Spaces:
Sleeping
Sleeping
Add per-session prompt override functionality across interfaces and AI client
Browse files- src/core/ai_client.py +105 -0
- src/core/simplified_medical_app.py +15 -1
- src/interface/file_upload_interface.py +12 -4
- src/interface/manual_input_interface.py +10 -3
- src/interface/simplified_gradio_app.py +49 -42
- tests/test_model_overrides_cross_mode.py +3 -3
- tests/verification_mode/test_properties_persistence.py +1 -1
src/core/ai_client.py
CHANGED
|
@@ -380,6 +380,10 @@ class AIClientManager:
|
|
| 380 |
# Expected shape: {agent_name: model_string}
|
| 381 |
self.model_overrides: Dict[str, str] = {}
|
| 382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
|
| 384 |
"""Set per-session model overrides.
|
| 385 |
|
|
@@ -387,6 +391,22 @@ class AIClientManager:
|
|
| 387 |
(chat / manual input / file upload) can share the same mechanism.
|
| 388 |
"""
|
| 389 |
self.model_overrides = dict(overrides or {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
# NEW: Enhanced client management for medical AI optimization
|
| 392 |
self.provider_performance_metrics = {}
|
|
@@ -491,6 +511,9 @@ class AIClientManager:
|
|
| 491 |
if model_override is None and self.model_overrides:
|
| 492 |
model_override = self.model_overrides.get("SpiritualDistressAnalyzer")
|
| 493 |
|
|
|
|
|
|
|
|
|
|
| 494 |
return self.generate_response(
|
| 495 |
system_prompt=system_prompt,
|
| 496 |
user_prompt=user_prompt,
|
|
@@ -499,6 +522,28 @@ class AIClientManager:
|
|
| 499 |
agent_name="SpiritualDistressAnalyzer",
|
| 500 |
model_override=model_override,
|
| 501 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
|
| 503 |
def call_medical_api(self, system_prompt: str, user_prompt: str,
|
| 504 |
temperature: float = 0.3,
|
|
@@ -519,6 +564,9 @@ class AIClientManager:
|
|
| 519 |
if model_override is None and self.model_overrides:
|
| 520 |
model_override = self.model_overrides.get("SoftMedicalTriage")
|
| 521 |
|
|
|
|
|
|
|
|
|
|
| 522 |
return self.generate_response(
|
| 523 |
system_prompt=system_prompt,
|
| 524 |
user_prompt=user_prompt,
|
|
@@ -528,6 +576,63 @@ class AIClientManager:
|
|
| 528 |
model_override=model_override,
|
| 529 |
)
|
| 530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
# Factory function for easy client creation
|
| 532 |
def create_ai_client(agent_name: str, model_override: Optional[str] = None) -> UniversalAIClient:
|
| 533 |
"""
|
|
|
|
| 380 |
# Expected shape: {agent_name: model_string}
|
| 381 |
self.model_overrides: Dict[str, str] = {}
|
| 382 |
|
| 383 |
+
# Optional per-session prompt overrides.
|
| 384 |
+
# Expected shape: {agent_name: system_prompt_string}
|
| 385 |
+
self.prompt_overrides = {}
|
| 386 |
+
|
| 387 |
def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
|
| 388 |
"""Set per-session model overrides.
|
| 389 |
|
|
|
|
| 391 |
(chat / manual input / file upload) can share the same mechanism.
|
| 392 |
"""
|
| 393 |
self.model_overrides = dict(overrides or {})
|
| 394 |
+
|
| 395 |
+
def set_prompt_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
|
| 396 |
+
"""Set per-session prompt overrides.
|
| 397 |
+
|
| 398 |
+
This avoids mutating module-level prompt constants and prevents
|
| 399 |
+
cross-session leakage.
|
| 400 |
+
|
| 401 |
+
Expected keys are agent names, e.g.:
|
| 402 |
+
- SpiritualDistressAnalyzer
|
| 403 |
+
- SoftSpiritualTriage
|
| 404 |
+
- TriageResponseEvaluator
|
| 405 |
+
- MedicalAssistant
|
| 406 |
+
- SoftMedicalTriage
|
| 407 |
+
- EntryClassifier
|
| 408 |
+
"""
|
| 409 |
+
self.prompt_overrides = dict(overrides or {})
|
| 410 |
|
| 411 |
# NEW: Enhanced client management for medical AI optimization
|
| 412 |
self.provider_performance_metrics = {}
|
|
|
|
| 511 |
if model_override is None and self.model_overrides:
|
| 512 |
model_override = self.model_overrides.get("SpiritualDistressAnalyzer")
|
| 513 |
|
| 514 |
+
if self.prompt_overrides and "SpiritualDistressAnalyzer" in self.prompt_overrides:
|
| 515 |
+
system_prompt = self.prompt_overrides["SpiritualDistressAnalyzer"]
|
| 516 |
+
|
| 517 |
return self.generate_response(
|
| 518 |
system_prompt=system_prompt,
|
| 519 |
user_prompt=user_prompt,
|
|
|
|
| 522 |
agent_name="SpiritualDistressAnalyzer",
|
| 523 |
model_override=model_override,
|
| 524 |
)
|
| 525 |
+
|
| 526 |
+
def call_entry_classifier_api(self, system_prompt: str, user_prompt: str,
|
| 527 |
+
temperature: float = 0.3,
|
| 528 |
+
model_override: Optional[str] = None) -> str:
|
| 529 |
+
"""Call AI API for entry classification.
|
| 530 |
+
|
| 531 |
+
This is used by Enhanced Verification manual input / file upload modes.
|
| 532 |
+
"""
|
| 533 |
+
if model_override is None and self.model_overrides:
|
| 534 |
+
model_override = self.model_overrides.get("EntryClassifier")
|
| 535 |
+
|
| 536 |
+
if self.prompt_overrides and "EntryClassifier" in self.prompt_overrides:
|
| 537 |
+
system_prompt = self.prompt_overrides["EntryClassifier"]
|
| 538 |
+
|
| 539 |
+
return self.generate_response(
|
| 540 |
+
system_prompt=system_prompt,
|
| 541 |
+
user_prompt=user_prompt,
|
| 542 |
+
temperature=temperature,
|
| 543 |
+
call_type="entry_classification",
|
| 544 |
+
agent_name="EntryClassifier",
|
| 545 |
+
model_override=model_override,
|
| 546 |
+
)
|
| 547 |
|
| 548 |
def call_medical_api(self, system_prompt: str, user_prompt: str,
|
| 549 |
temperature: float = 0.3,
|
|
|
|
| 564 |
if model_override is None and self.model_overrides:
|
| 565 |
model_override = self.model_overrides.get("SoftMedicalTriage")
|
| 566 |
|
| 567 |
+
if self.prompt_overrides and "SoftMedicalTriage" in self.prompt_overrides:
|
| 568 |
+
system_prompt = self.prompt_overrides["SoftMedicalTriage"]
|
| 569 |
+
|
| 570 |
return self.generate_response(
|
| 571 |
system_prompt=system_prompt,
|
| 572 |
user_prompt=user_prompt,
|
|
|
|
| 576 |
model_override=model_override,
|
| 577 |
)
|
| 578 |
|
| 579 |
+
def call_soft_spiritual_triage_api(self, system_prompt: str, user_prompt: str,
|
| 580 |
+
temperature: float = 0.3,
|
| 581 |
+
model_override: Optional[str] = None) -> str:
|
| 582 |
+
"""Call AI API for soft spiritual triage question generation."""
|
| 583 |
+
if model_override is None and self.model_overrides:
|
| 584 |
+
model_override = self.model_overrides.get("SoftSpiritualTriage")
|
| 585 |
+
|
| 586 |
+
if self.prompt_overrides and "SoftSpiritualTriage" in self.prompt_overrides:
|
| 587 |
+
system_prompt = self.prompt_overrides["SoftSpiritualTriage"]
|
| 588 |
+
|
| 589 |
+
return self.generate_response(
|
| 590 |
+
system_prompt=system_prompt,
|
| 591 |
+
user_prompt=user_prompt,
|
| 592 |
+
temperature=temperature,
|
| 593 |
+
call_type="soft_spiritual_triage",
|
| 594 |
+
agent_name="SoftSpiritualTriage",
|
| 595 |
+
model_override=model_override,
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
def call_triage_response_evaluator_api(self, system_prompt: str, user_prompt: str,
|
| 599 |
+
temperature: float = 0.3,
|
| 600 |
+
model_override: Optional[str] = None) -> str:
|
| 601 |
+
"""Call AI API for triage response evaluation."""
|
| 602 |
+
if model_override is None and self.model_overrides:
|
| 603 |
+
model_override = self.model_overrides.get("TriageResponseEvaluator")
|
| 604 |
+
|
| 605 |
+
if self.prompt_overrides and "TriageResponseEvaluator" in self.prompt_overrides:
|
| 606 |
+
system_prompt = self.prompt_overrides["TriageResponseEvaluator"]
|
| 607 |
+
|
| 608 |
+
return self.generate_response(
|
| 609 |
+
system_prompt=system_prompt,
|
| 610 |
+
user_prompt=user_prompt,
|
| 611 |
+
temperature=temperature,
|
| 612 |
+
call_type="triage_response_evaluator",
|
| 613 |
+
agent_name="TriageResponseEvaluator",
|
| 614 |
+
model_override=model_override,
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
def call_medical_assistant_api(self, system_prompt: str, user_prompt: str,
|
| 618 |
+
temperature: float = 0.3,
|
| 619 |
+
model_override: Optional[str] = None) -> str:
|
| 620 |
+
"""Call AI API for medical assistant responses."""
|
| 621 |
+
if model_override is None and self.model_overrides:
|
| 622 |
+
model_override = self.model_overrides.get("MedicalAssistant")
|
| 623 |
+
|
| 624 |
+
if self.prompt_overrides and "MedicalAssistant" in self.prompt_overrides:
|
| 625 |
+
system_prompt = self.prompt_overrides["MedicalAssistant"]
|
| 626 |
+
|
| 627 |
+
return self.generate_response(
|
| 628 |
+
system_prompt=system_prompt,
|
| 629 |
+
user_prompt=user_prompt,
|
| 630 |
+
temperature=temperature,
|
| 631 |
+
call_type="medical_assistant",
|
| 632 |
+
agent_name="MedicalAssistant",
|
| 633 |
+
model_override=model_override,
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
# Factory function for easy client creation
|
| 637 |
def create_ai_client(agent_name: str, model_override: Optional[str] = None) -> UniversalAIClient:
|
| 638 |
"""
|
src/core/simplified_medical_app.py
CHANGED
|
@@ -66,6 +66,10 @@ class SimplifiedMedicalApp:
|
|
| 66 |
self.api = AIClientManager()
|
| 67 |
# Optional per-session model overrides (set by the UI Model Settings tab)
|
| 68 |
self.model_overrides = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
# Medical components
|
| 71 |
self.medical_assistant = MedicalAssistant(self.api)
|
|
@@ -81,7 +85,7 @@ class SimplifiedMedicalApp:
|
|
| 81 |
logger.info(f"✅ Loaded patient: {self.clinical_background.patient_name}")
|
| 82 |
|
| 83 |
# Session state
|
| 84 |
-
self.chat_history
|
| 85 |
self.spiritual_state = SessionSpiritualState()
|
| 86 |
self.session_active = False
|
| 87 |
|
|
@@ -109,6 +113,16 @@ class SimplifiedMedicalApp:
|
|
| 109 |
if hasattr(self, "api") and hasattr(self.api, "model_overrides"):
|
| 110 |
self.api.model_overrides = dict(self.model_overrides)
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def _get_model_override(self, agent_name: str) -> Optional[str]:
|
| 113 |
"""Return the model override for an agent, if any."""
|
| 114 |
if not getattr(self, "model_overrides", None):
|
|
|
|
| 66 |
self.api = AIClientManager()
|
| 67 |
# Optional per-session model overrides (set by the UI Model Settings tab)
|
| 68 |
self.model_overrides = {}
|
| 69 |
+
|
| 70 |
+
# Optional per-session prompt overrides (set by the UI Edit Prompts tab)
|
| 71 |
+
# Expected: {agent_name: system_prompt_text}
|
| 72 |
+
self.prompt_overrides = {}
|
| 73 |
|
| 74 |
# Medical components
|
| 75 |
self.medical_assistant = MedicalAssistant(self.api)
|
|
|
|
| 85 |
logger.info(f"✅ Loaded patient: {self.clinical_background.patient_name}")
|
| 86 |
|
| 87 |
# Session state
|
| 88 |
+
self.chat_history = []
|
| 89 |
self.spiritual_state = SessionSpiritualState()
|
| 90 |
self.session_active = False
|
| 91 |
|
|
|
|
| 113 |
if hasattr(self, "api") and hasattr(self.api, "model_overrides"):
|
| 114 |
self.api.model_overrides = dict(self.model_overrides)
|
| 115 |
|
| 116 |
+
def set_prompt_overrides(self, overrides: Optional[dict] = None) -> None:
|
| 117 |
+
"""Set per-session prompt overrides.
|
| 118 |
+
|
| 119 |
+
`overrides` is expected to be a mapping of agent_name -> system_prompt string.
|
| 120 |
+
"""
|
| 121 |
+
self.prompt_overrides = dict(overrides or {})
|
| 122 |
+
|
| 123 |
+
if hasattr(self, "api") and hasattr(self.api, "set_prompt_overrides"):
|
| 124 |
+
self.api.set_prompt_overrides(self.prompt_overrides)
|
| 125 |
+
|
| 126 |
def _get_model_override(self, agent_name: str) -> Optional[str]:
|
| 127 |
"""Return the model override for an agent, if any."""
|
| 128 |
if not getattr(self, "model_overrides", None):
|
src/interface/file_upload_interface.py
CHANGED
|
@@ -48,6 +48,9 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
|
|
| 48 |
# Optional per-session model overrides (UI Model Settings tab)
|
| 49 |
self.model_overrides = {}
|
| 50 |
self.ai_client.set_model_overrides(self.model_overrides)
|
|
|
|
|
|
|
|
|
|
| 51 |
self.current_session = None
|
| 52 |
self.current_file_result = None
|
| 53 |
self.current_message_index = 0
|
|
@@ -57,6 +60,11 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
|
|
| 57 |
"""Set per-session model overrides from the UI."""
|
| 58 |
self.model_overrides = dict(overrides or {})
|
| 59 |
self.ai_client.set_model_overrides(self.model_overrides)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def process_uploaded_file(self, file_path: str) -> Tuple[bool, str, Optional[FileUploadResult], str]:
|
| 62 |
"""
|
|
@@ -298,10 +306,10 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
|
|
| 298 |
# Call AI classifier using the same approach as manual input
|
| 299 |
user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{test_message.text}"
|
| 300 |
|
| 301 |
-
response = self.ai_client.
|
| 302 |
system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
|
| 303 |
user_prompt=user_prompt,
|
| 304 |
-
temperature=0.3
|
| 305 |
)
|
| 306 |
|
| 307 |
# Parse the response to extract classification details
|
|
@@ -407,11 +415,11 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
|
|
| 407 |
"Please analyze this patient message for spiritual distress:\n\n"
|
| 408 |
f"{test_message.text}"
|
| 409 |
)
|
| 410 |
-
raw_response = self.ai_client.
|
| 411 |
system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
|
| 412 |
user_prompt=user_prompt,
|
| 413 |
temperature=0.3,
|
| 414 |
-
model_override=self.model_overrides.get("
|
| 415 |
)
|
| 416 |
|
| 417 |
classification_result = self._parse_classification_response(raw_response)
|
|
|
|
| 48 |
# Optional per-session model overrides (UI Model Settings tab)
|
| 49 |
self.model_overrides = {}
|
| 50 |
self.ai_client.set_model_overrides(self.model_overrides)
|
| 51 |
+
# Optional per-session prompt overrides (UI Edit Prompts tab)
|
| 52 |
+
self.prompt_overrides = {}
|
| 53 |
+
self.ai_client.set_prompt_overrides(self.prompt_overrides)
|
| 54 |
self.current_session = None
|
| 55 |
self.current_file_result = None
|
| 56 |
self.current_message_index = 0
|
|
|
|
| 60 |
"""Set per-session model overrides from the UI."""
|
| 61 |
self.model_overrides = dict(overrides or {})
|
| 62 |
self.ai_client.set_model_overrides(self.model_overrides)
|
| 63 |
+
|
| 64 |
+
def set_prompt_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
|
| 65 |
+
"""Set per-session prompt overrides from the UI."""
|
| 66 |
+
self.prompt_overrides = dict(overrides or {})
|
| 67 |
+
self.ai_client.set_prompt_overrides(self.prompt_overrides)
|
| 68 |
|
| 69 |
def process_uploaded_file(self, file_path: str) -> Tuple[bool, str, Optional[FileUploadResult], str]:
|
| 70 |
"""
|
|
|
|
| 306 |
# Call AI classifier using the same approach as manual input
|
| 307 |
user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{test_message.text}"
|
| 308 |
|
| 309 |
+
response = self.ai_client.call_entry_classifier_api(
|
| 310 |
system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
|
| 311 |
user_prompt=user_prompt,
|
| 312 |
+
temperature=0.3,
|
| 313 |
)
|
| 314 |
|
| 315 |
# Parse the response to extract classification details
|
|
|
|
| 415 |
"Please analyze this patient message for spiritual distress:\n\n"
|
| 416 |
f"{test_message.text}"
|
| 417 |
)
|
| 418 |
+
raw_response = self.ai_client.call_entry_classifier_api(
|
| 419 |
system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
|
| 420 |
user_prompt=user_prompt,
|
| 421 |
temperature=0.3,
|
| 422 |
+
model_override=self.model_overrides.get("EntryClassifier"),
|
| 423 |
)
|
| 424 |
|
| 425 |
classification_result = self._parse_classification_response(raw_response)
|
src/interface/manual_input_interface.py
CHANGED
|
@@ -60,16 +60,23 @@ class ManualInputController(ProgressTrackingMixin):
|
|
| 60 |
self.store = JSONVerificationStore()
|
| 61 |
self.ai_client = AIClientManager()
|
| 62 |
self.model_overrides = {}
|
|
|
|
| 63 |
self.state = ManualInputState()
|
| 64 |
self.classification_start_time = None
|
| 65 |
|
| 66 |
# Ensure the underlying AI client manager sees our overrides.
|
| 67 |
self.ai_client.set_model_overrides(self.model_overrides)
|
|
|
|
| 68 |
|
| 69 |
def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
|
| 70 |
"""Set per-session model overrides from the UI."""
|
| 71 |
self.model_overrides = dict(overrides or {})
|
| 72 |
self.ai_client.set_model_overrides(self.model_overrides)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def start_new_session(self, verifier_name: str) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
|
| 75 |
"""
|
|
@@ -139,11 +146,11 @@ class ManualInputController(ProgressTrackingMixin):
|
|
| 139 |
|
| 140 |
# Call AI classifier
|
| 141 |
user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{message_text.strip()}"
|
| 142 |
-
|
| 143 |
-
response = self.ai_client.
|
| 144 |
system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
|
| 145 |
user_prompt=user_prompt,
|
| 146 |
-
temperature=0.3
|
| 147 |
)
|
| 148 |
# Parse the response to extract classification details
|
| 149 |
classification_result = self._parse_classification_response(response)
|
|
|
|
| 60 |
self.store = JSONVerificationStore()
|
| 61 |
self.ai_client = AIClientManager()
|
| 62 |
self.model_overrides = {}
|
| 63 |
+
self.prompt_overrides = {}
|
| 64 |
self.state = ManualInputState()
|
| 65 |
self.classification_start_time = None
|
| 66 |
|
| 67 |
# Ensure the underlying AI client manager sees our overrides.
|
| 68 |
self.ai_client.set_model_overrides(self.model_overrides)
|
| 69 |
+
self.ai_client.set_prompt_overrides(self.prompt_overrides)
|
| 70 |
|
| 71 |
def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
|
| 72 |
"""Set per-session model overrides from the UI."""
|
| 73 |
self.model_overrides = dict(overrides or {})
|
| 74 |
self.ai_client.set_model_overrides(self.model_overrides)
|
| 75 |
+
|
| 76 |
+
def set_prompt_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None:
|
| 77 |
+
"""Set per-session prompt overrides from the UI."""
|
| 78 |
+
self.prompt_overrides = dict(overrides or {})
|
| 79 |
+
self.ai_client.set_prompt_overrides(self.prompt_overrides)
|
| 80 |
|
| 81 |
def start_new_session(self, verifier_name: str) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
|
| 82 |
"""
|
|
|
|
| 146 |
|
| 147 |
# Call AI classifier
|
| 148 |
user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{message_text.strip()}"
|
| 149 |
+
|
| 150 |
+
response = self.ai_client.call_entry_classifier_api(
|
| 151 |
system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
|
| 152 |
user_prompt=user_prompt,
|
| 153 |
+
temperature=0.3,
|
| 154 |
)
|
| 155 |
# Parse the response to extract classification details
|
| 156 |
classification_result = self._parse_classification_response(response)
|
src/interface/simplified_gradio_app.py
CHANGED
|
@@ -604,6 +604,13 @@ Changes apply only to your current session.
|
|
| 604 |
session.app_instance.set_model_overrides(custom_models)
|
| 605 |
else:
|
| 606 |
session.app_instance.set_model_overrides({})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
new_history, status = session.app_instance.process_message(message, history)
|
| 608 |
|
| 609 |
# Get updated conversation stats
|
|
@@ -866,8 +873,22 @@ Changes apply only to your current session.
|
|
| 866 |
|
| 867 |
return formatted
|
| 868 |
|
| 869 |
-
def
|
| 870 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
from src.core.spiritual_monitor import SYSTEM_PROMPT_SPIRITUAL_MONITOR
|
| 872 |
from src.core.soft_triage_manager import (
|
| 873 |
SYSTEM_PROMPT_TRIAGE_QUESTION,
|
|
@@ -886,7 +907,12 @@ Changes apply only to your current session.
|
|
| 886 |
"🩺 Soft Medical Triage": SYSTEM_PROMPT_SOFT_MEDICAL_TRIAGE
|
| 887 |
}
|
| 888 |
|
|
|
|
| 889 |
prompt_text = prompts.get(prompt_name, "")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
|
| 891 |
# Format with HTML for display
|
| 892 |
formatted_html = format_prompt_with_html(prompt_text)
|
|
@@ -924,32 +950,18 @@ Changes apply only to your current session.
|
|
| 924 |
</div>"""
|
| 925 |
return error_html, session
|
| 926 |
|
| 927 |
-
# Store custom prompt in session
|
| 928 |
if not hasattr(session, 'custom_prompts'):
|
| 929 |
session.custom_prompts = {}
|
| 930 |
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
elif prompt_name == "🟡 Soft Spiritual Triage":
|
| 940 |
-
import src.core.soft_triage_manager as stm
|
| 941 |
-
stm.SYSTEM_PROMPT_TRIAGE_QUESTION = prompt_text
|
| 942 |
-
elif prompt_name == "📊 Triage Response Evaluator":
|
| 943 |
-
import src.core.soft_triage_manager as stm
|
| 944 |
-
stm.SYSTEM_PROMPT_TRIAGE_EVALUATE = prompt_text
|
| 945 |
-
elif prompt_name == "🏥 Medical Assistant":
|
| 946 |
-
import src.config.prompts as p
|
| 947 |
-
p.SYSTEM_PROMPT_MEDICAL_ASSISTANT = prompt_text
|
| 948 |
-
elif prompt_name == "🩺 Soft Medical Triage":
|
| 949 |
-
import src.config.prompts as p
|
| 950 |
-
p.SYSTEM_PROMPT_SOFT_MEDICAL_TRIAGE = prompt_text
|
| 951 |
-
|
| 952 |
-
status = f"""<div style="padding: 1em; background-color: #ecfdf5; border-left: 4px solid #10b981; border-radius: 4px;">
|
| 953 |
<h4 style="color: #059669; margin-top: 0;">✅ Prompt Applied Successfully</h4>
|
| 954 |
|
| 955 |
<p><strong>Prompt:</strong> {prompt_name}</p>
|
|
@@ -961,15 +973,8 @@ Changes apply only to your current session.
|
|
| 961 |
To revert, use "Reset to Default" button.
|
| 962 |
</p>
|
| 963 |
</div>"""
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
except Exception as e:
|
| 968 |
-
error_html = f"""<div style="padding: 1em; background-color: #fef2f2; border-left: 4px solid #dc2626; border-radius: 4px;">
|
| 969 |
-
<h4 style="color: #dc2626; margin-top: 0;">❌ Error Applying Prompt</h4>
|
| 970 |
-
<p style="margin-bottom: 0;"><code>{str(e)}</code></p>
|
| 971 |
-
</div>"""
|
| 972 |
-
return error_html, session
|
| 973 |
|
| 974 |
def reset_prompt(prompt_name: str, session: SimplifiedSessionData):
|
| 975 |
"""Reset prompt to default."""
|
|
@@ -977,14 +982,16 @@ To revert, use "Reset to Default" button.
|
|
| 977 |
session = SimplifiedSessionData()
|
| 978 |
|
| 979 |
# Remove from custom prompts
|
| 980 |
-
|
| 981 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 982 |
|
| 983 |
# Reload default
|
| 984 |
-
prompt_text, info, status = load_prompt(prompt_name)
|
| 985 |
-
|
| 986 |
-
# Reapply default
|
| 987 |
-
apply_status, session = apply_prompt_changes(prompt_name, prompt_text, session)
|
| 988 |
|
| 989 |
reset_status = """<div style="padding: 1em; background-color: #eff6ff; border-left: 4px solid #3b82f6; border-radius: 4px;">
|
| 990 |
<h4 style="color: #2563eb; margin-top: 0;">🔄 Reset to Default</h4>
|
|
@@ -2235,7 +2242,7 @@ To revert, use "Reset to Default" button.
|
|
| 2235 |
# Prompt editing events
|
| 2236 |
load_prompt_btn.click(
|
| 2237 |
load_prompt,
|
| 2238 |
-
inputs=[prompt_selector],
|
| 2239 |
outputs=[prompt_editor, prompt_info_display, prompt_status]
|
| 2240 |
)
|
| 2241 |
|
|
@@ -2254,7 +2261,7 @@ To revert, use "Reset to Default" button.
|
|
| 2254 |
# Auto-load prompt when selector changes
|
| 2255 |
prompt_selector.change(
|
| 2256 |
load_prompt,
|
| 2257 |
-
inputs=[prompt_selector],
|
| 2258 |
outputs=[prompt_editor, prompt_info_display, prompt_status]
|
| 2259 |
)
|
| 2260 |
|
|
|
|
| 604 |
session.app_instance.set_model_overrides(custom_models)
|
| 605 |
else:
|
| 606 |
session.app_instance.set_model_overrides({})
|
| 607 |
+
|
| 608 |
+
# Apply per-session prompt overrides (if configured in Edit Prompts)
|
| 609 |
+
custom_prompts = getattr(session, 'custom_prompts', None)
|
| 610 |
+
if custom_prompts:
|
| 611 |
+
session.app_instance.set_prompt_overrides(custom_prompts)
|
| 612 |
+
else:
|
| 613 |
+
session.app_instance.set_prompt_overrides({})
|
| 614 |
new_history, status = session.app_instance.process_message(message, history)
|
| 615 |
|
| 616 |
# Get updated conversation stats
|
|
|
|
| 873 |
|
| 874 |
return formatted
|
| 875 |
|
| 876 |
+
def _prompt_name_to_agent(prompt_name: str) -> str:
|
| 877 |
+
"""Map UI prompt selection to internal agent/prompt key."""
|
| 878 |
+
mapping = {
|
| 879 |
+
"🔍 Spiritual Monitor (Classifier)": "SpiritualDistressAnalyzer",
|
| 880 |
+
"🟡 Soft Spiritual Triage": "SoftSpiritualTriage",
|
| 881 |
+
"📊 Triage Response Evaluator": "TriageResponseEvaluator",
|
| 882 |
+
"🏥 Medical Assistant": "MedicalAssistant",
|
| 883 |
+
"🩺 Soft Medical Triage": "SoftMedicalTriage",
|
| 884 |
+
}
|
| 885 |
+
return mapping.get(prompt_name, prompt_name)
|
| 886 |
+
|
| 887 |
+
def load_prompt(prompt_name: str, session: Optional[SimplifiedSessionData] = None):
|
| 888 |
+
"""Load selected prompt for editing.
|
| 889 |
+
|
| 890 |
+
If a session override exists, show it instead of the default.
|
| 891 |
+
"""
|
| 892 |
from src.core.spiritual_monitor import SYSTEM_PROMPT_SPIRITUAL_MONITOR
|
| 893 |
from src.core.soft_triage_manager import (
|
| 894 |
SYSTEM_PROMPT_TRIAGE_QUESTION,
|
|
|
|
| 907 |
"🩺 Soft Medical Triage": SYSTEM_PROMPT_SOFT_MEDICAL_TRIAGE
|
| 908 |
}
|
| 909 |
|
| 910 |
+
agent_key = _prompt_name_to_agent(prompt_name)
|
| 911 |
prompt_text = prompts.get(prompt_name, "")
|
| 912 |
+
|
| 913 |
+
# Prefer session override (true session-scoped behavior)
|
| 914 |
+
if session is not None and hasattr(session, 'custom_prompts'):
|
| 915 |
+
prompt_text = session.custom_prompts.get(agent_key, prompt_text)
|
| 916 |
|
| 917 |
# Format with HTML for display
|
| 918 |
formatted_html = format_prompt_with_html(prompt_text)
|
|
|
|
| 950 |
</div>"""
|
| 951 |
return error_html, session
|
| 952 |
|
| 953 |
+
# Store custom prompt in session (session-scoped)
|
| 954 |
if not hasattr(session, 'custom_prompts'):
|
| 955 |
session.custom_prompts = {}
|
| 956 |
|
| 957 |
+
agent_key = _prompt_name_to_agent(prompt_name)
|
| 958 |
+
session.custom_prompts[agent_key] = prompt_text
|
| 959 |
+
|
| 960 |
+
# Apply into the current session app instance (no global mutation)
|
| 961 |
+
if hasattr(session, 'app_instance') and hasattr(session.app_instance, 'set_prompt_overrides'):
|
| 962 |
+
session.app_instance.set_prompt_overrides(session.custom_prompts)
|
| 963 |
+
|
| 964 |
+
status = f"""<div style="padding: 1em; background-color: #ecfdf5; border-left: 4px solid #10b981; border-radius: 4px;">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 965 |
<h4 style="color: #059669; margin-top: 0;">✅ Prompt Applied Successfully</h4>
|
| 966 |
|
| 967 |
<p><strong>Prompt:</strong> {prompt_name}</p>
|
|
|
|
| 973 |
To revert, use "Reset to Default" button.
|
| 974 |
</p>
|
| 975 |
</div>"""
|
| 976 |
+
|
| 977 |
+
return status, session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 978 |
|
| 979 |
def reset_prompt(prompt_name: str, session: SimplifiedSessionData):
|
| 980 |
"""Reset prompt to default."""
|
|
|
|
| 982 |
session = SimplifiedSessionData()
|
| 983 |
|
| 984 |
# Remove from custom prompts
|
| 985 |
+
agent_key = _prompt_name_to_agent(prompt_name)
|
| 986 |
+
if hasattr(session, 'custom_prompts') and agent_key in session.custom_prompts:
|
| 987 |
+
del session.custom_prompts[agent_key]
|
| 988 |
+
|
| 989 |
+
# Apply into current session app instance
|
| 990 |
+
if hasattr(session, 'app_instance') and hasattr(session.app_instance, 'set_prompt_overrides'):
|
| 991 |
+
session.app_instance.set_prompt_overrides(getattr(session, 'custom_prompts', {}))
|
| 992 |
|
| 993 |
# Reload default
|
| 994 |
+
prompt_text, info, status = load_prompt(prompt_name, session)
|
|
|
|
|
|
|
|
|
|
| 995 |
|
| 996 |
reset_status = """<div style="padding: 1em; background-color: #eff6ff; border-left: 4px solid #3b82f6; border-radius: 4px;">
|
| 997 |
<h4 style="color: #2563eb; margin-top: 0;">🔄 Reset to Default</h4>
|
|
|
|
| 2242 |
# Prompt editing events
|
| 2243 |
load_prompt_btn.click(
|
| 2244 |
load_prompt,
|
| 2245 |
+
inputs=[prompt_selector, session_data],
|
| 2246 |
outputs=[prompt_editor, prompt_info_display, prompt_status]
|
| 2247 |
)
|
| 2248 |
|
|
|
|
| 2261 |
# Auto-load prompt when selector changes
|
| 2262 |
prompt_selector.change(
|
| 2263 |
load_prompt,
|
| 2264 |
+
inputs=[prompt_selector, session_data],
|
| 2265 |
outputs=[prompt_editor, prompt_info_display, prompt_status]
|
| 2266 |
)
|
| 2267 |
|
tests/test_model_overrides_cross_mode.py
CHANGED
|
@@ -26,7 +26,7 @@ def test_file_upload_controller_applies_overrides_to_ai_client():
|
|
| 26 |
|
| 27 |
def test_file_upload_batch_passes_override_into_call(monkeypatch):
|
| 28 |
controller = FileUploadInterfaceController()
|
| 29 |
-
controller.set_model_overrides({"
|
| 30 |
|
| 31 |
# Minimal session + dataset state so the batch method can run without UI.
|
| 32 |
controller.current_session = types.SimpleNamespace(
|
|
@@ -60,13 +60,13 @@ def test_file_upload_batch_passes_override_into_call(monkeypatch):
|
|
| 60 |
|
| 61 |
captured = {}
|
| 62 |
|
| 63 |
-
def
|
| 64 |
captured["system_prompt"] = system_prompt
|
| 65 |
captured["user_prompt"] = user_prompt
|
| 66 |
captured["model_override"] = model_override
|
| 67 |
return "green"
|
| 68 |
|
| 69 |
-
monkeypatch.setattr(controller.ai_client, "
|
| 70 |
|
| 71 |
# Run the simplest possible classification path.
|
| 72 |
result = controller.run_batch_classification()
|
|
|
|
| 26 |
|
| 27 |
def test_file_upload_batch_passes_override_into_call(monkeypatch):
|
| 28 |
controller = FileUploadInterfaceController()
|
| 29 |
+
controller.set_model_overrides({"EntryClassifier": "gemini-batch-override"})
|
| 30 |
|
| 31 |
# Minimal session + dataset state so the batch method can run without UI.
|
| 32 |
controller.current_session = types.SimpleNamespace(
|
|
|
|
| 60 |
|
| 61 |
captured = {}
|
| 62 |
|
| 63 |
+
def fake_call_entry_classifier_api(*, system_prompt=None, user_prompt=None, model_override=None, **kwargs):
|
| 64 |
captured["system_prompt"] = system_prompt
|
| 65 |
captured["user_prompt"] = user_prompt
|
| 66 |
captured["model_override"] = model_override
|
| 67 |
return "green"
|
| 68 |
|
| 69 |
+
monkeypatch.setattr(controller.ai_client, "call_entry_classifier_api", fake_call_entry_classifier_api)
|
| 70 |
|
| 71 |
# Run the simplest possible classification path.
|
| 72 |
result = controller.run_batch_classification()
|
tests/verification_mode/test_properties_persistence.py
CHANGED
|
@@ -213,7 +213,7 @@ class TestSessionStatePersistence:
|
|
| 213 |
assert restored_session.is_complete == session.is_complete
|
| 214 |
|
| 215 |
@given(verification_session_strategy())
|
| 216 |
-
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
|
| 217 |
def test_session_with_multiple_records_persists(
|
| 218 |
self, verification_store, session
|
| 219 |
):
|
|
|
|
| 213 |
assert restored_session.is_complete == session.is_complete
|
| 214 |
|
| 215 |
@given(verification_session_strategy())
|
| 216 |
+
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
|
| 217 |
def test_session_with_multiple_records_persists(
|
| 218 |
self, verification_store, session
|
| 219 |
):
|