DocUA commited on
Commit
37184d4
·
1 Parent(s): ba5a70c

Implement model overrides for AI clients and enhance interfaces for manual input and file upload

Browse files
src/config/ai_providers_config.py CHANGED
@@ -91,10 +91,10 @@ AGENT_CONFIGURATIONS = {
91
  },
92
 
93
  "SoftMedicalTriage": {
94
- "provider": AIProvider.ANTHROPIC,
95
- "model": AIModel.CLAUDE_SONNET_4_5,
96
  "temperature": 0.3,
97
- "reasoning": "Gentle triage requires empathy and nuanced understanding"
98
  },
99
 
100
  # Spiritual Distress Analyzer uses Google Gemini by default for speed/throughput
 
91
  },
92
 
93
  "SoftMedicalTriage": {
94
+ "provider": AIProvider.GEMINI,
95
+ "model": AIModel.GEMINI_2_5_FLASH,
96
  "temperature": 0.3,
97
+ "reasoning": "Batch and soft triage prioritize throughput; Gemini is the default unless overridden"
98
  },
99
 
100
  # Spiritual Distress Analyzer uses Google Gemini by default for speed/throughput
src/core/ai_client.py CHANGED
@@ -379,6 +379,14 @@ class AIClientManager:
379
  # Optional: allow an owning session/app to attach per-session overrides.
380
  # Expected shape: {agent_name: model_string}
381
  self.model_overrides: Dict[str, str] = {}
 
 
 
 
 
 
 
 
382
 
383
  # NEW: Enhanced client management for medical AI optimization
384
  self.provider_performance_metrics = {}
 
379
  # Optional: allow an owning session/app to attach per-session overrides.
380
  # Expected shape: {agent_name: model_string}
381
  self.model_overrides: Dict[str, str] = {}
382
+
383
+ def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
384
+ """Set per-session model overrides.
385
+
386
+ This is intentionally a thin setter so multiple UI controllers
387
+ (chat / manual input / file upload) can share the same mechanism.
388
+ """
389
+ self.model_overrides = dict(overrides or {})
390
 
391
  # NEW: Enhanced client management for medical AI optimization
392
  self.provider_performance_metrics = {}
src/interface/enhanced_verification_interface.py CHANGED
@@ -79,6 +79,9 @@ class EnhancedVerificationInterface:
79
  incomplete_sessions_state = gr.State(value=[])
80
  pending_mode_switch_state = gr.State(value=None)
81
  selected_session_state = gr.State(value=None)
 
 
 
82
 
83
  # Main container
84
  with gr.Column():
@@ -354,7 +357,7 @@ class EnhancedVerificationInterface:
354
  back_from_manual_btn = gr.Button("← Back to Mode Selection", size="sm")
355
  gr.Markdown("---")
356
  # Embed the actual interface
357
- manual_input_ui = EnhancedVerificationUIComponents.create_manual_input_interface()
358
 
359
  file_upload_interface = gr.Row(visible=False)
360
  with file_upload_interface:
@@ -364,7 +367,7 @@ class EnhancedVerificationInterface:
364
  back_from_file_btn = gr.Button("← Back to Mode Selection", size="sm")
365
  gr.Markdown("---")
366
  # Embed the actual interface
367
- file_upload_ui = EnhancedVerificationUIComponents.create_file_upload_interface()
368
 
369
  # Event handlers
370
  def initialize_interface():
 
79
  incomplete_sessions_state = gr.State(value=[])
80
  pending_mode_switch_state = gr.State(value=None)
81
  selected_session_state = gr.State(value=None)
82
+
83
+ # Model overrides (populated by the main app's AI Model Configuration, if wired)
84
+ model_overrides_state = gr.State(value={})
85
 
86
  # Main container
87
  with gr.Column():
 
357
  back_from_manual_btn = gr.Button("← Back to Mode Selection", size="sm")
358
  gr.Markdown("---")
359
  # Embed the actual interface
360
+ manual_input_ui = EnhancedVerificationUIComponents.create_manual_input_interface(model_overrides_state)
361
 
362
  file_upload_interface = gr.Row(visible=False)
363
  with file_upload_interface:
 
367
  back_from_file_btn = gr.Button("← Back to Mode Selection", size="sm")
368
  gr.Markdown("---")
369
  # Embed the actual interface
370
+ file_upload_ui = EnhancedVerificationUIComponents.create_file_upload_interface(model_overrides_state)
371
 
372
  # Event handlers
373
  def initialize_interface():
src/interface/enhanced_verification_ui.py CHANGED
@@ -854,7 +854,7 @@ class EnhancedVerificationUIComponents:
854
  return enhanced_dataset_interface
855
 
856
  @staticmethod
857
- def create_manual_input_interface() -> gr.Blocks:
858
  """
859
  Create manual input mode interface.
860
 
@@ -863,10 +863,10 @@ class EnhancedVerificationUIComponents:
863
  """
864
  # Import the complete manual input interface
865
  from src.interface.manual_input_interface import create_manual_input_interface
866
- return create_manual_input_interface()
867
 
868
  @staticmethod
869
- def create_file_upload_interface() -> gr.Blocks:
870
  """
871
  Create file upload mode interface.
872
 
@@ -875,7 +875,7 @@ class EnhancedVerificationUIComponents:
875
  """
876
  # Import the complete file upload interface
877
  from src.interface.file_upload_interface import create_file_upload_interface
878
- return create_file_upload_interface()
879
 
880
 
881
  def create_enhanced_verification_app() -> gr.Blocks:
 
854
  return enhanced_dataset_interface
855
 
856
  @staticmethod
857
+ def create_manual_input_interface(model_overrides_state: Optional[gr.State] = None) -> gr.Blocks:
858
  """
859
  Create manual input mode interface.
860
 
 
863
  """
864
  # Import the complete manual input interface
865
  from src.interface.manual_input_interface import create_manual_input_interface
866
+ return create_manual_input_interface(model_overrides_state=model_overrides_state)
867
 
868
  @staticmethod
869
+ def create_file_upload_interface(model_overrides_state: Optional[gr.State] = None) -> gr.Blocks:
870
  """
871
  Create file upload mode interface.
872
 
 
875
  """
876
  # Import the complete file upload interface
877
  from src.interface.file_upload_interface import create_file_upload_interface
878
+ return create_file_upload_interface(model_overrides_state=model_overrides_state)
879
 
880
 
881
  def create_enhanced_verification_app() -> gr.Blocks:
src/interface/file_upload_interface.py CHANGED
@@ -45,10 +45,18 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
45
  self.file_processor = FileProcessingService()
46
  self.store = JSONVerificationStore()
47
  self.ai_client = AIClientManager()
48
- self.current_session: Optional[EnhancedVerificationSession] = None
49
- self.current_file_result: Optional[FileUploadResult] = None
50
- self.current_message_index: int = 0
 
 
 
51
  self.batch_processing_start_time = None
 
 
 
 
 
52
 
53
  def process_uploaded_file(self, file_path: str) -> Tuple[bool, str, Optional[FileUploadResult], str]:
54
  """
@@ -403,6 +411,7 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
403
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
404
  user_prompt=user_prompt,
405
  temperature=0.3,
 
406
  )
407
 
408
  classification_result = self._parse_classification_response(raw_response)
@@ -691,7 +700,7 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
691
  return csv_content, xlsx_bytes
692
 
693
 
694
- def create_file_upload_interface() -> gr.Blocks:
695
  """
696
  Create the complete file upload mode interface.
697
 
@@ -699,6 +708,16 @@ def create_file_upload_interface() -> gr.Blocks:
699
  Gradio Blocks component for file upload mode
700
  """
701
  controller = FileUploadInterfaceController()
 
 
 
 
 
 
 
 
 
 
702
 
703
  with gr.Blocks() as file_upload_interface:
704
  # Headers and back button are in parent interface
 
45
  self.file_processor = FileProcessingService()
46
  self.store = JSONVerificationStore()
47
  self.ai_client = AIClientManager()
48
+ # Optional per-session model overrides (UI Model Settings tab)
49
+ self.model_overrides = {}
50
+ self.ai_client.set_model_overrides(self.model_overrides)
51
+ self.current_session = None
52
+ self.current_file_result = None
53
+ self.current_message_index = 0
54
  self.batch_processing_start_time = None
55
+
56
+ def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
57
+ """Set per-session model overrides from the UI."""
58
+ self.model_overrides = dict(overrides or {})
59
+ self.ai_client.set_model_overrides(self.model_overrides)
60
 
61
  def process_uploaded_file(self, file_path: str) -> Tuple[bool, str, Optional[FileUploadResult], str]:
62
  """
 
411
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
412
  user_prompt=user_prompt,
413
  temperature=0.3,
414
+ model_override=self.model_overrides.get("SpiritualDistressAnalyzer"),
415
  )
416
 
417
  classification_result = self._parse_classification_response(raw_response)
 
700
  return csv_content, xlsx_bytes
701
 
702
 
703
+ def create_file_upload_interface(model_overrides_state: Optional[gr.State] = None) -> gr.Blocks:
704
  """
705
  Create the complete file upload mode interface.
706
 
 
708
  Gradio Blocks component for file upload mode
709
  """
710
  controller = FileUploadInterfaceController()
711
+
712
+ # Apply any provided model overrides at build time.
713
+ # Note: this is safe even if the state is mutated later, because the click
714
+ # handlers also refresh overrides before calls.
715
+ if model_overrides_state is not None:
716
+ try:
717
+ controller.set_model_overrides(model_overrides_state.value or {})
718
+ except Exception:
719
+ # Don't fail UI creation if state isn't initialized yet
720
+ pass
721
 
722
  with gr.Blocks() as file_upload_interface:
723
  # Headers and back button are in parent interface
src/interface/manual_input_interface.py CHANGED
@@ -59,8 +59,17 @@ class ManualInputController(ProgressTrackingMixin):
59
  super().__init__(VerificationMode.MANUAL_INPUT)
60
  self.store = JSONVerificationStore()
61
  self.ai_client = AIClientManager()
 
62
  self.state = ManualInputState()
63
  self.classification_start_time = None
 
 
 
 
 
 
 
 
64
 
65
  def start_new_session(self, verifier_name: str) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
66
  """
@@ -136,7 +145,6 @@ class ManualInputController(ProgressTrackingMixin):
136
  user_prompt=user_prompt,
137
  temperature=0.3
138
  )
139
-
140
  # Parse the response to extract classification details
141
  classification_result = self._parse_classification_response(response)
142
 
@@ -479,7 +487,7 @@ class ManualInputController(ProgressTrackingMixin):
479
  return False, f"❌ Error completing session: {str(e)}"
480
 
481
 
482
- def create_manual_input_interface() -> gr.Blocks:
483
  """
484
  Create the complete manual input mode interface.
485
 
@@ -487,6 +495,8 @@ def create_manual_input_interface() -> gr.Blocks:
487
  Gradio Blocks component for manual input mode
488
  """
489
  controller = ManualInputController()
 
 
490
 
491
  with gr.Blocks() as manual_input_interface:
492
  # Headers and back button are in parent interface
@@ -507,13 +517,13 @@ def create_manual_input_interface() -> gr.Blocks:
507
  "🚀",
508
  "lg"
509
  )
510
-
511
  # Session info display
512
  session_info_display = gr.Markdown(
513
  "Enter your name and click 'Start Session' to begin",
514
  label="Session Status"
515
  )
516
-
517
  # Manual input section (initially hidden)
518
  manual_input_section = gr.Row(visible=False)
519
  with manual_input_section:
@@ -534,6 +544,11 @@ def create_manual_input_interface() -> gr.Blocks:
534
  "🎯",
535
  "lg"
536
  )
 
 
 
 
 
537
 
538
  # Classification results (initially hidden)
539
  classification_results_section = gr.Row(visible=False)
@@ -667,8 +682,9 @@ def create_manual_input_interface() -> gr.Blocks:
667
  message # status_message
668
  )
669
 
670
- def on_classify_message(message_text):
671
  """Handle message classification."""
 
672
  success, message, classification = controller.classify_message(message_text)
673
 
674
  if success:
@@ -831,7 +847,7 @@ def create_manual_input_interface() -> gr.Blocks:
831
 
832
  classify_btn.click(
833
  on_classify_message,
834
- inputs=[message_input],
835
  outputs=[
836
  classification_results_section,
837
  classifier_decision_display,
 
59
  super().__init__(VerificationMode.MANUAL_INPUT)
60
  self.store = JSONVerificationStore()
61
  self.ai_client = AIClientManager()
62
+ self.model_overrides = {}
63
  self.state = ManualInputState()
64
  self.classification_start_time = None
65
+
66
+ # Ensure the underlying AI client manager sees our overrides.
67
+ self.ai_client.set_model_overrides(self.model_overrides)
68
+
69
+ def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
70
+ """Set per-session model overrides from the UI."""
71
+ self.model_overrides = dict(overrides or {})
72
+ self.ai_client.set_model_overrides(self.model_overrides)
73
 
74
  def start_new_session(self, verifier_name: str) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
75
  """
 
145
  user_prompt=user_prompt,
146
  temperature=0.3
147
  )
 
148
  # Parse the response to extract classification details
149
  classification_result = self._parse_classification_response(response)
150
 
 
487
  return False, f"❌ Error completing session: {str(e)}"
488
 
489
 
490
+ def create_manual_input_interface(model_overrides_state: Optional[gr.State] = None) -> gr.Blocks:
491
  """
492
  Create the complete manual input mode interface.
493
 
 
495
  Gradio Blocks component for manual input mode
496
  """
497
  controller = ManualInputController()
498
+ if model_overrides_state is None:
499
+ model_overrides_state = gr.State(value={})
500
 
501
  with gr.Blocks() as manual_input_interface:
502
  # Headers and back button are in parent interface
 
517
  "🚀",
518
  "lg"
519
  )
520
+
521
  # Session info display
522
  session_info_display = gr.Markdown(
523
  "Enter your name and click 'Start Session' to begin",
524
  label="Session Status"
525
  )
526
+
527
  # Manual input section (initially hidden)
528
  manual_input_section = gr.Row(visible=False)
529
  with manual_input_section:
 
544
  "🎯",
545
  "lg"
546
  )
547
+
548
+ # Apply model overrides right before classification
549
+ def _classify_with_overrides(message_text: str, overrides: Dict[str, str]):
550
+ controller.set_model_overrides(overrides or {})
551
+ return controller.classify_message(message_text)
552
 
553
  # Classification results (initially hidden)
554
  classification_results_section = gr.Row(visible=False)
 
682
  message # status_message
683
  )
684
 
685
+ def on_classify_message(message_text, overrides):
686
  """Handle message classification."""
687
+ controller.set_model_overrides(overrides or {})
688
  success, message, classification = controller.classify_message(message_text)
689
 
690
  if success:
 
847
 
848
  classify_btn.click(
849
  on_classify_message,
850
+ inputs=[message_input, model_overrides_state],
851
  outputs=[
852
  classification_results_section,
853
  classifier_decision_display,
tests/test_model_overrides.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from src.core.ai_client import AIClientManager
4
+ from src.config.ai_providers_config import get_agent_config
5
+
6
+
7
+ def test_default_spiritual_distress_analyzer_is_gemini():
8
+ config = get_agent_config("SpiritualDistressAnalyzer")
9
+ assert config["provider"].value == "gemini"
10
+ assert config["model"].value.startswith("gemini")
11
+
12
+
13
+ def test_model_override_creates_non_cached_client_and_does_not_leak():
14
+ mgr = AIClientManager()
15
+
16
+ # Baseline (cached) client should be created once
17
+ base = mgr.get_client("SpiritualDistressAnalyzer")
18
+
19
+ # Override should return a fresh client each time (non-cached)
20
+ override_1 = mgr.get_client("SpiritualDistressAnalyzer", model_override="claude-sonnet-4-5-20250929")
21
+ override_2 = mgr.get_client("SpiritualDistressAnalyzer", model_override="claude-sonnet-4-5-20250929")
22
+
23
+ assert override_1 is not override_2
24
+
25
+ # And it should not replace the cached base client
26
+ base_again = mgr.get_client("SpiritualDistressAnalyzer")
27
+ assert base_again is base
28
+
29
+
30
+ def test_manager_attached_model_overrides_apply_to_spiritual_and_medical_calls_without_args(monkeypatch):
31
+ mgr = AIClientManager()
32
+
33
+ # Attach session-level overrides as done by SimplifiedMedicalApp
34
+ mgr.model_overrides = {
35
+ "SpiritualDistressAnalyzer": "claude-sonnet-4-5-20250929",
36
+ "SoftMedicalTriage": "gemini-2.5-flash",
37
+ }
38
+
39
+ captured = []
40
+
41
+ def fake_generate_response(*, system_prompt, user_prompt, temperature=None, call_type="", agent_name="DefaultAgent", medical_context=None, model_override=None):
42
+ captured.append({
43
+ "agent_name": agent_name,
44
+ "call_type": call_type,
45
+ "model_override": model_override,
46
+ })
47
+ return "ok"
48
+
49
+ monkeypatch.setattr(mgr, "generate_response", fake_generate_response)
50
+
51
+ mgr.call_spiritual_api("sys", "user")
52
+ mgr.call_medical_api("sys", "user")
53
+
54
+ assert captured[0]["agent_name"] == "SpiritualDistressAnalyzer"
55
+ assert captured[0]["model_override"] == "claude-sonnet-4-5-20250929"
56
+
57
+ assert captured[1]["agent_name"] == "SoftMedicalTriage"
58
+ assert captured[1]["model_override"] == "gemini-2.5-flash"
59
+
60
+
61
+ def test_two_managers_do_not_share_model_overrides():
62
+ mgr_a = AIClientManager()
63
+ mgr_b = AIClientManager()
64
+
65
+ mgr_a.model_overrides = {"SpiritualDistressAnalyzer": "claude-sonnet-4-5-20250929"}
66
+ mgr_b.model_overrides = {"SpiritualDistressAnalyzer": "gemini-2.5-flash"}
67
+
68
+ assert mgr_a.model_overrides["SpiritualDistressAnalyzer"] != mgr_b.model_overrides["SpiritualDistressAnalyzer"]