DocUA committed on
Commit
2e0e95f
·
1 Parent(s): 4367366

Add per-session prompt override functionality across interfaces and AI client

Browse files
src/core/ai_client.py CHANGED
@@ -380,6 +380,10 @@ class AIClientManager:
380
  # Expected shape: {agent_name: model_string}
381
  self.model_overrides: Dict[str, str] = {}
382
 
 
 
 
 
383
  def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
384
  """Set per-session model overrides.
385
 
@@ -387,6 +391,22 @@ class AIClientManager:
387
  (chat / manual input / file upload) can share the same mechanism.
388
  """
389
  self.model_overrides = dict(overrides or {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  # NEW: Enhanced client management for medical AI optimization
392
  self.provider_performance_metrics = {}
@@ -491,6 +511,9 @@ class AIClientManager:
491
  if model_override is None and self.model_overrides:
492
  model_override = self.model_overrides.get("SpiritualDistressAnalyzer")
493
 
 
 
 
494
  return self.generate_response(
495
  system_prompt=system_prompt,
496
  user_prompt=user_prompt,
@@ -499,6 +522,28 @@ class AIClientManager:
499
  agent_name="SpiritualDistressAnalyzer",
500
  model_override=model_override,
501
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
  def call_medical_api(self, system_prompt: str, user_prompt: str,
504
  temperature: float = 0.3,
@@ -519,6 +564,9 @@ class AIClientManager:
519
  if model_override is None and self.model_overrides:
520
  model_override = self.model_overrides.get("SoftMedicalTriage")
521
 
 
 
 
522
  return self.generate_response(
523
  system_prompt=system_prompt,
524
  user_prompt=user_prompt,
@@ -528,6 +576,63 @@ class AIClientManager:
528
  model_override=model_override,
529
  )
530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  # Factory function for easy client creation
532
  def create_ai_client(agent_name: str, model_override: Optional[str] = None) -> UniversalAIClient:
533
  """
 
380
  # Expected shape: {agent_name: model_string}
381
  self.model_overrides: Dict[str, str] = {}
382
 
383
+ # Optional per-session prompt overrides.
384
+ # Expected shape: {agent_name: system_prompt_string}
385
+ self.prompt_overrides = {}
386
+
387
  def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
388
  """Set per-session model overrides.
389
 
 
391
  (chat / manual input / file upload) can share the same mechanism.
392
  """
393
  self.model_overrides = dict(overrides or {})
394
+
395
+ def set_prompt_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
396
+ """Set per-session prompt overrides.
397
+
398
+ This avoids mutating module-level prompt constants and prevents
399
+ cross-session leakage.
400
+
401
+ Expected keys are agent names, e.g.:
402
+ - SpiritualDistressAnalyzer
403
+ - SoftSpiritualTriage
404
+ - TriageResponseEvaluator
405
+ - MedicalAssistant
406
+ - SoftMedicalTriage
407
+ - EntryClassifier
408
+ """
409
+ self.prompt_overrides = dict(overrides or {})
410
 
411
  # NEW: Enhanced client management for medical AI optimization
412
  self.provider_performance_metrics = {}
 
511
  if model_override is None and self.model_overrides:
512
  model_override = self.model_overrides.get("SpiritualDistressAnalyzer")
513
 
514
+ if self.prompt_overrides and "SpiritualDistressAnalyzer" in self.prompt_overrides:
515
+ system_prompt = self.prompt_overrides["SpiritualDistressAnalyzer"]
516
+
517
  return self.generate_response(
518
  system_prompt=system_prompt,
519
  user_prompt=user_prompt,
 
522
  agent_name="SpiritualDistressAnalyzer",
523
  model_override=model_override,
524
  )
525
+
526
+ def call_entry_classifier_api(self, system_prompt: str, user_prompt: str,
527
+ temperature: float = 0.3,
528
+ model_override: Optional[str] = None) -> str:
529
+ """Call AI API for entry classification.
530
+
531
+ This is used by Enhanced Verification manual input / file upload modes.
532
+ """
533
+ if model_override is None and self.model_overrides:
534
+ model_override = self.model_overrides.get("EntryClassifier")
535
+
536
+ if self.prompt_overrides and "EntryClassifier" in self.prompt_overrides:
537
+ system_prompt = self.prompt_overrides["EntryClassifier"]
538
+
539
+ return self.generate_response(
540
+ system_prompt=system_prompt,
541
+ user_prompt=user_prompt,
542
+ temperature=temperature,
543
+ call_type="entry_classification",
544
+ agent_name="EntryClassifier",
545
+ model_override=model_override,
546
+ )
547
 
548
  def call_medical_api(self, system_prompt: str, user_prompt: str,
549
  temperature: float = 0.3,
 
564
  if model_override is None and self.model_overrides:
565
  model_override = self.model_overrides.get("SoftMedicalTriage")
566
 
567
+ if self.prompt_overrides and "SoftMedicalTriage" in self.prompt_overrides:
568
+ system_prompt = self.prompt_overrides["SoftMedicalTriage"]
569
+
570
  return self.generate_response(
571
  system_prompt=system_prompt,
572
  user_prompt=user_prompt,
 
576
  model_override=model_override,
577
  )
578
 
579
+ def call_soft_spiritual_triage_api(self, system_prompt: str, user_prompt: str,
580
+ temperature: float = 0.3,
581
+ model_override: Optional[str] = None) -> str:
582
+ """Call AI API for soft spiritual triage question generation."""
583
+ if model_override is None and self.model_overrides:
584
+ model_override = self.model_overrides.get("SoftSpiritualTriage")
585
+
586
+ if self.prompt_overrides and "SoftSpiritualTriage" in self.prompt_overrides:
587
+ system_prompt = self.prompt_overrides["SoftSpiritualTriage"]
588
+
589
+ return self.generate_response(
590
+ system_prompt=system_prompt,
591
+ user_prompt=user_prompt,
592
+ temperature=temperature,
593
+ call_type="soft_spiritual_triage",
594
+ agent_name="SoftSpiritualTriage",
595
+ model_override=model_override,
596
+ )
597
+
598
+ def call_triage_response_evaluator_api(self, system_prompt: str, user_prompt: str,
599
+ temperature: float = 0.3,
600
+ model_override: Optional[str] = None) -> str:
601
+ """Call AI API for triage response evaluation."""
602
+ if model_override is None and self.model_overrides:
603
+ model_override = self.model_overrides.get("TriageResponseEvaluator")
604
+
605
+ if self.prompt_overrides and "TriageResponseEvaluator" in self.prompt_overrides:
606
+ system_prompt = self.prompt_overrides["TriageResponseEvaluator"]
607
+
608
+ return self.generate_response(
609
+ system_prompt=system_prompt,
610
+ user_prompt=user_prompt,
611
+ temperature=temperature,
612
+ call_type="triage_response_evaluator",
613
+ agent_name="TriageResponseEvaluator",
614
+ model_override=model_override,
615
+ )
616
+
617
+ def call_medical_assistant_api(self, system_prompt: str, user_prompt: str,
618
+ temperature: float = 0.3,
619
+ model_override: Optional[str] = None) -> str:
620
+ """Call AI API for medical assistant responses."""
621
+ if model_override is None and self.model_overrides:
622
+ model_override = self.model_overrides.get("MedicalAssistant")
623
+
624
+ if self.prompt_overrides and "MedicalAssistant" in self.prompt_overrides:
625
+ system_prompt = self.prompt_overrides["MedicalAssistant"]
626
+
627
+ return self.generate_response(
628
+ system_prompt=system_prompt,
629
+ user_prompt=user_prompt,
630
+ temperature=temperature,
631
+ call_type="medical_assistant",
632
+ agent_name="MedicalAssistant",
633
+ model_override=model_override,
634
+ )
635
+
636
  # Factory function for easy client creation
637
  def create_ai_client(agent_name: str, model_override: Optional[str] = None) -> UniversalAIClient:
638
  """
src/core/simplified_medical_app.py CHANGED
@@ -66,6 +66,10 @@ class SimplifiedMedicalApp:
66
  self.api = AIClientManager()
67
  # Optional per-session model overrides (set by the UI Model Settings tab)
68
  self.model_overrides = {}
 
 
 
 
69
 
70
  # Medical components
71
  self.medical_assistant = MedicalAssistant(self.api)
@@ -81,7 +85,7 @@ class SimplifiedMedicalApp:
81
  logger.info(f"✅ Loaded patient: {self.clinical_background.patient_name}")
82
 
83
  # Session state
84
- self.chat_history: List[ChatMessage] = []
85
  self.spiritual_state = SessionSpiritualState()
86
  self.session_active = False
87
 
@@ -109,6 +113,16 @@ class SimplifiedMedicalApp:
109
  if hasattr(self, "api") and hasattr(self.api, "model_overrides"):
110
  self.api.model_overrides = dict(self.model_overrides)
111
 
 
 
 
 
 
 
 
 
 
 
112
  def _get_model_override(self, agent_name: str) -> Optional[str]:
113
  """Return the model override for an agent, if any."""
114
  if not getattr(self, "model_overrides", None):
 
66
  self.api = AIClientManager()
67
  # Optional per-session model overrides (set by the UI Model Settings tab)
68
  self.model_overrides = {}
69
+
70
+ # Optional per-session prompt overrides (set by the UI Edit Prompts tab)
71
+ # Expected: {agent_name: system_prompt_text}
72
+ self.prompt_overrides = {}
73
 
74
  # Medical components
75
  self.medical_assistant = MedicalAssistant(self.api)
 
85
  logger.info(f"✅ Loaded patient: {self.clinical_background.patient_name}")
86
 
87
  # Session state
88
+ self.chat_history = []
89
  self.spiritual_state = SessionSpiritualState()
90
  self.session_active = False
91
 
 
113
  if hasattr(self, "api") and hasattr(self.api, "model_overrides"):
114
  self.api.model_overrides = dict(self.model_overrides)
115
 
116
+ def set_prompt_overrides(self, overrides: Optional[dict] = None) -> None:
117
+ """Set per-session prompt overrides.
118
+
119
+ `overrides` is expected to be a mapping of agent_name -> system_prompt string.
120
+ """
121
+ self.prompt_overrides = dict(overrides or {})
122
+
123
+ if hasattr(self, "api") and hasattr(self.api, "set_prompt_overrides"):
124
+ self.api.set_prompt_overrides(self.prompt_overrides)
125
+
126
  def _get_model_override(self, agent_name: str) -> Optional[str]:
127
  """Return the model override for an agent, if any."""
128
  if not getattr(self, "model_overrides", None):
src/interface/file_upload_interface.py CHANGED
@@ -48,6 +48,9 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
48
  # Optional per-session model overrides (UI Model Settings tab)
49
  self.model_overrides = {}
50
  self.ai_client.set_model_overrides(self.model_overrides)
 
 
 
51
  self.current_session = None
52
  self.current_file_result = None
53
  self.current_message_index = 0
@@ -57,6 +60,11 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
57
  """Set per-session model overrides from the UI."""
58
  self.model_overrides = dict(overrides or {})
59
  self.ai_client.set_model_overrides(self.model_overrides)
 
 
 
 
 
60
 
61
  def process_uploaded_file(self, file_path: str) -> Tuple[bool, str, Optional[FileUploadResult], str]:
62
  """
@@ -298,10 +306,10 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
298
  # Call AI classifier using the same approach as manual input
299
  user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{test_message.text}"
300
 
301
- response = self.ai_client.call_spiritual_api(
302
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
303
  user_prompt=user_prompt,
304
- temperature=0.3
305
  )
306
 
307
  # Parse the response to extract classification details
@@ -407,11 +415,11 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
407
  "Please analyze this patient message for spiritual distress:\n\n"
408
  f"{test_message.text}"
409
  )
410
- raw_response = self.ai_client.call_spiritual_api(
411
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
412
  user_prompt=user_prompt,
413
  temperature=0.3,
414
- model_override=self.model_overrides.get("SpiritualDistressAnalyzer"),
415
  )
416
 
417
  classification_result = self._parse_classification_response(raw_response)
 
48
  # Optional per-session model overrides (UI Model Settings tab)
49
  self.model_overrides = {}
50
  self.ai_client.set_model_overrides(self.model_overrides)
51
+ # Optional per-session prompt overrides (UI Edit Prompts tab)
52
+ self.prompt_overrides = {}
53
+ self.ai_client.set_prompt_overrides(self.prompt_overrides)
54
  self.current_session = None
55
  self.current_file_result = None
56
  self.current_message_index = 0
 
60
  """Set per-session model overrides from the UI."""
61
  self.model_overrides = dict(overrides or {})
62
  self.ai_client.set_model_overrides(self.model_overrides)
63
+
64
+ def set_prompt_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
65
+ """Set per-session prompt overrides from the UI."""
66
+ self.prompt_overrides = dict(overrides or {})
67
+ self.ai_client.set_prompt_overrides(self.prompt_overrides)
68
 
69
  def process_uploaded_file(self, file_path: str) -> Tuple[bool, str, Optional[FileUploadResult], str]:
70
  """
 
306
  # Call AI classifier using the same approach as manual input
307
  user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{test_message.text}"
308
 
309
+ response = self.ai_client.call_entry_classifier_api(
310
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
311
  user_prompt=user_prompt,
312
+ temperature=0.3,
313
  )
314
 
315
  # Parse the response to extract classification details
 
415
  "Please analyze this patient message for spiritual distress:\n\n"
416
  f"{test_message.text}"
417
  )
418
+ raw_response = self.ai_client.call_entry_classifier_api(
419
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
420
  user_prompt=user_prompt,
421
  temperature=0.3,
422
+ model_override=self.model_overrides.get("EntryClassifier"),
423
  )
424
 
425
  classification_result = self._parse_classification_response(raw_response)
src/interface/manual_input_interface.py CHANGED
@@ -60,16 +60,23 @@ class ManualInputController(ProgressTrackingMixin):
60
  self.store = JSONVerificationStore()
61
  self.ai_client = AIClientManager()
62
  self.model_overrides = {}
 
63
  self.state = ManualInputState()
64
  self.classification_start_time = None
65
 
66
  # Ensure the underlying AI client manager sees our overrides.
67
  self.ai_client.set_model_overrides(self.model_overrides)
 
68
 
69
  def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
70
  """Set per-session model overrides from the UI."""
71
  self.model_overrides = dict(overrides or {})
72
  self.ai_client.set_model_overrides(self.model_overrides)
 
 
 
 
 
73
 
74
  def start_new_session(self, verifier_name: str) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
75
  """
@@ -139,11 +146,11 @@ class ManualInputController(ProgressTrackingMixin):
139
 
140
  # Call AI classifier
141
  user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{message_text.strip()}"
142
-
143
- response = self.ai_client.call_spiritual_api(
144
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
145
  user_prompt=user_prompt,
146
- temperature=0.3
147
  )
148
  # Parse the response to extract classification details
149
  classification_result = self._parse_classification_response(response)
 
60
  self.store = JSONVerificationStore()
61
  self.ai_client = AIClientManager()
62
  self.model_overrides = {}
63
+ self.prompt_overrides = {}
64
  self.state = ManualInputState()
65
  self.classification_start_time = None
66
 
67
  # Ensure the underlying AI client manager sees our overrides.
68
  self.ai_client.set_model_overrides(self.model_overrides)
69
+ self.ai_client.set_prompt_overrides(self.prompt_overrides)
70
 
71
  def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
72
  """Set per-session model overrides from the UI."""
73
  self.model_overrides = dict(overrides or {})
74
  self.ai_client.set_model_overrides(self.model_overrides)
75
+
76
+ def set_prompt_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
77
+ """Set per-session prompt overrides from the UI."""
78
+ self.prompt_overrides = dict(overrides or {})
79
+ self.ai_client.set_prompt_overrides(self.prompt_overrides)
80
 
81
  def start_new_session(self, verifier_name: str) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
82
  """
 
146
 
147
  # Call AI classifier
148
  user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{message_text.strip()}"
149
+
150
+ response = self.ai_client.call_entry_classifier_api(
151
  system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
152
  user_prompt=user_prompt,
153
+ temperature=0.3,
154
  )
155
  # Parse the response to extract classification details
156
  classification_result = self._parse_classification_response(response)
src/interface/simplified_gradio_app.py CHANGED
@@ -604,6 +604,13 @@ Changes apply only to your current session.
604
  session.app_instance.set_model_overrides(custom_models)
605
  else:
606
  session.app_instance.set_model_overrides({})
 
 
 
 
 
 
 
607
  new_history, status = session.app_instance.process_message(message, history)
608
 
609
  # Get updated conversation stats
@@ -866,8 +873,22 @@ Changes apply only to your current session.
866
 
867
  return formatted
868
 
869
- def load_prompt(prompt_name: str):
870
- """Load selected prompt for editing."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871
  from src.core.spiritual_monitor import SYSTEM_PROMPT_SPIRITUAL_MONITOR
872
  from src.core.soft_triage_manager import (
873
  SYSTEM_PROMPT_TRIAGE_QUESTION,
@@ -886,7 +907,12 @@ Changes apply only to your current session.
886
  "🩺 Soft Medical Triage": SYSTEM_PROMPT_SOFT_MEDICAL_TRIAGE
887
  }
888
 
 
889
  prompt_text = prompts.get(prompt_name, "")
 
 
 
 
890
 
891
  # Format with HTML for display
892
  formatted_html = format_prompt_with_html(prompt_text)
@@ -924,32 +950,18 @@ Changes apply only to your current session.
924
  </div>"""
925
  return error_html, session
926
 
927
- # Store custom prompt in session
928
  if not hasattr(session, 'custom_prompts'):
929
  session.custom_prompts = {}
930
 
931
- session.custom_prompts[prompt_name] = prompt_text
932
-
933
- # Apply to app components
934
- try:
935
- if prompt_name == "🔍 Spiritual Monitor (Classifier)":
936
- # Update spiritual monitor prompt
937
- import src.core.spiritual_monitor as sm
938
- sm.SYSTEM_PROMPT_SPIRITUAL_MONITOR = prompt_text
939
- elif prompt_name == "🟡 Soft Spiritual Triage":
940
- import src.core.soft_triage_manager as stm
941
- stm.SYSTEM_PROMPT_TRIAGE_QUESTION = prompt_text
942
- elif prompt_name == "📊 Triage Response Evaluator":
943
- import src.core.soft_triage_manager as stm
944
- stm.SYSTEM_PROMPT_TRIAGE_EVALUATE = prompt_text
945
- elif prompt_name == "🏥 Medical Assistant":
946
- import src.config.prompts as p
947
- p.SYSTEM_PROMPT_MEDICAL_ASSISTANT = prompt_text
948
- elif prompt_name == "🩺 Soft Medical Triage":
949
- import src.config.prompts as p
950
- p.SYSTEM_PROMPT_SOFT_MEDICAL_TRIAGE = prompt_text
951
-
952
- status = f"""<div style="padding: 1em; background-color: #ecfdf5; border-left: 4px solid #10b981; border-radius: 4px;">
953
  <h4 style="color: #059669; margin-top: 0;">✅ Prompt Applied Successfully</h4>
954
 
955
  <p><strong>Prompt:</strong> {prompt_name}</p>
@@ -961,15 +973,8 @@ Changes apply only to your current session.
961
  To revert, use "Reset to Default" button.
962
  </p>
963
  </div>"""
964
-
965
- return status, session
966
-
967
- except Exception as e:
968
- error_html = f"""<div style="padding: 1em; background-color: #fef2f2; border-left: 4px solid #dc2626; border-radius: 4px;">
969
- <h4 style="color: #dc2626; margin-top: 0;">❌ Error Applying Prompt</h4>
970
- <p style="margin-bottom: 0;"><code>{str(e)}</code></p>
971
- </div>"""
972
- return error_html, session
973
 
974
  def reset_prompt(prompt_name: str, session: SimplifiedSessionData):
975
  """Reset prompt to default."""
@@ -977,14 +982,16 @@ To revert, use "Reset to Default" button.
977
  session = SimplifiedSessionData()
978
 
979
  # Remove from custom prompts
980
- if hasattr(session, 'custom_prompts') and prompt_name in session.custom_prompts:
981
- del session.custom_prompts[prompt_name]
 
 
 
 
 
982
 
983
  # Reload default
984
- prompt_text, info, status = load_prompt(prompt_name)
985
-
986
- # Reapply default
987
- apply_status, session = apply_prompt_changes(prompt_name, prompt_text, session)
988
 
989
  reset_status = """<div style="padding: 1em; background-color: #eff6ff; border-left: 4px solid #3b82f6; border-radius: 4px;">
990
  <h4 style="color: #2563eb; margin-top: 0;">🔄 Reset to Default</h4>
@@ -2235,7 +2242,7 @@ To revert, use "Reset to Default" button.
2235
  # Prompt editing events
2236
  load_prompt_btn.click(
2237
  load_prompt,
2238
- inputs=[prompt_selector],
2239
  outputs=[prompt_editor, prompt_info_display, prompt_status]
2240
  )
2241
 
@@ -2254,7 +2261,7 @@ To revert, use "Reset to Default" button.
2254
  # Auto-load prompt when selector changes
2255
  prompt_selector.change(
2256
  load_prompt,
2257
- inputs=[prompt_selector],
2258
  outputs=[prompt_editor, prompt_info_display, prompt_status]
2259
  )
2260
 
 
604
  session.app_instance.set_model_overrides(custom_models)
605
  else:
606
  session.app_instance.set_model_overrides({})
607
+
608
+ # Apply per-session prompt overrides (if configured in Edit Prompts)
609
+ custom_prompts = getattr(session, 'custom_prompts', None)
610
+ if custom_prompts:
611
+ session.app_instance.set_prompt_overrides(custom_prompts)
612
+ else:
613
+ session.app_instance.set_prompt_overrides({})
614
  new_history, status = session.app_instance.process_message(message, history)
615
 
616
  # Get updated conversation stats
 
873
 
874
  return formatted
875
 
876
+ def _prompt_name_to_agent(prompt_name: str) -> str:
877
+ """Map UI prompt selection to internal agent/prompt key."""
878
+ mapping = {
879
+ "🔍 Spiritual Monitor (Classifier)": "SpiritualDistressAnalyzer",
880
+ "🟡 Soft Spiritual Triage": "SoftSpiritualTriage",
881
+ "📊 Triage Response Evaluator": "TriageResponseEvaluator",
882
+ "🏥 Medical Assistant": "MedicalAssistant",
883
+ "🩺 Soft Medical Triage": "SoftMedicalTriage",
884
+ }
885
+ return mapping.get(prompt_name, prompt_name)
886
+
887
+ def load_prompt(prompt_name: str, session: Optional[SimplifiedSessionData] = None):
888
+ """Load selected prompt for editing.
889
+
890
+ If a session override exists, show it instead of the default.
891
+ """
892
  from src.core.spiritual_monitor import SYSTEM_PROMPT_SPIRITUAL_MONITOR
893
  from src.core.soft_triage_manager import (
894
  SYSTEM_PROMPT_TRIAGE_QUESTION,
 
907
  "🩺 Soft Medical Triage": SYSTEM_PROMPT_SOFT_MEDICAL_TRIAGE
908
  }
909
 
910
+ agent_key = _prompt_name_to_agent(prompt_name)
911
  prompt_text = prompts.get(prompt_name, "")
912
+
913
+ # Prefer session override (true session-scoped behavior)
914
+ if session is not None and hasattr(session, 'custom_prompts'):
915
+ prompt_text = session.custom_prompts.get(agent_key, prompt_text)
916
 
917
  # Format with HTML for display
918
  formatted_html = format_prompt_with_html(prompt_text)
 
950
  </div>"""
951
  return error_html, session
952
 
953
+ # Store custom prompt in session (session-scoped)
954
  if not hasattr(session, 'custom_prompts'):
955
  session.custom_prompts = {}
956
 
957
+ agent_key = _prompt_name_to_agent(prompt_name)
958
+ session.custom_prompts[agent_key] = prompt_text
959
+
960
+ # Apply into the current session app instance (no global mutation)
961
+ if hasattr(session, 'app_instance') and hasattr(session.app_instance, 'set_prompt_overrides'):
962
+ session.app_instance.set_prompt_overrides(session.custom_prompts)
963
+
964
+ status = f"""<div style="padding: 1em; background-color: #ecfdf5; border-left: 4px solid #10b981; border-radius: 4px;">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
  <h4 style="color: #059669; margin-top: 0;">✅ Prompt Applied Successfully</h4>
966
 
967
  <p><strong>Prompt:</strong> {prompt_name}</p>
 
973
  To revert, use "Reset to Default" button.
974
  </p>
975
  </div>"""
976
+
977
+ return status, session
 
 
 
 
 
 
 
978
 
979
  def reset_prompt(prompt_name: str, session: SimplifiedSessionData):
980
  """Reset prompt to default."""
 
982
  session = SimplifiedSessionData()
983
 
984
  # Remove from custom prompts
985
+ agent_key = _prompt_name_to_agent(prompt_name)
986
+ if hasattr(session, 'custom_prompts') and agent_key in session.custom_prompts:
987
+ del session.custom_prompts[agent_key]
988
+
989
+ # Apply into current session app instance
990
+ if hasattr(session, 'app_instance') and hasattr(session.app_instance, 'set_prompt_overrides'):
991
+ session.app_instance.set_prompt_overrides(getattr(session, 'custom_prompts', {}))
992
 
993
  # Reload default
994
+ prompt_text, info, status = load_prompt(prompt_name, session)
 
 
 
995
 
996
  reset_status = """<div style="padding: 1em; background-color: #eff6ff; border-left: 4px solid #3b82f6; border-radius: 4px;">
997
  <h4 style="color: #2563eb; margin-top: 0;">🔄 Reset to Default</h4>
 
2242
  # Prompt editing events
2243
  load_prompt_btn.click(
2244
  load_prompt,
2245
+ inputs=[prompt_selector, session_data],
2246
  outputs=[prompt_editor, prompt_info_display, prompt_status]
2247
  )
2248
 
 
2261
  # Auto-load prompt when selector changes
2262
  prompt_selector.change(
2263
  load_prompt,
2264
+ inputs=[prompt_selector, session_data],
2265
  outputs=[prompt_editor, prompt_info_display, prompt_status]
2266
  )
2267
 
tests/test_model_overrides_cross_mode.py CHANGED
@@ -26,7 +26,7 @@ def test_file_upload_controller_applies_overrides_to_ai_client():
26
 
27
  def test_file_upload_batch_passes_override_into_call(monkeypatch):
28
  controller = FileUploadInterfaceController()
29
- controller.set_model_overrides({"SpiritualDistressAnalyzer": "gemini-batch-override"})
30
 
31
  # Minimal session + dataset state so the batch method can run without UI.
32
  controller.current_session = types.SimpleNamespace(
@@ -60,13 +60,13 @@ def test_file_upload_batch_passes_override_into_call(monkeypatch):
60
 
61
  captured = {}
62
 
63
- def fake_call_spiritual_api(*, system_prompt=None, user_prompt=None, model_override=None, **kwargs):
64
  captured["system_prompt"] = system_prompt
65
  captured["user_prompt"] = user_prompt
66
  captured["model_override"] = model_override
67
  return "green"
68
 
69
- monkeypatch.setattr(controller.ai_client, "call_spiritual_api", fake_call_spiritual_api)
70
 
71
  # Run the simplest possible classification path.
72
  result = controller.run_batch_classification()
 
26
 
27
  def test_file_upload_batch_passes_override_into_call(monkeypatch):
28
  controller = FileUploadInterfaceController()
29
+ controller.set_model_overrides({"EntryClassifier": "gemini-batch-override"})
30
 
31
  # Minimal session + dataset state so the batch method can run without UI.
32
  controller.current_session = types.SimpleNamespace(
 
60
 
61
  captured = {}
62
 
63
+ def fake_call_entry_classifier_api(*, system_prompt=None, user_prompt=None, model_override=None, **kwargs):
64
  captured["system_prompt"] = system_prompt
65
  captured["user_prompt"] = user_prompt
66
  captured["model_override"] = model_override
67
  return "green"
68
 
69
+ monkeypatch.setattr(controller.ai_client, "call_entry_classifier_api", fake_call_entry_classifier_api)
70
 
71
  # Run the simplest possible classification path.
72
  result = controller.run_batch_classification()
tests/verification_mode/test_properties_persistence.py CHANGED
@@ -213,7 +213,7 @@ class TestSessionStatePersistence:
213
  assert restored_session.is_complete == session.is_complete
214
 
215
  @given(verification_session_strategy())
216
- @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
217
  def test_session_with_multiple_records_persists(
218
  self, verification_store, session
219
  ):
 
213
  assert restored_session.is_complete == session.is_complete
214
 
215
  @given(verification_session_strategy())
216
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
217
  def test_session_with_multiple_records_persists(
218
  self, verification_store, session
219
  ):