tillu-AI commited on
Commit
7e59074
·
verified ·
1 Parent(s): 1c5638b

upload app/transformers/classifiers.py

Browse files
Files changed (1) hide show
  1. app/transformers/classifiers.py +326 -0
app/transformers/classifiers.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Classification models for intent, emotion, and stress
3
+ Uses Hugging Face Inference API
4
+ """
5
+ import httpx
6
+ from typing import Dict, Any, Optional
7
+ from app.config import settings
8
+ from app.utils.logging import get_logger
9
+
10
+ logger = get_logger("classifiers")
11
+
12
+
13
+ class IntentClassifier:
14
+ """
15
+ Intent Classification
16
+ Model: distilbert-base-uncased (fine-tuned)
17
+ Output: Intent class from 14 categories
18
+ """
19
+
20
+ INTENT_CLASSES = [
21
+ "small_talk",
22
+ "general_query",
23
+ "follow_up",
24
+ "research_request",
25
+ "deep_analysis",
26
+ "action_required",
27
+ "real_world_query",
28
+ "multi_step_task",
29
+ "pattern_query",
30
+ "data_analysis",
31
+ "structured_request",
32
+ "distress",
33
+ "sadness",
34
+ "high_stress",
35
+ ]
36
+
37
+ def __init__(self):
38
+ self.model = settings.hf_classifier_model
39
+ self.api_url = settings.hf_inference_api_url
40
+ self.token = settings.hf_token
41
+ self.logger = get_logger("intent_classifier")
42
+
43
+ async def classify(self, text: str) -> Dict[str, Any]:
44
+ """
45
+ Classify user intent
46
+
47
+ Args:
48
+ text: User input text
49
+
50
+ Returns:
51
+ Dict with intent_class, confidence, all_scores
52
+ """
53
+ if not text or not text.strip():
54
+ return {
55
+ "intent_class": "general_query",
56
+ "confidence": 1.0,
57
+ "all_scores": {}
58
+ }
59
+
60
+ try:
61
+ # Use HF Inference API
62
+ result = await self._classify_hf_api(text)
63
+ if result:
64
+ return result
65
+ except Exception as e:
66
+ self.logger.error(f"Intent classification error: {e}")
67
+
68
+ # Fallback: Simple keyword-based classification
69
+ return self._classify_heuristic(text)
70
+
71
+ async def _classify_hf_api(self, text: str) -> Optional[Dict[str, Any]]:
72
+ """Classify via HF Inference API"""
73
+ if not self.token:
74
+ return None
75
+
76
+ async with httpx.AsyncClient() as client:
77
+ response = await client.post(
78
+ f"{self.api_url}/models/{self.model}",
79
+ headers={"Authorization": f"Bearer {self.token}"},
80
+ json={"inputs": text},
81
+ timeout=10.0
82
+ )
83
+
84
+ if response.status_code == 200:
85
+ data = response.json()
86
+
87
+ # Parse HF zero-shot or classification output
88
+ if isinstance(data, list) and len(data) > 0:
89
+ predictions = data[0]
90
+
91
+ # Get top prediction
92
+ if isinstance(predictions, list):
93
+ top = max(predictions, key=lambda x: x.get('score', 0))
94
+ return {
95
+ "intent_class": top.get('label', 'general_query').lower().replace(' ', '_'),
96
+ "confidence": top.get('score', 0.5),
97
+ "all_scores": {p.get('label', '').lower().replace(' ', '_'): p.get('score', 0) for p in predictions}
98
+ }
99
+
100
+ return None
101
+
102
+ def _classify_heuristic(self, text: str) -> Dict[str, Any]:
103
+ """Fallback heuristic classification"""
104
+ text_lower = text.lower()
105
+
106
+ # Research indicators
107
+ if any(word in text_lower for word in ['research', 'analyze', 'study', 'investigate', 'deep dive']):
108
+ return {"intent_class": "research_request", "confidence": 0.7, "all_scores": {}}
109
+
110
+ # Action indicators
111
+ if any(word in text_lower for word in ['book', 'schedule', 'set up', 'create', 'buy', 'order']):
112
+ return {"intent_class": "action_required", "confidence": 0.7, "all_scores": {}}
113
+
114
+ # Distress indicators
115
+ if any(word in text_lower for word in ['stressed', 'worried', 'anxious', 'overwhelmed', 'help']):
116
+ return {"intent_class": "distress", "confidence": 0.6, "all_scores": {}}
117
+
118
+ # Question indicators
119
+ if '?' in text or any(word in text_lower for word in ['what', 'how', 'why', 'when', 'where']):
120
+ return {"intent_class": "general_query", "confidence": 0.8, "all_scores": {}}
121
+
122
+ # Default
123
+ return {"intent_class": "small_talk", "confidence": 0.6, "all_scores": {}}
124
+
125
+
126
+ class EmotionDetector:
127
+ """
128
+ Emotion Detection
129
+ Model: j-hartmann/emotion-english-distilroberta-base
130
+ Output: {joy, sadness, anger, fear, surprise, disgust, neutral} scores
131
+ """
132
+
133
+ EMOTIONS = ["joy", "sadness", "anger", "fear", "surprise", "disgust", "neutral"]
134
+
135
+ def __init__(self):
136
+ self.model = settings.hf_emotion_model
137
+ self.api_url = settings.hf_inference_api_url
138
+ self.token = settings.hf_token
139
+ self.logger = get_logger("emotion_detector")
140
+
141
+ async def detect(self, text: str) -> Dict[str, Any]:
142
+ """
143
+ Detect emotions in text
144
+
145
+ Args:
146
+ text: Input text
147
+
148
+ Returns:
149
+ Dict with emotion scores and dominant emotion
150
+ """
151
+ if not text or not text.strip():
152
+ return {
153
+ "dominant_emotion": "neutral",
154
+ "scores": {e: 0.0 for e in self.EMOTIONS},
155
+ "emotion_intensity": 0.0
156
+ }
157
+
158
+ try:
159
+ result = await self._detect_hf_api(text)
160
+ if result:
161
+ return result
162
+ except Exception as e:
163
+ self.logger.error(f"Emotion detection error: {e}")
164
+
165
+ # Fallback: neutral
166
+ return {
167
+ "dominant_emotion": "neutral",
168
+ "scores": {e: 0.0 for e in self.EMOTIONS},
169
+ "emotion_intensity": 0.0
170
+ }
171
+
172
+ async def _detect_hf_api(self, text: str) -> Optional[Dict[str, Any]]:
173
+ """Detect emotions via HF Inference API"""
174
+ if not self.token:
175
+ return None
176
+
177
+ async with httpx.AsyncClient() as client:
178
+ response = await client.post(
179
+ f"{self.api_url}/models/{self.model}",
180
+ headers={"Authorization": f"Bearer {self.token}"},
181
+ json={"inputs": text},
182
+ timeout=10.0
183
+ )
184
+
185
+ if response.status_code == 200:
186
+ data = response.json()
187
+
188
+ if isinstance(data, list) and len(data) > 0:
189
+ predictions = data[0]
190
+
191
+ # Build scores dict
192
+ scores = {}
193
+ for pred in predictions:
194
+ label = pred.get('label', '').lower()
195
+ score = pred.get('score', 0.0)
196
+ scores[label] = score
197
+
198
+ # Fill missing emotions with 0
199
+ for emotion in self.EMOTIONS:
200
+ if emotion not in scores:
201
+ scores[emotion] = 0.0
202
+
203
+ # Determine dominant
204
+ dominant = max(scores, key=scores.get)
205
+ intensity = scores[dominant]
206
+
207
+ return {
208
+ "dominant_emotion": dominant,
209
+ "scores": scores,
210
+ "emotion_intensity": intensity
211
+ }
212
+
213
+ return None
214
+
215
+
216
+ class StressDetector:
217
+ """
218
+ Stress/Toxicity Detection
219
+ Model: martin-ha/toxic-comment-model
220
+ Output: Stress/distress probability score
221
+ """
222
+
223
+ def __init__(self):
224
+ self.model = "martin-ha/toxic-comment-model"
225
+ self.api_url = settings.hf_inference_api_url
226
+ self.token = settings.hf_token
227
+ self.logger = get_logger("stress_detector")
228
+
229
+ async def detect(self, text: str) -> Dict[str, Any]:
230
+ """
231
+ Detect stress level
232
+
233
+ Args:
234
+ text: Input text
235
+
236
+ Returns:
237
+ Dict with stress_level, score, is_stressed
238
+ """
239
+ if not text or not text.strip():
240
+ return {
241
+ "stress_level": "low",
242
+ "score": 0.0,
243
+ "is_stressed": False
244
+ }
245
+
246
+ try:
247
+ result = await self._detect_hf_api(text)
248
+ if result:
249
+ return result
250
+ except Exception as e:
251
+ self.logger.error(f"Stress detection error: {e}")
252
+
253
+ # Fallback: Heuristic
254
+ return self._detect_heuristic(text)
255
+
256
+ async def _detect_hf_api(self, text: str) -> Optional[Dict[str, Any]]:
257
+ """Detect stress via HF Inference API"""
258
+ if not self.token:
259
+ return None
260
+
261
+ async with httpx.AsyncClient() as client:
262
+ response = await client.post(
263
+ f"{self.api_url}/models/{self.model}",
264
+ headers={"Authorization": f"Bearer {self.token}"},
265
+ json={"inputs": text},
266
+ timeout=10.0
267
+ )
268
+
269
+ if response.status_code == 200:
270
+ data = response.json()
271
+
272
+ if isinstance(data, list) and len(data) > 0:
273
+ predictions = data[0]
274
+
275
+ # Calculate toxic score
276
+ toxic_score = 0.0
277
+ for pred in predictions:
278
+ if pred.get('label') == 'toxic' or pred.get('label') == 'LABEL_1':
279
+ toxic_score = pred.get('score', 0.0)
280
+
281
+ # Map to stress levels
282
+ if toxic_score > 0.7:
283
+ level = "high"
284
+ elif toxic_score > 0.3:
285
+ level = "medium"
286
+ else:
287
+ level = "low"
288
+
289
+ return {
290
+ "stress_level": level,
291
+ "score": toxic_score,
292
+ "is_stressed": toxic_score > 0.5
293
+ }
294
+
295
+ return None
296
+
297
+ def _detect_heuristic(self, text: str) -> Dict[str, Any]:
298
+ """Heuristic stress detection"""
299
+ text_lower = text.lower()
300
+
301
+ stress_words = [
302
+ 'stressed', 'overwhelmed', 'anxious', 'worried', 'panic',
303
+ 'urgent', 'emergency', 'help', 'desperate', 'exhausted'
304
+ ]
305
+
306
+ count = sum(1 for word in stress_words if word in text_lower)
307
+ intensity = min(count / 3, 1.0) # Cap at 1.0
308
+
309
+ if intensity > 0.6:
310
+ level = "high"
311
+ elif intensity > 0.3:
312
+ level = "medium"
313
+ else:
314
+ level = "low"
315
+
316
+ return {
317
+ "stress_level": level,
318
+ "score": intensity,
319
+ "is_stressed": intensity > 0.5
320
+ }
321
+
322
+
323
+ # Global instances
324
+ intent_classifier = IntentClassifier()
325
+ emotion_detector = EmotionDetector()
326
+ stress_detector = StressDetector()