Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import os | |
| import uuid | |
| import csv | |
| import traceback | |
| from datetime import datetime | |
| from typing import List, Optional, Any, Dict | |
| from src.interface.session_manager import SimplifiedSessionData | |
| from src.core.verification_models import VerificationSession, VerificationRecord, TestMessage | |
| from src.core.verification_store import JSONVerificationStore | |
| from src.core.chaplain_models import ClassificationFlowResult, DistressIndicator, FollowUpQuestion, TaggingRecord | |
| from src.core.verification_csv_exporter import VerificationCSVExporter | |
| from src.core.test_datasets import TestDatasetManager | |
| from src.interface.verification_ui import VerificationUIComponents | |
| from src.interface.chaplain_feedback_ui import ChaplainFeedbackUIComponents | |
| from src.core.conversation_verification import ( | |
| ConversationVerificationManager, | |
| VerificationRecord as ConvVerificationRecord, | |
| VerificationSession as ConvVerificationSession | |
| ) | |
def open_verification_window(session: SimplifiedSessionData):
    """Export the current conversation as a downloadable verification-session JSON file.

    Builds a verification session from ``session.app_instance.conversation_logger``
    with ``ConversationVerificationManager`` and writes it to
    ``./verification_sessions/verification_session_<UTC timestamp>_<session_id>.json``.

    Args:
        session: Current UI session; may be ``None``.

    Returns:
        An HTML status snippet for the UI:
        - an error box when there is no conversation logger or the export failed;
        - a warning box when the logger has no entries;
        - otherwise a success box with the number of exported exchanges.
    """
    import html
    from datetime import timezone

    # getattr() keeps a session object without ``app_instance`` on the
    # "nothing to verify" path instead of raising AttributeError.
    app = getattr(session, "app_instance", None)
    if not hasattr(app, 'conversation_logger'):
        return """<div style="background-color: #f8d7da; padding: 0.75em; border-radius: 4px; margin: 0.5em 0;">
β <strong>No conversation to verify</strong><br>
<small>Start a conversation first</small>
</div>"""
    try:
        conversation_logger = app.conversation_logger
        # Check if conversation has any entries
        if not conversation_logger.entries:
            return """<div style="background-color: #fff3cd; padding: 0.75em; border-radius: 4px; margin: 0.5em 0;">
β οΈ <strong>No conversation exchanges to verify</strong><br>
<small>Send some messages in the chat first</small>
</div>"""
        print(f"π Opening verification for {len(conversation_logger.entries)} exchanges...")
        manager = ConversationVerificationManager()
        verification_session = manager.create_verification_session(
            conversation_logger,
            "Medical Professional"
        )
        print(f"β Created verification session: {verification_session.session_id}")
        # HF Spaces / Gradio limitation:
        # Launching a *second* Gradio server from inside a running Gradio app is unreliable
        # and is currently causing the Button._id error in Spaces.
        # Instead, export the verification session to a JSON file that the user can download.
        export_dir = os.path.join(os.getcwd(), "verification_sessions")
        os.makedirs(export_dir, exist_ok=True)
        # Timezone-aware replacement for the deprecated datetime.utcnow(); same UTC stamp.
        stamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
        export_filename = f"verification_session_{stamp}_{verification_session.session_id}.json"
        export_path = os.path.join(export_dir, export_filename)

        # Serialize to JSON in a resilient way (dataclasses / pydantic / plain python).
        def _to_dict(obj):
            if hasattr(obj, "model_dump"):
                return obj.model_dump()
            if hasattr(obj, "dict") and callable(getattr(obj, "dict")):
                return obj.dict()
            if hasattr(obj, "__dict__"):
                return obj.__dict__
            return str(obj)

        def _iso(value):
            # None stays None; datetimes are ISO-formatted; strings and other
            # values are stringified instead of crashing on .isoformat().
            if value is None:
                return None
            isoformat = getattr(value, "isoformat", None)
            return isoformat() if callable(isoformat) else str(value)

        payload = {
            "session_id": verification_session.session_id,
            "patient_name": verification_session.patient_name,
            "verifier_name": verification_session.verifier_name,
            "start_time": _iso(getattr(verification_session, "start_time", None)),
            "verification_records": [
                {
                    # Conversation verification records use `exchange_id`.
                    # Keep a `record_id` alias for backward compatibility with older exports.
                    "exchange_id": getattr(r, "exchange_id", None),
                    "record_id": getattr(r, "exchange_id", None),
                    "exchange_number": getattr(r, "exchange_number", None),
                    "timestamp": _iso(getattr(r, "timestamp", None)),
                    "user_message": r.user_message,
                    "assistant_response": r.assistant_response,
                    "original_classification": r.original_classification,
                    "original_confidence": r.original_confidence,
                    "original_indicators": r.original_indicators,
                    "original_reasoning": r.original_reasoning,
                    "is_correct": r.is_correct,
                    "correct_classification": r.correct_classification,
                    "correction_reason": r.correction_reason,
                    "verifier_notes": r.verifier_notes,
                }
                for r in verification_session.verification_records
            ],
        }
        with open(export_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2, default=_to_dict)
        print(f"β Verification session exported: {export_path}")
        return f"""<div style="background-color: #d4edda; padding: 0.75em; border-radius: 4px; margin: 0.5em 0;">
β <strong>Verification session exported</strong><br>
<small>Exchanges: {len(verification_session.verification_records)}</small><br>
<small>Download JSON from the app's files panel (or add a dedicated download button).</small>
</div>"""
    except Exception as e:
        print(f"β Error opening verification: {str(e)}")
        traceback.print_exc()
        # The exception text is rendered as HTML, so escape it.
        return f"""<div style="background-color: #f8d7da; padding: 0.75em; border-radius: 4px; margin: 0.5em 0;">
β <strong>Error opening verification</strong><br>
<small>{html.escape(str(e))}</small>
</div>"""
def load_verification_dataset(dataset_name: str, store: JSONVerificationStore):
    """Start a new verification session for the dataset selected in the UI.

    ``dataset_name`` is the label from the dataset selector. It is resolved to
    a dataset id as follows:
    - an exact name match wins;
    - otherwise the longest non-empty dataset name contained in the label is
      used, so a short name cannot shadow a longer one that shares its prefix.

    A new ``VerificationSession`` is saved via ``store``, and the first message
    is rendered.

    Args:
        dataset_name: Selected dataset label; ``None``/empty yields "not found".
        store: Persistence backend used to save the new session.

    Returns:
        12-tuple (verification_session, dataset_info, message_text,
        decision_badge, confidence, indicators, progress_display,
        error_message, current_message_index, current_dataset_id,
        message_queue, verification_records).
    """
    # NOTE(review): placeholder classifier output; real values are not wired in here.
    placeholder_confidence = 0.85
    placeholder_indicators = ("Distress indicator 1", "Distress indicator 2")

    def _failure(info, error, current_dataset_id):
        return (
            None,                # verification_session
            info,                # dataset_info
            "", "", "", "",      # message_text, decision_badge, confidence, indicators
            "",                  # progress_display
            error,               # error_message
            0,                   # current_message_index
            current_dataset_id,  # current_dataset_id
            [],                  # message_queue
            [],                  # verification_records
        )

    try:
        dataset_id = None
        if dataset_name:
            datasets = TestDatasetManager.get_dataset_list()
            # Labels presumably decorate the dataset name, hence substring matching.
            # Prefer an exact match, then the longest contained name. Empty names are
            # skipped because "" is a substring of every label.
            match = next((d for d in datasets if d['name'] == dataset_name), None)
            if match is None:
                contained = [d for d in datasets if d['name'] and d['name'] in dataset_name]
                match = max(contained, key=lambda d: len(d['name']), default=None)
            if match is not None:
                dataset_id = match['dataset_id']
        if not dataset_id:
            return _failure("β Dataset not found", "β Dataset not found", None)

        dataset = TestDatasetManager.load_dataset(dataset_id)
        if not dataset.messages:
            # Bail out before creating/saving a session with nothing to review.
            return _failure("β Dataset is empty", "β Dataset is empty", dataset_id)

        message_ids = [m.message_id for m in dataset.messages]
        new_session = VerificationSession(
            session_id=str(uuid.uuid4()),
            verifier_name="Medical Professional",
            dataset_id=dataset_id,
            dataset_name=dataset.name,
            total_messages=dataset.message_count,
            message_queue=message_ids,
        )
        store.save_session(new_session)

        first_message = dataset.messages[0]
        message_text, decision_badge, confidence, indicators = VerificationUIComponents.render_message_review(
            first_message,
            first_message.pre_classified_label,
            placeholder_confidence,
            list(placeholder_indicators),
        )
        progress = VerificationUIComponents.update_progress_display(0, dataset.message_count)
        dataset_info_text = f"**{dataset.name}**\n\n{dataset.description}\n\nπ {dataset.message_count} messages to review"
        return (
            new_session,         # verification_session
            dataset_info_text,   # dataset_info
            message_text,        # message_text
            decision_badge,      # decision_badge
            confidence,          # confidence
            indicators,          # indicators
            progress,            # progress_display
            "",                  # error_message (empty = no error)
            0,                   # current_message_index
            dataset_id,          # current_dataset_id
            list(message_ids),   # message_queue (separate list from the session's)
            [],                  # verification_records
        )
    except Exception as e:
        return _failure(f"β Error loading dataset: {str(e)}", f"β Error: {str(e)}", None)
def handle_correct_feedback(session: VerificationSession, current_idx: int, dataset_id: str, message_queue: List[str], records: List[dict], store: JSONVerificationStore):
    """Record that the classifier's label for the current message is correct, then advance.

    Steps:
    - append a ``VerificationRecord`` (ground truth = the message's
      ``pre_classified_label``, ``is_correct=True``);
    - increment the session's verified/correct counters and save the session
      via ``store``;
    - render the next queued message.

    After the last message the session is marked complete and saved again.

    Args:
        session: Active verification session (may be ``None``).
        current_idx: Index into ``message_queue`` of the message being judged.
        dataset_id: Dataset the queued message ids belong to.
        message_queue: Ordered message ids for the session.
        records: Current records state; returned unchanged on early failures.
        store: Persistence backend used to save the session.

    Returns:
        12-tuple (session, status_message, message_text, decision_badge,
        confidence, indicators, progress, correct_str, incorrect_str,
        accuracy_str, current_index, records_as_dicts).
    """
    # NOTE(review): placeholder classifier output; real values are not wired in here.
    placeholder_confidence = 0.85
    placeholder_indicators = ("Distress indicator 1", "Distress indicator 2")

    def _failure(status, idx, recs):
        # Blank message view with zeroed statistics for error states.
        return (
            session,
            status,
            "", "", "", "",
            "",
            "β Correct: 0",
            "β Incorrect: 0",
            "π Accuracy: 0%",
            idx,
            recs,
        )

    def _find(messages, message_id):
        return next((m for m in messages if m.message_id == message_id), None)

    try:
        if not session or current_idx >= len(message_queue):
            return _failure("β Error: Invalid session state", current_idx, records)

        dataset = TestDatasetManager.load_dataset(dataset_id)
        current_message = _find(dataset.messages, message_queue[current_idx])
        if not current_message:
            return _failure("β Error: Message not found", current_idx, records)

        record = VerificationRecord(
            message_id=current_message.message_id,
            original_message=current_message.text,
            classifier_decision=current_message.pre_classified_label,
            classifier_confidence=placeholder_confidence,
            classifier_indicators=list(placeholder_indicators),
            ground_truth_label=current_message.pre_classified_label,
            verifier_notes="",
            is_correct=True,
        )
        session.verifications.append(record)
        session.verified_count += 1
        session.correct_count += 1
        store.save_session(session)

        next_idx = current_idx + 1
        if next_idx >= len(message_queue):
            # Session complete
            session.is_complete = True
            session.completed_at = datetime.now()
            store.save_session(session)
            correct_str, incorrect_str, accuracy_str = VerificationUIComponents.update_statistics_display(
                session.correct_count,
                session.incorrect_count
            )
            return (
                session,
                "β Verification complete!",
                "", "", "", "",
                "",
                correct_str,
                incorrect_str,
                accuracy_str,
                next_idx,
                [r.to_dict() for r in session.verifications],
            )

        next_message = _find(dataset.messages, message_queue[next_idx])
        if not next_message:
            # The current message is already recorded and saved, so still advance;
            # staying on current_idx would let a retry count it a second time.
            return _failure(
                "β Error processing feedback",
                next_idx,
                [r.to_dict() for r in session.verifications],
            )

        message_text, decision_badge, confidence, indicators = VerificationUIComponents.render_message_review(
            next_message,
            next_message.pre_classified_label,
            placeholder_confidence,
            list(placeholder_indicators),
        )
        progress = VerificationUIComponents.update_progress_display(next_idx, len(message_queue))
        correct_str, incorrect_str, accuracy_str = VerificationUIComponents.update_statistics_display(
            session.correct_count,
            session.incorrect_count
        )
        return (
            session,
            "",
            message_text,
            decision_badge,
            confidence,
            indicators,
            progress,
            correct_str,
            incorrect_str,
            accuracy_str,
            next_idx,
            [r.to_dict() for r in session.verifications],
        )
    except Exception as e:
        return _failure(f"β Error: {str(e)}", current_idx, records)
def handle_incorrect_feedback(session: VerificationSession, current_idx: int, dataset_id: str, message_queue: List[str], records: List[dict]):
    """Prompt the reviewer to choose the correct classification.

    The arguments match the other feedback handlers' signature but are not used.

    Returns:
        The status text asking the reviewer to pick a correction.
    """
    prompt = "β Please select the correct classification below"
    return prompt
def handle_submit_correction(session: VerificationSession, current_idx: int, dataset_id: str, message_queue: List[str], records: List[dict], correction: str, notes: str, store: JSONVerificationStore):
    """Record the reviewer's chosen classification for the current message and advance.

    The stored ``VerificationRecord`` uses ``correction`` as ground truth.
    ``is_correct`` is true when the correction equals the message's
    ``pre_classified_label``, and the correct/incorrect counter is bumped to
    match. The session is saved via ``store`` and the next message is rendered.
    After the last message the session is marked complete and a summary card
    is produced.

    Args:
        session: Active verification session (may be ``None``).
        current_idx: Index into ``message_queue`` of the message being judged.
        dataset_id: Dataset the queued message ids belong to.
        message_queue: Ordered message ids for the session.
        records: Current records state; returned unchanged on early failures.
        correction: Selected ground-truth label; empty means "nothing selected".
        notes: Free-text verifier notes (``None`` is stored as ``""``).
        store: Persistence backend used to save the session.

    Returns:
        16-tuple (status_message, session, current_index, dataset_id,
        message_queue, records_as_dicts, message_text, decision_badge,
        confidence, indicators, progress, correct_str, incorrect_str,
        accuracy_str, "", summary_html). ``summary_html`` is non-empty only
        once the last message has been verified.
    """
    # NOTE(review): placeholder classifier output; real values are not wired in here.
    placeholder_confidence = 0.85
    placeholder_indicators = ("Distress indicator 1", "Distress indicator 2")

    def _failure(status, idx, recs):
        # Blank message view with zeroed statistics for error states.
        return (
            status,
            session,
            idx,
            dataset_id,
            message_queue,
            recs,
            "", "", "", "",
            "",
            "β Correct: 0",
            "β Incorrect: 0",
            "π Accuracy: 0%",
            "",
            "",
        )

    def _find(messages, message_id):
        return next((m for m in messages if m.message_id == message_id), None)

    try:
        if not correction:
            return _failure("β Please select a correction before submitting", current_idx, records)
        if not session or current_idx >= len(message_queue):
            return _failure("β Error: Invalid session state", current_idx, records)

        dataset = TestDatasetManager.load_dataset(dataset_id)
        current_message = _find(dataset.messages, message_queue[current_idx])
        if not current_message:
            return _failure("β Error: Message not found", current_idx, records)

        record = VerificationRecord(
            message_id=current_message.message_id,
            original_message=current_message.text,
            classifier_decision=current_message.pre_classified_label,
            classifier_confidence=placeholder_confidence,
            classifier_indicators=list(placeholder_indicators),
            ground_truth_label=correction,
            verifier_notes=notes or "",
            is_correct=current_message.pre_classified_label == correction,
        )
        session.verifications.append(record)
        session.verified_count += 1
        if record.is_correct:
            session.correct_count += 1
        else:
            session.incorrect_count += 1
        store.save_session(session)

        next_idx = current_idx + 1
        if next_idx >= len(message_queue):
            # Session complete
            session.is_complete = True
            session.completed_at = datetime.now()
            store.save_session(session)
            correct_str, incorrect_str, accuracy_str = VerificationUIComponents.update_statistics_display(
                session.correct_count,
                session.incorrect_count
            )
            summary = VerificationUIComponents.render_summary_card(session, session.verifications)
            return (
                "β Verification complete!",
                session,
                next_idx,
                dataset_id,
                message_queue,
                [r.to_dict() for r in session.verifications],
                "", "", "", "",
                "",
                correct_str,
                incorrect_str,
                accuracy_str,
                "",
                summary,
            )

        next_message = _find(dataset.messages, message_queue[next_idx])
        if not next_message:
            # The correction is already recorded and saved, so still advance;
            # staying on current_idx would let a retry count it a second time.
            return _failure(
                "β Error processing correction",
                next_idx,
                [r.to_dict() for r in session.verifications],
            )

        message_text, decision_badge, confidence, indicators = VerificationUIComponents.render_message_review(
            next_message,
            next_message.pre_classified_label,
            placeholder_confidence,
            list(placeholder_indicators),
        )
        progress = VerificationUIComponents.update_progress_display(next_idx, len(message_queue))
        correct_str, incorrect_str, accuracy_str = VerificationUIComponents.update_statistics_display(
            session.correct_count,
            session.incorrect_count
        )
        return (
            "",
            session,
            next_idx,
            dataset_id,
            message_queue,
            [r.to_dict() for r in session.verifications],
            message_text,
            decision_badge,
            confidence,
            indicators,
            progress,
            correct_str,
            incorrect_str,
            accuracy_str,
            "",
            "",
        )
    except Exception as e:
        return _failure(f"β Error: {str(e)}", current_idx, records)
def handle_download_csv(session: VerificationSession, store: JSONVerificationStore):
    """Write the session's verification results to a CSV file for a DownloadButton.

    Args:
        session: Active verification session (may be ``None``).
        store: Not used by this function; kept for the existing event wiring.

    Returns:
        Path of the written CSV file. Returns ``None`` when there is no session,
        nothing has been verified yet, or the export failed (the traceback is
        printed).
    """
    import tempfile

    try:
        if not session or session.verified_count == 0:
            return None
        csv_content = VerificationCSVExporter.generate_csv_content(session)
        filename = VerificationCSVExporter.generate_csv_filename()
        # Use a fresh private directory under the system temp dir (Hugging Face
        # compatible). The temp dir is shared by all sessions of the app, and
        # generate_csv_filename() may not be unique per user, so a fixed path
        # could be overwritten by a concurrent export.
        export_dir = tempfile.mkdtemp(prefix="verification_csv_")
        file_path = os.path.join(export_dir, filename)
        # newline="" writes csv_content verbatim. Without it, "\r\n" row
        # terminators (the csv module default) would become "\r\r\n" on Windows.
        with open(file_path, 'w', encoding='utf-8', newline='') as f:
            f.write(csv_content)
        return file_path
    except Exception:
        print(f"CSV Export Error: {traceback.format_exc()}")
        return None
def handle_next_message(session: VerificationSession, current_idx: int, dataset_id: str, message_queue: List[str], records: List[dict]):
    """Show the message after ``current_idx`` without recording a verdict.

    Args:
        session: Active verification session (may be ``None``).
        current_idx: Index into ``message_queue`` of the message on screen.
        dataset_id: Dataset the queued message ids belong to.
        message_queue: Ordered message ids for the session.
        records: Current records state; always returned unchanged.

    Returns:
        12-tuple (session, status_message, message_text, decision_badge,
        confidence, indicators, progress, correct_str, incorrect_str,
        accuracy_str, current_index, records). When navigation is impossible
        or fails, the index stays at ``current_idx`` and the message fields
        are blank. Errors are reported in ``status_message`` rather than raised.
    """
    zero_stats = ("β Correct: 0", "β Incorrect: 0", "π Accuracy: 0%")

    def _session_stats():
        # Show the session's real tallies instead of resetting the display to zero.
        if not session:
            return zero_stats
        return tuple(VerificationUIComponents.update_statistics_display(
            session.correct_count,
            session.incorrect_count
        ))

    def _stay(status, stats):
        return (session, status, "", "", "", "", "", *stats, current_idx, records)

    try:
        if not session or current_idx >= len(message_queue) - 1:
            return _stay("β No more messages", _session_stats())

        next_idx = current_idx + 1
        dataset = TestDatasetManager.load_dataset(dataset_id)
        wanted_id = message_queue[next_idx]
        next_message = next((m for m in dataset.messages if m.message_id == wanted_id), None)
        if not next_message:
            return _stay("β Error loading next message", _session_stats())

        message_text, decision_badge, confidence, indicators = VerificationUIComponents.render_message_review(
            next_message,
            next_message.pre_classified_label,
            0.85,  # NOTE(review): placeholder confidence
            ["Distress indicator 1", "Distress indicator 2"]  # NOTE(review): placeholder indicators
        )
        progress = VerificationUIComponents.update_progress_display(next_idx, len(message_queue))
        correct_str, incorrect_str, accuracy_str = _session_stats()
        return (
            session,
            "",
            message_text,
            decision_badge,
            confidence,
            indicators,
            progress,
            correct_str,
            incorrect_str,
            accuracy_str,
            next_idx,
            records,
        )
    except Exception as e:
        # Report the problem in the status field instead of propagating it to Gradio.
        return _stay(f"β Error: {str(e)}", zero_stats)
def handle_previous_message(session: VerificationSession, current_idx: int, dataset_id: str, message_queue: List[str], records: List[dict]):
    """Show the message before ``current_idx`` without recording a verdict.

    Args:
        session: Active verification session (may be ``None``).
        current_idx: Index into ``message_queue`` of the message on screen.
        dataset_id: Dataset the queued message ids belong to.
        message_queue: Ordered message ids for the session.
        records: Current records state; always returned unchanged.

    Returns:
        12-tuple (session, status_message, message_text, decision_badge,
        confidence, indicators, progress, correct_str, incorrect_str,
        accuracy_str, current_index, records). When navigation is impossible
        or fails, the index stays at ``current_idx`` and the message fields
        are blank. Errors are reported in ``status_message`` rather than raised.
    """
    zero_stats = ("β Correct: 0", "β Incorrect: 0", "π Accuracy: 0%")

    def _session_stats():
        # Show the session's real tallies instead of resetting the display to zero.
        if not session:
            return zero_stats
        return tuple(VerificationUIComponents.update_statistics_display(
            session.correct_count,
            session.incorrect_count
        ))

    def _stay(status, stats):
        return (session, status, "", "", "", "", "", *stats, current_idx, records)

    try:
        if not session or current_idx <= 0:
            return _stay("β No previous messages", _session_stats())

        prev_idx = current_idx - 1
        dataset = TestDatasetManager.load_dataset(dataset_id)
        wanted_id = message_queue[prev_idx]
        prev_message = next((m for m in dataset.messages if m.message_id == wanted_id), None)
        if not prev_message:
            return _stay("β Error loading previous message", _session_stats())

        message_text, decision_badge, confidence, indicators = VerificationUIComponents.render_message_review(
            prev_message,
            prev_message.pre_classified_label,
            0.85,  # NOTE(review): placeholder confidence
            ["Distress indicator 1", "Distress indicator 2"]  # NOTE(review): placeholder indicators
        )
        progress = VerificationUIComponents.update_progress_display(prev_idx, len(message_queue))
        correct_str, incorrect_str, accuracy_str = _session_stats()
        return (
            session,
            "",
            message_text,
            decision_badge,
            confidence,
            indicators,
            progress,
            correct_str,
            incorrect_str,
            accuracy_str,
            prev_idx,
            records,
        )
    except Exception as e:
        # Report the problem in the status field instead of propagating it to Gradio.
        return _stay(f"β Error: {str(e)}", zero_stats)
def handle_skip_message(session: VerificationSession, current_idx: int, dataset_id: str, message_queue: List[str], records: List[dict]):
    """Skip the current message without recording a verdict.

    Skipping behaves exactly like moving to the next message, so this
    delegates to ``handle_next_message`` and returns its 12-tuple unchanged.
    """
    return handle_next_message(
        session=session,
        current_idx=current_idx,
        dataset_id=dataset_id,
        message_queue=message_queue,
        records=records,
    )
def handle_clear_session():
    """Reset the dataset-verification view to its empty state.

    Returns:
        12-tuple (verification_session=None, status_message, four blank
        message fields, blank progress, zeroed correct/incorrect/accuracy
        strings, current_index=0, records=[]).
    """
    blank_view = ("",) * 5  # message text, badge, confidence, indicators, progress
    zeroed_stats = ("β Correct: 0", "β Incorrect: 0", "π Accuracy: 0%")
    return (None, "β Session cleared", *blank_view, *zeroed_stats, 0, [])
def show_chaplain_feedback_section():
    """Reveal the chaplain feedback section after a message has been reviewed.

    Returns:
        A ``gr.Row`` constructed with ``visible=True``. It is presumably applied
        as a visibility update to the bound output row; confirm against the
        event wiring.
    """
    visible_row = gr.Row(visible=True)
    return visible_row
def handle_submit_feedback(
    classification_correct: bool,
    classification_subcategory: Optional[str],
    correct_classification: Optional[str],
    question_issues: List[str],
    question_comments: str,
    referral_issues: List[str],
    referral_comments: str,
    indicator_issues: str,
    indicator_comments: str,
    general_notes: str,
    session: VerificationSession,
    current_idx: int,
    message_queue: List[str],
):
    """Collect the chaplain's structured feedback for the current message.

    A ``TaggingRecord`` is built from the form fields:
    - ``indicator_issues`` is a comma-separated string, split into trimmed,
      non-empty entries (``None`` is treated as empty);
    - ``question_issues`` and ``referral_issues`` default to ``[]`` when falsy.

    Returns:
        ``(status_message, session, current_idx)``; ``session`` and
        ``current_idx`` are passed through unchanged. Errors are reported in
        ``status_message``.
    """
    try:
        if not session or current_idx >= len(message_queue):
            return "β Error: Invalid session state", session, current_idx
        current_message_id = message_queue[current_idx]
        indicator_list = [item.strip() for item in (indicator_issues or "").split(",") if item.strip()]
        tagging_record = TaggingRecord(
            record_id=str(uuid.uuid4()),
            message_id=current_message_id,
            is_classification_correct=classification_correct,
            classification_subcategory=classification_subcategory,
            correct_classification=correct_classification,
            question_issues=question_issues or [],
            question_comments=question_comments,
            referral_issues=referral_issues or [],
            referral_comments=referral_comments,
            indicator_issues=indicator_list,
            indicator_comments=indicator_comments,
            general_notes=general_notes,
        )
        # TODO: persist `tagging_record`. VerificationSession has no field for
        # tagging records yet, so the record is currently built and discarded.
        return f"β Feedback submitted for message {current_idx + 1}", session, current_idx
    except Exception as e:
        return f"β Error: {str(e)}", session, current_idx
def display_classification_flow(flow_result: Optional[ClassificationFlowResult]):
    """Turn a classification-flow result into the four strings the UI displays.

    Returns:
        ``(badge, explanation, content, indicators)`` as produced by
        ``ChaplainFeedbackUIComponents.render_classification_flow``, or four
        empty strings when ``flow_result`` is falsy.
    """
    if flow_result:
        rendered = ChaplainFeedbackUIComponents.render_classification_flow(flow_result)
        # Unpack to enforce exactly four parts, then hand them back as a tuple.
        badge, explanation, content, indicators = rendered
        return badge, explanation, content, indicators
    return ("",) * 4
| def _download_latest_verification_json(session: SimplifiedSessionData): | |
| """Return the most recently exported verification session JSON path (if present).""" | |
| # open_verification_window exports into ./verification_sessions | |
| import glob | |
| export_dir = os.path.join(os.getcwd(), "verification_sessions") | |
| if not os.path.isdir(export_dir): | |
| return None | |
| candidates = sorted( | |
| glob.glob(os.path.join(export_dir, "verification_session_*.json")), | |
| key=lambda p: os.path.getmtime(p), | |
| reverse=True, | |
| ) | |
| return candidates[0] if candidates else None | |
| def _render_conv_exchange(records: list, index: int): | |
| if not records: | |
| return "", "", "" | |
| index = max(0, min(index, len(records) - 1)) | |
| r = records[index] | |
| # Check if this is a Provider Summary exchange (Or_4.txt requirement) | |
| if isinstance(r, dict) and r.get("original_classification") == "PROVIDER_SUMMARY": | |
| # Render Provider Summary as final exchange | |
| provider_summary_html = r.get("provider_summary_html", "") | |
| if not provider_summary_html: | |
| # Fallback rendering if HTML not provided | |
| provider_summary_text = r.get("provider_summary", "") | |
| provider_summary_html = f""" | |
| <div style="background-color: #fff3cd; border-left: 4px solid #ffc107; padding: 1em; margin: 1em 0; border-radius: 4px;"> | |
| <h3 style="margin-top: 0; color: #856404;">π Provider Summary (Final Review)</h3> | |
| <div style="background-color: white; padding: 1em; border-radius: 4px; margin-top: 0.5em;"> | |
| <pre style="white-space: pre-wrap; font-family: system-ui; margin: 0;">{provider_summary_text}</pre> | |
| </div> | |
| <p style="margin-bottom: 0; margin-top: 0.5em; font-size: 0.9em; color: #856404;"> | |
| <strong>Please review this summary and provide feedback if incorrect or incomplete.</strong> | |
| </p> | |
| </div> | |
| """ | |
| html = provider_summary_html | |
| else: | |
| # Regular exchange rendering | |
| # Reuse renderer from conversation_verification_ui to keep style consistent | |
| from src.interface.conversation_verification_ui import VerificationInterface | |
| vi = VerificationInterface(ConversationVerificationManager()) | |
| # If we already have dicts, build a lightweight VerificationRecord | |
| if isinstance(r, dict): | |
| rec = ConvVerificationRecord( | |
| exchange_id=r.get("exchange_id") or r.get("record_id", ""), | |
| exchange_number=r.get("exchange_number", 0), | |
| user_message=r.get("user_message", ""), | |
| assistant_response=r.get("assistant_response", ""), | |
| original_classification=r.get("original_classification", ""), | |
| original_confidence=r.get("original_confidence", 0.0), | |
| original_indicators=r.get("original_indicators", []), | |
| original_reasoning=r.get("original_reasoning", ""), | |
| timestamp=r.get("timestamp"), | |
| is_correct=r.get("is_correct"), | |
| correct_classification=r.get("correct_classification"), | |
| correction_reason=r.get("correction_reason"), | |
| verifier_notes=r.get("verifier_notes"), | |
| ) | |
| else: | |
| rec = r | |
| html = vi._render_exchange_review(rec) | |
| # status badge | |
| cur_is_correct = (r.get("is_correct") if isinstance(r, dict) else getattr(r, "is_correct", None)) | |
| if cur_is_correct is True: | |
| badge = "β " | |
| elif cur_is_correct is False: | |
| badge = "β" | |
| else: | |
| badge = "β³" | |
| pos = f"### {badge} Exchange {index + 1} of {len(records)}" | |
| # richer stats | |
| reviewed = 0 | |
| correct = 0 | |
| incorrect = 0 | |
| incorrect_with_comment = 0 | |
| corrections = {} # Track classification corrections | |
| for x in records: | |
| v = (x.get("is_correct") if isinstance(x, dict) else getattr(x, "is_correct", None)) | |
| if v is None: | |
| continue | |
| reviewed += 1 | |
| if v is True: | |
| correct += 1 | |
| else: | |
| incorrect += 1 | |
| note = (x.get("verifier_notes") if isinstance(x, dict) else getattr(x, "verifier_notes", None)) | |
| if note and str(note).strip(): | |
| incorrect_with_comment += 1 | |
| # Track classification corrections | |
| original_class = (x.get("original_classification") if isinstance(x, dict) else getattr(x, "original_classification", "")) | |
| correct_class = (x.get("correct_classification") if isinstance(x, dict) else getattr(x, "correct_classification", None)) | |
| if original_class and correct_class: | |
| correction_key = f"{original_class}β{correct_class}" | |
| corrections[correction_key] = corrections.get(correction_key, 0) + 1 | |
| stats_parts = [ | |
| f"<div><strong>Reviewed:</strong> {reviewed}/{len(records)}</div>", | |
| f"<div><strong>β Correct:</strong> {correct}</div>", | |
| f"<div><strong>β Incorrect:</strong> {incorrect}</div>", | |
| f"<div><strong>π Incorrect w/ comment:</strong> {incorrect_with_comment}</div>" | |
| ] | |
| # Add correction breakdown if any corrections exist | |
| if corrections: | |
| correction_text = ", ".join([f"{k}: {v}" for k, v in corrections.items()]) | |
| stats_parts.append(f"<div><strong>π Corrections:</strong> {correction_text}</div>") | |
| stats = ( | |
| "<div style='display:flex; gap:12px; flex-wrap:wrap;'>" | |
| + "".join(stats_parts) + | |
| "</div>" | |
| ) | |
| return html, pos, stats | |
def _comment_ui_state(records: list, idx: int):
    """Decide whether the reviewer-comment row is shown for the selected record.

    Args:
        records: Verification records (dicts or attribute-style objects).
        idx: Selected position; clamped into the valid range.

    Returns:
        ``(row_update, note_value)``:
        - when the record is explicitly marked incorrect (``is_correct is
          False``), a visible-row update and the record's ``verifier_notes``
          as ``str``;
        - otherwise a hidden-row update and ``""``.
    """
    if not records:
        return gr.update(visible=False), ""
    record = records[max(0, min(idx, len(records) - 1))]

    if isinstance(record, dict):
        flagged_incorrect = record.get("is_correct") is False
    else:
        flagged_incorrect = getattr(record, "is_correct", None) is False
    if not flagged_incorrect:
        return gr.update(visible=False), ""

    if isinstance(record, dict):
        existing_note = record.get("verifier_notes") or ""
    else:
        existing_note = getattr(record, "verifier_notes", "") or ""
    return gr.update(visible=True), str(existing_note)
| def _export_conv_records_to_json(meta: dict, records: list): | |
| """Write reviewed conversation verification results to a JSON file and return its path.""" | |
| import json | |
| import os | |
| from datetime import datetime | |
| export_dir = os.path.join(os.getcwd(), "verification_sessions") | |
| os.makedirs(export_dir, exist_ok=True) | |
| session_id = (meta or {}).get("session_id") or "conversation_verification" | |
| export_filename = f"conversation_verification_reviewed_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{session_id}.json" | |
| export_path = os.path.join(export_dir, export_filename) | |
| payload = { | |
| **(meta or {}), | |
| "verification_records": records or [], | |
| } | |
| with open(export_path, "w", encoding="utf-8") as f: | |
| json.dump(payload, f, ensure_ascii=False, indent=2, default=str) | |
| return export_path | |
| def _export_conv_records_to_csv(meta: dict, records: list): | |
| """Write reviewed conversation verification results to a CSV file and return its path.""" | |
| import csv | |
| import os | |
| from datetime import datetime | |
| export_dir = os.path.join(os.getcwd(), "verification_exports") | |
| os.makedirs(export_dir, exist_ok=True) | |
| session_id = (meta or {}).get("session_id") or "conversation_verification" | |
| export_filename = f"conversation_verification_reviewed_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{session_id}.csv" | |
| export_path = os.path.join(export_dir, export_filename) | |
| fieldnames = [ | |
| "session_id", | |
| "patient_name", | |
| "patient_phone", | |
| "verifier_name", | |
| "start_time", | |
| "exchange_number", | |
| "exchange_id", | |
| "original_classification", | |
| "original_confidence", | |
| "is_correct", | |
| "correct_classification", | |
| "verifier_notes", | |
| "user_message", | |
| "assistant_response", | |
| "provider_summary", | |
| ] | |
| with open(export_path, "w", encoding="utf-8", newline="") as f: | |
| w = csv.DictWriter(f, fieldnames=fieldnames) | |
| w.writeheader() | |
| for r in records or []: | |
| # Include provider_summary only for RED cases | |
| provider_summary = "" | |
| if r.get("original_classification", "").upper() == "RED": | |
| provider_summary = r.get("provider_summary") or "" | |
| row = { | |
| "session_id": (meta or {}).get("session_id"), | |
| "patient_name": (meta or {}).get("patient_name"), | |
| "patient_phone": (meta or {}).get("patient_phone") or "", | |
| "verifier_name": (meta or {}).get("verifier_name"), | |
| "start_time": (meta or {}).get("start_time"), | |
| "exchange_number": r.get("exchange_number"), | |
| "exchange_id": r.get("exchange_id") or r.get("record_id"), | |
| "original_classification": r.get("original_classification"), | |
| "original_confidence": r.get("original_confidence"), | |
| "is_correct": r.get("is_correct"), | |
| "correct_classification": r.get("correct_classification") or "", | |
| "verifier_notes": r.get("verifier_notes") or "", | |
| "user_message": r.get("user_message"), | |
| "assistant_response": r.get("assistant_response"), | |
| "provider_summary": provider_summary, | |
| } | |
| w.writerow(row) | |
| return export_path | |
def _generate_conv_verification(session: SimplifiedSessionData):
    """Build a verification session from the app's conversation log.

    Every exchange in ``session.app_instance.conversation_logger`` becomes a
    record dict. Records whose original classification is RED also carry the
    formatted provider summary, when the app exposes one.

    Returns:
        A 7-tuple ``(meta, records, index, status, html, position, stats)``.
        The early-exit paths return the same number of values (``None`` meta,
        empty records, blank display fields), so callers always receive a
        consistently shaped result.
    """
    if session is None or not hasattr(session.app_instance, "conversation_logger"):
        return None, [], 0, "β No session/conversation found", "", "", ""
    app = session.app_instance
    if not app.conversation_logger.entries:
        return None, [], 0, "β οΈ No exchanges to verify yet", "", "", ""

    manager = ConversationVerificationManager()
    vs = manager.create_verification_session(app.conversation_logger, "Medical Professional")

    # patient_info may be absent or None; in both cases there is no phone.
    patient_phone = (getattr(app, "patient_info", None) or {}).get("phone") or ""
    start_time = getattr(vs, "start_time", None)
    meta = {
        "session_id": vs.session_id,
        "patient_name": vs.patient_name,
        "patient_phone": patient_phone,
        "verifier_name": vs.verifier_name,
        "start_time": start_time.isoformat() if start_time is not None else None,
    }

    # Formatted provider summary; attached to RED exchanges only (see below).
    provider_summary_text = ""
    if hasattr(app, "get_last_provider_summary"):
        summary = app.get_last_provider_summary()
        generator = getattr(app, "provider_summary_generator", None)
        if summary and generator is not None:
            provider_summary_text = generator.format_for_export(summary)

    records_as_dicts = [
        {
            "exchange_id": r.exchange_id,
            "exchange_number": r.exchange_number,
            "record_id": r.exchange_id,
            "timestamp": r.timestamp,
            "user_message": r.user_message,
            "assistant_response": r.assistant_response,
            "original_classification": r.original_classification,
            "original_confidence": r.original_confidence,
            "original_indicators": r.original_indicators,
            "original_reasoning": r.original_reasoning,
            "is_correct": r.is_correct,
            "correct_classification": r.correct_classification,
            "correction_reason": r.correction_reason,
            "verifier_notes": r.verifier_notes,
            # `or ""` guards a missing (None) classification before `.upper()`.
            "provider_summary": (
                provider_summary_text
                if str(r.original_classification or "").upper() == "RED"
                else ""
            ),
        }
        for r in vs.verification_records
    ]
    html, pos, stats = _render_conv_exchange(records_as_dicts, 0)
    return meta, records_as_dicts, 0, f"β Generated session `{vs.session_id}`", html, pos, stats
def _mark_conv_correct(records: list, idx: int):
    """Mark the exchange at ``idx`` as correctly classified and re-render it.

    ``idx`` is clamped into ``[0, len(records) - 1]``.

    Returns:
        A 9-tuple ``(records, idx, status, html, position, stats,
        comment_row_update, note_value, classification_value)``. The empty-input
        branch returns the same nine positions, with the comment row hidden.
    """
    if not records:
        # Four blank display slots (status, html, position, stats) before the
        # row update, matching the success path's layout.
        return records, idx, "", "", "", "", gr.update(visible=False), "", ""
    idx = max(0, min(idx, len(records) - 1))
    if isinstance(records[idx], dict):
        records[idx]["is_correct"] = True
        # clear comment and correct_classification when marked correct (avoid stale data)
        records[idx]["verifier_notes"] = ""
        records[idx]["correct_classification"] = None
    html, pos, stats = _render_conv_exchange(records, idx)
    row_upd, note_val = _comment_ui_state(records, idx)
    # Trailing "" resets the corrected-classification value.
    return records, idx, "β Marked correct", html, pos, stats, row_upd, note_val, ""
def _mark_conv_incorrect(records: list, idx: int):
    """Mark the exchange at ``idx`` as incorrectly classified and re-render it.

    ``idx`` is clamped into ``[0, len(records) - 1]``.

    Returns:
        A 9-tuple ``(records, idx, status, html, position, stats,
        comment_row_update, note_value, classification_label)``.
        ``classification_label`` is the dropdown text for any corrected
        classification already stored on the record, or ``""``. The empty-input
        branch returns the same nine positions.
    """
    if not records:
        # Same nine-value layout as the success path (four blank display slots).
        return records, idx, "", "", "", "", gr.update(visible=False), "", ""
    idx = max(0, min(idx, len(records) - 1))
    if isinstance(records[idx], dict):
        records[idx]["is_correct"] = False
    html, pos, stats = _render_conv_exchange(records, idx)
    row_upd, note_val = _comment_ui_state(records, idx)
    # Get existing correct_classification if any, mapped back to display text
    existing_classification = ""
    if isinstance(records[idx], dict):
        correct_class = records[idx].get("correct_classification")
        if correct_class:
            reverse_map = {
                "GREEN": "π’ Should be GREEN - No distress",
                "YELLOW": "π‘ Should be YELLOW - Needs clarification",
                "RED": "π΄ Should be RED - Spiritual distress"
            }
            existing_classification = reverse_map.get(correct_class, "")
    return records, idx, "β Marked incorrect", html, pos, stats, row_upd, note_val, existing_classification
def _show_incorrect_comment_ui(records: list, idx: int):
    """Mark incorrect and open the comment row, pre-filling any existing note.

    Returns the same 9 values as ``_mark_conv_incorrect``, except that the
    comment-row update forces the row visible.
    """
    if not records:
        # Nothing to mark: keep the row hidden. Guarding here means this function
        # never relies on the exact shape of _mark_conv_incorrect's empty branch.
        return records, idx, "", "", "", "", gr.update(visible=False), "", ""
    records, idx, status, html, pos, stats, _row, note, existing_classification = _mark_conv_incorrect(records, idx)
    return records, idx, status, html, pos, stats, gr.update(visible=True), note, existing_classification
def _save_incorrect_comment(records: list, idx: int, note: str, correct_classification: str):
    """Save the verifier's comment and corrected classification for one exchange.

    Args:
        records: Record dicts being reviewed.
        idx: Index of the exchange; clamped into range.
        note: Free-text comment. It is stored stripped, and ``None`` becomes ``""``.
        correct_classification: One of the "Should be GREEN/YELLOW/RED" dropdown
            labels. An empty or unrecognised label leaves any previously stored
            correction untouched.

    Returns:
        A 9-tuple ``(records, idx, status, html, position, stats,
        comment_row_update, note_value, classification_label)``.
        ``classification_label`` echoes the dropdown label of the correction now
        stored on the record (``""`` if none). This keeps the saved choice
        visible instead of clearing it, consistent with how navigation and
        "mark incorrect" restore the stored label.
    """
    if not records:
        return records, idx, "", "", "", "", gr.update(visible=False), "", ""
    idx = max(0, min(idx, len(records) - 1))
    # Map display text to classification code
    classification_map = {
        "π’ Should be GREEN - No distress": "GREEN",
        "π‘ Should be YELLOW - Needs clarification": "YELLOW",
        "π΄ Should be RED - Spiritual distress": "RED"
    }
    saved_label = ""
    if isinstance(records[idx], dict):
        record = records[idx]
        record["verifier_notes"] = (note or "").strip()
        if correct_classification and correct_classification in classification_map:
            record["correct_classification"] = classification_map[correct_classification]
        code_to_label = {code: label for label, code in classification_map.items()}
        saved_label = code_to_label.get(record.get("correct_classification"), "")
    html, pos, stats = _render_conv_exchange(records, idx)
    row_upd, note_val = _comment_ui_state(records, idx)
    # keep row visible after save (since still incorrect)
    return records, idx, "πΎ Comment saved", html, pos, stats, row_upd, note_val, saved_label
def _download_reviewed_json(meta: dict, records: list):
    """Download handler: export the reviewed records as JSON and hand back the file path."""
    json_path = _export_conv_records_to_json(meta, records)
    return json_path
def _download_reviewed_csv(meta: dict, records: list):
    """Download handler: export the reviewed records as CSV and hand back the file path."""
    csv_path = _export_conv_records_to_csv(meta, records)
    return csv_path
def _nav_conv(records: list, idx: int, delta: int):
    """Move the review cursor by ``delta`` exchanges and re-render.

    The new position is clamped to the valid index range.

    Returns:
        ``(idx, html, position, stats, comment_row_update, note_value,
        classification_label)``, where ``classification_label`` is the dropdown
        text for the stored corrected classification, or ``""``.
    """
    if not records:
        return idx, "", "", "", gr.update(visible=False), "", ""

    # Clamp the target index: first against the upper bound, then the lower.
    last_index = len(records) - 1
    target = idx + delta
    if target > last_index:
        target = last_index
    if target < 0:
        target = 0

    html, pos, stats = _render_conv_exchange(records, target)
    row_upd, note_val = _comment_ui_state(records, target)

    current = records[target]
    stored_code = current.get("correct_classification") if isinstance(current, dict) else None
    label = ""
    if stored_code:
        label = {
            "GREEN": "π’ Should be GREEN - No distress",
            "YELLOW": "π‘ Should be YELLOW - Needs clarification",
            "RED": "π΄ Should be RED - Spiritual distress",
        }.get(stored_code, "")
    return target, html, pos, stats, row_upd, note_val, label
| # ============================================================================ | |
| # NEW FUNCTIONS FOR SIMPLIFIED INTERFACE (Or_4.txt requirements) | |
| # ============================================================================ | |
def _render_provider_summary_final_html(summary_text: str) -> str:
    """Return the HTML card displayed for the Provider Summary final exchange.

    ``summary_text`` is HTML-escaped before it is embedded. It is arbitrary
    generated text, and unescaped ``<`` / ``&`` would either inject markup or
    break the card's layout.
    """
    from html import escape

    return f"""
<div style="background-color: #fff3cd; border-left: 4px solid #ffc107; padding: 1em; margin: 1em 0; border-radius: 4px;">
<h3 style="margin-top: 0; color: #856404;">π Provider Summary (Final Review)</h3>
<div style="background-color: white; padding: 1em; border-radius: 4px; margin-top: 0.5em;">
<pre style="white-space: pre-wrap; font-family: system-ui; margin: 0;">{escape(summary_text or "")}</pre>
</div>
<p style="margin-bottom: 0; margin-top: 0.5em; font-size: 0.9em; color: #856404;">
<strong>Please review this summary and provide feedback if incorrect or incomplete.</strong>
</p>
</div>
"""


def _generate_conv_verification_with_summary(session: SimplifiedSessionData):
    """
    Generate conversation verification with Provider Summary as the FINAL exchange.
    This addresses the customer requirement from Or_4.txt:
    "Provider Summary to be the final exchange presented in that tab"

    Returns:
        A 7-tuple ``(meta, records, index, status, html, position, stats)``.
        The early-exit paths return the same number of values. When the app
        reports a provider summary and has a summary generator, a synthetic
        record with ``original_classification == "PROVIDER_SUMMARY"`` is
        appended after the real exchanges.
    """
    if session is None or not hasattr(session.app_instance, "conversation_logger"):
        return None, [], 0, "β No session/conversation found", "", "", ""
    app = session.app_instance
    if not app.conversation_logger.entries:
        return None, [], 0, "β οΈ No exchanges to verify yet", "", "", ""

    manager = ConversationVerificationManager()
    vs = manager.create_verification_session(app.conversation_logger, "Medical Professional")

    # patient_info may be absent or None; in both cases there is no phone.
    patient_phone = (getattr(app, "patient_info", None) or {}).get("phone") or ""
    start_time = getattr(vs, "start_time", None)
    meta = {
        "session_id": vs.session_id,
        "patient_name": vs.patient_name,
        "patient_phone": patient_phone,
        "verifier_name": vs.verifier_name,
        "start_time": start_time.isoformat() if start_time is not None else None,
    }

    # Get provider summary if available (for RED cases)
    provider_summary_text = ""
    provider_summary_html = ""
    has_red_flag = False
    summary = app.get_last_provider_summary() if hasattr(app, "get_last_provider_summary") else None
    if summary:
        has_red_flag = True
        generator = getattr(app, "provider_summary_generator", None)
        if generator is not None:
            # Prefer the coherent narrative (LLM-generated) format; fall back to
            # the structured export format when it fails or comes back empty.
            try:
                provider_summary_text = generator.format_coherent_paragraph(summary)
            except Exception as e:
                print(f"ERROR: Failed to generate coherent summary: {e}")
                provider_summary_text = ""
            if not provider_summary_text:
                provider_summary_text = generator.format_for_export(summary)
            provider_summary_html = _render_provider_summary_final_html(provider_summary_text)

    records_as_dicts = [
        {
            "exchange_id": r.exchange_id,
            "exchange_number": r.exchange_number,
            "record_id": r.exchange_id,
            "timestamp": r.timestamp,
            "user_message": r.user_message,
            "assistant_response": r.assistant_response,
            "original_classification": r.original_classification,
            "original_confidence": r.original_confidence,
            "original_indicators": r.original_indicators,
            "original_reasoning": r.original_reasoning,
            "is_correct": r.is_correct,
            "correct_classification": r.correct_classification,
            "correction_reason": r.correction_reason,
            "verifier_notes": r.verifier_notes,
            "provider_summary": "",  # Not shown in regular exchanges
        }
        for r in vs.verification_records
    ]

    # ADD PROVIDER SUMMARY AS FINAL EXCHANGE (Or_4.txt requirement)
    if has_red_flag and provider_summary_html:
        summary_id = f"{vs.session_id}_provider_summary"
        records_as_dicts.append({
            "exchange_id": summary_id,
            "exchange_number": len(records_as_dicts) + 1,
            "record_id": summary_id,
            "timestamp": datetime.now().isoformat(),
            "user_message": "",
            "assistant_response": "",
            "original_classification": "PROVIDER_SUMMARY",
            "original_confidence": 1.0,
            "original_indicators": [],
            "original_reasoning": "Provider Summary for Spiritual Care Team",
            "is_correct": None,  # Needs review
            "correct_classification": None,
            "correction_reason": "",
            "verifier_notes": "",
            "provider_summary": provider_summary_text,
            "provider_summary_html": provider_summary_html,
        })

    html, pos, stats = _render_conv_exchange(records_as_dicts, 0)
    return meta, records_as_dicts, 0, f"β Generated session with {len(records_as_dicts)} exchanges (Provider Summary as final step)", html, pos, stats
def _auto_save_verification_report(meta: dict, records: list, session: SimplifiedSessionData):
    """Save the reviewed verification data as both JSON and CSV in one step.

    Implements the Or_4.txt request for a single button that saves the report
    automatically. Both files are written to ``<cwd>/verification_reports`` as
    ``report_<UTC timestamp>_<session_id>.json`` / ``.csv``. Unlike the RED-only
    CSV download, this CSV carries ``provider_summary`` for every record.

    ``session`` is accepted but not read.

    Returns:
        A Markdown status message. Errors are printed with their traceback and
        reported in the returned message rather than raised.
    """
    try:
        if not records:
            return "β οΈ No verification data to save"

        info = meta or {}
        auto_save_dir = os.path.join(os.getcwd(), "verification_reports")
        os.makedirs(auto_save_dir, exist_ok=True)

        stamp = datetime.utcnow().strftime('%Y%m%d_%H%M%S')
        stem = f"report_{stamp}_{info.get('session_id') or 'unknown'}"
        json_filename = stem + ".json"
        csv_filename = stem + ".csv"

        # --- JSON: session metadata plus every record -----------------------
        report = dict(info)
        report["verification_records"] = records
        report["auto_saved_at"] = datetime.utcnow().isoformat()
        with open(os.path.join(auto_save_dir, json_filename), "w", encoding="utf-8") as out:
            json.dump(report, out, ensure_ascii=False, indent=2, default=str)

        # --- CSV: one row per record, session columns repeated ---------------
        session_columns = ("session_id", "patient_name", "patient_phone", "verifier_name", "start_time")
        record_columns = (
            "exchange_number",
            "exchange_id",
            "original_classification",
            "original_confidence",
            "is_correct",
            "correct_classification",
            "verifier_notes",
            "user_message",
            "assistant_response",
            "provider_summary",
        )
        session_part = {
            "session_id": info.get("session_id"),
            "patient_name": info.get("patient_name"),
            "patient_phone": info.get("patient_phone") or "",
            "verifier_name": info.get("verifier_name"),
            "start_time": info.get("start_time"),
        }

        def to_row(rec):
            # provider_summary is kept for every record (notably the final one).
            row = dict(session_part)
            row["exchange_number"] = rec.get("exchange_number")
            row["exchange_id"] = rec.get("exchange_id") or rec.get("record_id")
            row["original_classification"] = rec.get("original_classification")
            row["original_confidence"] = rec.get("original_confidence")
            row["is_correct"] = rec.get("is_correct")
            row["correct_classification"] = rec.get("correct_classification") or ""
            row["verifier_notes"] = rec.get("verifier_notes") or ""
            row["user_message"] = rec.get("user_message")
            row["assistant_response"] = rec.get("assistant_response")
            row["provider_summary"] = rec.get("provider_summary") or ""
            return row

        with open(os.path.join(auto_save_dir, csv_filename), "w", encoding="utf-8", newline="") as out:
            writer = csv.DictWriter(out, fieldnames=[*session_columns, *record_columns])
            writer.writeheader()
            for rec in records:
                writer.writerow(to_row(rec))

        verdicts = [rec.get('is_correct') for rec in records]
        reviewed = sum(1 for v in verdicts if v is not None)
        correct = sum(1 for v in verdicts if v is True)
        incorrect = sum(1 for v in verdicts if v is False)
        return f"""β **Report Auto-Saved Successfully!**
π **Location:** `{auto_save_dir}`
π **Files:**
- JSON: `{json_filename}`
- CSV: `{csv_filename}`
π **Summary:**
- Total exchanges: {len(records)}
- Reviewed: {reviewed}
- Correct: {correct}
- Incorrect: {incorrect}
"""
    except Exception as exc:
        import traceback
        details = traceback.format_exc()
        print(f"β Auto-save error: {details}")
        return f"β **Auto-save failed:** {exc!s}"