PatientSim / patientsim /utils /common_utils.py
dek924's picture
feat: sanitize response by escaping braces to prevent formatting errors & add exception
e42bc71
import re
import random
import numpy as np
from typing import Union
from datetime import datetime, timedelta
# NOTE: torch removed — set_seed only uses random + numpy for reproducibility.
# This is sufficient for API-based agents (GPT, Gemini) that don't run local models.
from . import colorstr
from ..registry.detection_key import DDX_DETECT_KEYS
def set_seed(seed: int) -> None:
    """
    Seed the Python ``random`` module and NumPy's global generator.

    Only CPU-side generators are touched; no torch dependency is needed.

    Args:
        seed (int): Value handed to both ``random.seed`` and ``np.random.seed``.
    """
    # Same order as before: stdlib generator first, then NumPy.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
def split_string(string: Union[str, list], delimiter: str = ",") -> list:
    """
    Normalise a delimited string or a list of strings into stripped items.

    Args:
        string (Union[str, list]): Either a single string that is split on
            ``delimiter``, or a list whose elements are used as-is.
        delimiter (str): Separator applied only when ``string`` is a str.

    Returns:
        list: The items with surrounding whitespace removed.

    Raises:
        ValueError: If ``string`` is neither a str nor a list.
    """
    if isinstance(string, str):
        pieces = string.split(delimiter)
    elif isinstance(string, list):
        pieces = string
    else:
        raise ValueError(colorstr("red", "Input must be a string or a list of strings."))

    return [piece.strip() for piece in pieces]
def prompt_valid_check(prompt: str, data_dict: dict) -> None:
    """
    Check that every ``{placeholder}`` in a prompt template has a value.

    The template is scanned left to right. Doubled braces (``{{`` and ``}}``)
    are treated as escaped literal braces and skipped, so a template that
    embeds literal JSON such as ``{{"answer": 1}}`` is not reported as having
    missing keys. For a field that carries an attribute/index lookup, a
    conversion or a format spec (``{a.b}``, ``{a[0]}``, ``{a!r}``, ``{a:>10}``)
    only the leading name is looked up in ``data_dict``.
    NOTE(review): this mirrors ``str.format`` field syntax -- presumably the
    prompt is later filled with ``str.format``; confirm against callers.

    Args:
        prompt (str): Prompt template to validate.
        data_dict (dict): Mapping that must contain every placeholder name.

    Raises:
        ValueError: If at least one placeholder name is absent from ``data_dict``.
    """
    missing_keys = []
    # Alternation order matters: escaped braces are consumed before a single
    # "{" can open a field, so "{{x}}" is never mistaken for a placeholder.
    for token in re.finditer(r'\{\{|\}\}|\{(.*?)\}', prompt):
        field = token.group(1)
        if field is None:
            # Matched "{{" or "}}": a literal brace, not a placeholder.
            continue
        # Keep only the argument name; drop ".attr", "[idx]", "!conv", ":spec".
        key = re.split(r'[.\[!:]', field, maxsplit=1)[0]
        if key not in data_dict:
            missing_keys.append(key)

    if missing_keys:
        raise ValueError(colorstr("red", f"Missing keys in the prompt: {missing_keys}. Please ensure all required keys are present in the data dictionary."))
def detect_ed_termination(text: str) -> bool:
    """
    Report whether ``text`` contains a termination marker.

    Two signals are checked, both case-insensitively:
      * any entry of ``DDX_DETECT_KEYS`` occurring as a substring, or
      * a ``[ddx]:`` tag followed by a numbered item (``1. ...``).

    Args:
        text (str): Text to inspect.

    Returns:
        bool: True if either signal is present, otherwise False.
    """
    lowered = text.lower()

    # Substring scan over the configured detection phrases.
    keyword_hit = False
    for key in DDX_DETECT_KEYS:
        if key.lower() in lowered:
            keyword_hit = True
            break

    ddx_match = re.search(r'\[ddx\]:\s*\d+\.\s*.+', lowered, re.IGNORECASE)
    return ddx_match is not None or keyword_hit
def detect_op_termination(text: str) -> bool:
    """
    Report whether ``text`` contains a numbered ``Answer:`` line.

    The match is case-sensitive and requires ``Answer:``, a number followed by
    a period, and at least one further character on the same line.

    Args:
        text (str): Text to inspect.

    Returns:
        bool: True on a match; False when there is no match or when the search
            itself fails (for example, ``text`` is not a string).
    """
    try:
        found = re.search(r'Answer:\s*\d+\.\s*(.+)', text)
    except Exception:
        # Best-effort detection: unusable input counts as "no match".
        return False
    return found is not None
def str_to_datetime(iso_time: Union[str, datetime]) -> datetime:
    """
    Parse an ISO-format string into a ``datetime``.

    Args:
        iso_time (Union[str, datetime]): An ISO-format date/time string, or a
            value that is already a datetime.

    Returns:
        datetime: The parsed value when ``iso_time`` is a string. Any
            non-string argument is returned unchanged without validation.

    Raises:
        ValueError: If ``iso_time`` is a string that
            ``datetime.fromisoformat`` cannot parse.
    """
    if not isinstance(iso_time, str):
        # Non-strings are passed through untouched, as before.
        return iso_time

    try:
        return datetime.fromisoformat(iso_time)
    except ValueError as err:
        # Only a malformed string can reach this point, so report the value
        # itself rather than its type (which is always ``str`` here).
        raise ValueError(colorstr("red", f"`iso_time` must be an ISO-format string or a datetime, but got {iso_time!r}")) from err
def datetime_to_str(iso_time: Union[str, datetime], format: str) -> str:
    """
    Render a datetime as text using a ``strftime`` pattern.

    Args:
        iso_time (Union[str, datetime]): Value to render. A string is returned
            unchanged (it is not re-parsed or re-formatted).
        format (str): ``strftime`` pattern applied to non-string values.

    Returns:
        str: The formatted text, or the original string.

    Raises:
        ValueError: If formatting a non-string value fails.
    """
    if isinstance(iso_time, str):
        return iso_time

    try:
        return iso_time.strftime(format)
    except Exception:
        raise ValueError(colorstr("red", f"`iso_time` must be str or date format, but got {type(iso_time)}"))
def generate_random_date(start_date: Union[str, datetime] = '1960-01-01',
                         end_date: Union[str, datetime] = '2000-12-31') -> str:
    """
    Pick a random calendar day between two dates, both ends included.

    Args:
        start_date (Union[str, datetime]): Lower bound; converted with
            ``str_to_datetime``.
        end_date (Union[str, datetime]): Upper bound; converted with
            ``str_to_datetime``.

    Returns:
        str: The chosen date formatted as ``YYYY-MM-DD``.
    """
    lower = str_to_datetime(start_date)
    upper = str_to_datetime(end_date)

    # Whole-day offset drawn from the stdlib ``random`` generator.
    span_days = (upper - lower).days
    offset = timedelta(days=random.randint(0, span_days))

    return datetime_to_str(lower + offset, '%Y-%m-%d')
def exponential_backoff(retry_count: int,
                        base_delay: int = 5,
                        max_delay: int = 65,
                        jitter: bool = True) -> float:
    """
    Compute a retry delay that doubles with each attempt, up to a cap.

    Args:
        retry_count (int): Number of retries so far; the exponent of 2.
        base_delay (int): Delay used when ``retry_count`` is 0.
        max_delay (int): Upper bound applied before jitter.
        jitter (bool): When True, the capped delay is replaced by a uniform
            random draw within +/-20% of it, so the result may exceed
            ``max_delay`` by up to 20%.

    Returns:
        float: The delay. Without jitter this is the capped value itself
            (an int when the inputs are ints).
    """
    capped = min(base_delay * (2 ** retry_count), max_delay)
    if not jitter:
        return capped
    return random.uniform(capped * 0.8, capped * 1.2)