# Source: PatientSim/patientsim/environment/ed_simulation.py
# Uploaded by: dek924
# Commit message: "feat: patientsim update"
# Commit: e7069ae
import time
from typing import Optional
from ..doctor import DoctorAgent
from ..patient import PatientAgent
from ..checker import CheckerAgent
from ..utils import log, colorstr
from ..utils.common_utils import detect_ed_termination
class EDSimulation:
    """Drive a turn-based consultation between a doctor agent and a patient agent.

    The simulation starts from the doctor agent's greeting and then alternates
    patient and doctor responses for at most ``max_inferences`` rounds. A round
    loop ends early when ``detect_ed_termination`` flags the doctor response or,
    if a checker agent was supplied, when that agent answers ``"Y"``.
    """

    def __init__(self,
                 patient_agent: PatientAgent,
                 doctor_agent: DoctorAgent,
                 checker_agent: Optional[CheckerAgent] = None,
                 max_inferences: int = 15):
        """Store the agents and validate that their configuration is consistent.

        Args:
            patient_agent: Agent called to produce the patient's replies.
            doctor_agent: Agent called to produce the doctor's replies.
            checker_agent: Optional agent consulted after each doctor reply to
                decide whether the consultation has ended.
            max_inferences: Maximum number of patient/doctor rounds.

        Raises:
            AssertionError: If a checker agent is given and its ``visit_type``
                differs from the patient agent's ``visit_type``.
        """
        self.patient_agent = patient_agent
        self.doctor_agent = doctor_agent
        self.checker_agent = checker_agent
        self.max_inferences = max_inferences
        # Initialised here only; no method of this class updates it afterwards.
        self.current_inference = 0
        self._sanity_check()

    def _sanity_check(self):
        """Reconcile the agents' settings with the simulation's settings.

        The simulation's ``max_inferences`` wins over the doctor agent's value;
        on a mismatch the doctor agent is updated and its prompt is rebuilt.

        Raises:
            AssertionError: If the checker and patient agents disagree on
                ``visit_type``.
        """
        if self.doctor_agent.max_inferences != self.max_inferences:
            log("The maximum number of inferences between the Doctor agent and the ED simulation does not match.", level="warning")
            log(f"The simulation will start with the value ({self.max_inferences}) configured in the ED simulation, and the Doctor agent system prompt will be updated accordingly.", level="warning")
            self.doctor_agent.max_inferences = self.max_inferences
            self.doctor_agent.build_prompt()

        if self.checker_agent:
            checker_visit = self.checker_agent.visit_type
            patient_visit = self.patient_agent.visit_type
            if checker_visit != patient_visit:
                message = (
                    f"The visit type between the Checker agent ({checker_visit}) "
                    f"and the Patient agent ({patient_visit}) must be the same."
                )
                log(colorstr("red", message))
                # Raised explicitly (instead of via ``assert``) so the check is
                # not stripped under ``python -O`` and the error carries the
                # message; the exception type is kept for existing callers.
                raise AssertionError(message)

    def _init_agents(self, verbose: bool = True) -> None:
        """Reset the conversation history of the patient and doctor agents.

        Args:
            verbose: Forwarded to each agent's ``reset_history``.
        """
        self.patient_agent.reset_history(verbose=verbose)
        self.doctor_agent.reset_history(verbose=verbose)

    def simulate(self,
                 verbose: bool = True,
                 patient_kwargs: Optional[dict] = None,
                 doctor_kwargs: Optional[dict] = None,
                 **kwargs) -> dict:
        """Run one consultation and return the transcript with token usage.

        Args:
            verbose: Forwarded to the agents on every call.
            patient_kwargs: Extra keyword arguments for each patient agent
                call. The dict is not modified.
            doctor_kwargs: Extra keyword arguments for each doctor agent call.
                The dict is not modified.
            **kwargs: Keyword arguments forwarded to both agents; they take
                precedence over same-named keys in the per-agent dicts.

        Returns:
            A dict with ``"dialog_history"`` (list of ``{"role", "content"}``
            entries beginning with the doctor's greeting),
            ``"patient_token_usage"`` and ``"doctor_token_usage"``.
        """
        self._init_agents(verbose=verbose)

        # Merge once into fresh dicts: the previous in-place ``update`` on a
        # shared ``{}`` default leaked kwargs into later calls and mutated the
        # caller's dictionaries.
        patient_call_kwargs = {**(patient_kwargs or {}), **kwargs}
        doctor_call_kwargs = {**(doctor_kwargs or {}), **kwargs}

        dialog_history = [{"role": "Doctor", "content": self.doctor_agent.doctor_greet}]
        final_idx = self.max_inferences - 1

        for inference_idx in range(self.max_inferences):
            # Patient answers the doctor's latest message.
            patient_response = self.patient_agent(
                user_prompt=dialog_history[-1]["content"],
                using_multi_turn=True,
                verbose=verbose,
                **patient_call_kwargs
            )
            dialog_history.append({"role": "Patient", "content": patient_response})

            # On the last allowed round, instruct the doctor to conclude.
            doctor_prompt = patient_response
            if inference_idx == final_idx:
                doctor_prompt += "\nThis is the final turn. Now, you must provide your top5 differential diagnosis."

            doctor_response = self.doctor_agent(
                user_prompt=doctor_prompt,
                using_multi_turn=True,
                verbose=verbose,
                **doctor_call_kwargs
            )
            dialog_history.append({"role": "Doctor", "content": doctor_response})

            # The checker agent is only consulted when the rule-based
            # detection does not already end the consultation.
            if detect_ed_termination(doctor_response):
                break
            if self.checker_agent:
                termination_check = self.checker_agent(response=doctor_response).strip().upper()
                if termination_check == "Y":
                    log("Consultation termination detected by the checker agent.", level="warning")
                    break

            # Pause between rounds (presumably to throttle API requests).
            time.sleep(1.0)

        log("Simulation completed.", color=True)
        return {
            "dialog_history": dialog_history,
            "patient_token_usage": self.patient_agent.client.token_usages,
            "doctor_token_usage": self.doctor_agent.client.token_usages,
        }