Spaces:
Running
Running
Implement model override functionality for AI clients and enhance session management in the simplified medical app
Browse files
src/config/ai_providers_config.py
CHANGED
|
@@ -97,12 +97,13 @@ AGENT_CONFIGURATIONS = {
|
|
| 97 |
"reasoning": "Gentle triage requires empathy and nuanced understanding"
|
| 98 |
},
|
| 99 |
|
| 100 |
-
# Spiritual Distress Analyzer uses
|
|
|
|
| 101 |
"SpiritualDistressAnalyzer": {
|
| 102 |
-
"provider": AIProvider.
|
| 103 |
-
"model": AIModel.
|
| 104 |
-
"temperature": 0.
|
| 105 |
-
"reasoning": "
|
| 106 |
},
|
| 107 |
|
| 108 |
# Referral Message Generator uses Anthropic for compassionate communication
|
|
|
|
| 97 |
"reasoning": "Gentle triage requires empathy and nuanced understanding"
|
| 98 |
},
|
| 99 |
|
| 100 |
+
# Spiritual Distress Analyzer uses Google Gemini by default for speed/throughput
|
| 101 |
+
# (This matches the baseline UI defaults in Model Settings).
|
| 102 |
"SpiritualDistressAnalyzer": {
|
| 103 |
+
"provider": AIProvider.GEMINI,
|
| 104 |
+
"model": AIModel.GEMINI_2_5_FLASH,
|
| 105 |
+
"temperature": 0.1,
|
| 106 |
+
"reasoning": "Fast classification task, optimized for speed"
|
| 107 |
},
|
| 108 |
|
| 109 |
# Referral Message Generator uses Anthropic for compassionate communication
|
src/core/ai_client.py
CHANGED
|
@@ -230,19 +230,51 @@ class UniversalAIClient:
|
|
| 230 |
based on agent configuration and availability
|
| 231 |
"""
|
| 232 |
|
| 233 |
-
def __init__(self, agent_name: str):
|
| 234 |
self.agent_name = agent_name
|
|
|
|
| 235 |
self.config = get_agent_config(agent_name)
|
| 236 |
self.client = None
|
| 237 |
self.fallback_client = None
|
| 238 |
|
| 239 |
self._initialize_clients()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
def _initialize_clients(self):
|
| 242 |
"""Initialize primary and fallback clients"""
|
| 243 |
primary_provider = self.config["provider"]
|
| 244 |
primary_model = self.config["model"]
|
| 245 |
temperature = self.config.get("temperature", 0.3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
# Try to initialize primary client
|
| 248 |
try:
|
|
@@ -343,14 +375,25 @@ class AIClientManager:
|
|
| 343 |
def __init__(self):
|
| 344 |
self._clients = {} # Cache for AI clients
|
| 345 |
self.call_counter = 0 # Backward compatibility
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
# NEW: Enhanced client management for medical AI optimization
|
| 348 |
self.provider_performance_metrics = {}
|
| 349 |
self.medical_context_routing = {}
|
| 350 |
|
| 351 |
# Enhanced client retrieval with performance tracking
|
| 352 |
-
def get_client(self, agent_name: str):
|
| 353 |
-
"""Get or create an AI client for the specified agent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
if agent_name not in self._clients:
|
| 355 |
self._clients[agent_name] = create_ai_client(agent_name)
|
| 356 |
return self._clients[agent_name]
|
|
@@ -358,7 +401,8 @@ class AIClientManager:
|
|
| 358 |
def generate_response(self, system_prompt: str, user_prompt: str,
|
| 359 |
temperature: float = None, call_type: str = "",
|
| 360 |
agent_name: str = "DefaultAgent",
|
| 361 |
-
medical_context: Optional[Dict] = None
|
|
|
|
| 362 |
"""
|
| 363 |
Enhanced response generation with medical context awareness
|
| 364 |
|
|
@@ -369,7 +413,7 @@ class AIClientManager:
|
|
| 369 |
- Maintain full backward compatibility
|
| 370 |
"""
|
| 371 |
try:
|
| 372 |
-
client = self.get_client(agent_name)
|
| 373 |
response = client.generate_response(
|
| 374 |
system_prompt=system_prompt,
|
| 375 |
user_prompt=user_prompt,
|
|
@@ -421,7 +465,8 @@ class AIClientManager:
|
|
| 421 |
return {name: self.get_client_info(name) for name in self._clients}
|
| 422 |
|
| 423 |
def call_spiritual_api(self, system_prompt: str, user_prompt: str,
|
| 424 |
-
temperature: float = 0.7
|
|
|
|
| 425 |
"""
|
| 426 |
Call AI API for spiritual/emotional analysis.
|
| 427 |
|
|
@@ -435,16 +480,21 @@ class AIClientManager:
|
|
| 435 |
Returns:
|
| 436 |
AI response as string
|
| 437 |
"""
|
|
|
|
|
|
|
|
|
|
| 438 |
return self.generate_response(
|
| 439 |
system_prompt=system_prompt,
|
| 440 |
user_prompt=user_prompt,
|
| 441 |
temperature=temperature,
|
| 442 |
call_type="spiritual_analysis",
|
| 443 |
-
agent_name="SpiritualDistressAnalyzer"
|
|
|
|
| 444 |
)
|
| 445 |
|
| 446 |
def call_medical_api(self, system_prompt: str, user_prompt: str,
|
| 447 |
-
temperature: float = 0.3
|
|
|
|
| 448 |
"""
|
| 449 |
Call AI API for medical assistance.
|
| 450 |
|
|
@@ -458,16 +508,20 @@ class AIClientManager:
|
|
| 458 |
Returns:
|
| 459 |
AI response as string
|
| 460 |
"""
|
|
|
|
|
|
|
|
|
|
| 461 |
return self.generate_response(
|
| 462 |
system_prompt=system_prompt,
|
| 463 |
user_prompt=user_prompt,
|
| 464 |
temperature=temperature,
|
| 465 |
call_type="medical_assistance",
|
| 466 |
-
agent_name="SoftMedicalTriage"
|
|
|
|
| 467 |
)
|
| 468 |
|
| 469 |
# Factory function for easy client creation
|
| 470 |
-
def create_ai_client(agent_name: str) -> UniversalAIClient:
|
| 471 |
"""
|
| 472 |
Create an AI client for a specific agent
|
| 473 |
|
|
@@ -477,7 +531,7 @@ def create_ai_client(agent_name: str) -> UniversalAIClient:
|
|
| 477 |
Returns:
|
| 478 |
Configured UniversalAIClient instance
|
| 479 |
"""
|
| 480 |
-
return UniversalAIClient(agent_name)
|
| 481 |
|
| 482 |
if __name__ == "__main__":
|
| 483 |
print(" AI Client Test")
|
|
|
|
| 230 |
based on agent configuration and availability
|
| 231 |
"""
|
| 232 |
|
| 233 |
+
def __init__(self, agent_name: str, model_override: Optional[str] = None):
|
| 234 |
self.agent_name = agent_name
|
| 235 |
+
self.model_override = model_override
|
| 236 |
self.config = get_agent_config(agent_name)
|
| 237 |
self.client = None
|
| 238 |
self.fallback_client = None
|
| 239 |
|
| 240 |
self._initialize_clients()
|
| 241 |
+
|
| 242 |
+
@staticmethod
|
| 243 |
+
def _resolve_override_model(model_override: str) -> tuple[Optional[AIProvider], Optional[AIModel]]:
|
| 244 |
+
"""Resolve a UI-provided model string into provider+AIModel.
|
| 245 |
+
|
| 246 |
+
Expected strings (from UI dropdowns):
|
| 247 |
+
- gemini-2.5-flash / gemini-2.0-flash / gemini-flash-latest
|
| 248 |
+
- claude-sonnet-4-5-20250929 / ...
|
| 249 |
+
"""
|
| 250 |
+
if not model_override:
|
| 251 |
+
return None, None
|
| 252 |
+
override = model_override.strip()
|
| 253 |
+
if not override:
|
| 254 |
+
return None, None
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
if override.startswith("gemini"):
|
| 258 |
+
return AIProvider.GEMINI, AIModel(override)
|
| 259 |
+
if override.startswith("claude"):
|
| 260 |
+
return AIProvider.ANTHROPIC, AIModel(override)
|
| 261 |
+
except Exception:
|
| 262 |
+
return None, None
|
| 263 |
+
|
| 264 |
+
return None, None
|
| 265 |
|
| 266 |
def _initialize_clients(self):
|
| 267 |
"""Initialize primary and fallback clients"""
|
| 268 |
primary_provider = self.config["provider"]
|
| 269 |
primary_model = self.config["model"]
|
| 270 |
temperature = self.config.get("temperature", 0.3)
|
| 271 |
+
|
| 272 |
+
# Optional: override model/provider (session-level setting from UI)
|
| 273 |
+
if self.model_override:
|
| 274 |
+
override_provider, override_model = self._resolve_override_model(self.model_override)
|
| 275 |
+
if override_provider is not None and override_model is not None:
|
| 276 |
+
primary_provider = override_provider
|
| 277 |
+
primary_model = override_model
|
| 278 |
|
| 279 |
# Try to initialize primary client
|
| 280 |
try:
|
|
|
|
| 375 |
def __init__(self):
|
| 376 |
self._clients = {} # Cache for AI clients
|
| 377 |
self.call_counter = 0 # Backward compatibility
|
| 378 |
+
|
| 379 |
+
# Optional: allow an owning session/app to attach per-session overrides.
|
| 380 |
+
# Expected shape: {agent_name: model_string}
|
| 381 |
+
self.model_overrides: Dict[str, str] = {}
|
| 382 |
|
| 383 |
# NEW: Enhanced client management for medical AI optimization
|
| 384 |
self.provider_performance_metrics = {}
|
| 385 |
self.medical_context_routing = {}
|
| 386 |
|
| 387 |
# Enhanced client retrieval with performance tracking
|
| 388 |
+
def get_client(self, agent_name: str, model_override: Optional[str] = None):
|
| 389 |
+
"""Get or create an AI client for the specified agent.
|
| 390 |
+
|
| 391 |
+
If `model_override` is provided, a new (non-cached) client is returned
|
| 392 |
+
to avoid cross-session leakage.
|
| 393 |
+
"""
|
| 394 |
+
if model_override:
|
| 395 |
+
return create_ai_client(agent_name, model_override=model_override)
|
| 396 |
+
|
| 397 |
if agent_name not in self._clients:
|
| 398 |
self._clients[agent_name] = create_ai_client(agent_name)
|
| 399 |
return self._clients[agent_name]
|
|
|
|
| 401 |
def generate_response(self, system_prompt: str, user_prompt: str,
|
| 402 |
temperature: float = None, call_type: str = "",
|
| 403 |
agent_name: str = "DefaultAgent",
|
| 404 |
+
medical_context: Optional[Dict] = None,
|
| 405 |
+
model_override: Optional[str] = None):
|
| 406 |
"""
|
| 407 |
Enhanced response generation with medical context awareness
|
| 408 |
|
|
|
|
| 413 |
- Maintain full backward compatibility
|
| 414 |
"""
|
| 415 |
try:
|
| 416 |
+
client = self.get_client(agent_name, model_override=model_override)
|
| 417 |
response = client.generate_response(
|
| 418 |
system_prompt=system_prompt,
|
| 419 |
user_prompt=user_prompt,
|
|
|
|
| 465 |
return {name: self.get_client_info(name) for name in self._clients}
|
| 466 |
|
| 467 |
def call_spiritual_api(self, system_prompt: str, user_prompt: str,
|
| 468 |
+
temperature: float = 0.7,
|
| 469 |
+
model_override: Optional[str] = None) -> str:
|
| 470 |
"""
|
| 471 |
Call AI API for spiritual/emotional analysis.
|
| 472 |
|
|
|
|
| 480 |
Returns:
|
| 481 |
AI response as string
|
| 482 |
"""
|
| 483 |
+
if model_override is None and self.model_overrides:
|
| 484 |
+
model_override = self.model_overrides.get("SpiritualDistressAnalyzer")
|
| 485 |
+
|
| 486 |
return self.generate_response(
|
| 487 |
system_prompt=system_prompt,
|
| 488 |
user_prompt=user_prompt,
|
| 489 |
temperature=temperature,
|
| 490 |
call_type="spiritual_analysis",
|
| 491 |
+
agent_name="SpiritualDistressAnalyzer",
|
| 492 |
+
model_override=model_override,
|
| 493 |
)
|
| 494 |
|
| 495 |
def call_medical_api(self, system_prompt: str, user_prompt: str,
|
| 496 |
+
temperature: float = 0.3,
|
| 497 |
+
model_override: Optional[str] = None) -> str:
|
| 498 |
"""
|
| 499 |
Call AI API for medical assistance.
|
| 500 |
|
|
|
|
| 508 |
Returns:
|
| 509 |
AI response as string
|
| 510 |
"""
|
| 511 |
+
if model_override is None and self.model_overrides:
|
| 512 |
+
model_override = self.model_overrides.get("SoftMedicalTriage")
|
| 513 |
+
|
| 514 |
return self.generate_response(
|
| 515 |
system_prompt=system_prompt,
|
| 516 |
user_prompt=user_prompt,
|
| 517 |
temperature=temperature,
|
| 518 |
call_type="medical_assistance",
|
| 519 |
+
agent_name="SoftMedicalTriage",
|
| 520 |
+
model_override=model_override,
|
| 521 |
)
|
| 522 |
|
| 523 |
# Factory function for easy client creation
|
| 524 |
+
def create_ai_client(agent_name: str, model_override: Optional[str] = None) -> UniversalAIClient:
|
| 525 |
"""
|
| 526 |
Create an AI client for a specific agent
|
| 527 |
|
|
|
|
| 531 |
Returns:
|
| 532 |
Configured UniversalAIClient instance
|
| 533 |
"""
|
| 534 |
+
return UniversalAIClient(agent_name, model_override=model_override)
|
| 535 |
|
| 536 |
if __name__ == "__main__":
|
| 537 |
print(" AI Client Test")
|
src/core/simplified_medical_app.py
CHANGED
|
@@ -64,6 +64,8 @@ class SimplifiedMedicalApp:
|
|
| 64 |
|
| 65 |
# AI client
|
| 66 |
self.api = AIClientManager()
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# Medical components
|
| 69 |
self.medical_assistant = MedicalAssistant(self.api)
|
|
@@ -89,6 +91,29 @@ class SimplifiedMedicalApp:
|
|
| 89 |
)
|
| 90 |
|
| 91 |
logger.info("✅ SimplifiedMedicalApp initialized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
def process_message(
|
| 94 |
self,
|
|
|
|
| 64 |
|
| 65 |
# AI client
|
| 66 |
self.api = AIClientManager()
|
| 67 |
+
# Optional per-session model overrides (set by the UI Model Settings tab)
|
| 68 |
+
self.model_overrides = {}
|
| 69 |
|
| 70 |
# Medical components
|
| 71 |
self.medical_assistant = MedicalAssistant(self.api)
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
logger.info("✅ SimplifiedMedicalApp initialized")
|
| 94 |
+
|
| 95 |
+
def set_model_overrides(self, overrides: Optional[dict] = None) -> None:
|
| 96 |
+
"""Set per-session model overrides.
|
| 97 |
+
|
| 98 |
+
`overrides` is expected to be a mapping of agent_name -> model string.
|
| 99 |
+
Example keys used by the UI:
|
| 100 |
+
- SpiritualDistressAnalyzer
|
| 101 |
+
- SoftSpiritualTriage
|
| 102 |
+
- TriageResponseEvaluator
|
| 103 |
+
- MedicalAssistant
|
| 104 |
+
- SoftMedicalTriage
|
| 105 |
+
"""
|
| 106 |
+
self.model_overrides = dict(overrides or {})
|
| 107 |
+
|
| 108 |
+
# Propagate to AI manager so core components can read overrides.
|
| 109 |
+
if hasattr(self, "api") and hasattr(self.api, "model_overrides"):
|
| 110 |
+
self.api.model_overrides = dict(self.model_overrides)
|
| 111 |
+
|
| 112 |
+
def _get_model_override(self, agent_name: str) -> Optional[str]:
|
| 113 |
+
"""Return the model override for an agent, if any."""
|
| 114 |
+
if not getattr(self, "model_overrides", None):
|
| 115 |
+
return None
|
| 116 |
+
return self.model_overrides.get(agent_name)
|
| 117 |
|
| 118 |
def process_message(
|
| 119 |
self,
|
src/interface/simplified_gradio_app.py
CHANGED
|
@@ -597,6 +597,13 @@ Changes apply only to your current session.
|
|
| 597 |
session = SimplifiedSessionData()
|
| 598 |
|
| 599 |
session.update_activity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
new_history, status = session.app_instance.process_message(message, history)
|
| 601 |
|
| 602 |
# Get updated conversation stats
|
|
|
|
| 597 |
session = SimplifiedSessionData()
|
| 598 |
|
| 599 |
session.update_activity()
|
| 600 |
+
|
| 601 |
+
# Apply per-session model overrides (if configured in Model Settings)
|
| 602 |
+
custom_models = getattr(session, 'custom_models', None)
|
| 603 |
+
if custom_models:
|
| 604 |
+
session.app_instance.set_model_overrides(custom_models)
|
| 605 |
+
else:
|
| 606 |
+
session.app_instance.set_model_overrides({})
|
| 607 |
new_history, status = session.app_instance.process_message(message, history)
|
| 608 |
|
| 609 |
# Get updated conversation stats
|