File size: 9,279 Bytes
88b8fd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
"""
Test script for the trained CBT binary classifier.
"""
import argparse
from binary_classifier import CBTBinaryClassifier
def _format_status(result):
    """Return the one-line label printed for a ``classifier.predict`` result.

    Only two entries of the result dict are read: ``'is_cbt_trigger'`` (truthy
    selects the trigger label) and ``'confidence'`` (formatted to three decimals).
    """
    status = "🚨 CBT TRIGGER" if result['is_cbt_trigger'] else "✅ NORMAL"
    return f"{status} (confidence: {result['confidence']:.3f})"


def main():
    """Run the trained CBT binary classifier on sample texts, then interactively.

    Command-line options:
        --model_path  path handed to ``CBTBinaryClassifier.load_model``
                      (default ``./cbt_classifier``).
        --threshold   float forwarded as ``threshold=`` to ``predict``
                      (default 0.7).

    First prints a labelled prediction for each built-in example. Then it reads
    lines from stdin and classifies each one. The loop ends when the user types
    quit/exit/q, when input ends (EOF), or on Ctrl-C.
    """
    parser = argparse.ArgumentParser(description='Test CBT Binary Classifier')
    parser.add_argument('--model_path', default='./cbt_classifier',
                        help='Path to the trained model')
    parser.add_argument('--threshold', type=float, default=0.7,
                        help='Confidence threshold for CBT trigger detection')
    args = parser.parse_args()

    # Load the trained model
    classifier = CBTBinaryClassifier()
    classifier.load_model(args.model_path)

    # Test examples
    test_texts = [
        # Normal conversation examples
        "How was your weekend?",
        "Nice weather today!",
        "Did you see that movie last night?",
        "I had a great lunch at that new restaurant",
        "What are your plans for tonight?",
        # CBT trigger examples
        "I'm such a failure at everything",
        "I always mess things up",
        "Everyone probably thinks I'm stupid",
        "I'm not good enough for this job",
        "I'll never be successful",
        "It's all my fault that this happened",
    ]

    print(f"Testing classifier with threshold: {args.threshold}")
    print("=" * 60)
    for text in test_texts:
        result = classifier.predict(text, threshold=args.threshold)
        print(_format_status(result))
        print(f"Text: '{text}'")
        print("-" * 60)

    # Interactive testing
    print("\nInteractive testing (type 'quit' to exit):")
    while True:
        try:
            user_input = input("\nEnter text to classify: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed (Ctrl-D, exhausted pipe) or Ctrl-C: stop cleanly
            # instead of crashing with a traceback.
            print()
            break
        if user_input.lower() in {'quit', 'exit', 'q'}:
            break
        if not user_input:
            continue
        result = classifier.predict(user_input, threshold=args.threshold)
        print(_format_status(result))
# Run the script entry point only when executed directly, not when imported.
if __name__ == "__main__":
    main()