Spaces:
Running
Running
| import time | |
| from typing import Optional | |
| from ..doctor import DoctorAgent | |
| from ..patient import PatientAgent | |
| from ..checker import CheckerAgent | |
| from ..utils import log, colorstr | |
| from ..utils.common_utils import detect_ed_termination | |
class EDSimulation:
    """Run a turn-based dialogue between a Patient agent and a Doctor agent.

    The Doctor agent opens with its greeting; the Patient and Doctor agents
    then alternate until the doctor's response is recognised as a termination
    (by ``detect_ed_termination`` or by the optional Checker agent) or until
    ``max_inferences`` rounds have been played.
    """

    def __init__(self,
                 patient_agent: PatientAgent,
                 doctor_agent: DoctorAgent,
                 checker_agent: Optional[CheckerAgent] = None,
                 max_inferences: int = 15):
        """Store the agents and validate that their configuration is consistent.

        Args:
            patient_agent: Agent called to produce the patient's replies.
            doctor_agent: Agent called to produce the doctor's replies.
            checker_agent: Optional agent asked, after each doctor reply,
                whether the consultation has ended.
            max_inferences: Maximum number of patient/doctor rounds.

        Raises:
            AssertionError: If a checker agent is given and its ``visit_type``
                differs from the patient agent's.
        """
        self.patient_agent = patient_agent
        self.doctor_agent = doctor_agent
        self.checker_agent = checker_agent
        self.max_inferences = max_inferences
        # NOTE(review): initialised here but not updated anywhere in this class.
        self.current_inference = 0
        self._sanity_check()

    def _sanity_check(self):
        """Reconcile ``max_inferences`` and verify the checker's visit type.

        If the Doctor agent was configured with a different ``max_inferences``,
        the simulation's value wins: the Doctor agent's attribute is
        overwritten and its prompt is rebuilt.

        Raises:
            AssertionError: If the Checker and Patient agents have different
                ``visit_type`` values.
        """
        if self.doctor_agent.max_inferences != self.max_inferences:
            log("The maximum number of inferences between the Doctor agent and the ED simulation does not match.", level="warning")
            log(f"The simulation will start with the value ({self.max_inferences}) configured in the ED simulation, and the Doctor agent system prompt will be updated accordingly.", level="warning")
            self.doctor_agent.max_inferences = self.max_inferences
            self.doctor_agent.build_prompt()

        if self.checker_agent and self.checker_agent.visit_type != self.patient_agent.visit_type:
            # Explicit raise instead of `assert` so the check is not stripped
            # under `python -O` and the exception carries the message text.
            message = f"The visit type between the Checker agent ({self.checker_agent.visit_type}) and the Patient agent ({self.patient_agent.visit_type}) must be the same."
            log(colorstr("red", message))
            raise AssertionError(message)

    def _init_agents(self, verbose: bool = True) -> None:
        """Reset the conversation history of both the patient and doctor agents."""
        self.patient_agent.reset_history(verbose=verbose)
        self.doctor_agent.reset_history(verbose=verbose)

    def simulate(self,
                 verbose: bool = True,
                 patient_kwargs: Optional[dict] = None,
                 doctor_kwargs: Optional[dict] = None,
                 **kwargs) -> dict:
        """Play the consultation and return the transcript with token usage.

        Args:
            verbose: Forwarded to the agents' reset and call methods.
            patient_kwargs: Extra keyword arguments for every Patient agent
                call. Not modified by this method.
            doctor_kwargs: Extra keyword arguments for every Doctor agent
                call. Not modified by this method.
            **kwargs: Extra keyword arguments forwarded to both agents; they
                override same-named entries of ``patient_kwargs`` and
                ``doctor_kwargs``.

        Returns:
            A dict with ``"dialog_history"`` (list of
            ``{"role": ..., "content": ...}`` turns, starting with the doctor's
            greeting), ``"patient_token_usage"`` and ``"doctor_token_usage"``
            (read from each agent's ``client.token_usages``).
        """
        self._init_agents(verbose=verbose)

        # Merge into fresh dicts once, up front: the shared kwargs are loop
        # invariant, and neither the caller's dicts nor any default object
        # must be mutated (state would otherwise leak between calls).
        patient_call_kwargs = {**(patient_kwargs or {}), **kwargs}
        doctor_call_kwargs = {**(doctor_kwargs or {}), **kwargs}

        doctor_greet = self.doctor_agent.doctor_greet
        dialog_history = [{"role": "Doctor", "content": doctor_greet}]
        final_idx = self.max_inferences - 1

        for inference_idx in range(self.max_inferences):
            # Patient answers the doctor's latest message.
            patient_response = self.patient_agent(
                user_prompt=dialog_history[-1]["content"],
                using_multi_turn=True,
                verbose=verbose,
                **patient_call_kwargs
            )
            dialog_history.append({"role": "Patient", "content": patient_response})

            # On the last allowed round, instruct the doctor to conclude.
            doctor_prompt = patient_response
            if inference_idx == final_idx:
                doctor_prompt += "\nThis is the final turn. Now, you must provide your top5 differential diagnosis."

            doctor_response = self.doctor_agent(
                user_prompt=doctor_prompt,
                using_multi_turn=True,
                verbose=verbose,
                **doctor_call_kwargs
            )
            dialog_history.append({"role": "Doctor", "content": doctor_response})

            # Stop early when the doctor's reply ends the consultation; the
            # checker agent is consulted only if the detector did not fire.
            if detect_ed_termination(doctor_response):
                break
            elif self.checker_agent:
                termination_check = self.checker_agent(response=doctor_response).strip().upper()
                if termination_check == "Y":
                    log("Consultation termination detected by the checker agent.", level="warning")
                    break

            # Pause between rounds (presumably to stay within API rate
            # limits -- confirm).
            time.sleep(1.0)

        log("Simulation completed.", color=True)
        return {
            "dialog_history": dialog_history,
            "patient_token_usage": self.patient_agent.client.token_usages,
            "doctor_token_usage": self.doctor_agent.client.token_usages,
        }