import time
from typing import Optional

from ..doctor import DoctorAgent
from ..patient import PatientAgent
from ..checker import CheckerAgent
from ..utils import log, colorstr
from ..utils.common_utils import detect_ed_termination


class EDSimulation:
    """Drive a turn-based dialog between a Patient agent and a Doctor agent.

    The Doctor opens with its greeting, after which the Patient and the
    Doctor alternate for at most ``max_inferences`` rounds. An optional
    Checker agent can end the consultation early.
    """

    def __init__(self,
                 patient_agent: PatientAgent,
                 doctor_agent: DoctorAgent,
                 checker_agent: Optional[CheckerAgent] = None,
                 max_inferences: int = 15):
        """Store the agents and validate that their configuration agrees.

        Args:
            patient_agent: Agent that answers the Doctor's utterances.
            doctor_agent: Agent that leads the consultation.
            checker_agent: Optional agent asked whether a Doctor response
                ends the consultation.
            max_inferences: Maximum number of Patient/Doctor rounds.

        Raises:
            AssertionError: If a Checker agent is given and its
                ``visit_type`` differs from the Patient agent's.
        """
        self.patient_agent = patient_agent
        self.doctor_agent = doctor_agent
        self.checker_agent = checker_agent
        self.max_inferences = max_inferences
        # NOTE(review): initialised here but never updated by the methods
        # of this class; kept for compatibility with external readers.
        self.current_inference = 0
        self._sanity_check()

    def _sanity_check(self):
        """Reconcile the turn limit and verify the visit types match.

        If the Doctor agent's ``max_inferences`` differs from the
        simulation's, the simulation's value wins and the Doctor prompt is
        rebuilt via ``build_prompt()``.

        Raises:
            AssertionError: If the Checker and Patient agents disagree on
                ``visit_type``.
        """
        if self.doctor_agent.max_inferences != self.max_inferences:
            log("The maximum number of inferences between the Doctor agent and the ED simulation does not match.", level="warning")
            log(f"The simulation will start with the value ({self.max_inferences}) configured in the ED simulation, and the Doctor agent system prompt will be updated accordingly.", level="warning")
            self.doctor_agent.max_inferences = self.max_inferences
            self.doctor_agent.build_prompt()

        if self.checker_agent:
            # Explicit check instead of ``assert`` so the validation still
            # runs under ``python -O``; AssertionError is kept so existing
            # callers that catch it continue to work.
            if self.checker_agent.visit_type != self.patient_agent.visit_type:
                message = f"The visit type between the Checker agent ({self.checker_agent.visit_type}) and the Patient agent ({self.patient_agent.visit_type}) must be the same."
                log(colorstr("red", message))
                raise AssertionError(message)

    def _init_agents(self, verbose: bool = True) -> None:
        """Reset the Patient and Doctor conversation histories."""
        self.patient_agent.reset_history(verbose=verbose)
        self.doctor_agent.reset_history(verbose=verbose)

    def simulate(self,
                 verbose: bool = True,
                 patient_kwargs: Optional[dict] = None,
                 doctor_kwargs: Optional[dict] = None,
                 **kwargs) -> dict:
        """Run one consultation and return its transcript and token usage.

        Args:
            verbose: Forwarded to the agents' reset and call methods.
            patient_kwargs: Extra keyword arguments for each Patient call.
                The mapping is not modified.
            doctor_kwargs: Extra keyword arguments for each Doctor call.
                The mapping is not modified.
            **kwargs: Keyword arguments forwarded to both agents; on a key
                clash these override the role-specific mappings.

        Returns:
            A dict with ``"dialog_history"`` (list of ``{"role", "content"}``
            entries starting with the Doctor's greeting),
            ``"patient_token_usage"`` and ``"doctor_token_usage"``.
        """
        self._init_agents(verbose=verbose)

        # Build fresh option dicts once. The previous implementation used
        # mutable ``{}`` defaults and updated them in place, which leaked
        # ``kwargs`` from one call into the next and mutated caller dicts.
        patient_options = {**(patient_kwargs or {}), **kwargs}
        doctor_options = {**(doctor_kwargs or {}), **kwargs}

        doctor_greet = self.doctor_agent.doctor_greet
        dialog_history = [{"role": "Doctor", "content": doctor_greet}]
        final_idx = self.max_inferences - 1

        for inference_idx in range(self.max_inferences):
            # Patient replies to the Doctor's most recent utterance.
            patient_response = self.patient_agent(
                user_prompt=dialog_history[-1]["content"],
                using_multi_turn=True,
                verbose=verbose,
                **patient_options
            )
            dialog_history.append({"role": "Patient", "content": patient_response})

            # On the last allowed turn, instruct the Doctor to conclude.
            doctor_prompt = dialog_history[-1]["content"]
            if inference_idx == final_idx:
                doctor_prompt += "\nThis is the final turn. Now, you must provide your top5 differential diagnosis."

            doctor_response = self.doctor_agent(
                user_prompt=doctor_prompt,
                using_multi_turn=True,
                verbose=verbose,
                **doctor_options
            )
            dialog_history.append({"role": "Doctor", "content": doctor_response})

            # Stop when the Doctor's response is detected as terminating;
            # otherwise ask the Checker agent (if any) for a "Y" verdict.
            if detect_ed_termination(doctor_response):
                break
            elif self.checker_agent:
                termination_check = self.checker_agent(response=doctor_response).strip().upper()
                if termination_check == "Y":
                    log("Consultation termination detected by the checker agent.", level="warning")
                    break

            # Pause between rounds (presumably to pace API requests).
            time.sleep(1.0)

        log("Simulation completed.", color=True)

        return {
            "dialog_history": dialog_history,
            "patient_token_usage": self.patient_agent.client.token_usages,
            "doctor_token_usage": self.doctor_agent.client.token_usages,
        }