Spaces:
Sleeping
Sleeping
| """ | |
| Configuration management for MedGemma AI Medical Triage System. | |
| """ | |
| import os | |
| from pathlib import Path | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
# Load environment variables from a .env file (python-dotenv) so that the
# os.getenv() lookups in the config classes below can see them. This must run
# first: those lookups are evaluated once, when each class body executes.
load_dotenv()
# Directory layout, anchored at the folder that contains this file.
PROJECT_ROOT = Path(__file__).parent
SRC_DIR = PROJECT_ROOT.joinpath("src")
DATA_DIR = PROJECT_ROOT.joinpath("data")
LOGS_DIR = PROJECT_ROOT.joinpath("logs")
MODELS_DIR = PROJECT_ROOT.joinpath("models")

# Ensure the logs, models and models/cache directories exist at import time.
# Order matters: "models" has to exist before "models/cache" because mkdir is
# called without parents=True. SRC_DIR and DATA_DIR are not created here.
for _required_dir in (LOGS_DIR, MODELS_DIR, MODELS_DIR.joinpath("cache")):
    _required_dir.mkdir(exist_ok=True)
del _required_dir  # do not leave the loop variable behind as a module attribute
class ModelConfig:
    """Configuration for the MedGemma model.

    Every attribute is a class-level constant. Values backed by an
    environment variable are read once, when this class body is executed;
    later changes to the environment are not picked up.
    """

    # Model selection
    MODEL_NAME: str = os.getenv("MODEL_NAME", "google/medgemma-1.5-4b-it")
    MODEL_CACHE_DIR: str = os.getenv("MODEL_CACHE_DIR", str(MODELS_DIR / "cache"))
    # Optional: load from S3 instead of Hugging Face (e.g. s3://bucket/medgemma-1.5-4b-it/).
    # Empty string when MODEL_S3_URI is not set.
    MODEL_S3_URI: str = os.getenv("MODEL_S3_URI", "")

    # Hugging Face
    HF_TOKEN: Optional[str] = os.getenv("HF_TOKEN")  # None when not set
    HF_HOME: str = os.getenv("HF_HOME", str(MODELS_DIR / "cache"))

    # Model parameters
    # Boolean flags are true only for the literal string "true" (case-insensitive).
    USE_GPU: bool = os.getenv("USE_GPU", "true").lower() == "true"
    MAX_LENGTH: int = int(os.getenv("MAX_LENGTH", "2048"))
    MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "384"))
    TEMPERATURE: float = float(os.getenv("TEMPERATURE", "0.7"))
    TOP_P: float = 0.9
    TOP_K: int = 50

    # Performance: 4-bit greatly speeds up inference on GPU (e.g. HF Spaces).
    # Default on when SPACE_ID is set.
    _default_4bit = "true" if os.getenv("SPACE_ID") else "false"
    LOAD_IN_8BIT: bool = os.getenv("LOAD_IN_8BIT", "false").lower() == "true"
    LOAD_IN_4BIT: bool = os.getenv("LOAD_IN_4BIT", _default_4bit).lower() == "true"

    # Attention: "eager" (default), "sdpa" (faster on GPU),
    # "flash_attention_2" (fastest, needs flash-attn; Gemma can be flaky)
    ATTN_IMPLEMENTATION: str = os.getenv("ATTN_IMPLEMENTATION", "eager").lower()

    @classmethod
    def get_device(cls) -> str:
        """Return the device string to use for inference.

        Returns:
            ``"cuda"`` when ``USE_GPU`` is true and ``torch.cuda.is_available()``
            reports a usable device; ``"cpu"`` otherwise.

        Raises:
            ImportError: if ``USE_GPU`` is true and ``torch`` cannot be imported.
        """
        if not cls.USE_GPU:
            return "cpu"

        # Imported here rather than at module level, so torch is only needed
        # when GPU use is requested.
        import torch

        return "cuda" if torch.cuda.is_available() else "cpu"
class AgentConfig:
    """Limits and retry settings for agent behavior.

    ``MAX_ITERATIONS`` and ``TIMEOUT`` come from the ``MAX_AGENT_ITERATIONS``
    and ``AGENT_TIMEOUT`` environment variables, read once when the class body
    runs. The retry settings are fixed constants.
    """

    # Environment-driven limits; the string value is parsed with int().
    MAX_ITERATIONS: int = int(os.environ.get("MAX_AGENT_ITERATIONS", "10"))
    TIMEOUT: int = int(os.environ.get("AGENT_TIMEOUT", "300"))

    # Fixed retry policy
    RETRY_ATTEMPTS: int = 3
    RETRY_DELAY: float = 1.0
class TriageConfig:
    """Vocabulary and switches for the triage workflow.

    Holds the conversation settings (read from the environment once, when the
    class body runs) and the phrase lists used to classify symptom text. All
    phrases are lowercase. List order is preserved exactly as originally
    defined.
    """

    # --- Conversation settings -------------------------------------------
    # True only for the literal string "true" (case-insensitive).
    ENABLE_HISTORY: bool = (
        os.environ.get("ENABLE_CONVERSATION_HISTORY", "true").lower() == "true"
    )
    MAX_TURNS: int = int(os.environ.get("MAX_CONVERSATION_TURNS", "20"))

    # --- Urgency levels, most to least urgent ----------------------------
    URGENCY_LEVELS = [
        "EMERGENCY",    # Immediate life-threatening
        "URGENT",       # Needs attention within hours
        "SEMI-URGENT",  # Needs attention within 1-2 days
        "NON-URGENT",   # Can wait for routine appointment
    ]

    # --- Care settings ----------------------------------------------------
    CARE_SETTINGS = [
        "Emergency Department (ER)",
        "Urgent Care Center",
        "Primary Care Physician",
        "Telemedicine Consultation",
        "Self-Care at Home",
    ]

    # --- Red flag symptoms (require immediate attention) -----------------
    # Legacy list - kept for backward compatibility
    RED_FLAGS = [
        "chest pain", "difficulty breathing", "severe bleeding",
        "altered consciousness", "severe head injury", "stroke symptoms",
        "severe allergic reaction", "suicidal thoughts", "seizure",
    ]

    # --- Critical red flags: always EMERGENCY (life-threatening) ---------
    # These are more specific than RED_FLAGS to reduce false positives.
    CRITICAL_RED_FLAGS = [
        "severe chest pain", "crushing chest pain", "crushing chest pressure",
        "chest pain radiating",
        "severe difficulty breathing", "cannot breathe", "blue lips",
        "cyanosis", "gasping for air",
        "stroke symptoms", "face drooping", "facial droop", "arm weakness",
        "arm paralysis", "speech difficulty", "slurred speech",
        "altered consciousness", "unresponsive", "severe confusion",
        "severe bleeding", "hemorrhage", "bleeding heavily",
        "severe allergic reaction", "anaphylaxis", "throat swelling",
        "severe head injury", "worst headache of life", "thunderclap headache",
        "suicidal thoughts",
        "seizure",
    ]

    # --- Warning flags: concerning but context-dependent -----------------
    WARNING_FLAGS = [
        "chest discomfort", "chest pain",
        "high fever", "fever over 103", "103°f", "103 degrees", "fever 103",
        "severe abdominal pain",
        "moderate breathing difficulty", "shortness of breath",
        "severe injury", "severe pain",
        "persistent vomiting", "cannot keep fluids down",
        "can't keep food down", "can't keep fluids down",
        "extreme swelling", "extremely swollen",
        "purple discoloration", "turning purple",
        "can't bear weight", "can't put weight", "can't walk", "unable to walk",
        "pain getting worse", "progressively worse",
    ]

    # --- Severity keywords for context analysis --------------------------
    # Keys run from most to least severe.
    # NOTE(review): "extreme" is listed under both "critical" and "high",
    # "persistent" under both "high" and "moderate", and "slight" appears
    # twice under "low". Kept as-is; confirm whether consumers rely on this.
    SEVERITY_KEYWORDS = dict(
        critical=[
            "severe", "crushing", "worst", "sudden", "radiating",
            "intense", "unbearable", "excruciating", "life-threatening",
            "blue", "cyanosis", "unresponsive", "cannot breathe",
            "drooping", "paralyzed", "cannot speak", "hemorrhaging",
            "gasping", "worst ever", "cannot move", "won't stop bleeding",
            "throat swelling", "cannot swallow", "extreme", "acute",
        ],
        high=[
            "significant", "extreme", "very painful", "getting worse",
            "rapidly", "spreading", "persistent", "worsening",
            "progressively worse", "increasingly", "very severe",
            "extremely swollen", "turning purple", "can't walk",
            "can't keep down",
        ],
        moderate=[
            "moderate", "noticeable", "concerning", "uncomfortable",
            "increasing", "frequent", "persistent", "ongoing",
            "bothersome", "troublesome", "annoying",
        ],
        low=[
            "mild", "slight", "minor", "occasional", "intermittent",
            "tolerable", "manageable", "slight", "little",
            "somewhat", "a bit", "goes away", "comes and goes",
            "improving", "better", "healing",
        ],
    )

    # --- Negative indicators that reduce severity ------------------------
    NEGATION_KEYWORDS = [
        "no", "not", "without", "denies", "deny",
        "absent", "never", "neither", "none", "nothing",
    ]

    # --- Temporal indicators that suggest non-acute/historical -----------
    TEMPORAL_KEYWORDS = [
        "history of", "previously", "past", "used to",
        "last year", "months ago", "years ago", "chronic",
    ]
class DemoConfig:
    """Settings for the demo application.

    Launch options are taken from ``DEMO_*`` environment variables, read once
    when the class body runs. The UI text is fixed.
    """

    # Launch options
    PORT: int = int(os.environ.get("DEMO_PORT", "7860"))
    # True only for the literal string "true" (case-insensitive).
    SHARE: bool = os.environ.get("DEMO_SHARE", "false").lower() == "true"
    # Bind to 0.0.0.0 so the app is reachable from outside the container
    # (e.g. ECS host network).
    SERVER_NAME: str = os.environ.get("DEMO_SERVER_NAME", "0.0.0.0")

    # UI settings
    THEME: str = "soft"
    TITLE: str = "MedGemma AI Medical Triage System"
    # Markdown text (note the ** emphasis); it starts and ends with a newline.
    DESCRIPTION: str = """
This intelligent triage system uses multiple AI agents powered by MedGemma
to assess your symptoms and recommend appropriate care.
⚠️ **Disclaimer**: This is a demonstration system for research purposes only.
It does NOT replace professional medical advice. For emergencies, call 911.
"""
class APIConfig:
    """Settings for the API server (if enabled).

    All values are read from the environment once, when the class body runs.
    """

    HOST: str = os.environ.get("API_HOST", "0.0.0.0")
    PORT: int = int(os.environ.get("API_PORT", "8000"))
    # Controlled by the DEBUG variable (there is no API_-prefixed variable
    # for it); true only for the literal string "true", case-insensitive.
    RELOAD: bool = os.environ.get("DEBUG", "false").lower() == "true"
class LogConfig:
    """Settings for logging.

    ``LEVEL`` and ``LOG_FILE`` are read from the environment once, when the
    class body runs; the format and rotation settings are fixed.
    """

    LEVEL: str = os.environ.get("LOG_LEVEL", "INFO")
    # Defaults to app.log inside LOGS_DIR.
    LOG_FILE: str = os.environ.get("LOG_FILE", str(LOGS_DIR / "app.log"))

    # One format string, split across adjacent literals for readability.
    # NOTE(review): the <green>/<level> markup and {time:...} field look like
    # loguru format syntax - confirm against the logger setup.
    FORMAT: str = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan> - "
        "<level>{message}</level>"
    )

    # Log rotation
    ROTATION: str = "100 MB"
    RETENTION: str = "30 days"
    COMPRESSION: str = "zip"
# Public names exported by `from <module> import *`:
# the configuration classes first, then the path constants.
__all__ = [
    "ModelConfig", "AgentConfig", "TriageConfig",
    "DemoConfig", "APIConfig", "LogConfig",
    "PROJECT_ROOT", "SRC_DIR", "DATA_DIR", "LOGS_DIR", "MODELS_DIR",
]