# enhanced_dataset_interface.py
"""Enhanced Dataset Interface Controller.

Provides the complete interface logic for enhanced dataset mode including
dataset selection, editing, creation, and verification workflows.

Requirements: 2.1, 2.2, 2.7
"""

import gradio as gr
from typing import List, Dict, Tuple, Optional, Any, Union
from datetime import datetime
import uuid

from src.core.verification_models import (
    EnhancedVerificationSession,
    VerificationRecord,
    TestMessage,
    TestDataset,
)
from src.core.enhanced_dataset_manager import EnhancedDatasetManager
from src.core.verification_store import JSONVerificationStore
from src.core.test_datasets import TestDatasetManager
from src.interface.verification_ui import VerificationUIComponents
from src.core.spiritual_monitor import SpiritualMonitor
from src.core.ai_client import AIClientManager
from src.core.enhanced_progress_tracker import EnhancedProgressTracker, VerificationMode
from src.interface.enhanced_progress_components import ProgressTrackingMixin


class EnhancedDatasetInterfaceController(ProgressTrackingMixin):
    """Controller for enhanced dataset mode interface."""

    def __init__(self, store: Optional[JSONVerificationStore] = None):
        """Initialize the enhanced dataset interface controller.

        Args:
            store: Optional verification store; a fresh ``JSONVerificationStore``
                is created when none is supplied.
        """
        super().__init__(VerificationMode.ENHANCED_DATASET)
        self.store = store or JSONVerificationStore()
        self.dataset_manager = EnhancedDatasetManager()
        self.ai_client_manager = AIClientManager()
        self.spiritual_monitor = SpiritualMonitor(self.ai_client_manager)
        self.current_session: Optional[EnhancedVerificationSession] = None
        self.current_dataset: Optional[TestDataset] = None
        self.current_message_index = 0
        self.verification_start_time: Optional[datetime] = None
        # Cache of the classification last shown to the verifier, so that
        # feedback submission does not re-run the classifier (which could
        # yield a different result than the one the verifier judged, and
        # also reset the per-message timing clock).
        self._last_classification: Dict[str, Any] = {}

    def initialize_interface(self) -> Tuple[List[str], str, str, List[str]]:
        """Initialize the enhanced dataset interface.

        Returns:
            Tuple of (dataset_choices, dataset_info, status_message, templates).
            NOTE: the previous annotation/docstring claimed a 3-tuple, but both
            return paths produce four values (templates last); fixed here.
        """
        try:
            # Get all available datasets and build the dropdown labels.
            datasets = self.dataset_manager.list_datasets()
            dataset_choices = [
                f"{dataset.name} ({dataset.message_count} messages)"
                for dataset in datasets
            ]
            # Templates available for the "create dataset" flow.
            templates = self.dataset_manager.get_available_templates()
            return (
                dataset_choices,
                "Select a dataset to view details and start verification or editing.",
                "✨ Enhanced Dataset Mode initialized. Select a dataset to get started.",
                templates,
            )
        except Exception as e:
            return (
                [],
                f"❌ Error loading datasets: {str(e)}",
                f"❌ Failed to initialize interface: {str(e)}",
                [],
            )

    def get_dataset_info(self, dataset_selection: str) -> Tuple[str, Optional[TestDataset]]:
        """Get dataset information for display.

        Args:
            dataset_selection: Selected dataset string from dropdown
                (format: "<name> (<count> messages)").

        Returns:
            Tuple of (dataset_info_markdown, dataset_object).
        """
        try:
            if not dataset_selection:
                return "Select a dataset to view details", None

            # Parse the dataset name back out of the dropdown label.
            dataset_name = dataset_selection.split(" (")[0]

            # Find the matching dataset by name.
            datasets = self.dataset_manager.list_datasets()
            selected_dataset = next(
                (d for d in datasets if d.name == dataset_name), None
            )
            if not selected_dataset:
                return "❌ Dataset not found", None

            info_markdown = f"""### {selected_dataset.name}

**Description:** {selected_dataset.description}

**Message Count:** {selected_dataset.message_count} messages

**Dataset ID:** `{selected_dataset.dataset_id}`

**Classification Breakdown:**
"""

            # Count messages per expected label (case-insensitive).
            green_count = sum(
                1 for msg in selected_dataset.messages
                if msg.pre_classified_label.lower() == "green"
            )
            yellow_count = sum(
                1 for msg in selected_dataset.messages
                if msg.pre_classified_label.lower() == "yellow"
            )
            red_count = sum(
                1 for msg in selected_dataset.messages
                if msg.pre_classified_label.lower() == "red"
            )

            info_markdown += f"""
- 🟢 GREEN: {green_count} messages
- 🟡 YELLOW: {yellow_count} messages
- 🔴 RED: {red_count} messages
"""
            return info_markdown, selected_dataset

        except Exception as e:
            return f"❌ Error loading dataset info: {str(e)}", None

    def render_test_cases_display(self, dataset: TestDataset) -> str:
        """Render test cases for editing display.

        Args:
            dataset: Dataset to display test cases for.

        Returns:
            HTML string for the test-case list.
        """
        # NOTE(review): the original HTML tags were lost when this file was
        # mangled; the markup below is a plain reconstruction that carries the
        # same textual content — confirm against the intended styling.
        if not dataset or not dataset.messages:
            return "<div class='no-test-cases'>No test cases in this dataset.</div>"

        badge_colors = {"green": "🟢", "yellow": "🟡", "red": "🔴"}
        html = "<div class='test-case-list'>"
        for i, message in enumerate(dataset.messages):
            # Classification badge, with a fallback for unknown labels.
            badge = badge_colors.get(message.pre_classified_label.lower(), "❓")

            # Truncate long message text for display.
            display_text = (
                message.text[:100] + "..." if len(message.text) > 100 else message.text
            )

            html += f"""
            <div class='test-case'>
                <div class='test-case-header'>{badge} Test Case {i + 1}</div>
                <div><strong>Message:</strong> {display_text}</div>
                <div><strong>Expected Classification:</strong> {message.pre_classified_label.upper()}</div>
                <div><small>ID: {message.message_id}</small></div>
            </div>
            """
        html += "</div>"
        return html

    def create_new_dataset(
        self,
        name: str,
        description: str,
        template_type: Optional[str] = None
    ) -> Tuple[bool, str, Optional[TestDataset]]:
        """Create a new dataset.

        Args:
            name: Dataset name.
            description: Dataset description.
            template_type: Optional template type; when given, the dataset is
                seeded from that template and then renamed.

        Returns:
            Tuple of (success, message, dataset).
        """
        try:
            if not name or not name.strip():
                return False, "❌ Dataset name is required", None
            if not description or not description.strip():
                return False, "❌ Dataset description is required", None

            if template_type and template_type != "":
                # Seed from the template, then overwrite name/description.
                dataset = self.dataset_manager.create_template_dataset(template_type)
                dataset.name = name.strip()
                dataset.description = description.strip()
                self.dataset_manager.update_dataset(dataset.dataset_id, dataset)
            else:
                dataset = self.dataset_manager.create_dataset(
                    name.strip(), description.strip()
                )

            return True, f"✅ Dataset '{name}' created successfully", dataset

        except Exception as e:
            return False, f"❌ Error creating dataset: {str(e)}", None

    def add_test_case(
        self,
        dataset: TestDataset,
        message_text: str,
        classification: str
    ) -> Tuple[bool, str, TestDataset]:
        """Add a new test case to the dataset.

        Args:
            dataset: Dataset to add test case to.
            message_text: Message text.
            classification: Expected classification (green/yellow/red).

        Returns:
            Tuple of (success, message, updated_dataset).
        """
        try:
            if not message_text or not message_text.strip():
                return False, "❌ Message text is required", dataset
            if not classification:
                return False, "❌ Classification is required", dataset

            # Build the new test message with a short unique suffix.
            test_message = TestMessage(
                message_id=f"{dataset.dataset_id}_{uuid.uuid4().hex[:8]}",
                text=message_text.strip(),
                pre_classified_label=classification.lower()
            )

            # Persist, then reload so counters reflect the addition.
            self.dataset_manager.add_test_case(dataset.dataset_id, test_message)
            updated_dataset = self.dataset_manager.get_dataset(dataset.dataset_id)

            return True, "✅ Test case added successfully", updated_dataset

        except Exception as e:
            return False, f"❌ Error adding test case: {str(e)}", dataset

    def save_dataset(self, dataset: TestDataset) -> Tuple[bool, str]:
        """Save dataset changes.

        Args:
            dataset: Dataset to save.

        Returns:
            Tuple of (success, message).
        """
        try:
            # Validate before persisting; surface all errors at once.
            validation_errors = self.dataset_manager.validate_dataset(dataset)
            if validation_errors:
                error_list = "\n".join([f"• {error}" for error in validation_errors])
                return False, f"❌ Validation errors:\n{error_list}"

            self.dataset_manager.update_dataset(dataset.dataset_id, dataset)
            return True, f"✅ Dataset '{dataset.name}' saved successfully"

        except Exception as e:
            return False, f"❌ Error saving dataset: {str(e)}"

    def start_verification_session(
        self,
        dataset: TestDataset,
        verifier_name: str
    ) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
        """Start a new verification session.

        Args:
            dataset: Dataset to verify.
            verifier_name: Name of the verifier.

        Returns:
            Tuple of (success, message, session).
        """
        try:
            if not verifier_name or not verifier_name.strip():
                return False, "❌ Verifier name is required", None
            if not dataset or not dataset.messages:
                return False, "❌ Dataset is empty or invalid", None

            session = EnhancedVerificationSession(
                session_id=f"enhanced_{uuid.uuid4().hex}",
                verifier_name=verifier_name.strip(),
                dataset_id=dataset.dataset_id,
                dataset_name=dataset.name,
                mode_type="enhanced_dataset",
                total_messages=len(dataset.messages),
                message_queue=[msg.message_id for msg in dataset.messages],
                mode_metadata={
                    # Snapshot when the session was created; used as a
                    # lightweight dataset "version" marker.
                    "dataset_version": datetime.now().isoformat(),
                    "original_message_count": len(dataset.messages)
                }
            )

            # Persist and adopt as the controller's live session state.
            self.store.save_session(session)
            self.current_session = session
            self.current_dataset = dataset
            self.current_message_index = 0
            self._last_classification = {}

            # Setup progress tracking (from ProgressTrackingMixin).
            self.setup_progress_tracking(len(dataset.messages))

            return True, f"✅ Verification session started for '{dataset.name}'", session

        except Exception as e:
            return False, f"❌ Error starting verification: {str(e)}", None

    def get_current_message_for_verification(self) -> Tuple[Optional[TestMessage], Dict[str, Any]]:
        """Get the current message for verification.

        Returns:
            Tuple of (test_message, classification_results); (None, {}) when
            no session is active or the queue is exhausted.
        """
        try:
            if not self.current_session or not self.current_dataset:
                return None, {}
            if self.current_message_index >= len(self.current_dataset.messages):
                return None, {}

            current_message = self.current_dataset.messages[self.current_message_index]

            # Start the per-message timing clock for progress tracking.
            self.verification_start_time = datetime.now()

            # Run the spiritual-distress classifier on this message.
            assessment = self.spiritual_monitor.classify(current_message.text)
            classification_result = {
                "decision": assessment.state.value,
                "confidence": assessment.confidence,
                "indicators": assessment.indicators
            }

            # Cache so submit_verification_feedback judges the exact result
            # the verifier saw, without re-running the classifier.
            self._last_classification = classification_result

            return current_message, classification_result

        except Exception as e:
            return None, {"error": str(e)}

    def submit_verification_feedback(
        self,
        is_correct: bool,
        correction: Optional[str] = None,
        notes: str = ""
    ) -> Tuple[bool, str, Dict[str, Any]]:
        """Submit verification feedback for the current message.

        Args:
            is_correct: Whether the classification is correct.
            correction: Correct classification if incorrect.
            notes: Optional notes.

        Returns:
            Tuple of (success, message, session_stats).
        """
        try:
            if not self.current_session or not self.current_dataset:
                return False, "❌ No active verification session", {}

            current_message = self.current_dataset.messages[self.current_message_index]

            # BUGFIX: previously this re-ran the classifier (possibly giving a
            # different result than the one the verifier judged) and reset the
            # timing clock just before it was read. Use the cached result.
            classification_result = self._last_classification
            if not classification_result:
                _, classification_result = self.get_current_message_for_verification()

            # Clamp both labels to the valid set (green/yellow/red).
            classifier_decision = classification_result.get("decision", "green")
            if classifier_decision not in ["green", "yellow", "red"]:
                classifier_decision = "green"  # Safe fallback
            ground_truth = (
                correction.lower() if correction
                else current_message.pre_classified_label
            )
            if ground_truth not in ["green", "yellow", "red"]:
                ground_truth = "green"  # Safe fallback

            record = VerificationRecord(
                message_id=current_message.message_id,
                original_message=current_message.text,
                classifier_decision=classifier_decision,
                classifier_confidence=classification_result.get("confidence", 0.0),
                classifier_indicators=classification_result.get("indicators", []),
                ground_truth_label=ground_truth,
                verifier_notes=notes,
                is_correct=is_correct
            )

            # Fold the record into the session tallies.
            self.current_session.verifications.append(record)
            self.current_session.verified_count += 1
            self.current_session.verified_message_ids.append(current_message.message_id)
            if is_correct:
                self.current_session.correct_count += 1
            else:
                self.current_session.incorrect_count += 1

            # Record verification with timing for progress tracking.
            self.record_verification_with_timing(is_correct, self.verification_start_time)

            # Advance to the next message; clear the consumed classification.
            self.current_message_index += 1
            self.current_session.current_queue_index = self.current_message_index
            self._last_classification = {}

            if self.current_message_index >= len(self.current_dataset.messages):
                self.current_session.is_complete = True
                self.current_session.completed_at = datetime.now()

            self.store.save_session(self.current_session)

            verified = self.current_session.verified_count
            session_stats = {
                "processed": verified,
                "total": self.current_session.total_messages,
                "correct": self.current_session.correct_count,
                "incorrect": self.current_session.incorrect_count,
                "accuracy": (
                    self.current_session.correct_count / verified * 100
                ) if verified > 0 else 0,
                "is_complete": self.current_session.is_complete
            }

            success_msg = "✅ Feedback recorded"
            if self.current_session.is_complete:
                success_msg += (
                    f" - Session complete! Final accuracy: "
                    f"{session_stats['accuracy']:.1f}%"
                )

            return True, success_msg, session_stats

        except Exception as e:
            return False, f"❌ Error submitting feedback: {str(e)}", {}

    def export_session_results(self, format_type: str) -> Tuple[bool, str, Optional[str]]:
        """Export session results in the specified format.

        Args:
            format_type: Export format ("csv", "json", "xlsx").

        Returns:
            Tuple of (success, message, file_path).
        """
        try:
            if not self.current_session:
                return False, "❌ No active session to export", None

            # NOTE(review): file_content is produced but never written to
            # file_path here — presumably the store export already persists,
            # or a caller writes it; verify, otherwise the reported path is
            # a dangling name.
            if format_type == "csv":
                file_content = self.store.export_to_csv(self.current_session.session_id)
                file_path = f"session_{self.current_session.session_id}.csv"
            elif format_type == "json":
                file_content = self.store.export_to_json(self.current_session.session_id)
                file_path = f"session_{self.current_session.session_id}.json"
            elif format_type == "xlsx":
                file_content = self.store.export_to_xlsx(self.current_session.session_id)
                file_path = f"session_{self.current_session.session_id}.xlsx"
            else:
                return False, f"❌ Unsupported export format: {format_type}", None

            return True, f"✅ Results exported to {format_type.upper()}", file_path

        except Exception as e:
            return False, f"❌ Error exporting results: {str(e)}", None

    def get_enhanced_progress_info(self) -> Dict[str, Any]:
        """Get enhanced progress information for display.

        Returns:
            Dictionary containing progress information.
        """
        # No tracker yet (session not started) — show idle placeholders.
        if not hasattr(self, 'progress_tracker') or not self.progress_tracker:
            return {
                "progress_display": "📊 Progress: Ready to start",
                "accuracy_display": "🎯 Current Accuracy: No verifications yet",
                "time_display": "⏱️ Time: Not started",
                "error_display": "",
                "stats_summary": "No active session"
            }

        return {
            "progress_display": self.progress_tracker.get_progress_display(),
            "accuracy_display": self.progress_tracker.get_accuracy_display(),
            "time_display": self.progress_tracker.get_time_tracking_display(),
            "error_display": self.progress_tracker.get_error_display(),
            "stats_summary": self._get_session_stats_summary()
        }

    def record_verification_error(self, error_message: str, can_continue: bool = True) -> None:
        """Record a verification error.

        Args:
            error_message: Description of the error.
            can_continue: Whether processing can continue.
        """
        if hasattr(self, 'progress_tracker') and self.progress_tracker:
            self.progress_tracker.record_error(error_message, can_continue)

    def pause_verification_session(self) -> Tuple[bool, bool, bool]:
        """Pause the current verification session.

        Returns:
            Tuple of control button visibility states.
        """
        if hasattr(self, 'progress_tracker') and self.progress_tracker:
            return self.handle_session_pause()
        return False, False, True

    def resume_verification_session(self) -> Tuple[bool, bool, bool]:
        """Resume the current verification session.

        Returns:
            Tuple of control button visibility states.
        """
        if hasattr(self, 'progress_tracker') and self.progress_tracker:
            return self.handle_session_resume()
        return True, False, True

    def _get_session_stats_summary(self) -> str:
        """Get formatted session statistics summary."""
        if not self.current_session:
            return "No active session"

        verified = self.current_session.verified_count
        accuracy = (
            self.current_session.correct_count / verified * 100
        ) if verified > 0 else 0

        return f"""
**Session Progress:**
- Dataset: {self.current_session.dataset_name}
- Processed: {verified}/{self.current_session.total_messages}
- Accuracy: {accuracy:.1f}%
- Correct: {self.current_session.correct_count}
- Incorrect: {self.current_session.incorrect_count}
"""