Spaces:
Running
Running
| import os | |
| from typing import Optional | |
| from .registry.persona import * | |
| from .utils import colorstr, log | |
| from .utils.common_utils import set_seed | |
| from .client import GeminiClient, GeminiVertexClient, GPTClient, GPTAzureClient | |
# Bundled prompt templates live in ``assets/prompt`` next to this module.
_MODULE_DIR = os.path.dirname(__file__)
_PROMPT_DIR = os.path.join(_MODULE_DIR, "assets", "prompt")
class DoctorAgent:
    """LLM-backed doctor agent driven by a templated system prompt.

    The system prompt is rendered from a text template (by default
    ``ed_doctor_sys.txt`` under ``_PROMPT_DIR``). It is re-rendered before
    every call so that the turn counters substituted into it reflect the
    number of model turns found in the client's history.
    """

    def __init__(self,
                 model: str,
                 max_inferences: int = 15,
                 top_k_diagnosis: int = 5,
                 api_key: Optional[str] = None,
                 use_azure: bool = False,
                 use_vertex: bool = False,
                 azure_endpoint: Optional[str] = None,
                 system_prompt_path: Optional[str] = None,
                 **kwargs) -> None:
        """Create the agent, its LLM client, and the initial system prompt.

        Args:
            model: Model name; must contain ``'gemini'`` or ``'gpt'``
                (case-insensitive) to select a client family.
            max_inferences: Turn budget, substituted into the prompt as
                ``total_idx``.
            top_k_diagnosis: Substituted into the prompt as ``top_k_diagnosis``.
            api_key: API key forwarded to the client constructor.
            use_azure: For GPT models, use ``GPTAzureClient`` instead of
                ``GPTClient``.
            use_vertex: For Gemini models, use ``GeminiVertexClient`` instead
                of ``GeminiClient``.
            azure_endpoint: Endpoint forwarded to ``GPTAzureClient``.
            system_prompt_path: Optional path to a custom prompt template;
                the bundled template is used when falsy.
            **kwargs: Options read by ``_init_env``: ``random_seed``,
                ``temperature``, ``doctor_greet``, ``age``, ``gender``,
                ``arrival_transport``.

        Raises:
            ValueError: If ``model`` is not a Gemini or GPT model name.
            FileNotFoundError: If ``system_prompt_path`` is given but is not
                an existing file.
        """
        self.current_inference = 0
        self.max_inferences = max_inferences
        self.top_k_diagnosis = top_k_diagnosis
        self._init_env(**kwargs)
        self.model = model
        self._init_model(
            model=self.model,
            api_key=api_key,
            use_azure=use_azure,
            use_vertex=use_vertex,
            azure_endpoint=azure_endpoint,
        )
        self._system_prompt_template = self._init_prompt(system_prompt_path)
        self.build_prompt()
        log("DoctorAgent initialized successfully", color=True)

    def _init_env(self, **kwargs) -> None:
        """Read sampling options and patient conditions from ``kwargs``.

        Patient fields that are not supplied default to ``'N/A'`` and are
        reported in a single warning. When ``random_seed`` is provided
        (including ``0``), ``set_seed`` is called with it.
        """
        self.random_seed = kwargs.get('random_seed', None)
        self.temperature = kwargs.get('temperature', 0.2)
        self.doctor_greet = kwargs.get('doctor_greet', "Hello, how can I help you?")
        self.patient_conditions = {
            'age': kwargs.get('age', 'N/A'),
            'gender': kwargs.get('gender', 'N/A'),
            'arrival_transport': kwargs.get('arrival_transport', 'N/A'),
        }
        missing_conditions = [k for k, v in self.patient_conditions.items() if v == 'N/A']
        if missing_conditions:
            log(f"Required patient information missing for the doctor agent: {', '.join(missing_conditions)}. Using default values.", level="warning")
        # Compare against None explicitly: 0 is a valid seed but falsy, and a
        # truthiness check would silently skip seeding for it.
        if self.random_seed is not None:
            set_seed(self.random_seed)

    def _init_model(self,
                    model: str,
                    api_key: Optional[str] = None,
                    use_azure: bool = False,
                    use_vertex: bool = False,
                    azure_endpoint: Optional[str] = None) -> None:
        """Instantiate ``self.client`` for the model family named by ``model``.

        Raises:
            ValueError: If ``model`` contains neither ``'gemini'`` nor ``'gpt'``.
        """
        family = model.lower()
        if 'gemini' in family:
            self.client = GeminiVertexClient(model, api_key) if use_vertex else GeminiClient(model, api_key)
        elif 'gpt' in family:
            self.client = GPTAzureClient(model, api_key, azure_endpoint) if use_azure else GPTClient(model, api_key)
        else:
            raise ValueError(colorstr("red", f"Unsupported model: {model}. Supported models are 'gemini' and 'gpt'."))

    def _init_prompt(self, system_prompt_path: Optional[str] = None) -> str:
        """Return the raw system-prompt template text.

        Args:
            system_prompt_path: Custom template path. When falsy, the bundled
                ``ed_doctor_sys.txt`` under ``_PROMPT_DIR`` is used.

        Raises:
            FileNotFoundError: If a custom path is given but is not a file.
        """
        if not system_prompt_path:
            system_prompt_path = os.path.join(_PROMPT_DIR, "ed_doctor_sys.txt")
        elif not os.path.isfile(system_prompt_path):
            # isfile (not exists) so a directory path gets this clear error too.
            raise FileNotFoundError(colorstr("red", f"System prompt file not found: {system_prompt_path}"))
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(system_prompt_path, 'r', encoding='utf-8') as f:
            return f.read()

    def reset_history(self, verbose: bool = True) -> None:
        """Delegate history reset to the underlying client."""
        self.client.reset_history(verbose=verbose)

    def build_prompt(self) -> None:
        """Render ``self.system_prompt`` from the template via ``str.format``.

        Substitutes ``total_idx``, ``curr_idx``, ``remain_idx``,
        ``top_k_diagnosis`` and the patient-condition keys. Literal braces in
        the template must therefore be doubled (``{{``/``}}``).
        """
        self.system_prompt = self._system_prompt_template.format(
            total_idx=self.max_inferences,
            curr_idx=self.current_inference,
            remain_idx=self.max_inferences - self.current_inference,
            top_k_diagnosis=self.top_k_diagnosis,
            **self.patient_conditions
        )

    @staticmethod
    def _is_model_turn(message) -> bool:
        """Return True if a history entry was authored by the model.

        Dict entries count when ``role == 'assistant'``; other objects count
        when their ``.role`` attribute is ``'model'``. Presumably these are
        OpenAI-style message dicts and Gemini content objects respectively;
        verify against the client implementations.
        """
        if isinstance(message, dict):
            return message.get('role') == 'assistant'
        return getattr(message, 'role', None) == 'model'

    def update_system_prompt(self) -> None:
        """Recount model turns, re-render the prompt, and patch the history.

        ``current_inference`` is set to the number of model-authored entries
        in ``self.client.histories`` (at least 1). If the first history entry
        is a ``'system'`` dict, its ``content`` is replaced in place with the
        freshly rendered prompt.
        """
        histories = self.client.histories
        model_turns = sum(1 for message in histories if self._is_model_turn(message))
        # Clamped to at least 1; presumably so the prompt never reports turn 0
        # once a conversation has started. NOTE(review): confirm intent.
        self.current_inference = max(1, model_turns)
        self.build_prompt()
        if histories and isinstance(histories[0], dict) and histories[0].get('role') == 'system':
            histories[0]['content'] = self.system_prompt

    def __call__(self,
                 user_prompt: str,
                 using_multi_turn: bool = True,
                 verbose: bool = True,
                 **kwargs) -> str:
        """Send one user turn to the client and return its response.

        Refreshes the system prompt first, then forwards the prompt, greeting,
        temperature, seed, and any extra ``kwargs`` to the client.
        """
        self.update_system_prompt()
        response = self.client(
            user_prompt=user_prompt,
            system_prompt=self.system_prompt,
            using_multi_turn=using_multi_turn,
            greeting=self.doctor_greet,
            verbose=verbose,
            temperature=self.temperature,
            seed=self.random_seed,
            **kwargs
        )
        return response