stanlee47 commited on
Commit
1dbb69b
·
1 Parent(s): 21f026d
README.md CHANGED
@@ -14,7 +14,7 @@ Backend API for the CBT Companion mental health support app.
14
  ## Features
15
 
16
  - πŸ” User authentication (register/login)
17
- - 🧠 Cognitive distortion classification (TinyBERT)
18
  - πŸ’¬ Therapeutic conversations (Groq LLaMA 3.3)
19
  - 🚨 Crisis detection and flagging
20
  - πŸ“Š User statistics tracking
 
14
  ## Features
15
 
16
  - πŸ” User authentication (register/login)
17
+ - 🧠 Cognitive distortion classification (LLM-based)
18
  - πŸ’¬ Therapeutic conversations (Groq LLaMA 3.3)
19
  - 🚨 Crisis detection and flagging
20
  - πŸ“Š User statistics tracking
__pycache__/groq_client.cpython-312.pyc CHANGED
Binary files a/__pycache__/groq_client.cpython-312.pyc and b/__pycache__/groq_client.cpython-312.pyc differ
 
app.py CHANGED
@@ -6,7 +6,6 @@ Hosted on HuggingFace Spaces
6
 
7
  from flask import Flask, request, jsonify
8
  from flask_cors import CORS
9
- from classifier import DistortionClassifier
10
  from groq_client import GroqClient
11
  from database import get_db
12
  from auth import register_user, login_user, token_required
@@ -27,7 +26,6 @@ app.register_blueprint(wearable_bp)
27
  app.register_blueprint(admin_bp)
28
 
29
  # Initialize components
30
- classifier = DistortionClassifier()
31
  groq_client = GroqClient(api_key=os.environ.get("GROQ_API_KEY"))
32
 
33
 
@@ -198,8 +196,8 @@ def chat():
198
  beck_data = db.get_beck_session(session_id)
199
 
200
  if not beck_data:
201
- # First message - check if distorted
202
- classification = classifier.classify(user_message)
203
 
204
  if classification["group"] == "G0":
205
  # No distortion - supportive listening
 
6
 
7
  from flask import Flask, request, jsonify
8
  from flask_cors import CORS
 
9
  from groq_client import GroqClient
10
  from database import get_db
11
  from auth import register_user, login_user, token_required
 
26
  app.register_blueprint(admin_bp)
27
 
28
  # Initialize components
 
29
  groq_client = GroqClient(api_key=os.environ.get("GROQ_API_KEY"))
30
 
31
 
 
196
  beck_data = db.get_beck_session(session_id)
197
 
198
  if not beck_data:
199
+ # First message - check if distorted using LLM
200
+ classification = groq_client.classify_distortion(user_message)
201
 
202
  if classification["group"] == "G0":
203
  # No distortion - supportive listening
classifier.py DELETED
@@ -1,142 +0,0 @@
1
- """
2
- Cognitive Distortion Classifier
3
- Uses TinyBERT model from HuggingFace
4
- """
5
-
6
- import torch
7
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
- import torch.nn.functional as F
9
-
10
-
11
class DistortionClassifier:
    """
    Classifies text into cognitive distortion groups.

    Groups:
    - G0: No Distortion
    - G1: Binary & Global Evaluation (All-or-nothing, Labeling)
    - G2: Overgeneralized Beliefs (Overgeneralization, Mind Reading, Fortune-telling)
    - G3: Attentional Bias (Mental Filter, Magnification)
    - G4: Self-Referential Reasoning (Emotional Reasoning, Personalization, Should statements)
    """

    # HuggingFace Hub id of the fine-tuned sequence-classification model.
    MODEL_NAME = "santa47/cbt-distortion-classifier-bert"

    def __init__(self) -> None:
        """Load tokenizer + model once and move the model to GPU when available."""
        print(f"Loading classifier from {self.MODEL_NAME}...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME)
        self.model.to(self.device)
        # Inference only: eval() disables dropout and other train-time behavior.
        self.model.eval()

        # Label mapping: model output index -> group label.
        self.id_to_label = {
            0: "G0",
            1: "G1",
            2: "G2",
            3: "G3",
            4: "G4"
        }

        # Human-readable metadata returned alongside each classification.
        self.label_info = {
            "G0": {
                "name": "No Distortion Detected",
                "description": "Healthy, balanced thinking",
                "distortions": []
            },
            "G1": {
                "name": "Binary & Global Evaluation",
                "description": "All-or-nothing thinking patterns",
                "distortions": ["All-or-nothing thinking", "Labeling"]
            },
            "G2": {
                "name": "Overgeneralized Beliefs",
                "description": "Making broad conclusions from limited evidence",
                "distortions": ["Overgeneralization", "Mind Reading", "Fortune-telling"]
            },
            "G3": {
                "name": "Attentional & Salience Bias",
                "description": "Focusing on negatives, ignoring positives",
                "distortions": ["Mental Filter", "Magnification"]
            },
            "G4": {
                "name": "Self-Referential & Emotion-Driven",
                "description": "Letting emotions drive conclusions",
                "distortions": ["Emotional Reasoning", "Personalization", "Should statements"]
            }
        }

        print("βœ… Classifier loaded successfully!")

    def classify(self, text: str) -> dict:
        """
        Classify text into a distortion group.

        Args:
            text: Input text to classify

        Returns:
            dict with group, confidence, and group info
        """
        # Tokenize (truncate to the model's 512-token window)
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        # Move every input tensor to the same device as the model.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Predict (no_grad: skip autograd bookkeeping for inference)
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=-1)

        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence = probabilities[0][predicted_class].item()

        # Get group label (defaults to G0 if the index is somehow unmapped)
        group = self.id_to_label.get(predicted_class, "G0")
        group_info = self.label_info.get(group, {})

        return {
            "group": group,
            "confidence": round(confidence, 4),
            "group_name": group_info.get("name", "Unknown"),
            "description": group_info.get("description", ""),
            "distortions": group_info.get("distortions", []),
            # Full softmax distribution, keyed by group label.
            "all_probabilities": {
                self.id_to_label[i]: round(probabilities[0][i].item(), 4)
                for i in range(len(self.id_to_label))
            }
        }

    def get_group_info(self, group: str) -> dict:
        """Get detailed info about a distortion group."""
        return self.label_info.get(group, {})
122
-
123
-
124
# Test if run directly (smoke test: downloads the model, so it needs network + torch)
if __name__ == "__main__":
    classifier = DistortionClassifier()

    # One example per distortion family, plus a healthy-thinking control case.
    test_texts = [
        "I failed my exam. I'll never succeed at anything.",
        "My friend didn't text back. She must hate me.",
        "I made one mistake so the whole project is ruined.",
        "I feel anxious so something bad must be happening.",
        "I had a nice day today and enjoyed my lunch."
    ]

    print("\nπŸ§ͺ Testing classifier:\n")
    for text in test_texts:
        result = classifier.classify(text)
        print(f"Text: {text[:50]}...")
        print(f" β†’ {result['group']}: {result['group_name']}")
        print(f" β†’ Confidence: {result['confidence']}")
        print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
groq_client.py CHANGED
@@ -279,6 +279,118 @@ Respond as Aria for the {current_state} state."""
279
  print(f"Agent 3 error: {e}")
280
  return f"You've done really good work here, {user_name}. πŸ’™"
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  # ==================== SUPPORTIVE RESPONSE (G0 - No Distortion) ====================
283
 
284
  def generate_supportive_response(
 
279
  print(f"Agent 3 error: {e}")
280
  return f"You've done really good work here, {user_name}. πŸ’™"
281
 
282
+ # ==================== COGNITIVE DISTORTION CLASSIFIER ====================
283
+
284
+ def classify_distortion(self, text: str) -> dict:
285
+ """
286
+ Classify text into cognitive distortion groups using LLM.
287
+
288
+ Groups:
289
+ - G0: No Distortion
290
+ - G1: Binary & Global Evaluation (All-or-nothing, Labeling)
291
+ - G2: Overgeneralized Beliefs (Overgeneralization, Mind Reading, Fortune-telling)
292
+ - G3: Attentional Bias (Mental Filter, Magnification)
293
+ - G4: Self-Referential Reasoning (Emotional Reasoning, Personalization, Should statements)
294
+
295
+ Args:
296
+ text: User's thought to classify
297
+
298
+ Returns:
299
+ dict with group, confidence, group_name, description, and distortions
300
+ """
301
+ system_prompt = """You are a cognitive distortion classifier for CBT (Cognitive Behavioral Therapy).
302
+
303
+ Analyze the user's thought and classify it into ONE of these groups:
304
+
305
+ G0: No Distortion - Healthy, balanced thinking
306
+ G1: Binary & Global Evaluation - All-or-nothing thinking, Labeling
307
+ G2: Overgeneralized Beliefs - Overgeneralization, Mind Reading, Fortune-telling
308
+ G3: Attentional & Salience Bias - Mental Filter, Magnification/Minimization
309
+ G4: Self-Referential & Emotion-Driven - Emotional Reasoning, Personalization, Should statements
310
+
311
+ Examples:
312
+ - "I failed my exam. I'll never succeed at anything." β†’ G1 (all-or-nothing)
313
+ - "My friend didn't text back. She must hate me." β†’ G2 (mind reading)
314
+ - "Everything in my life is terrible." β†’ G3 (mental filter)
315
+ - "I feel anxious so something bad must be happening." β†’ G4 (emotional reasoning)
316
+ - "I had a nice day today and enjoyed my lunch." β†’ G0 (no distortion)
317
+
318
+ Respond ONLY in JSON format:
319
+ {
320
+ "group": "G0/G1/G2/G3/G4",
321
+ "confidence": 0.85,
322
+ "reasoning": "Brief explanation of why this classification"
323
+ }"""
324
+
325
+ user_prompt = f'Classify this thought: "{text}"'
326
+
327
+ try:
328
+ response = self.client.chat.completions.create(
329
+ model=self.MODEL,
330
+ messages=[
331
+ {"role": "system", "content": system_prompt},
332
+ {"role": "user", "content": user_prompt}
333
+ ],
334
+ temperature=0.3,
335
+ max_tokens=200,
336
+ response_format={"type": "json_object"}
337
+ )
338
+
339
+ result = json.loads(response.choices[0].message.content)
340
+ group = result.get("group", "G0")
341
+
342
+ # Group information
343
+ label_info = {
344
+ "G0": {
345
+ "name": "No Distortion Detected",
346
+ "description": "Healthy, balanced thinking",
347
+ "distortions": []
348
+ },
349
+ "G1": {
350
+ "name": "Binary & Global Evaluation",
351
+ "description": "All-or-nothing thinking patterns",
352
+ "distortions": ["All-or-nothing thinking", "Labeling"]
353
+ },
354
+ "G2": {
355
+ "name": "Overgeneralized Beliefs",
356
+ "description": "Making broad conclusions from limited evidence",
357
+ "distortions": ["Overgeneralization", "Mind Reading", "Fortune-telling"]
358
+ },
359
+ "G3": {
360
+ "name": "Attentional & Salience Bias",
361
+ "description": "Focusing on negatives, ignoring positives",
362
+ "distortions": ["Mental Filter", "Magnification"]
363
+ },
364
+ "G4": {
365
+ "name": "Self-Referential & Emotion-Driven",
366
+ "description": "Letting emotions drive conclusions",
367
+ "distortions": ["Emotional Reasoning", "Personalization", "Should statements"]
368
+ }
369
+ }
370
+
371
+ group_info = label_info.get(group, label_info["G0"])
372
+
373
+ return {
374
+ "group": group,
375
+ "confidence": round(result.get("confidence", 0.8), 4),
376
+ "group_name": group_info["name"],
377
+ "description": group_info["description"],
378
+ "distortions": group_info["distortions"],
379
+ "reasoning": result.get("reasoning", "")
380
+ }
381
+
382
+ except Exception as e:
383
+ print(f"Classification error: {e}")
384
+ # Default to G0 (no distortion) on error
385
+ return {
386
+ "group": "G0",
387
+ "confidence": 0.5,
388
+ "group_name": "No Distortion Detected",
389
+ "description": "Healthy, balanced thinking",
390
+ "distortions": [],
391
+ "reasoning": "Classification failed, defaulting to G0"
392
+ }
393
+
394
  # ==================== SUPPORTIVE RESPONSE (G0 - No Distortion) ====================
395
 
396
  def generate_supportive_response(
requirements.txt CHANGED
@@ -5,8 +5,6 @@ gunicorn==21.2.0
5
 
6
  # AI/ML
7
  numpy<2
8
- transformers==4.36.0
9
- torch==2.1.0
10
  httpx>=0.27.0,<0.28.0
11
  groq>=0.11.0
12
 
 
5
 
6
  # AI/ML
7
  numpy<2
 
 
8
  httpx>=0.27.0,<0.28.0
9
  groq>=0.11.0
10
 
test_beck_protocol.py CHANGED
@@ -5,16 +5,20 @@ Run this to verify the 3-agent system works
5
 
6
  import os
7
  from groq_client import GroqClient
8
- from classifier import DistortionClassifier
9
 
10
  # Test the components
11
  def test_classifier():
12
- """Test that classifier still works for binary detection"""
13
  print("=" * 60)
14
- print("TEST 1: CLASSIFIER (Binary Detection)")
15
  print("=" * 60)
16
 
17
- classifier = DistortionClassifier()
 
 
 
 
 
18
 
19
  test_cases = [
20
  ("I had a nice day today", "G0"),
@@ -23,12 +27,13 @@ def test_classifier():
23
  ]
24
 
25
  for text, expected in test_cases:
26
- result = classifier.classify(text)
27
  print(f"\nText: {text[:50]}...")
28
  print(f" Predicted: {result['group']} ({result['confidence']:.2%})")
 
29
  print(f" Expected: {expected}")
30
 
31
- print("\nβœ… Classifier test complete\n")
32
 
33
 
34
  def test_agent1():
 
5
 
6
  import os
7
  from groq_client import GroqClient
 
8
 
9
  # Test the components
10
  def test_classifier():
11
+ """Test LLM-based cognitive distortion classification"""
12
  print("=" * 60)
13
+ print("TEST 1: LLM CLASSIFIER (Cognitive Distortion Detection)")
14
  print("=" * 60)
15
 
16
+ api_key = os.environ.get("GROQ_API_KEY")
17
+ if not api_key:
18
+ print("❌ GROQ_API_KEY not set - skipping classifier test")
19
+ return
20
+
21
+ client = GroqClient(api_key)
22
 
23
  test_cases = [
24
  ("I had a nice day today", "G0"),
 
27
  ]
28
 
29
  for text, expected in test_cases:
30
+ result = client.classify_distortion(text)
31
  print(f"\nText: {text[:50]}...")
32
  print(f" Predicted: {result['group']} ({result['confidence']:.2%})")
33
+ print(f" Reasoning: {result.get('reasoning', 'N/A')}")
34
  print(f" Expected: {expected}")
35
 
36
+ print("\nβœ… LLM Classifier test complete\n")
37
 
38
 
39
  def test_agent1():