DocUA commited on
Commit
ba5a70c
·
1 Parent(s): 82ca3ff

Implement model override functionality for AI clients and enhance session management in the simplified medical app

Browse files
src/config/ai_providers_config.py CHANGED
@@ -97,12 +97,13 @@ AGENT_CONFIGURATIONS = {
97
  "reasoning": "Gentle triage requires empathy and nuanced understanding"
98
  },
99
 
100
- # Spiritual Distress Analyzer uses Anthropic for empathy and safety
 
101
  "SpiritualDistressAnalyzer": {
102
- "provider": AIProvider.ANTHROPIC,
103
- "model": AIModel.CLAUDE_SONNET_4_5,
104
- "temperature": 0.2,
105
- "reasoning": "Spiritual distress assessment requires empathy, safety, and nuanced understanding"
106
  },
107
 
108
  # Referral Message Generator uses Anthropic for compassionate communication
 
97
  "reasoning": "Gentle triage requires empathy and nuanced understanding"
98
  },
99
 
100
+ # Spiritual Distress Analyzer uses Google Gemini by default for speed/throughput
101
+ # (This matches the baseline UI defaults in Model Settings).
102
  "SpiritualDistressAnalyzer": {
103
+ "provider": AIProvider.GEMINI,
104
+ "model": AIModel.GEMINI_2_5_FLASH,
105
+ "temperature": 0.1,
106
+ "reasoning": "Fast classification task, optimized for speed"
107
  },
108
 
109
  # Referral Message Generator uses Anthropic for compassionate communication
src/core/ai_client.py CHANGED
@@ -230,19 +230,51 @@ class UniversalAIClient:
230
  based on agent configuration and availability
231
  """
232
 
233
- def __init__(self, agent_name: str):
234
  self.agent_name = agent_name
 
235
  self.config = get_agent_config(agent_name)
236
  self.client = None
237
  self.fallback_client = None
238
 
239
  self._initialize_clients()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  def _initialize_clients(self):
242
  """Initialize primary and fallback clients"""
243
  primary_provider = self.config["provider"]
244
  primary_model = self.config["model"]
245
  temperature = self.config.get("temperature", 0.3)
 
 
 
 
 
 
 
246
 
247
  # Try to initialize primary client
248
  try:
@@ -343,14 +375,25 @@ class AIClientManager:
343
  def __init__(self):
344
  self._clients = {} # Cache for AI clients
345
  self.call_counter = 0 # Backward compatibility
 
 
 
 
346
 
347
  # NEW: Enhanced client management for medical AI optimization
348
  self.provider_performance_metrics = {}
349
  self.medical_context_routing = {}
350
 
351
  # Enhanced client retrieval with performance tracking
352
- def get_client(self, agent_name: str):
353
- """Get or create an AI client for the specified agent"""
 
 
 
 
 
 
 
354
  if agent_name not in self._clients:
355
  self._clients[agent_name] = create_ai_client(agent_name)
356
  return self._clients[agent_name]
@@ -358,7 +401,8 @@ class AIClientManager:
358
  def generate_response(self, system_prompt: str, user_prompt: str,
359
  temperature: float = None, call_type: str = "",
360
  agent_name: str = "DefaultAgent",
361
- medical_context: Optional[Dict] = None):
 
362
  """
363
  Enhanced response generation with medical context awareness
364
 
@@ -369,7 +413,7 @@ class AIClientManager:
369
  - Maintain full backward compatibility
370
  """
371
  try:
372
- client = self.get_client(agent_name)
373
  response = client.generate_response(
374
  system_prompt=system_prompt,
375
  user_prompt=user_prompt,
@@ -421,7 +465,8 @@ class AIClientManager:
421
  return {name: self.get_client_info(name) for name in self._clients}
422
 
423
  def call_spiritual_api(self, system_prompt: str, user_prompt: str,
424
- temperature: float = 0.7) -> str:
 
425
  """
426
  Call AI API for spiritual/emotional analysis.
427
 
@@ -435,16 +480,21 @@ class AIClientManager:
435
  Returns:
436
  AI response as string
437
  """
 
 
 
438
  return self.generate_response(
439
  system_prompt=system_prompt,
440
  user_prompt=user_prompt,
441
  temperature=temperature,
442
  call_type="spiritual_analysis",
443
- agent_name="SpiritualDistressAnalyzer"
 
444
  )
445
 
446
  def call_medical_api(self, system_prompt: str, user_prompt: str,
447
- temperature: float = 0.3) -> str:
 
448
  """
449
  Call AI API for medical assistance.
450
 
@@ -458,16 +508,20 @@ class AIClientManager:
458
  Returns:
459
  AI response as string
460
  """
 
 
 
461
  return self.generate_response(
462
  system_prompt=system_prompt,
463
  user_prompt=user_prompt,
464
  temperature=temperature,
465
  call_type="medical_assistance",
466
- agent_name="SoftMedicalTriage"
 
467
  )
468
 
469
  # Factory function for easy client creation
470
- def create_ai_client(agent_name: str) -> UniversalAIClient:
471
  """
472
  Create an AI client for a specific agent
473
 
@@ -477,7 +531,7 @@ def create_ai_client(agent_name: str) -> UniversalAIClient:
477
  Returns:
478
  Configured UniversalAIClient instance
479
  """
480
- return UniversalAIClient(agent_name)
481
 
482
  if __name__ == "__main__":
483
  print(" AI Client Test")
 
230
  based on agent configuration and availability
231
  """
232
 
233
+ def __init__(self, agent_name: str, model_override: Optional[str] = None):
234
  self.agent_name = agent_name
235
+ self.model_override = model_override
236
  self.config = get_agent_config(agent_name)
237
  self.client = None
238
  self.fallback_client = None
239
 
240
  self._initialize_clients()
241
+
242
+ @staticmethod
243
+ def _resolve_override_model(model_override: str) -> tuple[Optional[AIProvider], Optional[AIModel]]:
244
+ """Resolve a UI-provided model string into provider+AIModel.
245
+
246
+ Expected strings (from UI dropdowns):
247
+ - gemini-2.5-flash / gemini-2.0-flash / gemini-flash-latest
248
+ - claude-sonnet-4-5-20250929 / ...
249
+ """
250
+ if not model_override:
251
+ return None, None
252
+ override = model_override.strip()
253
+ if not override:
254
+ return None, None
255
+
256
+ try:
257
+ if override.startswith("gemini"):
258
+ return AIProvider.GEMINI, AIModel(override)
259
+ if override.startswith("claude"):
260
+ return AIProvider.ANTHROPIC, AIModel(override)
261
+ except Exception:
262
+ return None, None
263
+
264
+ return None, None
265
 
266
  def _initialize_clients(self):
267
  """Initialize primary and fallback clients"""
268
  primary_provider = self.config["provider"]
269
  primary_model = self.config["model"]
270
  temperature = self.config.get("temperature", 0.3)
271
+
272
+ # Optional: override model/provider (session-level setting from UI)
273
+ if self.model_override:
274
+ override_provider, override_model = self._resolve_override_model(self.model_override)
275
+ if override_provider is not None and override_model is not None:
276
+ primary_provider = override_provider
277
+ primary_model = override_model
278
 
279
  # Try to initialize primary client
280
  try:
 
375
  def __init__(self):
376
  self._clients = {} # Cache for AI clients
377
  self.call_counter = 0 # Backward compatibility
378
+
379
+ # Optional: allow an owning session/app to attach per-session overrides.
380
+ # Expected shape: {agent_name: model_string}
381
+ self.model_overrides: Dict[str, str] = {}
382
 
383
  # NEW: Enhanced client management for medical AI optimization
384
  self.provider_performance_metrics = {}
385
  self.medical_context_routing = {}
386
 
387
  # Enhanced client retrieval with performance tracking
388
+ def get_client(self, agent_name: str, model_override: Optional[str] = None):
389
+ """Get or create an AI client for the specified agent.
390
+
391
+ If `model_override` is provided, a new (non-cached) client is returned
392
+ to avoid cross-session leakage.
393
+ """
394
+ if model_override:
395
+ return create_ai_client(agent_name, model_override=model_override)
396
+
397
  if agent_name not in self._clients:
398
  self._clients[agent_name] = create_ai_client(agent_name)
399
  return self._clients[agent_name]
 
401
  def generate_response(self, system_prompt: str, user_prompt: str,
402
  temperature: float = None, call_type: str = "",
403
  agent_name: str = "DefaultAgent",
404
+ medical_context: Optional[Dict] = None,
405
+ model_override: Optional[str] = None):
406
  """
407
  Enhanced response generation with medical context awareness
408
 
 
413
  - Maintain full backward compatibility
414
  """
415
  try:
416
+ client = self.get_client(agent_name, model_override=model_override)
417
  response = client.generate_response(
418
  system_prompt=system_prompt,
419
  user_prompt=user_prompt,
 
465
  return {name: self.get_client_info(name) for name in self._clients}
466
 
467
  def call_spiritual_api(self, system_prompt: str, user_prompt: str,
468
+ temperature: float = 0.7,
469
+ model_override: Optional[str] = None) -> str:
470
  """
471
  Call AI API for spiritual/emotional analysis.
472
 
 
480
  Returns:
481
  AI response as string
482
  """
483
+ if model_override is None and self.model_overrides:
484
+ model_override = self.model_overrides.get("SpiritualDistressAnalyzer")
485
+
486
  return self.generate_response(
487
  system_prompt=system_prompt,
488
  user_prompt=user_prompt,
489
  temperature=temperature,
490
  call_type="spiritual_analysis",
491
+ agent_name="SpiritualDistressAnalyzer",
492
+ model_override=model_override,
493
  )
494
 
495
  def call_medical_api(self, system_prompt: str, user_prompt: str,
496
+ temperature: float = 0.3,
497
+ model_override: Optional[str] = None) -> str:
498
  """
499
  Call AI API for medical assistance.
500
 
 
508
  Returns:
509
  AI response as string
510
  """
511
+ if model_override is None and self.model_overrides:
512
+ model_override = self.model_overrides.get("SoftMedicalTriage")
513
+
514
  return self.generate_response(
515
  system_prompt=system_prompt,
516
  user_prompt=user_prompt,
517
  temperature=temperature,
518
  call_type="medical_assistance",
519
+ agent_name="SoftMedicalTriage",
520
+ model_override=model_override,
521
  )
522
 
523
  # Factory function for easy client creation
524
+ def create_ai_client(agent_name: str, model_override: Optional[str] = None) -> UniversalAIClient:
525
  """
526
  Create an AI client for a specific agent
527
 
 
531
  Returns:
532
  Configured UniversalAIClient instance
533
  """
534
+ return UniversalAIClient(agent_name, model_override=model_override)
535
 
536
  if __name__ == "__main__":
537
  print(" AI Client Test")
src/core/simplified_medical_app.py CHANGED
@@ -64,6 +64,8 @@ class SimplifiedMedicalApp:
64
 
65
  # AI client
66
  self.api = AIClientManager()
 
 
67
 
68
  # Medical components
69
  self.medical_assistant = MedicalAssistant(self.api)
@@ -89,6 +91,29 @@ class SimplifiedMedicalApp:
89
  )
90
 
91
  logger.info("✅ SimplifiedMedicalApp initialized")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def process_message(
94
  self,
 
64
 
65
  # AI client
66
  self.api = AIClientManager()
67
+ # Optional per-session model overrides (set by the UI Model Settings tab)
68
+ self.model_overrides = {}
69
 
70
  # Medical components
71
  self.medical_assistant = MedicalAssistant(self.api)
 
91
  )
92
 
93
  logger.info("✅ SimplifiedMedicalApp initialized")
94
+
95
+ def set_model_overrides(self, overrides: Optional[dict] = None) -> None:
96
+ """Set per-session model overrides.
97
+
98
+ `overrides` is expected to be a mapping of agent_name -> model string.
99
+ Example keys used by the UI:
100
+ - SpiritualDistressAnalyzer
101
+ - SoftSpiritualTriage
102
+ - TriageResponseEvaluator
103
+ - MedicalAssistant
104
+ - SoftMedicalTriage
105
+ """
106
+ self.model_overrides = dict(overrides or {})
107
+
108
+ # Propagate to AI manager so core components can read overrides.
109
+ if hasattr(self, "api") and hasattr(self.api, "model_overrides"):
110
+ self.api.model_overrides = dict(self.model_overrides)
111
+
112
+ def _get_model_override(self, agent_name: str) -> Optional[str]:
113
+ """Return the model override for an agent, if any."""
114
+ if not getattr(self, "model_overrides", None):
115
+ return None
116
+ return self.model_overrides.get(agent_name)
117
 
118
  def process_message(
119
  self,
src/interface/simplified_gradio_app.py CHANGED
@@ -597,6 +597,13 @@ Changes apply only to your current session.
597
  session = SimplifiedSessionData()
598
 
599
  session.update_activity()
 
 
 
 
 
 
 
600
  new_history, status = session.app_instance.process_message(message, history)
601
 
602
  # Get updated conversation stats
 
597
  session = SimplifiedSessionData()
598
 
599
  session.update_activity()
600
+
601
+ # Apply per-session model overrides (if configured in Model Settings)
602
+ custom_models = getattr(session, 'custom_models', None)
603
+ if custom_models:
604
+ session.app_instance.set_model_overrides(custom_models)
605
+ else:
606
+ session.app_instance.set_model_overrides({})
607
  new_history, status = session.app_instance.process_message(message, history)
608
 
609
  # Get updated conversation stats