Spaces:
Sleeping
Sleeping
# verification_models.py
"""
Data models for Verification Mode.
Defines core data structures for verification sessions, records, and test datasets.
Includes enhanced models for multi-mode verification support.
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class VerificationRecord:
    """Single verification record for a message.

    Stores the classifier output for one message next to the ground-truth
    label and the verifier's notes, and converts to/from a plain dictionary
    for serialization.
    """

    message_id: str
    original_message: str
    classifier_decision: str  # "green", "yellow", "red"
    classifier_confidence: float  # 0.0-1.0
    classifier_indicators: List[str]
    ground_truth_label: str  # "green", "yellow", "red"
    verifier_notes: str = ""
    is_correct: bool = False
    # default_factory so every record gets its own creation time instead of
    # a single value evaluated once when the class is defined.
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> dict:
        """Convert record to dictionary for serialization.

        Returns:
            A dict with one entry per field; ``timestamp`` is rendered as an
            ISO 8601 string.
        """
        return {
            "message_id": self.message_id,
            "original_message": self.original_message,
            "classifier_decision": self.classifier_decision,
            "classifier_confidence": self.classifier_confidence,
            "classifier_indicators": self.classifier_indicators,
            "ground_truth_label": self.ground_truth_label,
            "verifier_notes": self.verifier_notes,
            "is_correct": self.is_correct,
            "timestamp": self.timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationRecord":
        """Create record from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``. ``timestamp`` may be an ISO 8601 string or a
                ``datetime``. The mapping itself is not modified.

        Returns:
            A new record built from ``data``.

        Raises:
            TypeError: If ``data`` lacks a required field or contains a key
                that is not a field.
            ValueError: If ``timestamp`` is a string that is not valid ISO
                format.
        """
        payload = data.copy()
        raw_timestamp = payload.get("timestamp")
        if isinstance(raw_timestamp, str):
            payload["timestamp"] = datetime.fromisoformat(raw_timestamp)
        return cls(**payload)
@dataclass
class VerificationSession:
    """Tracks a complete verification session.

    Carries session identity, progress counters, the collected
    ``VerificationRecord`` entries and the message-queue state, and converts
    to/from a plain dictionary for serialization.
    """

    session_id: str
    verifier_name: str
    dataset_id: str
    dataset_name: str
    created_at: datetime = field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None
    total_messages: int = 0
    verified_count: int = 0
    correct_count: int = 0
    incorrect_count: int = 0
    verifications: List[VerificationRecord] = field(default_factory=list)
    is_complete: bool = False
    message_queue: List[str] = field(default_factory=list)  # List of message IDs
    current_queue_index: int = 0  # Current position in queue
    verified_message_ids: List[str] = field(default_factory=list)  # Verified message IDs

    def to_dict(self) -> dict:
        """Convert session to dictionary for serialization.

        Returns:
            A dict with one entry per field. Datetimes are rendered as ISO
            8601 strings (``completed_at`` stays ``None`` when unset) and
            each verification is rendered through its own ``to_dict``.
        """
        return {
            "session_id": self.session_id,
            "verifier_name": self.verifier_name,
            "dataset_id": self.dataset_id,
            "dataset_name": self.dataset_name,
            "created_at": self.created_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "total_messages": self.total_messages,
            "verified_count": self.verified_count,
            "correct_count": self.correct_count,
            "incorrect_count": self.incorrect_count,
            "verifications": [v.to_dict() for v in self.verifications],
            "is_complete": self.is_complete,
            "message_queue": self.message_queue,
            "current_queue_index": self.current_queue_index,
            "verified_message_ids": self.verified_message_ids,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationSession":
        """Create session from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``. Datetime fields may be ISO 8601 strings;
                ``verifications`` is a list of record dictionaries. The
                mapping itself is not modified.

        Returns:
            A new session whose ``verifications`` are rebuilt as
            ``VerificationRecord`` instances.

        Raises:
            TypeError: If ``data`` lacks a required field or contains a key
                that is not a field.
            ValueError: If a datetime string is not valid ISO format.
        """
        payload = data.copy()
        for key in ("created_at", "completed_at"):
            if isinstance(payload.get(key), str):
                payload[key] = datetime.fromisoformat(payload[key])

        # Records are rebuilt separately; the constructor must not receive
        # the raw dictionaries.
        raw_records = payload.pop("verifications", [])

        # Ensure queue fields exist for backward compatibility with data
        # saved before they were introduced.
        payload.setdefault("message_queue", [])
        payload.setdefault("current_queue_index", 0)
        payload.setdefault("verified_message_ids", [])

        session = cls(**payload)
        session.verifications = [VerificationRecord.from_dict(v) for v in raw_records]
        return session
@dataclass
class TestMessage:
    """A single test message with pre-classified label."""

    message_id: str
    text: str
    pre_classified_label: str  # "green", "yellow", "red"
@dataclass
class TestDataset:
    """A test dataset for verification.

    Groups ``TestMessage`` entries under an id, name and description, and
    converts to/from a plain dictionary for serialization.
    """

    dataset_id: str
    name: str
    description: str
    messages: List[TestMessage] = field(default_factory=list)

    # NOTE(review): kept as a plain method because that is how it appears in
    # this listing; if callers read it as an attribute it needs @property —
    # confirm against call sites.
    def message_count(self) -> int:
        """Get total number of messages in dataset."""
        return len(self.messages)

    def to_dict(self) -> dict:
        """Convert dataset to dictionary for serialization.

        Returns:
            A dict with the dataset fields; each message is rendered as a
            dict of ``message_id``, ``text`` and ``pre_classified_label``.
        """
        return {
            "dataset_id": self.dataset_id,
            "name": self.name,
            "description": self.description,
            "messages": [
                {
                    "message_id": m.message_id,
                    "text": m.text,
                    "pre_classified_label": m.pre_classified_label,
                }
                for m in self.messages
            ],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "TestDataset":
        """Create dataset from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``; ``messages`` is a list of message dictionaries.
                The mapping itself is not modified.

        Returns:
            A new dataset whose ``messages`` are rebuilt as ``TestMessage``
            instances.

        Raises:
            TypeError: If ``data`` or a message dictionary lacks a required
                field or contains a key that is not a field.
        """
        payload = data.copy()
        # Messages are rebuilt separately; the constructor must not receive
        # the raw dictionaries.
        raw_messages = payload.pop("messages", [])
        dataset = cls(**payload)
        dataset.messages = [TestMessage(**m) for m in raw_messages]
        return dataset
@dataclass
class TestCaseEdit:
    """Represents an edit operation on a test case.

    Records which test case was touched, the operation, the values before
    and after, when it happened and who made it.
    """

    edit_id: str
    test_case_id: str
    operation: str  # "add", "modify", "delete"
    old_values: Optional[Dict[str, Any]]
    new_values: Optional[Dict[str, Any]]
    timestamp: datetime
    editor_name: str

    def to_dict(self) -> dict:
        """Convert edit to dictionary for serialization.

        Returns:
            A dict with one entry per field; ``timestamp`` is rendered as an
            ISO 8601 string.
        """
        return {
            "edit_id": self.edit_id,
            "test_case_id": self.test_case_id,
            "operation": self.operation,
            "old_values": self.old_values,
            "new_values": self.new_values,
            "timestamp": self.timestamp.isoformat(),
            "editor_name": self.editor_name,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "TestCaseEdit":
        """Create edit from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``. ``timestamp`` may be an ISO 8601 string or a
                ``datetime``. The mapping itself is not modified.

        Returns:
            A new edit built from ``data``.

        Raises:
            TypeError: If ``data`` lacks a required field or contains a key
                that is not a field.
            ValueError: If ``timestamp`` is a string that is not valid ISO
                format.
        """
        payload = data.copy()
        raw_timestamp = payload.get("timestamp")
        if isinstance(raw_timestamp, str):
            payload["timestamp"] = datetime.fromisoformat(raw_timestamp)
        return cls(**payload)
@dataclass
class FileUploadResult:
    """Result of file upload processing.

    Carries the file identity, row counts, validation errors and the test
    cases parsed from the upload, and converts to/from a plain dictionary
    for serialization.
    """

    file_id: str
    original_filename: str
    file_format: str  # "csv", "xlsx"
    total_rows: int
    valid_rows: int
    validation_errors: List[str]
    parsed_test_cases: List[TestMessage]
    upload_timestamp: datetime

    def to_dict(self) -> dict:
        """Convert file upload result to dictionary for serialization.

        Returns:
            A dict with one entry per field; ``upload_timestamp`` is an ISO
            8601 string and each parsed test case is rendered as a dict of
            ``message_id``, ``text`` and ``pre_classified_label``.
        """
        return {
            "file_id": self.file_id,
            "original_filename": self.original_filename,
            "file_format": self.file_format,
            "total_rows": self.total_rows,
            "valid_rows": self.valid_rows,
            "validation_errors": self.validation_errors,
            "parsed_test_cases": [
                {
                    "message_id": tc.message_id,
                    "text": tc.text,
                    "pre_classified_label": tc.pre_classified_label,
                }
                for tc in self.parsed_test_cases
            ],
            "upload_timestamp": self.upload_timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: dict) -> "FileUploadResult":
        """Create file upload result from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``. ``upload_timestamp`` may be an ISO 8601 string
                or a ``datetime``; ``parsed_test_cases`` is a list of
                message dictionaries. The mapping itself is not modified.

        Returns:
            A new result whose ``parsed_test_cases`` are ``TestMessage``
            instances (an empty list when the key is absent).

        Raises:
            TypeError: If ``data`` or a test-case dictionary lacks a
                required field or contains a key that is not a field.
            ValueError: If ``upload_timestamp`` is a string that is not
                valid ISO format.
        """
        payload = data.copy()
        raw_timestamp = payload.get("upload_timestamp")
        if isinstance(raw_timestamp, str):
            payload["upload_timestamp"] = datetime.fromisoformat(raw_timestamp)

        # Replace the raw dictionaries with TestMessage instances before the
        # constructor sees them.
        raw_cases = payload.pop("parsed_test_cases", [])
        payload["parsed_test_cases"] = [TestMessage(**tc) for tc in raw_cases]
        return cls(**payload)
@dataclass
class EnhancedVerificationSession(VerificationSession):
    """Extended verification session with mode support.

    Adds mode-specific fields on top of ``VerificationSession``; every added
    field has a default, so data saved without them can still be loaded.
    """

    mode_type: str = "enhanced_dataset"  # "enhanced_dataset", "manual_input", "file_upload"
    mode_metadata: Dict[str, Any] = field(default_factory=dict)  # Mode-specific metadata
    file_source: Optional[str] = None  # Original filename for file upload mode
    dataset_version: Optional[str] = None  # Dataset version for enhanced dataset mode
    manual_input_count: int = 0  # Number of manual inputs in session

    def to_dict(self) -> dict:
        """Convert enhanced session to dictionary for serialization.

        Returns:
            The base session dictionary extended with the mode fields.
        """
        base_dict = super().to_dict()
        base_dict.update({
            "mode_type": self.mode_type,
            "mode_metadata": self.mode_metadata,
            "file_source": self.file_source,
            "dataset_version": self.dataset_version,
            "manual_input_count": self.manual_input_count,
        })
        return base_dict

    @classmethod
    def from_dict(cls, data: dict) -> "EnhancedVerificationSession":
        """Create enhanced session from dictionary.

        Args:
            data: Mapping of field names to values, as produced by
                ``to_dict``. Datetime fields may be ISO 8601 strings;
                ``verifications`` is a list of record dictionaries. Queue
                and mode fields may be absent. The mapping itself is not
                modified.

        Returns:
            A new enhanced session whose ``verifications`` are rebuilt as
            ``VerificationRecord`` instances.

        Raises:
            TypeError: If ``data`` lacks a required field or contains a key
                that is not a field.
            ValueError: If a datetime string is not valid ISO format.
        """
        payload = data.copy()

        # Handle datetime fields
        for key in ("created_at", "completed_at"):
            if isinstance(payload.get(key), str):
                payload[key] = datetime.fromisoformat(payload[key])

        # Extract verifications for separate processing
        raw_records = payload.pop("verifications", [])

        # Ensure backward compatibility for queue fields
        payload.setdefault("message_queue", [])
        payload.setdefault("current_queue_index", 0)
        payload.setdefault("verified_message_ids", [])

        # Ensure enhanced fields have defaults
        payload.setdefault("mode_type", "enhanced_dataset")
        payload.setdefault("mode_metadata", {})
        payload.setdefault("file_source", None)
        payload.setdefault("dataset_version", None)
        payload.setdefault("manual_input_count", 0)

        session = cls(**payload)
        session.verifications = [VerificationRecord.from_dict(v) for v in raw_records]
        return session