File size: 9,279 Bytes
88b8fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""                                                                                                                                
Test script for the trained CBT binary classifier.                                                                                 
"""                                                                                                                                
                                                                                                                                   
import argparse                                                                                                                    
from binary_classifier import CBTBinaryClassifier                                                                                  
                                                                                                                                   
# Commands that end the interactive session (compared after lower-casing).
_QUIT_COMMANDS = frozenset({'quit', 'exit', 'q'})


def _format_result(result):
    """Return the one-line verdict shown for a ``classifier.predict`` result.

    ``result`` must provide the ``'is_cbt_trigger'`` and ``'confidence'``
    keys; the confidence is rendered with three decimals.
    """
    status = "🚨 CBT TRIGGER" if result['is_cbt_trigger'] else "✅ NORMAL"
    return f"{status} (confidence: {result['confidence']:.3f})"


def _prompt(message):
    """Read one line from stdin with surrounding whitespace stripped.

    Returns None when input ends (EOF from Ctrl-D or a closed/piped stdin)
    or the user presses Ctrl-C, so the caller can stop cleanly instead of
    crashing with a traceback.
    """
    try:
        return input(message).strip()
    except (EOFError, KeyboardInterrupt):
        print()  # finish the dangling prompt line before exiting
        return None


def main(argv=None):
    """Classify a fixed set of example texts, then classify user input.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.
    """
    parser = argparse.ArgumentParser(description='Test CBT Binary Classifier')
    parser.add_argument('--model_path', default='./cbt_classifier',
                        help='Path to the trained model')
    parser.add_argument('--threshold', type=float, default=0.7,
                        help='Confidence threshold for CBT trigger detection')
    parser.add_argument('--no_interactive', action='store_true',
                        help='Only run the built-in examples; skip the prompt')

    args = parser.parse_args(argv)

    # Load the trained model
    classifier = CBTBinaryClassifier()
    classifier.load_model(args.model_path)

    # Test examples
    test_texts = [
        # Normal conversation examples
        "How was your weekend?",
        "Nice weather today!",
        "Did you see that movie last night?",
        "I had a great lunch at that new restaurant",
        "What are your plans for tonight?",

        # CBT trigger examples
        "I'm such a failure at everything",
        "I always mess things up",
        "Everyone probably thinks I'm stupid",
        "I'm not good enough for this job",
        "I'll never be successful",
        "It's all my fault that this happened",
    ]

    print(f"Testing classifier with threshold: {args.threshold}")
    print("=" * 60)

    for text in test_texts:
        result = classifier.predict(text, threshold=args.threshold)
        print(_format_result(result))
        print(f"Text: '{text}'")
        print("-" * 60)

    if args.no_interactive:
        return

    # Interactive testing; ends on a quit word, EOF, or Ctrl-C.
    print("\nInteractive testing (type 'quit' to exit):")
    while True:
        user_input = _prompt("\nEnter text to classify: ")

        if user_input is None or user_input.lower() in _QUIT_COMMANDS:
            break

        if not user_input:
            continue

        result = classifier.predict(user_input, threshold=args.threshold)
        print(_format_result(result))
                                                                                                                                   
# Run the classifier demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()