| | from langchain_openai import ChatOpenAI |
| | from langchain.chains import LLMChain |
| |
|
| | from langchain.agents import ZeroShotAgent, Tool, AgentExecutor |
| | from langchain_community.utilities import SerpAPIWrapper |
| |
|
| | from typing import List, Dict, Callable |
| | from langchain.chains import ConversationChain |
| | from .config import ConversationConfig |
| | |
| |
|
| | from langchain.memory import ConversationBufferMemory |
| | from langchain.prompts.prompt import PromptTemplate |
| | from langchain.schema import ( |
| | AIMessage, |
| | HumanMessage, |
| | SystemMessage, |
| | BaseMessage, |
| | ) |
| |
|
| | from typing import Optional |
| |
|
| |
|
| | |
class DialougeAgent:
    """A chat participant that answers based on the whole transcript it has received.

    The class name keeps its historical spelling because other code refers to it.
    """

    def __init__(
        self,
        name: str,
        system_message: SystemMessage,
        model: ChatOpenAI,
    ) -> None:
        self.name = name
        self.system_message = system_message
        self.model = model
        # Cue placed at the end of the prompt so the model continues speaking as this agent.
        self.prefix = f"{self.name}: "
        self.reset()

    def reset(self):
        """Replace the transcript with one holding only the introductory line."""
        self.message_history = ["Here is the conversation so far."]

    def send(self) -> str:
        """Ask the model for this agent's next utterance.

        The prompt is the system message followed by a single human message whose
        content is every transcript line plus the ``"<name>: "`` cue, joined by
        newlines. Returns the ``content`` of the model's response.
        """
        prompt_lines = [*self.message_history, self.prefix]
        transcript = HumanMessage(content="\n".join(prompt_lines))
        reply = self.model.invoke([self.system_message, transcript])
        return reply.content

    def receive(self, name: str, message: str) -> None:
        """Append ``message`` to the transcript, attributed to speaker ``name``."""
        line = f"{name}: {message}"
        self.message_history.append(line)
| |
|
| | |
class JudgeAgent(DialougeAgent):
    """A DialougeAgent configured from a plain-text system prompt.

    Unlike its base class, a judge starts with an empty transcript.
    """

    def __init__(
        self,
        name: str,
        system_message: str,
        model: Optional[ChatOpenAI] = None,
    ) -> None:
        """
        Args:
            name: speaker name used for the ``"<name>: "`` prompt cue.
            system_message: system prompt text; wrapped in a ``SystemMessage``.
            model: chat model to use. When omitted, a new ``ChatOpenAI()`` is
                created for this instance.
        """
        # A default of ``ChatOpenAI()`` in the signature would be evaluated once, when the
        # class is defined at import time, and shared by every JudgeAgent; build it lazily.
        if model is None:
            model = ChatOpenAI()
        super().__init__(name, SystemMessage(content=system_message), model)

    def reset(self):
        """Clear the transcript completely (no introductory line)."""
        self.message_history = []
| |
|
| | |
class DialogueSimulator:
    """Drive a turn-based conversation among a fixed list of agents.

    Each agent must expose ``name``, ``reset()``, ``send()`` and
    ``receive(name, message)``. Every message, including a speaker's own, is
    delivered to all agents.
    """

    def __init__(
        self,
        agents: List["DialougeAgent"],
        selection_function: Callable[[int, List["DialougeAgent"]], int],
    ) -> None:
        """
        Args:
            agents: participants, in the order the selection function indexes them.
            selection_function: maps ``(step, agents)`` to the index of the next speaker.
        """
        self.agents = agents
        # Turns taken so far (injected messages count too); drives speaker selection.
        self._step = 0
        self.select_next_speaker = selection_function

    def reset(self):
        """Reset every agent's transcript and restart the turn counter."""
        for agent in self.agents:
            agent.reset()
        # Without this, speaker selection after a reset would continue from the stale count.
        self._step = 0

    def inject(self, name: str, message: str):
        """
        Deliver ``message`` from ``name`` to every agent, e.g. to open the conversation.

        Counts as a turn, so it advances the step counter.
        """
        for agent in self.agents:
            agent.receive(name, message)

        self._step += 1

    def step(self) -> tuple[str, str]:
        """Let the selected agent speak once and broadcast its message.

        Returns:
            ``(speaker_name, message)`` for this turn.
        """
        speaker_idx = self.select_next_speaker(self._step, self.agents)
        speaker = self.agents[speaker_idx]

        message = speaker.send()

        for receiver in self.agents:
            receiver.receive(speaker.name, message)

        self._step += 1

        return speaker.name, message
| |
|
| |
|
class SalesSimulator(DialogueSimulator):
    """Two-agent simulation of a salesperson talking with a customer.

    Agent 0 is "Sales" and agent 1 is "Customer"; each gets its own default
    ``ChatOpenAI()`` model. Speakers strictly alternate.
    """

    def __init__(self, sales_sys_message: str, customer_sys_message: str, sales_first: bool = True,
                 data_path: Optional[str] = None, date: Optional[str] = None):
        """
        Args:
            sales_sys_message: system prompt for the "Sales" agent.
            customer_sys_message: system prompt for the "Customer" agent.
            sales_first: when True the sales agent speaks on even steps (so it
                opens from step 0); otherwise the customer does.
            data_path: file written by ``store_conversation``.
            date: optional suffix appended to ``self.name``.
        """
        sales_agent = DialougeAgent(
            name="Sales",
            system_message=SystemMessage(content=sales_sys_message),
            model=ChatOpenAI(),
        )

        customer_agent = DialougeAgent(
            name="Customer",
            system_message=SystemMessage(content=customer_sys_message),
            model=ChatOpenAI(),
        )

        def talk_in_turns(step: int, agents: List[DialougeAgent]) -> int:
            # Strict alternation; shifting by one lets the customer speak first.
            offset = 0 if sales_first else 1
            return (step + offset) % len(agents)

        super().__init__([sales_agent, customer_agent], talk_in_turns)

        self.data_path = data_path
        # `date` defaults to None; only append it when provided (concatenating None raises TypeError).
        self.name = 'SalesSimulator' if date is None else 'SalesSimulator_' + date

    @property
    def lattest_utterance(self) -> str:
        """Last line of the customer's transcript (misspelled name kept for existing callers)."""
        return self.agents[1].message_history[-1]

    @property
    def conversation_history(self) -> List[str]:
        """The dialogue so far, taken from the customer's transcript.

        Every agent receives every message, so agent 1's history holds the whole
        conversation; the leading "Here is the conversation so far." line is dropped.
        """
        return self.agents[1].message_history[1:]

    def _check_purchase(self, message: str) -> Optional[bool]:
        """
        Classify ``message`` by its purchase tag.

        Returns True if it contains "<PURCHASE>", False if it contains
        "<NO_PURCHASE>", and None if it contains neither.
        """
        if "<PURCHASE>" in message:
            return True
        if "<NO_PURCHASE>" in message:
            return False
        return None

    def simulate(self, n: int = 20, print_conversation: bool = False,
                 return_result: bool = False) -> "List[str] | tuple[List[str], Optional[bool]]":
        """Reset and run ``n`` rounds (one turn per agent per round).

        Args:
            n: number of rounds.
            print_conversation: print each message, with a header per round.
            return_result: also return the first purchase verdict seen.

        Returns:
            ``conversation_history``, or ``(conversation_history, verdict)`` when
            ``return_result`` is True. ``verdict`` is True/False for the first
            message carrying a purchase tag, or None if no message did.
        """
        self.reset()
        n_agents = len(self.agents)
        res = None
        for i in range(int(n * n_agents)):
            name, message = self.step()
            if print_conversation:
                if i % n_agents == 0:
                    print('---'*4 + 'Round ' + str(i // n_agents) + '---'*4)
                print(f'{name}: {message}')
            # Keep the first verdict; later messages do not override it.
            if res is None:
                res = self._check_purchase(message)
        history = self.conversation_history
        if return_result:
            return history, res
        return history

    @classmethod
    def make(cls, conversation_config: Optional[ConversationConfig] = None) -> 'SalesSimulator':
        """Build a simulator from ``conversation_config`` (``ConversationConfig.make()`` if omitted)."""
        if conversation_config is None:
            conversation_config = ConversationConfig.make()
        return cls(sales_sys_message=conversation_config.sales_system_prompt,
                   customer_sys_message=conversation_config.customer_system_prompt,
                   sales_first=conversation_config.sales_first,
                   data_path=conversation_config.data_path,
                   date=conversation_config.date)

    def store_conversation(self):
        """Write ``conversation_history`` to ``self.data_path`` as a JSON list of strings.

        Raises:
            ValueError: if no ``data_path`` was configured.
        """
        import json
        if self.data_path is None:
            raise ValueError("SalesSimulator.data_path is not set; cannot store the conversation")
        with open(self.data_path, 'w', encoding='utf-8') as f:
            json.dump(self.conversation_history, f)
| |
|
| |
|
| |
|
| |
|
| | |
| | from .prompt import CUSTOMER_SYSTEM_PROMPT, SALES_SYSTEM_PROMPT |
def test_simsales():
    """Smoke-test SalesSimulator end to end.

    Builds a simulator from the default config, runs 10 printed rounds, then
    prints the collected history. The agents call their ``ChatOpenAI`` models,
    so this performs real model requests.
    """
    print('------Initializing Test function for SalesSimulator------')
    simsales = SalesSimulator.make()
    simsales.simulate(10, print_conversation=True)

    # The transcript is exposed via the `conversation_history` property;
    # `_get_conservation_history` does not exist and raised AttributeError.
    print('Conversation History: \n', simsales.conversation_history)
| |
|
| |
|