# PatientSim / patientsim/doctor.py
# NOTE(review): the following lines are web-capture residue from the hosting page,
# preserved here as comments — author: dek924, commit e7069ae, "feat: patientsim update".
import os
from typing import Optional
from .registry.persona import *
from .utils import colorstr, log
from .utils.common_utils import set_seed
from .client import GeminiClient, GeminiVertexClient, GPTClient, GPTAzureClient
# Directory holding the bundled prompt-template text files (patientsim/assets/prompt).
_PROMPT_DIR = os.path.join(os.path.dirname(__file__), "assets", "prompt")
class DoctorAgent:
    """LLM-backed doctor agent for the simulated-patient dialogue.

    Wraps a Gemini or GPT chat client behind a system prompt that is
    re-rendered before every turn so the model always sees how many
    question turns remain before it must commit to a diagnosis.
    """

    def __init__(self,
                 model: str,
                 max_inferences: int = 15,
                 top_k_diagnosis: int = 5,
                 api_key: Optional[str] = None,
                 use_azure: bool = False,
                 use_vertex: bool = False,
                 azure_endpoint: Optional[str] = None,
                 system_prompt_path: Optional[str] = None,
                 **kwargs) -> None:
        """Initialize the doctor agent.

        Args:
            model: Model identifier; must contain 'gemini' or 'gpt'.
            max_inferences: Total question turns the doctor is allowed.
            top_k_diagnosis: Number of candidate diagnoses to request.
            api_key: API key forwarded to the underlying client.
            use_azure: Use the Azure OpenAI client for GPT models.
            use_vertex: Use the Vertex AI client for Gemini models.
            azure_endpoint: Azure endpoint URL (Azure GPT only).
            system_prompt_path: Optional override for the bundled
                system-prompt template file.
            **kwargs: Extra settings consumed by ``_init_env`` —
                ``random_seed``, ``temperature``, ``doctor_greet``,
                ``age``, ``gender``, ``arrival_transport``.

        Raises:
            ValueError: If ``model`` is neither a Gemini nor a GPT model.
            FileNotFoundError: If ``system_prompt_path`` does not exist.
        """
        self.current_inference = 0
        self.max_inferences = max_inferences
        self.top_k_diagnosis = top_k_diagnosis
        self._init_env(**kwargs)
        self.model = model
        self._init_model(
            model=self.model,
            api_key=api_key,
            use_azure=use_azure,
            use_vertex=use_vertex,
            azure_endpoint=azure_endpoint,
        )
        self._system_prompt_template = self._init_prompt(system_prompt_path)
        self.build_prompt()
        log("DoctorAgent initialized successfully", color=True)

    def _init_env(self, **kwargs) -> None:
        """Read optional settings from kwargs and seed RNGs if requested."""
        self.random_seed = kwargs.get('random_seed', None)
        self.temperature = kwargs.get('temperature', 0.2)
        self.doctor_greet = kwargs.get('doctor_greet', "Hello, how can I help you?")
        # Minimal patient context that is interpolated into the system prompt.
        self.patient_conditions = {
            'age': kwargs.get('age', 'N/A'),
            'gender': kwargs.get('gender', 'N/A'),
            'arrival_transport': kwargs.get('arrival_transport', 'N/A'),
        }
        missing_conditions = [k for k, v in self.patient_conditions.items() if v == 'N/A']
        if missing_conditions:
            log(f"Required patient information missing for the doctor agent: {', '.join(missing_conditions)}. Using default values.", level="warning")
        # BUGFIX: a seed of 0 is falsy, so `if self.random_seed:` skipped it.
        if self.random_seed is not None:
            set_seed(self.random_seed)

    def _init_model(self,
                    model: str,
                    api_key: Optional[str] = None,
                    use_azure: bool = False,
                    use_vertex: bool = False,
                    azure_endpoint: Optional[str] = None) -> None:
        """Instantiate the chat client matching the model family.

        Raises:
            ValueError: If the model name matches neither family.
        """
        # Consistency: use the `model` parameter rather than reaching back
        # through `self.model` (they are identical; the parameter is canonical).
        if 'gemini' in model.lower():
            self.client = GeminiVertexClient(model, api_key) if use_vertex else GeminiClient(model, api_key)
        elif 'gpt' in model.lower():
            self.client = GPTAzureClient(model, api_key, azure_endpoint) if use_azure else GPTClient(model, api_key)
        else:
            raise ValueError(colorstr("red", f"Unsupported model: {self.model}. Supported models are 'gemini' and 'gpt'."))

    def _init_prompt(self, system_prompt_path: Optional[str] = None) -> str:
        """Load the system-prompt template (bundled default or user override).

        Raises:
            FileNotFoundError: If an explicit path is given but missing.
        """
        if not system_prompt_path:
            # BUGFIX: explicit encoding — platform-default codecs mangle
            # non-ASCII prompt text on some systems (e.g. Windows cp1252).
            with open(os.path.join(_PROMPT_DIR, "ed_doctor_sys.txt"), 'r', encoding='utf-8') as f:
                return f.read()
        else:
            if not os.path.exists(system_prompt_path):
                raise FileNotFoundError(colorstr("red", f"System prompt file not found: {system_prompt_path}"))
            with open(system_prompt_path, 'r', encoding='utf-8') as f:
                return f.read()

    def reset_history(self, verbose: bool = True) -> None:
        """Clear the underlying client's conversation history."""
        self.client.reset_history(verbose=verbose)

    def build_prompt(self) -> None:
        """Render the system prompt from the template and current turn state."""
        self.system_prompt = self._system_prompt_template.format(
            total_idx=self.max_inferences,
            curr_idx=self.current_inference,
            remain_idx=self.max_inferences - self.current_inference,
            top_k_diagnosis=self.top_k_diagnosis,
            **self.patient_conditions
        )

    def update_system_prompt(self) -> None:
        """Recount completed doctor turns and refresh the system prompt.

        A turn is any history entry authored by the model: Gemini-style
        objects with ``.role == 'model'`` or OpenAI-style dicts with
        ``role == 'assistant'``. The count is floored at 1.
        """
        self.current_inference = max(1, len(list(filter(
            lambda x: (not isinstance(x, dict) and x.role == 'model') or
                      (isinstance(x, dict) and x.get('role') == 'assistant'),
            self.client.histories
        ))))
        self.build_prompt()
        # OpenAI-style histories carry the system prompt as entry 0; keep it fresh.
        if len(self.client.histories) and isinstance(self.client.histories[0], dict) and self.client.histories[0].get('role') == 'system':
            self.client.histories[0]['content'] = self.system_prompt

    def __call__(self,
                 user_prompt: str,
                 using_multi_turn: bool = True,
                 verbose: bool = True,
                 **kwargs) -> str:
        """Send one doctor turn to the LLM and return its reply text."""
        self.update_system_prompt()
        response = self.client(
            user_prompt=user_prompt,
            system_prompt=self.system_prompt,
            using_multi_turn=using_multi_turn,
            greeting=self.doctor_greet,
            verbose=verbose,
            temperature=self.temperature,
            seed=self.random_seed,
            **kwargs
        )
        return response