Spaces:
Running
Running
stanlee47 commited on
Commit Β·
1dbb69b
1
Parent(s): 21f026d
changes
Browse files- README.md +1 -1
- __pycache__/groq_client.cpython-312.pyc +0 -0
- app.py +2 -4
- classifier.py +0 -142
- groq_client.py +112 -0
- requirements.txt +0 -2
- test_beck_protocol.py +11 -6
README.md
CHANGED
|
@@ -14,7 +14,7 @@ Backend API for the CBT Companion mental health support app.
|
|
| 14 |
## Features
|
| 15 |
|
| 16 |
- π User authentication (register/login)
|
| 17 |
-
- π§ Cognitive distortion classification (
|
| 18 |
- π¬ Therapeutic conversations (Groq LLaMA 3.3)
|
| 19 |
- π¨ Crisis detection and flagging
|
| 20 |
- π User statistics tracking
|
|
|
|
| 14 |
## Features
|
| 15 |
|
| 16 |
- π User authentication (register/login)
|
| 17 |
+
- π§ Cognitive distortion classification (LLM-based)
|
| 18 |
- π¬ Therapeutic conversations (Groq LLaMA 3.3)
|
| 19 |
- π¨ Crisis detection and flagging
|
| 20 |
- π User statistics tracking
|
__pycache__/groq_client.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/groq_client.cpython-312.pyc and b/__pycache__/groq_client.cpython-312.pyc differ
|
|
|
app.py
CHANGED
|
@@ -6,7 +6,6 @@ Hosted on HuggingFace Spaces
|
|
| 6 |
|
| 7 |
from flask import Flask, request, jsonify
|
| 8 |
from flask_cors import CORS
|
| 9 |
-
from classifier import DistortionClassifier
|
| 10 |
from groq_client import GroqClient
|
| 11 |
from database import get_db
|
| 12 |
from auth import register_user, login_user, token_required
|
|
@@ -27,7 +26,6 @@ app.register_blueprint(wearable_bp)
|
|
| 27 |
app.register_blueprint(admin_bp)
|
| 28 |
|
| 29 |
# Initialize components
|
| 30 |
-
classifier = DistortionClassifier()
|
| 31 |
groq_client = GroqClient(api_key=os.environ.get("GROQ_API_KEY"))
|
| 32 |
|
| 33 |
|
|
@@ -198,8 +196,8 @@ def chat():
|
|
| 198 |
beck_data = db.get_beck_session(session_id)
|
| 199 |
|
| 200 |
if not beck_data:
|
| 201 |
-
# First message - check if distorted
|
| 202 |
-
classification =
|
| 203 |
|
| 204 |
if classification["group"] == "G0":
|
| 205 |
# No distortion - supportive listening
|
|
|
|
| 6 |
|
| 7 |
from flask import Flask, request, jsonify
|
| 8 |
from flask_cors import CORS
|
|
|
|
| 9 |
from groq_client import GroqClient
|
| 10 |
from database import get_db
|
| 11 |
from auth import register_user, login_user, token_required
|
|
|
|
| 26 |
app.register_blueprint(admin_bp)
|
| 27 |
|
| 28 |
# Initialize components
|
|
|
|
| 29 |
groq_client = GroqClient(api_key=os.environ.get("GROQ_API_KEY"))
|
| 30 |
|
| 31 |
|
|
|
|
| 196 |
beck_data = db.get_beck_session(session_id)
|
| 197 |
|
| 198 |
if not beck_data:
|
| 199 |
+
# First message - check if distorted using LLM
|
| 200 |
+
classification = groq_client.classify_distortion(user_message)
|
| 201 |
|
| 202 |
if classification["group"] == "G0":
|
| 203 |
# No distortion - supportive listening
|
classifier.py
DELETED
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Cognitive Distortion Classifier
|
| 3 |
-
Uses TinyBERT model from HuggingFace
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class DistortionClassifier:
|
| 12 |
-
"""
|
| 13 |
-
Classifies text into cognitive distortion groups.
|
| 14 |
-
|
| 15 |
-
Groups:
|
| 16 |
-
- G0: No Distortion
|
| 17 |
-
- G1: Binary & Global Evaluation (All-or-nothing, Labeling)
|
| 18 |
-
- G2: Overgeneralized Beliefs (Overgeneralization, Mind Reading, Fortune-telling)
|
| 19 |
-
- G3: Attentional Bias (Mental Filter, Magnification)
|
| 20 |
-
- G4: Self-Referential Reasoning (Emotional Reasoning, Personalization, Should statements)
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
MODEL_NAME = "santa47/cbt-distortion-classifier-bert"
|
| 24 |
-
|
| 25 |
-
def __init__(self):
|
| 26 |
-
print(f"Loading classifier from {self.MODEL_NAME}...")
|
| 27 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 28 |
-
print(f"Using device: {self.device}")
|
| 29 |
-
|
| 30 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
|
| 31 |
-
self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME)
|
| 32 |
-
self.model.to(self.device)
|
| 33 |
-
self.model.eval()
|
| 34 |
-
|
| 35 |
-
# Label mapping
|
| 36 |
-
self.id_to_label = {
|
| 37 |
-
0: "G0",
|
| 38 |
-
1: "G1",
|
| 39 |
-
2: "G2",
|
| 40 |
-
3: "G3",
|
| 41 |
-
4: "G4"
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
self.label_info = {
|
| 45 |
-
"G0": {
|
| 46 |
-
"name": "No Distortion Detected",
|
| 47 |
-
"description": "Healthy, balanced thinking",
|
| 48 |
-
"distortions": []
|
| 49 |
-
},
|
| 50 |
-
"G1": {
|
| 51 |
-
"name": "Binary & Global Evaluation",
|
| 52 |
-
"description": "All-or-nothing thinking patterns",
|
| 53 |
-
"distortions": ["All-or-nothing thinking", "Labeling"]
|
| 54 |
-
},
|
| 55 |
-
"G2": {
|
| 56 |
-
"name": "Overgeneralized Beliefs",
|
| 57 |
-
"description": "Making broad conclusions from limited evidence",
|
| 58 |
-
"distortions": ["Overgeneralization", "Mind Reading", "Fortune-telling"]
|
| 59 |
-
},
|
| 60 |
-
"G3": {
|
| 61 |
-
"name": "Attentional & Salience Bias",
|
| 62 |
-
"description": "Focusing on negatives, ignoring positives",
|
| 63 |
-
"distortions": ["Mental Filter", "Magnification"]
|
| 64 |
-
},
|
| 65 |
-
"G4": {
|
| 66 |
-
"name": "Self-Referential & Emotion-Driven",
|
| 67 |
-
"description": "Letting emotions drive conclusions",
|
| 68 |
-
"distortions": ["Emotional Reasoning", "Personalization", "Should statements"]
|
| 69 |
-
}
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
print("β
Classifier loaded successfully!")
|
| 73 |
-
|
| 74 |
-
def classify(self, text: str) -> dict:
|
| 75 |
-
"""
|
| 76 |
-
Classify text into a distortion group.
|
| 77 |
-
|
| 78 |
-
Args:
|
| 79 |
-
text: Input text to classify
|
| 80 |
-
|
| 81 |
-
Returns:
|
| 82 |
-
dict with group, confidence, and group info
|
| 83 |
-
"""
|
| 84 |
-
# Tokenize
|
| 85 |
-
inputs = self.tokenizer(
|
| 86 |
-
text,
|
| 87 |
-
return_tensors="pt",
|
| 88 |
-
truncation=True,
|
| 89 |
-
max_length=512,
|
| 90 |
-
padding=True
|
| 91 |
-
)
|
| 92 |
-
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 93 |
-
|
| 94 |
-
# Predict
|
| 95 |
-
with torch.no_grad():
|
| 96 |
-
outputs = self.model(**inputs)
|
| 97 |
-
logits = outputs.logits
|
| 98 |
-
probabilities = F.softmax(logits, dim=-1)
|
| 99 |
-
|
| 100 |
-
predicted_class = torch.argmax(probabilities, dim=-1).item()
|
| 101 |
-
confidence = probabilities[0][predicted_class].item()
|
| 102 |
-
|
| 103 |
-
# Get group label
|
| 104 |
-
group = self.id_to_label.get(predicted_class, "G0")
|
| 105 |
-
group_info = self.label_info.get(group, {})
|
| 106 |
-
|
| 107 |
-
return {
|
| 108 |
-
"group": group,
|
| 109 |
-
"confidence": round(confidence, 4),
|
| 110 |
-
"group_name": group_info.get("name", "Unknown"),
|
| 111 |
-
"description": group_info.get("description", ""),
|
| 112 |
-
"distortions": group_info.get("distortions", []),
|
| 113 |
-
"all_probabilities": {
|
| 114 |
-
self.id_to_label[i]: round(probabilities[0][i].item(), 4)
|
| 115 |
-
for i in range(len(self.id_to_label))
|
| 116 |
-
}
|
| 117 |
-
}
|
| 118 |
-
|
| 119 |
-
def get_group_info(self, group: str) -> dict:
|
| 120 |
-
"""Get detailed info about a distortion group."""
|
| 121 |
-
return self.label_info.get(group, {})
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# Test if run directly
|
| 125 |
-
if __name__ == "__main__":
|
| 126 |
-
classifier = DistortionClassifier()
|
| 127 |
-
|
| 128 |
-
test_texts = [
|
| 129 |
-
"I failed my exam. I'll never succeed at anything.",
|
| 130 |
-
"My friend didn't text back. She must hate me.",
|
| 131 |
-
"I made one mistake so the whole project is ruined.",
|
| 132 |
-
"I feel anxious so something bad must be happening.",
|
| 133 |
-
"I had a nice day today and enjoyed my lunch."
|
| 134 |
-
]
|
| 135 |
-
|
| 136 |
-
print("\nπ§ͺ Testing classifier:\n")
|
| 137 |
-
for text in test_texts:
|
| 138 |
-
result = classifier.classify(text)
|
| 139 |
-
print(f"Text: {text[:50]}...")
|
| 140 |
-
print(f" β {result['group']}: {result['group_name']}")
|
| 141 |
-
print(f" β Confidence: {result['confidence']}")
|
| 142 |
-
print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
groq_client.py
CHANGED
|
@@ -279,6 +279,118 @@ Respond as Aria for the {current_state} state."""
|
|
| 279 |
print(f"Agent 3 error: {e}")
|
| 280 |
return f"You've done really good work here, {user_name}. π"
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
# ==================== SUPPORTIVE RESPONSE (G0 - No Distortion) ====================
|
| 283 |
|
| 284 |
def generate_supportive_response(
|
|
|
|
| 279 |
print(f"Agent 3 error: {e}")
|
| 280 |
return f"You've done really good work here, {user_name}. π"
|
| 281 |
|
| 282 |
+
# ==================== COGNITIVE DISTORTION CLASSIFIER ====================
|
| 283 |
+
|
| 284 |
+
def classify_distortion(self, text: str) -> dict:
|
| 285 |
+
"""
|
| 286 |
+
Classify text into cognitive distortion groups using LLM.
|
| 287 |
+
|
| 288 |
+
Groups:
|
| 289 |
+
- G0: No Distortion
|
| 290 |
+
- G1: Binary & Global Evaluation (All-or-nothing, Labeling)
|
| 291 |
+
- G2: Overgeneralized Beliefs (Overgeneralization, Mind Reading, Fortune-telling)
|
| 292 |
+
- G3: Attentional Bias (Mental Filter, Magnification)
|
| 293 |
+
- G4: Self-Referential Reasoning (Emotional Reasoning, Personalization, Should statements)
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
text: User's thought to classify
|
| 297 |
+
|
| 298 |
+
Returns:
|
| 299 |
+
dict with group, confidence, group_name, description, and distortions
|
| 300 |
+
"""
|
| 301 |
+
system_prompt = """You are a cognitive distortion classifier for CBT (Cognitive Behavioral Therapy).
|
| 302 |
+
|
| 303 |
+
Analyze the user's thought and classify it into ONE of these groups:
|
| 304 |
+
|
| 305 |
+
G0: No Distortion - Healthy, balanced thinking
|
| 306 |
+
G1: Binary & Global Evaluation - All-or-nothing thinking, Labeling
|
| 307 |
+
G2: Overgeneralized Beliefs - Overgeneralization, Mind Reading, Fortune-telling
|
| 308 |
+
G3: Attentional & Salience Bias - Mental Filter, Magnification/Minimization
|
| 309 |
+
G4: Self-Referential & Emotion-Driven - Emotional Reasoning, Personalization, Should statements
|
| 310 |
+
|
| 311 |
+
Examples:
|
| 312 |
+
- "I failed my exam. I'll never succeed at anything." β G1 (all-or-nothing)
|
| 313 |
+
- "My friend didn't text back. She must hate me." β G2 (mind reading)
|
| 314 |
+
- "Everything in my life is terrible." β G3 (mental filter)
|
| 315 |
+
- "I feel anxious so something bad must be happening." β G4 (emotional reasoning)
|
| 316 |
+
- "I had a nice day today and enjoyed my lunch." β G0 (no distortion)
|
| 317 |
+
|
| 318 |
+
Respond ONLY in JSON format:
|
| 319 |
+
{
|
| 320 |
+
"group": "G0/G1/G2/G3/G4",
|
| 321 |
+
"confidence": 0.85,
|
| 322 |
+
"reasoning": "Brief explanation of why this classification"
|
| 323 |
+
}"""
|
| 324 |
+
|
| 325 |
+
user_prompt = f'Classify this thought: "{text}"'
|
| 326 |
+
|
| 327 |
+
try:
|
| 328 |
+
response = self.client.chat.completions.create(
|
| 329 |
+
model=self.MODEL,
|
| 330 |
+
messages=[
|
| 331 |
+
{"role": "system", "content": system_prompt},
|
| 332 |
+
{"role": "user", "content": user_prompt}
|
| 333 |
+
],
|
| 334 |
+
temperature=0.3,
|
| 335 |
+
max_tokens=200,
|
| 336 |
+
response_format={"type": "json_object"}
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
result = json.loads(response.choices[0].message.content)
|
| 340 |
+
group = result.get("group", "G0")
|
| 341 |
+
|
| 342 |
+
# Group information
|
| 343 |
+
label_info = {
|
| 344 |
+
"G0": {
|
| 345 |
+
"name": "No Distortion Detected",
|
| 346 |
+
"description": "Healthy, balanced thinking",
|
| 347 |
+
"distortions": []
|
| 348 |
+
},
|
| 349 |
+
"G1": {
|
| 350 |
+
"name": "Binary & Global Evaluation",
|
| 351 |
+
"description": "All-or-nothing thinking patterns",
|
| 352 |
+
"distortions": ["All-or-nothing thinking", "Labeling"]
|
| 353 |
+
},
|
| 354 |
+
"G2": {
|
| 355 |
+
"name": "Overgeneralized Beliefs",
|
| 356 |
+
"description": "Making broad conclusions from limited evidence",
|
| 357 |
+
"distortions": ["Overgeneralization", "Mind Reading", "Fortune-telling"]
|
| 358 |
+
},
|
| 359 |
+
"G3": {
|
| 360 |
+
"name": "Attentional & Salience Bias",
|
| 361 |
+
"description": "Focusing on negatives, ignoring positives",
|
| 362 |
+
"distortions": ["Mental Filter", "Magnification"]
|
| 363 |
+
},
|
| 364 |
+
"G4": {
|
| 365 |
+
"name": "Self-Referential & Emotion-Driven",
|
| 366 |
+
"description": "Letting emotions drive conclusions",
|
| 367 |
+
"distortions": ["Emotional Reasoning", "Personalization", "Should statements"]
|
| 368 |
+
}
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
group_info = label_info.get(group, label_info["G0"])
|
| 372 |
+
|
| 373 |
+
return {
|
| 374 |
+
"group": group,
|
| 375 |
+
"confidence": round(result.get("confidence", 0.8), 4),
|
| 376 |
+
"group_name": group_info["name"],
|
| 377 |
+
"description": group_info["description"],
|
| 378 |
+
"distortions": group_info["distortions"],
|
| 379 |
+
"reasoning": result.get("reasoning", "")
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
except Exception as e:
|
| 383 |
+
print(f"Classification error: {e}")
|
| 384 |
+
# Default to G0 (no distortion) on error
|
| 385 |
+
return {
|
| 386 |
+
"group": "G0",
|
| 387 |
+
"confidence": 0.5,
|
| 388 |
+
"group_name": "No Distortion Detected",
|
| 389 |
+
"description": "Healthy, balanced thinking",
|
| 390 |
+
"distortions": [],
|
| 391 |
+
"reasoning": "Classification failed, defaulting to G0"
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
# ==================== SUPPORTIVE RESPONSE (G0 - No Distortion) ====================
|
| 395 |
|
| 396 |
def generate_supportive_response(
|
requirements.txt
CHANGED
|
@@ -5,8 +5,6 @@ gunicorn==21.2.0
|
|
| 5 |
|
| 6 |
# AI/ML
|
| 7 |
numpy<2
|
| 8 |
-
transformers==4.36.0
|
| 9 |
-
torch==2.1.0
|
| 10 |
httpx>=0.27.0,<0.28.0
|
| 11 |
groq>=0.11.0
|
| 12 |
|
|
|
|
| 5 |
|
| 6 |
# AI/ML
|
| 7 |
numpy<2
|
|
|
|
|
|
|
| 8 |
httpx>=0.27.0,<0.28.0
|
| 9 |
groq>=0.11.0
|
| 10 |
|
test_beck_protocol.py
CHANGED
|
@@ -5,16 +5,20 @@ Run this to verify the 3-agent system works
|
|
| 5 |
|
| 6 |
import os
|
| 7 |
from groq_client import GroqClient
|
| 8 |
-
from classifier import DistortionClassifier
|
| 9 |
|
| 10 |
# Test the components
|
| 11 |
def test_classifier():
|
| 12 |
-
"""Test
|
| 13 |
print("=" * 60)
|
| 14 |
-
print("TEST 1: CLASSIFIER (
|
| 15 |
print("=" * 60)
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
test_cases = [
|
| 20 |
("I had a nice day today", "G0"),
|
|
@@ -23,12 +27,13 @@ def test_classifier():
|
|
| 23 |
]
|
| 24 |
|
| 25 |
for text, expected in test_cases:
|
| 26 |
-
result =
|
| 27 |
print(f"\nText: {text[:50]}...")
|
| 28 |
print(f" Predicted: {result['group']} ({result['confidence']:.2%})")
|
|
|
|
| 29 |
print(f" Expected: {expected}")
|
| 30 |
|
| 31 |
-
print("\nβ
Classifier test complete\n")
|
| 32 |
|
| 33 |
|
| 34 |
def test_agent1():
|
|
|
|
| 5 |
|
| 6 |
import os
|
| 7 |
from groq_client import GroqClient
|
|
|
|
| 8 |
|
| 9 |
# Test the components
|
| 10 |
def test_classifier():
|
| 11 |
+
"""Test LLM-based cognitive distortion classification"""
|
| 12 |
print("=" * 60)
|
| 13 |
+
print("TEST 1: LLM CLASSIFIER (Cognitive Distortion Detection)")
|
| 14 |
print("=" * 60)
|
| 15 |
|
| 16 |
+
api_key = os.environ.get("GROQ_API_KEY")
|
| 17 |
+
if not api_key:
|
| 18 |
+
print("β GROQ_API_KEY not set - skipping classifier test")
|
| 19 |
+
return
|
| 20 |
+
|
| 21 |
+
client = GroqClient(api_key)
|
| 22 |
|
| 23 |
test_cases = [
|
| 24 |
("I had a nice day today", "G0"),
|
|
|
|
| 27 |
]
|
| 28 |
|
| 29 |
for text, expected in test_cases:
|
| 30 |
+
result = client.classify_distortion(text)
|
| 31 |
print(f"\nText: {text[:50]}...")
|
| 32 |
print(f" Predicted: {result['group']} ({result['confidence']:.2%})")
|
| 33 |
+
print(f" Reasoning: {result.get('reasoning', 'N/A')}")
|
| 34 |
print(f" Expected: {expected}")
|
| 35 |
|
| 36 |
+
print("\nβ
LLM Classifier test complete\n")
|
| 37 |
|
| 38 |
|
| 39 |
def test_agent1():
|