Spaces:
Sleeping
Sleeping
| # conftest.py | |
| """ | |
| Pytest fixtures for verification mode tests. | |
| Provides comprehensive fixtures for test datasets, sessions, records, and utility functions | |
| for generating test data and making assertions. | |
| """ | |
import shutil
import tempfile
from datetime import datetime
from typing import Any, Dict, List, Optional

import pytest

from src.core.verification_models import (
    VerificationRecord,
    VerificationSession,
    TestMessage,
    TestDataset,
)
from src.core.verification_store import JSONVerificationStore
from src.core.test_datasets import TestDatasetManager
from src.core.message_queue_manager import MessageQueueManager
from src.core.verification_feedback_handler import VerificationFeedbackHandler
from src.core.verification_metrics import VerificationMetricsCalculator
from src.core.verification_csv_exporter import VerificationCSVExporter
| # ============================================================================ | |
| # STORAGE AND STORE FIXTURES | |
| # ============================================================================ | |
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def temp_storage_dir():
    """Yield the path of a fresh temporary directory, removing it afterwards.

    The directory (and everything written into it) is deleted when the
    generator finishes, including when it is closed or garbage-collected
    before being resumed past the ``yield`` — the previous mkdtemp/rmtree
    pair leaked the directory in that case.

    Yields:
        str: Absolute path of the temporary directory.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        yield temp_dir
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def verification_store(temp_storage_dir):
    """Build a JSONVerificationStore rooted at the given directory.

    Args:
        temp_storage_dir: Directory path passed to the store as ``storage_dir``.
    """
    store = JSONVerificationStore(storage_dir=temp_storage_dir)
    return store
| # ============================================================================ | |
| # BASIC DATA MODEL FIXTURES | |
| # ============================================================================ | |
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def sample_verification_record():
    """Return a correctly-verified "yellow" record for message ``msg_001``.

    The classifier decision and ground truth agree (both "yellow"), and the
    timestamp is ``datetime.now()`` at call time.
    """
    fields = dict(
        message_id="msg_001",
        original_message="I'm feeling very anxious about my health",
        classifier_decision="yellow",
        classifier_confidence=0.85,
        classifier_indicators=["anxiety", "health concern"],
        ground_truth_label="yellow",
        verifier_notes="Correctly identified anxiety",
        is_correct=True,
        timestamp=datetime.now(),
    )
    return VerificationRecord(**fields)
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def sample_verification_session():
    """Return a new, not-yet-started session for "Dr. Smith" on ``dataset_001``.

    All counters start at zero, ``total_messages`` is 10, and each call gets
    its own empty ``verifications`` list.
    """
    started_at = datetime.now()
    return VerificationSession(
        session_id="session_001",
        verifier_name="Dr. Smith",
        dataset_id="dataset_001",
        dataset_name="Anxiety Messages",
        created_at=started_at,
        total_messages=10,
        verified_count=0,
        correct_count=0,
        incorrect_count=0,
        verifications=[],
        is_complete=False,
    )
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def sample_test_dataset():
    """Return dataset ``dataset_001`` holding one green, one yellow and one red message."""
    seed_rows = (
        ("msg_001", "I'm feeling fine today", "green"),
        ("msg_002", "I'm a bit worried about my symptoms", "yellow"),
        ("msg_003", "I'm having severe thoughts of harming myself", "red"),
    )
    messages = [
        TestMessage(message_id=msg_id, text=body, pre_classified_label=label)
        for msg_id, body, label in seed_rows
    ]
    return TestDataset(
        dataset_id="dataset_001",
        name="Test Dataset",
        description="A test dataset with sample messages",
        messages=messages,
    )
| # ============================================================================ | |
| # DATASET FIXTURES | |
| # ============================================================================ | |
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def all_test_datasets():
    """Return the result of ``TestDatasetManager.get_all_datasets()``."""
    fetch_all = TestDatasetManager.get_all_datasets
    return fetch_all()
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def suicidal_ideation_dataset():
    """Return ``TestDatasetManager.SUICIDAL_IDEATION_DATASET``.

    This is the class attribute itself, not a copy, so mutating it is visible
    to every other caller.
    """
    dataset = TestDatasetManager.SUICIDAL_IDEATION_DATASET
    return dataset
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def anxiety_worry_dataset():
    """Return ``TestDatasetManager.ANXIETY_WORRY_DATASET``.

    This is the class attribute itself, not a copy, so mutating it is visible
    to every other caller.
    """
    dataset = TestDatasetManager.ANXIETY_WORRY_DATASET
    return dataset
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def healthy_positive_dataset():
    """Return ``TestDatasetManager.HEALTHY_POSITIVE_DATASET``.

    This is the class attribute itself, not a copy, so mutating it is visible
    to every other caller.
    """
    dataset = TestDatasetManager.HEALTHY_POSITIVE_DATASET
    return dataset
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def mixed_scenarios_dataset():
    """Return ``TestDatasetManager.MIXED_SCENARIOS_DATASET``.

    This is the class attribute itself, not a copy, so mutating it is visible
    to every other caller.
    """
    dataset = TestDatasetManager.MIXED_SCENARIOS_DATASET
    return dataset
| # ============================================================================ | |
| # COMPONENT FIXTURES | |
| # ============================================================================ | |
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def message_queue_manager(sample_verification_session):
    """Wrap the given session in a MessageQueueManager.

    Args:
        sample_verification_session: Session passed positionally to the manager.
    """
    manager = MessageQueueManager(sample_verification_session)
    return manager
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def verification_feedback_handler(sample_verification_session, verification_store, message_queue_manager):
    """Build a VerificationFeedbackHandler from a session, store and queue manager.

    The three collaborators are passed positionally in the order
    (session, store, queue manager).
    """
    collaborators = (
        sample_verification_session,
        verification_store,
        message_queue_manager,
    )
    return VerificationFeedbackHandler(*collaborators)
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def metrics_calculator():
    """Return a new VerificationMetricsCalculator built with no arguments."""
    calculator = VerificationMetricsCalculator()
    return calculator
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def csv_exporter():
    """Return a new VerificationCSVExporter built with no arguments."""
    exporter = VerificationCSVExporter()
    return exporter
| # ============================================================================ | |
| # TEST DATA GENERATION UTILITIES | |
| # ============================================================================ | |
class TestDataGenerator:
    """Factory helpers for building verification test data.

    Every helper is a ``staticmethod``, so it can be called on the class
    (``TestDataGenerator.create_test_dataset()``) or on an instance.
    """

    # The class name matches pytest's default ``Test*`` class-collection
    # pattern; opt out so pytest never tries to collect this helper.
    __test__ = False

    @staticmethod
    def create_verification_record(
        message_id: str = "msg_001",
        original_message: str = "Test message",
        classifier_decision: str = "yellow",
        classifier_confidence: float = 0.85,
        classifier_indicators: Optional[List[str]] = None,
        ground_truth_label: str = "yellow",
        verifier_notes: str = "",
        is_correct: bool = True,
        timestamp: Optional[datetime] = None,
    ) -> VerificationRecord:
        """Create a verification record; any field can be overridden by keyword.

        Args:
            classifier_indicators: Defaults to ``["test_indicator"]``. A new list
                is built on each call, so records never share it.
            timestamp: Defaults to ``datetime.now()`` at call time.
            Other arguments are copied verbatim onto the record.

        Returns:
            The constructed VerificationRecord.
        """
        if classifier_indicators is None:
            classifier_indicators = ["test_indicator"]
        if timestamp is None:
            timestamp = datetime.now()
        return VerificationRecord(
            message_id=message_id,
            original_message=original_message,
            classifier_decision=classifier_decision,
            classifier_confidence=classifier_confidence,
            classifier_indicators=classifier_indicators,
            ground_truth_label=ground_truth_label,
            verifier_notes=verifier_notes,
            is_correct=is_correct,
            timestamp=timestamp,
        )

    @staticmethod
    def create_verification_session(
        session_id: str = "session_001",
        verifier_name: str = "Test Verifier",
        dataset_id: str = "dataset_001",
        dataset_name: str = "Test Dataset",
        total_messages: int = 10,
        verified_count: int = 0,
        correct_count: int = 0,
        incorrect_count: int = 0,
        is_complete: bool = False,
    ) -> VerificationSession:
        """Create a verification session with custom counters and metadata.

        ``created_at`` is ``datetime.now()`` at call time. ``verifications``
        always starts as a new empty list.
        """
        return VerificationSession(
            session_id=session_id,
            verifier_name=verifier_name,
            dataset_id=dataset_id,
            dataset_name=dataset_name,
            created_at=datetime.now(),
            total_messages=total_messages,
            verified_count=verified_count,
            correct_count=correct_count,
            incorrect_count=incorrect_count,
            verifications=[],
            is_complete=is_complete,
        )

    @staticmethod
    def create_test_messages(
        count: int = 5,
        classification_type: str = "mixed",
    ) -> List[TestMessage]:
        """Create ``count`` test messages.

        Args:
            count: Number of messages to generate.
            classification_type: ``"green"``, ``"yellow"`` or ``"red"`` produce
                messages that all carry that label, with ids ``"<label>_<i>"``.
                Any other value (including the default ``"mixed"``) cycles
                green/yellow/red, with ids ``"msg_<i>"``.

        Returns:
            The generated messages. The index ``i`` is appended to each text so
            that every text is unique.
        """
        single_label_text = {
            "green": "I'm feeling great and positive.",
            "yellow": "I'm feeling worried and anxious.",
            "red": "I'm having severe thoughts of harming myself.",
        }
        if classification_type in single_label_text:
            base_text = single_label_text[classification_type]
            return [
                TestMessage(
                    message_id=f"{classification_type}_{i}",
                    text=f"{base_text} {i}",
                    pre_classified_label=classification_type,
                )
                for i in range(count)
            ]

        # Mixed (and any unrecognised type): rotate through the three labels.
        mixed_cycle = (
            ("green", "I'm feeling great."),
            ("yellow", "I'm feeling worried."),
            ("red", "I'm having severe thoughts."),
        )
        return [
            TestMessage(
                message_id=f"msg_{i}",
                text=f"{mixed_cycle[i % len(mixed_cycle)][1]} {i}",
                pre_classified_label=mixed_cycle[i % len(mixed_cycle)][0],
            )
            for i in range(count)
        ]

    @staticmethod
    def create_test_dataset(
        dataset_id: str = "test_dataset",
        name: str = "Test Dataset",
        description: str = "A test dataset",
        message_count: int = 5,
        classification_type: str = "mixed",
    ) -> TestDataset:
        """Create a TestDataset populated by ``create_test_messages``.

        ``message_count`` and ``classification_type`` are forwarded to
        ``create_test_messages`` as ``count`` and ``classification_type``.
        """
        messages = TestDataGenerator.create_test_messages(
            count=message_count,
            classification_type=classification_type,
        )
        return TestDataset(
            dataset_id=dataset_id,
            name=name,
            description=description,
            messages=messages,
        )

    @staticmethod
    def create_verification_records_batch(
        count: int = 5,
        correct_ratio: float = 0.8,
        classification_types: Optional[List[str]] = None,
    ) -> List[VerificationRecord]:
        """Create ``count`` records, of which the first ``count * correct_ratio`` are correct.

        Args:
            count: Number of records.
            correct_ratio: Fraction of records (rounded down) marked correct.
                The correct records come first.
            classification_types: Labels cycled as the classifier decision.
                ``None`` or an empty list means ``["green", "yellow", "red"]``.

        Returns:
            The generated records. An incorrect record's ground truth is the
            next label in the cycle. With a single label type, that label equals
            the decision, so there is no differing label to use.
        """
        # An empty list would make ``i % len(...)`` raise ZeroDivisionError.
        if not classification_types:
            classification_types = ["green", "yellow", "red"]
        type_count = len(classification_types)
        # Strip float noise before truncating: 100 * 0.29 evaluates to
        # 28.999999999999996, which plain int() would turn into 28.
        correct_count = int(round(count * correct_ratio, 9))

        records = []
        for i in range(count):
            decision = classification_types[i % type_count]
            is_correct = i < correct_count
            ground_truth = decision if is_correct else classification_types[(i + 1) % type_count]
            records.append(
                TestDataGenerator.create_verification_record(
                    message_id=f"msg_{i}",
                    original_message=f"Test message {i}",
                    classifier_decision=decision,
                    # Step by 0.01 per record, capped at 1.0 so large batches
                    # stay within a probability range (presumably what
                    # VerificationRecord expects — confirm).
                    classifier_confidence=min(0.85 + (i * 0.01), 1.0),
                    ground_truth_label=ground_truth,
                    is_correct=is_correct,
                )
            )
        return records
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def test_data_generator():
    """Return the TestDataGenerator class itself (not an instance)."""
    generator_cls = TestDataGenerator
    return generator_cls
| # ============================================================================ | |
| # ASSERTION HELPER UTILITIES | |
| # ============================================================================ | |
class AssertionHelpers:
    """Reusable assertion helpers for verification-mode tests.

    Every helper is a ``staticmethod``, so it can be called on the class or on
    an instance. Failures raise ``AssertionError`` with the mismatching values
    in the message.
    """

    # Attributes compared by assert_record_fields_match (``timestamp`` is not
    # compared).
    _RECORD_FIELDS = (
        "message_id",
        "original_message",
        "classifier_decision",
        "classifier_confidence",
        "classifier_indicators",
        "ground_truth_label",
        "verifier_notes",
        "is_correct",
    )

    # Attributes compared by assert_session_fields_match (``created_at`` and
    # ``verifications`` are not compared).
    _SESSION_FIELDS = (
        "session_id",
        "verifier_name",
        "dataset_id",
        "dataset_name",
        "total_messages",
        "verified_count",
        "correct_count",
        "incorrect_count",
        "is_complete",
    )

    @staticmethod
    def _assert_fields_match(obj1, obj2, field_names, exclude_fields) -> None:
        """Assert ``getattr(obj1, f) == getattr(obj2, f)`` for each field not excluded."""
        excluded = exclude_fields if exclude_fields is not None else ()
        for field_name in field_names:
            if field_name in excluded:
                continue
            value1 = getattr(obj1, field_name)
            value2 = getattr(obj2, field_name)
            assert value1 == value2, (
                f"Field '{field_name}' differs: {value1!r} != {value2!r}"
            )

    @staticmethod
    def assert_record_fields_match(
        record1: VerificationRecord,
        record2: VerificationRecord,
        exclude_fields: Optional[List[str]] = None,
    ) -> None:
        """Assert that two verification records have matching fields.

        Args:
            record1: First record.
            record2: Second record.
            exclude_fields: Field names to skip. ``None`` compares every field
                in the record field list (``timestamp`` is never compared).

        Raises:
            AssertionError: On the first differing field.
        """
        AssertionHelpers._assert_fields_match(
            record1, record2, AssertionHelpers._RECORD_FIELDS, exclude_fields
        )

    @staticmethod
    def assert_session_fields_match(
        session1: VerificationSession,
        session2: VerificationSession,
        exclude_fields: Optional[List[str]] = None,
    ) -> None:
        """Assert that two verification sessions have matching fields.

        Args:
            session1: First session.
            session2: Second session.
            exclude_fields: Field names to skip. ``None`` compares every field
                in the session field list (``created_at`` and ``verifications``
                are never compared).

        Raises:
            AssertionError: On the first differing field.
        """
        AssertionHelpers._assert_fields_match(
            session1, session2, AssertionHelpers._SESSION_FIELDS, exclude_fields
        )

    @staticmethod
    def assert_csv_contains_columns(csv_content: str, required_columns: List[str]) -> None:
        """Assert that each column name occurs somewhere in ``csv_content``.

        NOTE: this is a plain substring check over the whole text, not only
        the header row, so a match inside a data cell also passes.
        """
        for column in required_columns:
            assert column in csv_content, f"Column '{column}' not found in CSV"

    @staticmethod
    def assert_csv_has_summary_section(csv_content: str) -> None:
        """Assert that ``csv_content`` contains the summary heading and its labels.

        The checks are substring checks, made in order: heading, total, correct,
        incorrect, accuracy.
        """
        for marker in (
            "VERIFICATION SUMMARY",
            "Total Messages",
            "Correct",
            "Incorrect",
            "Accuracy %",
        ):
            assert marker in csv_content, f"'{marker}' not found in CSV summary"

    @staticmethod
    def assert_accuracy_calculation(
        correct_count: int,
        total_count: int,
        calculated_accuracy: float,
        tolerance: float = 0.01,
    ) -> None:
        """Assert that ``calculated_accuracy`` equals ``correct/total * 100``.

        Args:
            correct_count: Number of correct verifications.
            total_count: Total verifications. When 0, the accuracy must be
                exactly ``0.0``.
            calculated_accuracy: Percentage under test (0-100 scale).
            tolerance: Maximum absolute difference allowed (strictly less than).
        """
        if total_count == 0:
            assert calculated_accuracy == 0.0, (
                f"Expected 0.0 accuracy for zero messages, got {calculated_accuracy!r}"
            )
            return
        expected_accuracy = (correct_count / total_count) * 100
        assert abs(calculated_accuracy - expected_accuracy) < tolerance, (
            f"Accuracy {calculated_accuracy!r} differs from expected "
            f"{expected_accuracy!r} by more than {tolerance!r}"
        )
# NOTE(review): no @pytest.fixture decorator is visible on this function, so
# pytest will not inject it by name as written — confirm it was not lost.
def assertion_helpers():
    """Return the AssertionHelpers class itself (not an instance)."""
    helpers_cls = AssertionHelpers
    return helpers_cls