# Spiritual_Health_Project / src/interface/enhanced_dataset_interface.py
# Author: DocUA
# Last commit message: Fix VerificationRecord validation errors for classification values
# Commit: 3e1ef65
# enhanced_dataset_interface.py
"""
Enhanced Dataset Interface Controller.
Provides the complete interface logic for enhanced dataset mode including
dataset selection, editing, creation, and verification workflows.
Requirements: 2.1, 2.2, 2.7
"""
import html
import json
import os
import uuid
from datetime import datetime
from typing import List, Dict, Tuple, Optional, Any, Union

import gradio as gr

from src.core.verification_models import (
    EnhancedVerificationSession,
    VerificationRecord,
    TestMessage,
    TestDataset,
)
from src.core.enhanced_dataset_manager import EnhancedDatasetManager
from src.core.verification_store import JSONVerificationStore
from src.core.test_datasets import TestDatasetManager
from src.interface.verification_ui import VerificationUIComponents
from src.core.spiritual_monitor import SpiritualMonitor
from src.core.ai_client import AIClientManager
from src.core.enhanced_progress_tracker import EnhancedProgressTracker, VerificationMode
from src.interface.enhanced_progress_components import ProgressTrackingMixin
class EnhancedDatasetInterfaceController(ProgressTrackingMixin):
    """Controller for the enhanced dataset mode interface.

    Holds the state of one verification workflow (current session, dataset
    and message cursor) and exposes plain-Python handlers for the UI. Most
    handlers catch exceptions and return user-facing status strings instead
    of raising.
    """

    # Labels the verification workflow works with; anything else is mapped
    # to a fallback before building a VerificationRecord.
    _VALID_LABELS = ("green", "yellow", "red")

    # Badge shown next to each test case, keyed by lower-cased label.
    _BADGES = {"green": "🟢", "yellow": "🟡", "red": "🔴"}

    def __init__(self, store: Optional[JSONVerificationStore] = None):
        """Initialize the enhanced dataset interface controller.

        Args:
            store: Store used to persist verification sessions. A default
                JSONVerificationStore is created when omitted.
        """
        super().__init__(VerificationMode.ENHANCED_DATASET)
        self.store = store or JSONVerificationStore()
        self.dataset_manager = EnhancedDatasetManager()
        self.ai_client_manager = AIClientManager()
        self.spiritual_monitor = SpiritualMonitor(self.ai_client_manager)
        self.current_session = None
        self.current_dataset = None
        self.current_message_index = 0
        self.verification_start_time = None
        # (message_id, classification_result) last shown to the verifier.
        # Feedback is recorded against this exact result, so the classifier
        # is not re-run (and the timer is not reset) on submit.
        self._last_classification: Optional[Tuple[str, Dict[str, Any]]] = None

    # ------------------------------------------------------------------
    # Dataset selection / display
    # ------------------------------------------------------------------

    @staticmethod
    def _format_dataset_choice(dataset: TestDataset) -> str:
        """Build the dropdown label for a dataset: "<name> (<count> messages)"."""
        return f"{dataset.name} ({dataset.message_count} messages)"

    @staticmethod
    def _parse_dataset_name(dataset_selection: str) -> str:
        """Recover a dataset name from a dropdown label.

        Only the LAST " (" suffix is stripped. Names that themselves contain
        " (" (e.g. "Crisis (v2) (5 messages)") therefore round-trip
        correctly. A label with no " (" is returned unchanged.
        """
        name, sep, _ = dataset_selection.rpartition(" (")
        return name if sep else dataset_selection

    @classmethod
    def _count_labels(cls, messages) -> Dict[str, int]:
        """Count messages per pre-classified label (case-insensitive)."""
        counts = {label: 0 for label in cls._VALID_LABELS}
        for message in messages or []:
            label = (message.pre_classified_label or "").lower()
            if label in counts:
                counts[label] += 1
        return counts

    def initialize_interface(self) -> Tuple[List[str], str, str, Any]:
        """
        Initialize the enhanced dataset interface.

        Returns:
            Tuple of (dataset_choices, dataset_info, status_message, templates).
            ``dataset_choices`` are "<name> (<count> messages)" dropdown
            labels; ``templates`` is whatever
            ``dataset_manager.get_available_templates()`` returns (an empty
            list on error).
        """
        try:
            datasets = self.dataset_manager.list_datasets()
            dataset_choices = [self._format_dataset_choice(dataset) for dataset in datasets]
            templates = self.dataset_manager.get_available_templates()
            return (
                dataset_choices,
                "Select a dataset to view details and start verification or editing.",
                "✨ Enhanced Dataset Mode initialized. Select a dataset to get started.",
                templates,
            )
        except Exception as e:
            return (
                [],
                f"❌ Error loading datasets: {str(e)}",
                f"❌ Failed to initialize interface: {str(e)}",
                [],
            )

    def get_dataset_info(self, dataset_selection: str) -> Tuple[str, Optional[TestDataset]]:
        """
        Get dataset information for display.

        Args:
            dataset_selection: Selected dataset label from the dropdown.

        Returns:
            Tuple of (dataset_info_markdown, dataset_object). The dataset is
            None when nothing is selected, it cannot be found, or loading fails.
        """
        try:
            if not dataset_selection:
                return "Select a dataset to view details", None

            dataset_name = self._parse_dataset_name(dataset_selection)
            datasets = self.dataset_manager.list_datasets()
            # First match wins when several datasets share a name.
            selected_dataset = next(
                (dataset for dataset in datasets if dataset.name == dataset_name), None
            )
            if not selected_dataset:
                return "❌ Dataset not found", None

            counts = self._count_labels(selected_dataset.messages)
            # Blank lines between fields so Markdown renders separate paragraphs.
            info_markdown = "\n\n".join([
                f"### {selected_dataset.name}",
                f"**Description:** {selected_dataset.description}",
                f"**Message Count:** {selected_dataset.message_count} messages",
                f"**Dataset ID:** `{selected_dataset.dataset_id}`",
                "**Classification Breakdown:**",
            ])
            info_markdown += (
                "\n\n"
                f"- 🟢 GREEN: {counts['green']} messages\n"
                f"- 🟡 YELLOW: {counts['yellow']} messages\n"
                f"- 🔴 RED: {counts['red']} messages\n"
            )
            return info_markdown, selected_dataset
        except Exception as e:
            return f"❌ Error loading dataset info: {str(e)}", None

    def render_test_cases_display(self, dataset: TestDataset) -> str:
        """
        Render a dataset's test cases as an HTML list for the editing view.

        Test-case text and IDs come from user input, so they are HTML-escaped.
        IDs embedded in the inline ``onclick`` handlers are JSON-encoded (a
        valid JS string literal) and then HTML-escaped for the attribute.
        NOTE(review): the inline <script> handlers only log to the console;
        whether the hosting HTML component executes them is not visible here.

        Args:
            dataset: Dataset whose test cases are displayed.

        Returns:
            HTML string for the test cases display.
        """
        if not dataset or not dataset.messages:
            return "<p>No test cases in this dataset.</p>"

        parts = ['\n<div style="font-family: system-ui; max-height: 400px; overflow-y: auto;">\n']
        for number, message in enumerate(dataset.messages, start=1):
            label = message.pre_classified_label or ""
            badge = self._BADGES.get(label.lower(), "❓")
            text = message.text or ""
            # Truncate BEFORE escaping so an HTML entity is never cut in half.
            display_text = text[:100] + "..." if len(text) > 100 else text
            message_id = str(message.message_id)
            js_id = html.escape(json.dumps(message_id), quote=True)
            parts.append(f"""
    <div style="margin-bottom: 1em; padding: 1em; background-color: #f9fafb; border-radius: 6px; border: 1px solid #e5e7eb;">
        <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 0.5em;">
            <h4 style="margin: 0; color: #1f2937;">
                {badge} Test Case {number}
            </h4>
            <div>
                <button onclick="editTestCase({js_id})"
                        style="background: #3b82f6; color: white; border: none; padding: 0.25em 0.5em; border-radius: 4px; cursor: pointer; margin-right: 0.5em;">
                    ✏️ Edit
                </button>
                <button onclick="deleteTestCase({js_id})"
                        style="background: #dc2626; color: white; border: none; padding: 0.25em 0.5em; border-radius: 4px; cursor: pointer;">
                    🗑️ Delete
                </button>
            </div>
        </div>
        <div style="margin-bottom: 0.5em;">
            <strong>Message:</strong> {html.escape(display_text)}
        </div>
        <div style="font-size: 0.875em; color: #6b7280;">
            <strong>Expected Classification:</strong> {html.escape(label.upper())}
        </div>
        <div style="font-size: 0.75em; color: #9ca3af; margin-top: 0.5em;">
            ID: {html.escape(message_id)}
        </div>
    </div>
""")
        parts.append("""
</div>
<script>
function editTestCase(messageId) {
    // This would trigger the edit modal
    console.log('Edit test case:', messageId);
}
function deleteTestCase(messageId) {
    if (confirm('Are you sure you want to delete this test case?')) {
        console.log('Delete test case:', messageId);
    }
}
</script>
""")
        return "".join(parts)

    # ------------------------------------------------------------------
    # Dataset editing
    # ------------------------------------------------------------------

    def create_new_dataset(
        self,
        name: str,
        description: str,
        template_type: Optional[str] = None
    ) -> Tuple[bool, str, Optional[TestDataset]]:
        """
        Create a new dataset, optionally seeded from a template.

        Args:
            name: Dataset name (surrounding whitespace is stripped).
            description: Dataset description (surrounding whitespace is stripped).
            template_type: Optional template identifier. When given, the
                template dataset is created and then renamed/re-described.

        Returns:
            Tuple of (success, message, dataset).
        """
        try:
            clean_name = (name or "").strip()
            clean_description = (description or "").strip()
            if not clean_name:
                return False, "❌ Dataset name is required", None
            if not clean_description:
                return False, "❌ Dataset description is required", None

            if template_type:
                dataset = self.dataset_manager.create_template_dataset(template_type)
                dataset.name = clean_name
                dataset.description = clean_description
                self.dataset_manager.update_dataset(dataset.dataset_id, dataset)
            else:
                dataset = self.dataset_manager.create_dataset(clean_name, clean_description)

            return True, f"✅ Dataset '{clean_name}' created successfully", dataset
        except Exception as e:
            return False, f"❌ Error creating dataset: {str(e)}", None

    def add_test_case(
        self,
        dataset: TestDataset,
        message_text: str,
        classification: str
    ) -> Tuple[bool, str, TestDataset]:
        """
        Add a new test case to the dataset and return the reloaded dataset.

        Args:
            dataset: Dataset to add the test case to.
            message_text: Message text (surrounding whitespace is stripped).
            classification: Expected classification; stored lower-cased.

        Returns:
            Tuple of (success, message, updated_dataset). On failure the
            input dataset is returned unchanged.
        """
        try:
            if not dataset:
                return False, "❌ No dataset selected", dataset
            if not message_text or not message_text.strip():
                return False, "❌ Message text is required", dataset
            if not classification or not classification.strip():
                return False, "❌ Classification is required", dataset

            test_message = TestMessage(
                message_id=f"{dataset.dataset_id}_{uuid.uuid4().hex[:8]}",
                text=message_text.strip(),
                pre_classified_label=classification.strip().lower(),
            )
            self.dataset_manager.add_test_case(dataset.dataset_id, test_message)
            updated_dataset = self.dataset_manager.get_dataset(dataset.dataset_id)
            return True, "✅ Test case added successfully", updated_dataset
        except Exception as e:
            return False, f"❌ Error adding test case: {str(e)}", dataset

    def save_dataset(self, dataset: TestDataset) -> Tuple[bool, str]:
        """
        Validate and persist dataset changes.

        Args:
            dataset: Dataset to save.

        Returns:
            Tuple of (success, message). Validation errors are listed as
            bullet points in the message.
        """
        try:
            if not dataset:
                return False, "❌ No dataset to save"
            validation_errors = self.dataset_manager.validate_dataset(dataset)
            if validation_errors:
                error_list = "\n".join(f"• {error}" for error in validation_errors)
                return False, f"❌ Validation errors:\n{error_list}"

            self.dataset_manager.update_dataset(dataset.dataset_id, dataset)
            return True, f"✅ Dataset '{dataset.name}' saved successfully"
        except Exception as e:
            return False, f"❌ Error saving dataset: {str(e)}"

    # ------------------------------------------------------------------
    # Verification workflow
    # ------------------------------------------------------------------

    def start_verification_session(
        self,
        dataset: TestDataset,
        verifier_name: str
    ) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
        """
        Start a new verification session over every message of a dataset.

        Persists the session via the store, resets the message cursor and
        sets up progress tracking for ``len(dataset.messages)`` items.

        Args:
            dataset: Dataset to verify.
            verifier_name: Name of the verifier (required, stripped).

        Returns:
            Tuple of (success, message, session).
        """
        try:
            if not verifier_name or not verifier_name.strip():
                return False, "❌ Verifier name is required", None
            if not dataset or not dataset.messages:
                return False, "❌ Dataset is empty or invalid", None

            session = EnhancedVerificationSession(
                session_id=f"enhanced_{uuid.uuid4().hex}",
                verifier_name=verifier_name.strip(),
                dataset_id=dataset.dataset_id,
                dataset_name=dataset.name,
                mode_type="enhanced_dataset",
                total_messages=len(dataset.messages),
                message_queue=[msg.message_id for msg in dataset.messages],
                mode_metadata={
                    "dataset_version": datetime.now().isoformat(),
                    "original_message_count": len(dataset.messages)
                }
            )

            self.store.save_session(session)
            self.current_session = session
            self.current_dataset = dataset
            self.current_message_index = 0
            # Forget state from any previous session (message IDs may repeat
            # when the same dataset is verified again).
            self._last_classification = None
            self.verification_start_time = None

            self.setup_progress_tracking(len(dataset.messages))
            return True, f"✅ Verification session started for '{dataset.name}'", session
        except Exception as e:
            return False, f"❌ Error starting verification: {str(e)}", None

    def _current_message(self) -> Optional[TestMessage]:
        """Return the message at the cursor, or None if no session or exhausted."""
        if not self.current_session or not self.current_dataset:
            return None
        messages = self.current_dataset.messages
        if 0 <= self.current_message_index < len(messages):
            return messages[self.current_message_index]
        return None

    def _classify(self, message: TestMessage) -> Dict[str, Any]:
        """Run the spiritual-distress classifier on a message's text."""
        assessment = self.spiritual_monitor.classify(message.text)
        return {
            "decision": assessment.state.value,
            "confidence": assessment.confidence,
            "indicators": assessment.indicators,
        }

    def _classification_for(self, message: TestMessage) -> Dict[str, Any]:
        """Return the classification shown for ``message``, classifying if needed.

        Reuses the cached result from get_current_message_for_verification
        so the recorded decision matches what the verifier judged. Only when
        nothing was shown for this message is the classifier called here.
        Classifier failures yield ``{"error": ...}`` (best-effort, matching
        get_current_message_for_verification).
        """
        cached = getattr(self, "_last_classification", None)
        if cached is not None and cached[0] == message.message_id:
            return cached[1]
        if self.verification_start_time is None:
            self.verification_start_time = datetime.now()
        try:
            result = self._classify(message)
        except Exception as e:
            result = {"error": str(e)}
        self._last_classification = (message.message_id, result)
        return result

    @classmethod
    def _normalize_label(cls, value: Any, default: str = "green") -> str:
        """Lower-case a classification label, falling back to ``default``.

        Values outside green/yellow/red (None, unexpected classifier states,
        typos) are mapped to the safe fallback. This keeps the resulting
        VerificationRecord within the label set this workflow uses.
        """
        label = str(value).strip().lower() if value is not None else ""
        return label if label in cls._VALID_LABELS else default

    def _session_accuracy(self) -> float:
        """Percentage of verified messages judged correct (0 if none verified)."""
        session = self.current_session
        if not session or session.verified_count <= 0:
            return 0
        return session.correct_count / session.verified_count * 100

    def _session_stats(self) -> Dict[str, Any]:
        """Snapshot of the current session's counters for the UI."""
        session = self.current_session
        return {
            "processed": session.verified_count,
            "total": session.total_messages,
            "correct": session.correct_count,
            "incorrect": session.incorrect_count,
            "accuracy": self._session_accuracy(),
            "is_complete": session.is_complete,
        }

    def get_current_message_for_verification(self) -> Tuple[Optional[TestMessage], Dict[str, Any]]:
        """
        Get the current message and its fresh classification for verification.

        Starts the per-message timer (before classifying, so timing includes
        classification) and caches the result for submit_verification_feedback.

        Returns:
            Tuple of (test_message, classification_result). Returns
            ``(None, {})`` when there is no session or no message left, and
            ``(None, {"error": ...})`` if classification fails.
        """
        try:
            current_message = self._current_message()
            if current_message is None:
                return None, {}

            self.verification_start_time = datetime.now()
            classification_result = self._classify(current_message)
            self._last_classification = (current_message.message_id, classification_result)
            return current_message, classification_result
        except Exception as e:
            return None, {"error": str(e)}

    def submit_verification_feedback(
        self,
        is_correct: bool,
        correction: Optional[str] = None,
        notes: str = ""
    ) -> Tuple[bool, str, Dict[str, Any]]:
        """
        Record verification feedback for the current message and advance.

        The record uses the classification previously shown for this message
        (see get_current_message_for_verification). Labels are normalized to
        green/yellow/red with a "green" fallback.

        Args:
            is_correct: Whether the classifier's decision was correct.
            correction: Correct classification when incorrect. When empty,
                the message's pre-classified label is the ground truth.
            notes: Optional verifier notes.

        Returns:
            Tuple of (success, message, session_stats).
        """
        try:
            if not self.current_session or not self.current_dataset:
                return False, "❌ No active verification session", {}

            current_message = self._current_message()
            if current_message is None:
                return (
                    False,
                    "❌ All messages in this session have already been verified",
                    self._session_stats(),
                )

            classification_result = self._classification_for(current_message)
            classifier_decision = self._normalize_label(classification_result.get("decision"))
            ground_truth = self._normalize_label(
                correction if correction else current_message.pre_classified_label
            )

            record = VerificationRecord(
                message_id=current_message.message_id,
                original_message=current_message.text,
                classifier_decision=classifier_decision,
                classifier_confidence=classification_result.get("confidence", 0.0),
                classifier_indicators=classification_result.get("indicators", []),
                ground_truth_label=ground_truth,
                verifier_notes=notes,
                is_correct=is_correct
            )

            session = self.current_session
            session.verifications.append(record)
            session.verified_count += 1
            session.verified_message_ids.append(current_message.message_id)
            if is_correct:
                session.correct_count += 1
            else:
                session.incorrect_count += 1

            # The start time was set when the message was shown, so this
            # measures the verifier's actual review time.
            self.record_verification_with_timing(is_correct, self.verification_start_time)
            self._last_classification = None

            self.current_message_index += 1
            session.current_queue_index = self.current_message_index
            if self.current_message_index >= len(self.current_dataset.messages):
                session.is_complete = True
                session.completed_at = datetime.now()

            self.store.save_session(session)

            session_stats = self._session_stats()
            success_msg = "✅ Feedback recorded"
            if session.is_complete:
                success_msg += f" - Session complete! Final accuracy: {session_stats['accuracy']:.1f}%"
            return True, success_msg, session_stats
        except Exception as e:
            return False, f"❌ Error submitting feedback: {str(e)}", {}

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------

    @staticmethod
    def _write_export(file_content: Any, file_path: str) -> str:
        """Persist exported content and return the path to it.

        NOTE(review): the return contract of the store's export_to_* methods
        is not visible here. If one returns the path of a file that already
        exists, that path is returned unchanged. Otherwise bytes are written
        verbatim and anything else as UTF-8 text to ``file_path`` (relative
        to the current working directory, as before).
        """
        if isinstance(file_content, str) and os.path.isfile(file_content):
            return file_content
        if isinstance(file_content, (bytes, bytearray)):
            with open(file_path, "wb") as fh:
                fh.write(file_content)
        else:
            # newline="" keeps any CSV line endings produced by the exporter.
            with open(file_path, "w", encoding="utf-8", newline="") as fh:
                fh.write(str(file_content))
        return file_path

    def export_session_results(self, format_type: str) -> Tuple[bool, str, Optional[str]]:
        """
        Export the current session's results to a file.

        Args:
            format_type: Export format ("csv", "json" or "xlsx", case-insensitive).

        Returns:
            Tuple of (success, message, file_path). ``file_path`` points to
            the written export file, or is None on failure.
        """
        exporters = {
            "csv": "export_to_csv",
            "json": "export_to_json",
            "xlsx": "export_to_xlsx",
        }
        try:
            if not self.current_session:
                return False, "❌ No active session to export", None

            normalized = (format_type or "").lower()
            exporter_name = exporters.get(normalized)
            if exporter_name is None:
                return False, f"❌ Unsupported export format: {format_type}", None

            session_id = self.current_session.session_id
            file_content = getattr(self.store, exporter_name)(session_id)
            if file_content is None:
                return False, "❌ Export returned no data", None

            file_path = self._write_export(file_content, f"session_{session_id}.{normalized}")
            return True, f"✅ Results exported to {normalized.upper()}", file_path
        except Exception as e:
            return False, f"❌ Error exporting results: {str(e)}", None

    # ------------------------------------------------------------------
    # Progress tracking
    # ------------------------------------------------------------------

    def get_enhanced_progress_info(self) -> Dict[str, Any]:
        """
        Get progress information for display.

        Returns:
            Dictionary of display strings. Placeholders are used when no
            progress tracker has been set up yet.
        """
        tracker = getattr(self, "progress_tracker", None)
        if not tracker:
            return {
                "progress_display": "📊 Progress: Ready to start",
                "accuracy_display": "🎯 Current Accuracy: No verifications yet",
                "time_display": "⏱️ Time: Not started",
                "error_display": "",
                "stats_summary": "No active session"
            }
        return {
            "progress_display": tracker.get_progress_display(),
            "accuracy_display": tracker.get_accuracy_display(),
            "time_display": tracker.get_time_tracking_display(),
            "error_display": tracker.get_error_display(),
            "stats_summary": self._get_session_stats_summary()
        }

    def record_verification_error(self, error_message: str, can_continue: bool = True) -> None:
        """
        Record a verification error with the progress tracker, if one exists.

        Args:
            error_message: Description of the error.
            can_continue: Whether processing can continue.
        """
        tracker = getattr(self, "progress_tracker", None)
        if tracker:
            tracker.record_error(error_message, can_continue)

    def pause_verification_session(self) -> Tuple[bool, bool, bool]:
        """
        Pause the current verification session.

        Returns:
            Tuple of control button visibility states. Without a progress
            tracker, returns (False, False, True).
        """
        if getattr(self, "progress_tracker", None):
            return self.handle_session_pause()
        return False, False, True

    def resume_verification_session(self) -> Tuple[bool, bool, bool]:
        """
        Resume the current verification session.

        Returns:
            Tuple of control button visibility states. Without a progress
            tracker, returns (True, False, True).
        """
        if getattr(self, "progress_tracker", None):
            return self.handle_session_resume()
        return True, False, True

    def _get_session_stats_summary(self) -> str:
        """Get a Markdown summary of the current session's statistics."""
        session = self.current_session
        if not session:
            return "No active session"
        accuracy = self._session_accuracy()
        return (
            "\n"
            "**Session Progress:**\n"
            f"- Dataset: {session.dataset_name}\n"
            f"- Processed: {session.verified_count}/{session.total_messages}\n"
            f"- Accuracy: {accuracy:.1f}%\n"
            f"- Correct: {session.correct_count}\n"
            f"- Incorrect: {session.incorrect_count}\n"
        )