Nada committed on
Commit
a28b27c
·
1 Parent(s): 66ed960
Files changed (2) hide show
  1. chatbot.py +8 -4
  2. conversation_flow.py +42 -43
chatbot.py CHANGED
@@ -242,7 +242,8 @@ class MentalHealthChatbot:
242
  self.summary_model = pipeline(
243
  "summarization",
244
  model="philschmid/bart-large-cnn-samsum",
245
- device=0 if self.device == "cuda" else -1
 
246
  )
247
  logger.info("Summary model loaded successfully")
248
 
@@ -295,7 +296,10 @@ Response:"""
295
  # Setup embeddings for vector search
296
  self.embeddings = HuggingFaceEmbeddings(
297
  model_name="sentence-transformers/all-MiniLM-L6-v2",
298
- model_kwargs={"device": self.device}
 
 
 
299
  )
300
 
301
  # Setup vector database for retrieving relevant past conversations
@@ -321,7 +325,7 @@ Response:"""
321
  model="SamLowe/roberta-base-go_emotions",
322
  top_k=None,
323
  device_map="auto" if torch.cuda.is_available() else None,
324
- cache_dir=CACHE_DIR,
325
  local_files_only=False # Ensure we download from Hugging Face
326
  )
327
  except Exception as e:
@@ -333,7 +337,7 @@ Response:"""
333
  model="j-hartmann/emotion-english-distilroberta-base",
334
  return_all_scores=True,
335
  device_map="auto" if torch.cuda.is_available() else None,
336
- cache_dir=CACHE_DIR,
337
  local_files_only=False # Ensure we download from Hugging Face
338
  )
339
  except Exception as e:
 
242
  self.summary_model = pipeline(
243
  "summarization",
244
  model="philschmid/bart-large-cnn-samsum",
245
+ device=0 if self.device == "cuda" else -1,
246
+ model_kwargs={"cache_dir": CACHE_DIR}
247
  )
248
  logger.info("Summary model loaded successfully")
249
 
 
296
  # Setup embeddings for vector search
297
  self.embeddings = HuggingFaceEmbeddings(
298
  model_name="sentence-transformers/all-MiniLM-L6-v2",
299
+ model_kwargs={
300
+ "device": self.device,
301
+ "cache_dir": CACHE_DIR
302
+ }
303
  )
304
 
305
  # Setup vector database for retrieving relevant past conversations
 
325
  model="SamLowe/roberta-base-go_emotions",
326
  top_k=None,
327
  device_map="auto" if torch.cuda.is_available() else None,
328
+ model_kwargs={"cache_dir": CACHE_DIR},
329
  local_files_only=False # Ensure we download from Hugging Face
330
  )
331
  except Exception as e:
 
337
  model="j-hartmann/emotion-english-distilroberta-base",
338
  return_all_scores=True,
339
  device_map="auto" if torch.cuda.is_available() else None,
340
+ model_kwargs={"cache_dir": CACHE_DIR},
341
  local_files_only=False # Ensure we download from Hugging Face
342
  )
343
  except Exception as e:
conversation_flow.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  import json
 
3
  import time
4
  from datetime import datetime
5
  from typing import List, Dict, Any, Optional
@@ -8,6 +9,21 @@ from pydantic import BaseModel, Field
8
  # Configure logging
9
  logger = logging.getLogger(__name__)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class ConversationPhase(BaseModel):
12
  name: str
13
  description: str
@@ -15,7 +31,7 @@ class ConversationPhase(BaseModel):
15
  typical_duration: int # in minutes
16
  started_at: Optional[str] = None # ISO timestamp
17
  ended_at: Optional[str] = None # ISO timestamp
18
- completion_metrics: Dict[str, float] = {} # e.g., {'goal_progress': 0.8}
19
 
20
  class FlowManager:
21
 
@@ -230,27 +246,22 @@ class FlowManager:
230
 
231
  response = self.llm.invoke(prompt)
232
 
233
- # Extract JSON from response
234
- import re
235
- json_match = re.search(r'\{.*\}', response, re.DOTALL)
236
- if json_match:
237
- try:
238
- evaluation = json.loads(json_match.group(0))
239
-
240
- # Update goal progress metrics
241
- if 'goals_progress' in evaluation:
242
- for goal, score in evaluation['goals_progress'].items():
243
- if goal in current_phase.goals:
244
- current_phase.completion_metrics[goal] = score
245
-
246
- # Check if we should transition
247
- if evaluation.get('should_transition', False):
248
- next_phase_name = evaluation.get('next_phase')
249
- if next_phase_name in self.PHASES:
250
- self._transition_to_phase(user_id, next_phase_name, evaluation.get('reasoning', ''))
251
- except json.JSONDecodeError:
252
- self._check_time_based_transition(user_id)
253
- else:
254
  self._check_time_based_transition(user_id)
255
 
256
  def _check_time_based_transition(self, user_id: str):
@@ -366,27 +377,15 @@ class FlowManager:
366
 
367
  response = self.llm.invoke(prompt)
368
 
369
- # Extract JSON from response
370
- import re
371
- json_match = re.search(r'\{.*\}', response, re.DOTALL)
372
- if json_match:
373
- try:
374
- characteristics = json.loads(json_match.group(0))
375
- # Validate required fields
376
- required_fields = [
377
- 'alliance_strength', 'engagement_level', 'emotional_pattern',
378
- 'cognitive_pattern', 'coping_mechanisms', 'progress_quality',
379
- 'recommended_focus'
380
- ]
381
- if all(field in characteristics for field in required_fields):
382
- session['llm_context']['session_characteristics'] = characteristics
383
- logger.info(f"Updated session characteristics for user {user_id}")
384
- else:
385
- logger.warning("Missing required fields in session characteristics")
386
- except json.JSONDecodeError:
387
- logger.warning("Failed to parse session characteristics from LLM")
388
- else:
389
- logger.warning("No JSON object found in LLM response")
390
 
391
  def _create_flow_context(self, user_id: str) -> Dict[str, Any]:
392
 
 
1
  import logging
2
  import json
3
+ import json5
4
  import time
5
  from datetime import datetime
6
  from typing import List, Dict, Any, Optional
 
9
  # Configure logging
10
  logger = logging.getLogger(__name__)
11
 
12
+ class PhaseTransitionResponse(BaseModel):
13
+ goals_progress: Dict[str, float]
14
+ should_transition: bool
15
+ next_phase: str
16
+ reasoning: str
17
+
18
+ class SessionCharacteristics(BaseModel):
19
+ alliance_strength: float = Field(ge=0.0, le=1.0)
20
+ engagement_level: float = Field(ge=0.0, le=1.0)
21
+ emotional_pattern: str
22
+ cognitive_pattern: str
23
+ coping_mechanisms: List[str] = Field(min_items=2)
24
+ progress_quality: float = Field(ge=0.0, le=1.0)
25
+ recommended_focus: str
26
+
27
  class ConversationPhase(BaseModel):
28
  name: str
29
  description: str
 
31
  typical_duration: int # in minutes
32
  started_at: Optional[str] = None # ISO timestamp
33
  ended_at: Optional[str] = None # ISO timestamp
34
+ completion_metrics: Dict[str, float] = Field(default_factory=dict) # e.g., {'goal_progress': 0.8}
35
 
36
  class FlowManager:
37
 
 
246
 
247
  response = self.llm.invoke(prompt)
248
 
249
+ try:
250
+ # Parse with json5 for more tolerant parsing
251
+ evaluation = json5.loads(response)
252
+ # Validate with Pydantic
253
+ phase_transition = PhaseTransitionResponse.parse_obj(evaluation)
254
+
255
+ # Update goal progress metrics
256
+ for goal, score in phase_transition.goals_progress.items():
257
+ if goal in current_phase.goals:
258
+ current_phase.completion_metrics[goal] = score
259
+
260
+ # Check if we should transition
261
+ if phase_transition.should_transition:
262
+ if phase_transition.next_phase in self.PHASES:
263
+ self._transition_to_phase(user_id, phase_transition.next_phase, phase_transition.reasoning)
264
+ except (json5.Json5DecodeError, ValueError):
 
 
 
 
 
265
  self._check_time_based_transition(user_id)
266
 
267
  def _check_time_based_transition(self, user_id: str):
 
377
 
378
  response = self.llm.invoke(prompt)
379
 
380
+ try:
381
+ # Parse with json5 for more tolerant parsing
382
+ characteristics = json5.loads(response)
383
+ # Validate with Pydantic
384
+ session_chars = SessionCharacteristics.parse_obj(characteristics)
385
+ session['llm_context']['session_characteristics'] = session_chars.dict()
386
+ logger.info(f"Updated session characteristics for user {user_id}")
387
+ except (json5.Json5DecodeError, ValueError) as e:
388
+ logger.warning(f"Failed to parse session characteristics: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
389
 
390
  def _create_flow_context(self, user_id: str) -> Dict[str, Any]:
391