Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test script for Context-Aware Classifier implementation. | |
| This script validates the context-aware classification functionality including: | |
| - Context-aware classification with conversation history | |
| - Defensive response pattern detection | |
| - Contextual indicator weighting | |
| - Contextual follow-up question generation | |
| - Medical context integration | |
| """ | |
| import sys | |
| import os | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src')) | |
| from datetime import datetime, timedelta | |
| from config.prompt_management.context_aware_classifier import ContextAwareClassifier | |
| from config.prompt_management.data_models import ConversationHistory, Message, Classification | |
def test_context_aware_classifier():
    """Exercise ``ContextAwareClassifier`` across seven scenarios.

    Scenarios, in order:
      1. Basic classification with an empty conversation history.
      2. A dismissive reply after two earlier distress messages.
      3. ``detect_defensive_responses`` on several "I'm fine"-style replies.
      4. ``evaluate_contextual_indicators`` on three context dictionaries.
      5. ``generate_contextual_follow_up`` for an ambiguous message.
      6. Classification with ``medical_context`` attached to the history.
      7. Spot checks on GREEN / YELLOW / RED example messages. Every result
         must be a valid category; only the RED expectation is enforced
         exactly, the others are printed for inspection.

    Progress is printed to stdout; any failed expectation raises
    ``AssertionError``.

    Returns:
        True once every check has passed. ``main`` ignores the value; it is
        kept so existing callers see the same return.
    """
    print("Testing Context-Aware Classifier...")
    classifier = ContextAwareClassifier()
    valid_categories = ('GREEN', 'YELLOW', 'RED')

    # Test 1: Basic classification without context
    print("\n1. Testing basic classification...")
    message = "I'm feeling stressed about work"
    empty_history = ConversationHistory(
        messages=[],
        distress_indicators_found=[],
        context_flags=[],
    )
    result = classifier.classify_with_context(message, empty_history)
    print(f" Message: '{message}'")
    print(f" Classification: {result.category} (confidence: {result.confidence:.2f})")
    print(f" Reasoning: {result.reasoning}")
    assert result.category in valid_categories, "Invalid category"
    assert 0.0 <= result.confidence <= 1.0, "Invalid confidence"
    print(" ✓ Basic classification works")

    # Test 2: Historical distress with dismissive response
    print("\n2. Testing historical distress with dismissive response...")
    # Take a single reference time so both synthetic timestamps are offsets
    # from the same instant.
    now = datetime.now()
    history_with_distress = ConversationHistory(
        messages=[
            Message("I'm really struggling with anxiety", "YELLOW", now - timedelta(hours=1)),
            Message("I feel overwhelmed and sad", "YELLOW", now - timedelta(minutes=30)),
        ],
        distress_indicators_found=['anxiety', 'overwhelmed', 'sad'],
        context_flags=['distress_expressed'],
    )
    dismissive_message = "I'm fine now, everything is okay"
    result = classifier.classify_with_context(dismissive_message, history_with_distress)
    print(f" Message: '{dismissive_message}'")
    print(f" Classification: {result.category} (confidence: {result.confidence:.2f})")
    print(f" Context factors: {result.context_factors}")
    print(f" Reasoning: {result.reasoning}")
    # Earlier distress should keep a dismissive reply out of GREEN.
    assert result.category in ('YELLOW', 'RED'), (
        f"Expected YELLOW/RED with historical distress, got {result.category}"
    )
    reasoning = result.reasoning.lower()
    assert 'historical' in reasoning or 'previous' in reasoning, "Should mention historical context"
    print(" ✓ Historical context influences classification")

    # Test 3: Defensive response detection
    print("\n3. Testing defensive response detection...")
    defensive_responses = [
        "I'm fine",
        "Everything is okay",
        "No problems here",
        "I don't need help",
    ]
    for response in defensive_responses:
        is_defensive = classifier.detect_defensive_responses(response, history_with_distress)
        print(f" '{response}' -> Defensive: {is_defensive}")
        assert is_defensive, f"Should detect '{response}' as defensive with distress history"
    print(" ✓ Defensive response detection works")

    # Test 4: Contextual indicator weighting
    print("\n4. Testing contextual indicator weighting...")
    context_scenarios = [
        {'historical_mentions': 0, 'recent_mention': False, 'conversation_length': 1},
        {'historical_mentions': 3, 'recent_mention': True, 'conversation_length': 5},
        {'historical_mentions': 1, 'recent_mention': False, 'conversation_length': 2},
    ]
    for scenario_number, context in enumerate(context_scenarios, start=1):
        weight = classifier.evaluate_contextual_indicators(['stress'], context)
        print(f" Scenario {scenario_number}: {context} -> Weight: {weight:.2f}")
        assert 0.0 <= weight <= 1.0, "Weight should be between 0 and 1"
        # Repeated historical mentions are expected to push the weight up.
        if context['historical_mentions'] >= 2:
            assert weight >= 0.5, "High historical mentions should increase weight"
    print(" ✓ Contextual indicator weighting works")

    # Test 5: Contextual follow-up generation
    print("\n5. Testing contextual follow-up generation...")
    follow_up = classifier.generate_contextual_follow_up(
        "I'm not sure how I feel",
        history_with_distress,
        "YELLOW",
    )
    print(f" Follow-up question: '{follow_up}'")
    assert len(follow_up.strip()) > 0, "Follow-up should not be empty"
    assert '?' in follow_up, "Follow-up should be a question"
    print(" ✓ Contextual follow-up generation works")

    # Test 6: Medical context integration
    print("\n6. Testing medical context integration...")
    medical_history = ConversationHistory(
        messages=[],
        distress_indicators_found=[],
        context_flags=[],
        medical_context={'conditions': ['anxiety disorder'], 'medications': ['SSRI']},
    )
    medical_message = "I'm managing my anxiety with medication but still feel stressed"
    result = classifier.classify_with_context(medical_message, medical_history)
    print(f" Message: '{medical_message}'")
    print(f" Classification: {result.category} (confidence: {result.confidence:.2f})")
    print(f" Reasoning: {result.reasoning}")
    # Stress reported alongside a known condition should not be GREEN.
    assert result.category in ['YELLOW', 'RED'], "Medical context with stress should be YELLOW/RED"
    print(" ✓ Medical context integration works")

    # Test 7: Classification consistency
    print("\n7. Testing classification consistency...")
    test_messages = [
        ("I feel great today", "GREEN"),
        ("I'm worried about my job", "YELLOW"),
        ("I want to end it all", "RED"),
    ]
    for sample, expected_category in test_messages:
        result = classifier.classify_with_context(sample, empty_history)
        print(f" '{sample}' -> {result.category} (expected: {expected_category})")
        assert result.category in valid_categories, (
            f"Invalid category {result.category!r} for '{sample}'"
        )
        # GREEN/YELLOW may legitimately vary, but a RED message must never be
        # downgraded.
        if expected_category == "RED":
            assert result.category == "RED", (
                f"RED message '{sample}' should be classified as RED, got {result.category}"
            )
    print(" ✓ Classification consistency maintained")
    return True
def test_data_model_integration():
    """Round-trip the data models through ``to_dict`` / ``from_dict``.

    Builds a ``Message``, a ``Classification`` and a ``ConversationHistory``.
    Each is serialized with ``to_dict`` and rebuilt with ``from_dict``, and
    the constructor-supplied fields are asserted to survive the round trip.

    Progress is printed to stdout; any mismatch raises ``AssertionError``.

    Returns:
        True once every check has passed. ``main`` ignores the value; it is
        kept so existing callers see the same return.
    """
    print("\nTesting data model integration...")

    # Test Message serialization
    message = Message(
        content="Test message",
        classification="YELLOW",
        timestamp=datetime.now(),
        confidence=0.8,
    )
    restored_message = Message.from_dict(message.to_dict())
    assert restored_message.content == message.content, "Message content should match"
    assert restored_message.classification == message.classification, "Classification should match"
    assert restored_message.confidence == message.confidence, "Message confidence should match"
    # NOTE(review): timestamp equality is not asserted; confirm whether
    # to_dict preserves microsecond precision before tightening this check.
    print(" ✓ Message serialization works")

    # Test Classification serialization
    classification = Classification(
        category="YELLOW",
        confidence=0.7,
        reasoning="Test reasoning",
        indicators_found=['stress'],
        context_factors=['historical_distress'],
    )
    restored_class = Classification.from_dict(classification.to_dict())
    assert restored_class.category == classification.category, "Category should match"
    assert restored_class.confidence == classification.confidence, "Confidence should match"
    assert restored_class.reasoning == classification.reasoning, "Reasoning should match"
    assert restored_class.indicators_found == classification.indicators_found, "Indicators should match"
    assert restored_class.context_factors == classification.context_factors, "Context factors should match"
    print(" ✓ Classification serialization works")

    # Test ConversationHistory serialization (nests the Message built above)
    history = ConversationHistory(
        messages=[message],
        distress_indicators_found=['stress', 'anxiety'],
        context_flags=['distress_expressed'],
        medical_context={'conditions': ['anxiety'], 'medications': []},
    )
    restored_history = ConversationHistory.from_dict(history.to_dict())
    assert len(restored_history.messages) == 1, "Should have one message"
    assert restored_history.messages[0].content == message.content, "Nested message content should match"
    assert restored_history.distress_indicators_found == history.distress_indicators_found, "Indicators should match"
    assert restored_history.context_flags == history.context_flags, "Context flags should match"
    assert restored_history.medical_context == history.medical_context, "Medical context should match"
    print(" ✓ ConversationHistory serialization works")
    return True
def main():
    """Run every test function in this module and print an overall verdict.

    Returns:
        True if all tests finished without raising; False otherwise, after
        printing the exception message and its traceback.
    """
    separator = "=" * 60
    print(separator)
    print("CONTEXT-AWARE CLASSIFIER TEST SUITE")
    print(separator)
    try:
        # Each test raises on the first failed expectation.
        for run_test in (test_context_aware_classifier, test_data_model_integration):
            run_test()
        print("\n" + separator)
        print("✅ ALL TESTS PASSED!")
        print("Context-Aware Classifier implementation is working correctly.")
        print(separator)
    except Exception as exc:
        print(f"\n❌ TEST FAILED: {exc}")
        import traceback
        traceback.print_exc()
        return False
    return True
if __name__ == "__main__":
    # Translate the suite verdict into a conventional process exit status:
    # 0 when everything passed, 1 otherwise.
    all_passed = main()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)