Eric Botti commited on
Commit ·
a92f249
1
Parent(s): 760a529
added Kani agents
Browse files- src/agents.py +12 -30
- src/parser.py +3 -2
- src/player.py +21 -21
src/agents.py
CHANGED
|
@@ -1,34 +1,16 @@
|
|
| 1 |
-
from
|
| 2 |
-
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
| 3 |
-
from langchain_openai import ChatOpenAI
|
| 4 |
|
| 5 |
-
from langchain.prompts import PromptTemplate
|
| 6 |
-
from reasoning_tools import animal_tools, extract_vote
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
'temperature': 1
|
| 13 |
-
},
|
| 14 |
-
"herd": {
|
| 15 |
-
'model': 'gpt-3.5-turbo',
|
| 16 |
-
'temperature': 1
|
| 17 |
-
},
|
| 18 |
-
"judge": {
|
| 19 |
-
'model': 'gpt-3.5-turbo',
|
| 20 |
-
'temperature': 1
|
| 21 |
-
}
|
| 22 |
-
}
|
| 23 |
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
agent = create_openai_functions_agent(llm, animal_tools, prompt)
|
| 33 |
-
|
| 34 |
-
super().__init__(agent=agent, tools=animal_tools, verbose=True, return_intermediate_steps=True)
|
|
|
|
| 1 |
+
from kani import Kani
|
|
|
|
|
|
|
| 2 |
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
class LogMessagesKani(Kani):
|
| 5 |
+
def __init__(self, engine, log_filepath: str = None, *args, **kwargs):
|
| 6 |
+
super().__init__(engine, *args, **kwargs)
|
| 7 |
+
self.log_filepath = log_filepath
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
async def add_to_history(self, message, *args, **kwargs):
|
| 10 |
+
await super().add_to_history(message, *args, **kwargs)
|
| 11 |
|
| 12 |
+
# Logs Message to File
|
| 13 |
+
if self.log_filepath:
|
| 14 |
+
with open(self.log_filepath, "a") as log_file:
|
| 15 |
+
log_file.write(message.model_dump_json())
|
| 16 |
+
log_file.write("\n")
|
|
|
|
|
|
|
|
|
|
|
|
src/parser.py
CHANGED
|
@@ -2,10 +2,11 @@ from typing import Type
|
|
| 2 |
import asyncio
|
| 3 |
import json
|
| 4 |
|
| 5 |
-
from kani import Kani
|
| 6 |
from kani.engines.openai import OpenAIEngine
|
| 7 |
from pydantic import BaseModel, ValidationError
|
| 8 |
|
|
|
|
|
|
|
| 9 |
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
|
| 10 |
Here is the output schema:
|
| 11 |
```
|
|
@@ -24,7 +25,7 @@ Output:
|
|
| 24 |
"""
|
| 25 |
|
| 26 |
|
| 27 |
-
class ParserKani(
|
| 28 |
def __init__(self, engine, *args, **kwargs):
|
| 29 |
super().__init__(engine, *args, **kwargs)
|
| 30 |
|
|
|
|
| 2 |
import asyncio
|
| 3 |
import json
|
| 4 |
|
|
|
|
| 5 |
from kani.engines.openai import OpenAIEngine
|
| 6 |
from pydantic import BaseModel, ValidationError
|
| 7 |
|
| 8 |
+
from agents import LogMessagesKani
|
| 9 |
+
|
| 10 |
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
|
| 11 |
Here is the output schema:
|
| 12 |
```
|
|
|
|
| 25 |
"""
|
| 26 |
|
| 27 |
|
| 28 |
+
class ParserKani(LogMessagesKani):
|
| 29 |
def __init__(self, engine, *args, **kwargs):
|
| 30 |
super().__init__(engine, *args, **kwargs)
|
| 31 |
|
src/player.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import openai
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
# Using TGI Inference Endpoints from Hugging Face
|
| 5 |
# api_type = "tgi"
|
|
@@ -15,40 +20,35 @@ else:
|
|
| 15 |
model_name = "gpt-3.5-turbo"
|
| 16 |
client = openai.Client()
|
| 17 |
|
|
|
|
|
|
|
|
|
|
| 18 |
class Player:
|
| 19 |
-
def __init__(self, name: str,
|
| 20 |
self.name = name
|
| 21 |
-
self.controller =
|
|
|
|
|
|
|
|
|
|
| 22 |
self.role = role
|
| 23 |
self.messages = []
|
| 24 |
|
| 25 |
-
def
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
output = self.
|
| 29 |
-
|
| 30 |
return output
|
| 31 |
|
| 32 |
-
def
|
| 33 |
if self.controller == "human":
|
| 34 |
print(prompt)
|
| 35 |
return input()
|
| 36 |
|
| 37 |
elif self.controller == "ai":
|
| 38 |
-
|
| 39 |
-
model=model_name,
|
| 40 |
-
messages=self.messages,
|
| 41 |
-
stream=False,
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
return chat_completion.choices[0].message.content
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def add_message(self, message: str):
|
| 48 |
-
"""Add a message to the messages list. No response required."""
|
| 49 |
-
self.messages.append({"role": "user", "content": message})
|
| 50 |
-
|
| 51 |
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
|
|
|
|
| 1 |
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
|
| 4 |
import openai
|
| 5 |
+
from agents import LogMessagesKani
|
| 6 |
+
from kani.engines.openai import OpenAIEngine
|
| 7 |
+
|
| 8 |
|
| 9 |
# Using TGI Inference Endpoints from Hugging Face
|
| 10 |
# api_type = "tgi"
|
|
|
|
| 20 |
model_name = "gpt-3.5-turbo"
|
| 21 |
client = openai.Client()
|
| 22 |
|
| 23 |
+
openai_engine = OpenAIEngine(model="gpt-3.5-turbo")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
class Player:
|
| 27 |
+
def __init__(self, name: str, controller_type: str, role: str, log_filepath: str = None):
|
| 28 |
self.name = name
|
| 29 |
+
self.controller = controller_type
|
| 30 |
+
if controller_type == "ai":
|
| 31 |
+
self.kani = LogMessagesKani(openai_engine, log_filepath=log_filepath)
|
| 32 |
+
|
| 33 |
self.role = role
|
| 34 |
self.messages = []
|
| 35 |
|
| 36 |
+
async def respond_to(self, prompt: str) -> str:
|
| 37 |
+
"""Makes the player respond to a prompt. Returns the response."""
|
| 38 |
+
# Generate a response from the controller
|
| 39 |
+
output = await self.__generate(prompt)
|
| 40 |
+
|
| 41 |
return output
|
| 42 |
|
| 43 |
+
async def __generate(self, prompt: str) -> str:
|
| 44 |
if self.controller == "human":
|
| 45 |
print(prompt)
|
| 46 |
return input()
|
| 47 |
|
| 48 |
elif self.controller == "ai":
|
| 49 |
+
output = await self.kani.chat_round_str(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
+
return output
|
| 52 |
|
| 53 |
|
| 54 |
|