Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Conversation Verification System - Core Models and Manager. | |
| Provides data models and management functionality for verifying AI classifier | |
| decisions made during patient conversations. | |
| """ | |
| import json | |
| import os | |
| import uuid | |
| from datetime import datetime | |
| from typing import Dict, List, Any, Optional, Tuple | |
| from dataclasses import dataclass, asdict, field | |
| from src.core.conversation_logger import ConversationLogger, ConversationEntry | |
| from src.interface.enhanced_results_display_manager import EnhancedResultsDisplayManager, EnhancedDisplayConfig | |
| from src.core.provider_summary_generator import ProviderSummary | |
@dataclass
class VerificationFeedback:
    """Feedback provided by a verifier for a single conversation exchange.

    ``correct_classification`` is expected when ``is_correct`` is False; the
    manager's ``submit_exchange_verification`` enforces that and rejects labels
    other than GREEN/YELLOW/RED.
    """

    exchange_id: str  # Matches EnhancedVerificationRecord.exchange_id
    is_correct: bool  # Whether the classifier's original label was right
    correct_classification: Optional[str] = None  # Required if is_correct=False
    correction_reason: Optional[str] = None  # Why the original label was wrong
    notes: Optional[str] = None  # Free-form verifier notes
@dataclass
class EnhancedVerificationRecord:
    """Verification record for one conversation exchange, with enhanced-format support.

    Holds the classifier's original decision for the exchange, optional
    pre-rendered display data, and the verifier's feedback once it is applied.
    """

    exchange_id: str
    exchange_number: int
    timestamp: datetime
    user_message: str
    assistant_response: str
    original_classification: str  # GREEN/YELLOW/RED
    original_confidence: float
    original_indicators: List[str]
    original_reasoning: str
    # Enhanced fields for new formats
    enhanced_display_format: Optional[str] = None  # HTML formatted display
    provider_summary: Optional[Dict[str, Any]] = None  # Provider summary data (to_dict() form)
    coherent_summary_paragraph: Optional[str] = None  # Coherent paragraph format
    visual_sections: Optional[List[Dict[str, str]]] = None  # Visual section data
    # Verification fields: all None until apply_feedback() is called
    is_correct: Optional[bool] = None
    correct_classification: Optional[str] = None
    correction_reason: Optional[str] = None
    verifier_notes: Optional[str] = None
    verification_timestamp: Optional[datetime] = None

    @classmethod
    def from_conversation_entry(
        cls, entry: 'ConversationEntry', exchange_number: int
    ) -> 'EnhancedVerificationRecord':
        """Create an unverified record from a logged conversation entry.

        Args:
            entry: Logged exchange; its ``timestamp`` may be an ISO-8601 string
                or an already-parsed ``datetime``.
            exchange_number: Ordinal position of the exchange in the conversation.

        Returns:
            A new record whose ``exchange_id`` is ``"<session_id>_<message_index>"``.
        """
        raw_timestamp = entry.timestamp
        timestamp = (
            raw_timestamp
            if isinstance(raw_timestamp, datetime)
            else datetime.fromisoformat(raw_timestamp)
        )
        return cls(
            exchange_id=f"{entry.session_id}_{entry.message_index}",
            exchange_number=exchange_number,
            timestamp=timestamp,
            user_message=entry.user_message,
            assistant_response=entry.assistant_response,
            original_classification=entry.spiritual_classification,
            original_confidence=entry.classification_confidence,
            # Independent copy so edits to the record never mutate the logged entry;
            # list() also accepts tuples, and a missing value becomes [].
            original_indicators=list(entry.classification_indicators or []),
            original_reasoning=entry.classification_reasoning,
        )

    def apply_feedback(self, feedback: 'VerificationFeedback') -> None:
        """Copy the verifier's feedback onto this record and stamp the current time."""
        self.is_correct = feedback.is_correct
        self.correct_classification = feedback.correct_classification
        self.correction_reason = feedback.correction_reason
        self.verifier_notes = feedback.notes
        self.verification_timestamp = datetime.now()

    def set_enhanced_formats(
        self,
        enhanced_display: Optional[str] = None,
        provider_summary: Optional['ProviderSummary'] = None,
        coherent_paragraph: Optional[str] = None,
        visual_sections: Optional[List[Dict[str, str]]] = None,
    ) -> None:
        """Store pre-rendered display data on this record.

        ``enhanced_display`` and ``coherent_paragraph`` are always overwritten.
        ``visual_sections`` is always overwritten and defaults to an empty list.
        ``provider_summary`` replaces the stored summary (via its ``to_dict()``)
        only when one is supplied; passing None keeps any existing summary.
        """
        self.enhanced_display_format = enhanced_display
        # Explicit None check: a supplied summary object must never be dropped
        # just because it happens to evaluate as falsy.
        if provider_summary is not None:
            self.provider_summary = provider_summary.to_dict()
        self.coherent_summary_paragraph = coherent_paragraph
        self.visual_sections = visual_sections or []
@dataclass
class VerificationProgress:
    """Snapshot of progress and accuracy for a verification session."""

    total_exchanges: int
    verified_exchanges: int
    accuracy_overall: float = 0.0  # Fraction (0-1) of verified exchanges marked correct
    accuracy_by_type: Dict[str, float] = field(default_factory=dict)  # label -> fraction correct
    common_errors: List[Tuple[str, str, int]] = field(default_factory=list)  # (original, corrected, count)

    def calculate_progress_percentage(self) -> float:
        """Return verified/total as a percentage, or 0.0 when there are no exchanges."""
        if self.total_exchanges == 0:
            return 0.0
        return (self.verified_exchanges / self.total_exchanges) * 100

    def is_complete(self) -> bool:
        """Return True once every exchange has been verified.

        ``>=`` rather than ``==`` so a verified count that overshoots the total
        (e.g. records added after the total was fixed) still reports complete.
        An empty session (0 of 0) is considered complete.
        """
        return self.verified_exchanges >= self.total_exchanges
@dataclass
class EnhancedVerificationSession:
    """Verification session for one logged conversation, with enhanced-format support.

    Holds one ``EnhancedVerificationRecord`` per exchange and tracks how many of
    them have received verifier feedback.
    """

    session_id: str
    conversation_session_id: str  # Links to ConversationLogger session
    patient_name: str
    verifier_name: str
    start_time: datetime
    end_time: Optional[datetime] = None  # Stamped when the session first becomes complete
    total_exchanges: int = 0
    verified_exchanges: int = 0
    verification_records: List[EnhancedVerificationRecord] = field(default_factory=list)
    is_complete: bool = False
    # Enhanced format support
    display_manager: Optional[EnhancedResultsDisplayManager] = None
    enhanced_format_enabled: bool = True

    def get_progress(self) -> VerificationProgress:
        """Compute progress and accuracy statistics from the current records.

        A record counts as verified once ``is_correct`` is not None. Accuracy
        values are fractions in [0, 1]; a label with no verified records reports
        0.0. ``common_errors`` holds up to five ``(original, corrected, count)``
        tuples, most frequent first.
        """
        verified_records = [r for r in self.verification_records if r.is_correct is not None]

        def _accuracy(records: List[EnhancedVerificationRecord]) -> float:
            # Fraction of records marked correct; 0.0 for an empty list.
            if not records:
                return 0.0
            return sum(1 for r in records if r.is_correct) / len(records)

        accuracy_by_type = {
            label: _accuracy([r for r in verified_records if r.original_classification == label])
            for label in ('GREEN', 'YELLOW', 'RED')
        }

        # Tally (original -> corrected) label pairs for records marked incorrect.
        error_patterns: Dict[Tuple[str, str], int] = {}
        for record in verified_records:
            if not record.is_correct and record.correct_classification:
                key = (record.original_classification, record.correct_classification)
                error_patterns[key] = error_patterns.get(key, 0) + 1
        # sorted() is stable, so equally frequent pairs keep first-seen order.
        ranked = sorted(error_patterns.items(), key=lambda item: item[1], reverse=True)[:5]
        common_errors = [(src, dst, count) for (src, dst), count in ranked]

        return VerificationProgress(
            total_exchanges=self.total_exchanges,
            verified_exchanges=len(verified_records),
            accuracy_overall=_accuracy(verified_records),
            accuracy_by_type=accuracy_by_type,
            common_errors=common_errors,
        )

    def add_verification_record(self, record: EnhancedVerificationRecord) -> None:
        """Append a record to the session (``total_exchanges`` is not changed)."""
        self.verification_records.append(record)

    def apply_feedback(self, exchange_id: str, feedback: VerificationFeedback) -> bool:
        """Apply ``feedback`` to the first record whose ``exchange_id`` matches.

        Recomputes ``verified_exchanges``. The first time every exchange is
        verified, the session is marked complete and ``end_time`` is stamped.
        Later corrections to an already-complete session keep that ``end_time``.

        Returns:
            True if a matching record was found and updated, False otherwise.
        """
        record = next(
            (r for r in self.verification_records if r.exchange_id == exchange_id), None
        )
        if record is None:
            return False

        record.apply_feedback(feedback)
        self.verified_exchanges = sum(
            1 for r in self.verification_records if r.is_correct is not None
        )
        # Only stamp completion once; re-submissions must not move end_time.
        if not self.is_complete and self.verified_exchanges >= self.total_exchanges:
            self.is_complete = True
            self.end_time = datetime.now()
        return True

    def get_unverified_records(self) -> List[EnhancedVerificationRecord]:
        """Return records that have not yet received feedback, in session order."""
        return [r for r in self.verification_records if r.is_correct is None]

    def get_next_unverified_record(self) -> Optional[EnhancedVerificationRecord]:
        """Return the first unverified record, or None when all are verified."""
        unverified = self.get_unverified_records()
        return unverified[0] if unverified else None

    def generate_enhanced_display_for_record(self, record: EnhancedVerificationRecord) -> str:
        """Render the enhanced display for ``record``, cache it on the record, and return it.

        A default ``EnhancedResultsDisplayManager`` is created lazily if the
        session has none.
        """
        if not self.display_manager:
            self.display_manager = EnhancedResultsDisplayManager()

        ai_analysis = {
            'classification': record.original_classification,
            'indicators': record.original_indicators,
            'reasoning': record.original_reasoning,
            'confidence': record.original_confidence,
        }
        enhanced_display = self.display_manager.format_combined_results(
            ai_analysis=ai_analysis,
            patient_message=record.user_message,
            provider_summary=None,  # Would need to be generated separately
        )
        record.enhanced_display_format = enhanced_display
        return enhanced_display
class EnhancedConversationVerificationManager:
    """Create, persist and update enhanced conversation verification sessions.

    Persistence is delegated to a ``JSONVerificationStore`` rooted at ``storage_dir``.
    """

    # Labels accepted as a verifier's corrected classification.
    VALID_CLASSIFICATIONS = ('GREEN', 'YELLOW', 'RED')

    def __init__(self, storage_dir: str = "verification_sessions"):
        """Initialize the manager with its session store and a shared display manager.

        Args:
            storage_dir: Directory passed to ``JSONVerificationStore``.
        """
        # Local import, presumably to avoid an import cycle with the store module — TODO confirm.
        from src.core.verification_store import JSONVerificationStore
        self.store = JSONVerificationStore(storage_dir)
        self.display_manager = EnhancedResultsDisplayManager()

    def create_verification_session(
        self,
        conversation_logger: ConversationLogger,
        verifier_name: str = "Medical Professional",
        enable_enhanced_formats: bool = True
    ) -> EnhancedVerificationSession:
        """
        Create new enhanced verification session from conversation logger.

        One record is created per logged entry (numbered from 1) and the new
        session is saved to the store before being returned.

        Args:
            conversation_logger: Source conversation to verify
            verifier_name: Name of person doing verification
            enable_enhanced_formats: Whether to enable enhanced display formats

        Returns:
            New EnhancedVerificationSession ready for verification
        """
        # Single timestamp so the ID prefix and start_time always agree.
        now = datetime.now()
        session_id = f"verification_{now.strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}"

        session = EnhancedVerificationSession(
            session_id=session_id,
            conversation_session_id=conversation_logger.session_id,
            patient_name=conversation_logger.patient_name,
            verifier_name=verifier_name,
            start_time=now,
            total_exchanges=len(conversation_logger.entries),
            display_manager=self.display_manager if enable_enhanced_formats else None,
            enhanced_format_enabled=enable_enhanced_formats,
        )

        for exchange_number, entry in enumerate(conversation_logger.entries, 1):
            record = EnhancedVerificationRecord.from_conversation_entry(entry, exchange_number)
            if enable_enhanced_formats:
                self._generate_enhanced_formats_for_record(record, entry)
            session.add_verification_record(record)

        self.store.save_session(session)
        return session

    def _generate_enhanced_formats_for_record(
        self,
        record: EnhancedVerificationRecord,
        entry: ConversationEntry
    ) -> None:
        """Populate ``record``'s enhanced display and visual sections from ``entry``.

        Best effort: any failure is logged as a warning and the record is left
        without enhanced formats, so session creation still succeeds.
        """
        try:
            ai_analysis = {
                'classification': entry.spiritual_classification,
                'indicators': entry.classification_indicators,
                'reasoning': entry.classification_reasoning,
                'confidence': entry.classification_confidence
            }
            enhanced_display = self.display_manager.format_combined_results(
                ai_analysis=ai_analysis,
                patient_message=entry.user_message,
                provider_summary=None  # Would need provider summary data
            )
            visual_sections = [
                {
                    'type': 'ai_analysis',
                    'classification': entry.spiritual_classification,
                    'confidence': str(entry.classification_confidence),
                    'indicators': '; '.join(entry.classification_indicators),
                    'reasoning': entry.classification_reasoning
                },
                {
                    'type': 'patient_message',
                    'content': entry.user_message
                },
                {
                    'type': 'assistant_response',
                    'content': entry.assistant_response
                }
            ]
            record.set_enhanced_formats(
                enhanced_display=enhanced_display,
                visual_sections=visual_sections
            )
        except Exception as exc:  # Deliberate best-effort boundary (see docstring).
            import logging
            logging.getLogger(__name__).warning(
                "Could not generate enhanced formats for record %s: %s",
                record.exchange_id,
                exc,
                exc_info=True,
            )

    def get_verification_progress(self, session_id: str) -> Optional[VerificationProgress]:
        """Return progress for the stored session, or None if it cannot be loaded."""
        session = self.store.load_session(session_id)
        return session.get_progress() if session else None

    def submit_exchange_verification(
        self,
        session_id: str,
        exchange_id: str,
        feedback: VerificationFeedback
    ) -> bool:
        """
        Submit verification feedback for an exchange.

        The session is saved back to the store only when the feedback was applied.

        Args:
            session_id: Verification session ID
            exchange_id: Exchange being verified
            feedback: Verification feedback

        Returns:
            True if feedback was applied successfully; False if the session or
            exchange was not found.

        Raises:
            ValueError: If ``is_correct`` is False and ``correct_classification``
                is missing or not one of GREEN, YELLOW, RED.
        """
        session = self.store.load_session(session_id)
        if not session:
            return False

        # A correction must name a valid replacement label.
        if not feedback.is_correct:
            if not feedback.correct_classification:
                raise ValueError("correct_classification required when is_correct=False")
            if feedback.correct_classification not in self.VALID_CLASSIFICATIONS:
                raise ValueError("correct_classification must be GREEN, YELLOW, or RED")

        success = session.apply_feedback(exchange_id, feedback)
        if success:
            self.store.save_session(session)
        return success

    def get_session_statistics(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return a summary dict for the session, or None if it cannot be loaded.

        Includes identity/timing fields, the ``asdict`` form of the progress
        snapshot, and counts of records carrying each enhanced format.
        """
        session = self.store.load_session(session_id)
        if not session:
            return None

        progress = session.get_progress()
        records = session.verification_records
        return {
            "session_id": session.session_id,
            "patient_name": session.patient_name,
            "verifier_name": session.verifier_name,
            "start_time": session.start_time.isoformat(),
            "end_time": session.end_time.isoformat() if session.end_time else None,
            "is_complete": session.is_complete,
            "progress": asdict(progress),
            "total_exchanges": session.total_exchanges,
            "verified_exchanges": session.verified_exchanges,
            # getattr: stored sessions may predate this attribute.
            "enhanced_format_enabled": getattr(session, 'enhanced_format_enabled', False),
            "records_with_enhanced_display": sum(1 for r in records if r.enhanced_display_format),
            "records_with_provider_summary": sum(1 for r in records if r.provider_summary),
            "records_with_coherent_paragraph": sum(1 for r in records if r.coherent_summary_paragraph)
        }

    def load_session(self, session_id: str) -> Optional[EnhancedVerificationSession]:
        """Load enhanced verification session by ID (delegates to the store)."""
        return self.store.load_session(session_id)

    def save_session(self, session: EnhancedVerificationSession) -> None:
        """Save enhanced verification session (delegates to the store)."""
        self.store.save_session(session)

    def list_sessions(self) -> List[Dict[str, Any]]:
        """List all verification sessions known to the store."""
        return self.store.list_sessions()

    def get_incomplete_sessions(self) -> List[Dict[str, Any]]:
        """Return the store's list of incomplete verification sessions."""
        return self.store.get_incomplete_sessions()

    def export_session_with_enhanced_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Export a session and all its records, including enhanced-format data.

        Datetimes are rendered as ISO-8601 strings. Returns None if the session
        cannot be loaded.
        """
        session = self.load_session(session_id)
        if not session:
            return None

        return {
            'session_info': {
                'session_id': session.session_id,
                'conversation_session_id': session.conversation_session_id,
                'patient_name': session.patient_name,
                'verifier_name': session.verifier_name,
                'start_time': session.start_time.isoformat(),
                'end_time': session.end_time.isoformat() if session.end_time else None,
                'enhanced_format_enabled': getattr(session, 'enhanced_format_enabled', False)
            },
            'records': [self._export_record(record) for record in session.verification_records]
        }

    @staticmethod
    def _export_record(record: EnhancedVerificationRecord) -> Dict[str, Any]:
        """Flatten one verification record into its export dict."""
        return {
            'exchange_id': record.exchange_id,
            'exchange_number': record.exchange_number,
            'timestamp': record.timestamp.isoformat(),
            'user_message': record.user_message,
            'assistant_response': record.assistant_response,
            'original_classification': record.original_classification,
            'original_confidence': record.original_confidence,
            'original_indicators': record.original_indicators,
            'original_reasoning': record.original_reasoning,
            'is_correct': record.is_correct,
            'correct_classification': record.correct_classification,
            'correction_reason': record.correction_reason,
            'verifier_notes': record.verifier_notes,
            'verification_timestamp': record.verification_timestamp.isoformat() if record.verification_timestamp else None,
            # Enhanced format data
            'enhanced_display_format': record.enhanced_display_format,
            'provider_summary': record.provider_summary,
            'coherent_summary_paragraph': record.coherent_summary_paragraph,
            'visual_sections': record.visual_sections
        }
# Legacy names kept for backward compatibility: each is the very same class
# object as its Enhanced* counterpart, not a subclass or wrapper.
(
    ConversationVerificationManager,
    VerificationSession,
    VerificationRecord,
) = (
    EnhancedConversationVerificationManager,
    EnhancedVerificationSession,
    EnhancedVerificationRecord,
)