Spaces:
Paused
Paused
Nada commited on
Commit ·
a28b27c
1
Parent(s): 66ed960
yarab
Browse files- chatbot.py +8 -4
- conversation_flow.py +42 -43
chatbot.py
CHANGED
|
@@ -242,7 +242,8 @@ class MentalHealthChatbot:
|
|
| 242 |
self.summary_model = pipeline(
|
| 243 |
"summarization",
|
| 244 |
model="philschmid/bart-large-cnn-samsum",
|
| 245 |
-
device=0 if self.device == "cuda" else -1
|
|
|
|
| 246 |
)
|
| 247 |
logger.info("Summary model loaded successfully")
|
| 248 |
|
|
@@ -295,7 +296,10 @@ Response:"""
|
|
| 295 |
# Setup embeddings for vector search
|
| 296 |
self.embeddings = HuggingFaceEmbeddings(
|
| 297 |
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 298 |
-
model_kwargs={
|
|
|
|
|
|
|
|
|
|
| 299 |
)
|
| 300 |
|
| 301 |
# Setup vector database for retrieving relevant past conversations
|
|
@@ -321,7 +325,7 @@ Response:"""
|
|
| 321 |
model="SamLowe/roberta-base-go_emotions",
|
| 322 |
top_k=None,
|
| 323 |
device_map="auto" if torch.cuda.is_available() else None,
|
| 324 |
-
|
| 325 |
local_files_only=False # Ensure we download from Hugging Face
|
| 326 |
)
|
| 327 |
except Exception as e:
|
|
@@ -333,7 +337,7 @@ Response:"""
|
|
| 333 |
model="j-hartmann/emotion-english-distilroberta-base",
|
| 334 |
return_all_scores=True,
|
| 335 |
device_map="auto" if torch.cuda.is_available() else None,
|
| 336 |
-
|
| 337 |
local_files_only=False # Ensure we download from Hugging Face
|
| 338 |
)
|
| 339 |
except Exception as e:
|
|
|
|
| 242 |
self.summary_model = pipeline(
|
| 243 |
"summarization",
|
| 244 |
model="philschmid/bart-large-cnn-samsum",
|
| 245 |
+
device=0 if self.device == "cuda" else -1,
|
| 246 |
+
model_kwargs={"cache_dir": CACHE_DIR}
|
| 247 |
)
|
| 248 |
logger.info("Summary model loaded successfully")
|
| 249 |
|
|
|
|
| 296 |
# Setup embeddings for vector search
|
| 297 |
self.embeddings = HuggingFaceEmbeddings(
|
| 298 |
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 299 |
+
model_kwargs={
|
| 300 |
+
"device": self.device,
|
| 301 |
+
"cache_dir": CACHE_DIR
|
| 302 |
+
}
|
| 303 |
)
|
| 304 |
|
| 305 |
# Setup vector database for retrieving relevant past conversations
|
|
|
|
| 325 |
model="SamLowe/roberta-base-go_emotions",
|
| 326 |
top_k=None,
|
| 327 |
device_map="auto" if torch.cuda.is_available() else None,
|
| 328 |
+
model_kwargs={"cache_dir": CACHE_DIR},
|
| 329 |
local_files_only=False # Ensure we download from Hugging Face
|
| 330 |
)
|
| 331 |
except Exception as e:
|
|
|
|
| 337 |
model="j-hartmann/emotion-english-distilroberta-base",
|
| 338 |
return_all_scores=True,
|
| 339 |
device_map="auto" if torch.cuda.is_available() else None,
|
| 340 |
+
model_kwargs={"cache_dir": CACHE_DIR},
|
| 341 |
local_files_only=False # Ensure we download from Hugging Face
|
| 342 |
)
|
| 343 |
except Exception as e:
|
conversation_flow.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import logging
|
| 2 |
import json
|
|
|
|
| 3 |
import time
|
| 4 |
from datetime import datetime
|
| 5 |
from typing import List, Dict, Any, Optional
|
|
@@ -8,6 +9,21 @@ from pydantic import BaseModel, Field
|
|
| 8 |
# Configure logging
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
class ConversationPhase(BaseModel):
|
| 12 |
name: str
|
| 13 |
description: str
|
|
@@ -15,7 +31,7 @@ class ConversationPhase(BaseModel):
|
|
| 15 |
typical_duration: int # in minutes
|
| 16 |
started_at: Optional[str] = None # ISO timestamp
|
| 17 |
ended_at: Optional[str] = None # ISO timestamp
|
| 18 |
-
completion_metrics: Dict[str, float] =
|
| 19 |
|
| 20 |
class FlowManager:
|
| 21 |
|
|
@@ -230,27 +246,22 @@ class FlowManager:
|
|
| 230 |
|
| 231 |
response = self.llm.invoke(prompt)
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
if
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
if next_phase_name in self.PHASES:
|
| 250 |
-
self._transition_to_phase(user_id, next_phase_name, evaluation.get('reasoning', ''))
|
| 251 |
-
except json.JSONDecodeError:
|
| 252 |
-
self._check_time_based_transition(user_id)
|
| 253 |
-
else:
|
| 254 |
self._check_time_based_transition(user_id)
|
| 255 |
|
| 256 |
def _check_time_based_transition(self, user_id: str):
|
|
@@ -366,27 +377,15 @@ class FlowManager:
|
|
| 366 |
|
| 367 |
response = self.llm.invoke(prompt)
|
| 368 |
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
'cognitive_pattern', 'coping_mechanisms', 'progress_quality',
|
| 379 |
-
'recommended_focus'
|
| 380 |
-
]
|
| 381 |
-
if all(field in characteristics for field in required_fields):
|
| 382 |
-
session['llm_context']['session_characteristics'] = characteristics
|
| 383 |
-
logger.info(f"Updated session characteristics for user {user_id}")
|
| 384 |
-
else:
|
| 385 |
-
logger.warning("Missing required fields in session characteristics")
|
| 386 |
-
except json.JSONDecodeError:
|
| 387 |
-
logger.warning("Failed to parse session characteristics from LLM")
|
| 388 |
-
else:
|
| 389 |
-
logger.warning("No JSON object found in LLM response")
|
| 390 |
|
| 391 |
def _create_flow_context(self, user_id: str) -> Dict[str, Any]:
|
| 392 |
|
|
|
|
| 1 |
import logging
|
| 2 |
import json
|
| 3 |
+
import json5
|
| 4 |
import time
|
| 5 |
from datetime import datetime
|
| 6 |
from typing import List, Dict, Any, Optional
|
|
|
|
| 9 |
# Configure logging
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
+
class PhaseTransitionResponse(BaseModel):
|
| 13 |
+
goals_progress: Dict[str, float]
|
| 14 |
+
should_transition: bool
|
| 15 |
+
next_phase: str
|
| 16 |
+
reasoning: str
|
| 17 |
+
|
| 18 |
+
class SessionCharacteristics(BaseModel):
|
| 19 |
+
alliance_strength: float = Field(ge=0.0, le=1.0)
|
| 20 |
+
engagement_level: float = Field(ge=0.0, le=1.0)
|
| 21 |
+
emotional_pattern: str
|
| 22 |
+
cognitive_pattern: str
|
| 23 |
+
coping_mechanisms: List[str] = Field(min_items=2)
|
| 24 |
+
progress_quality: float = Field(ge=0.0, le=1.0)
|
| 25 |
+
recommended_focus: str
|
| 26 |
+
|
| 27 |
class ConversationPhase(BaseModel):
|
| 28 |
name: str
|
| 29 |
description: str
|
|
|
|
| 31 |
typical_duration: int # in minutes
|
| 32 |
started_at: Optional[str] = None # ISO timestamp
|
| 33 |
ended_at: Optional[str] = None # ISO timestamp
|
| 34 |
+
completion_metrics: Dict[str, float] = Field(default_factory=dict) # e.g., {'goal_progress': 0.8}
|
| 35 |
|
| 36 |
class FlowManager:
|
| 37 |
|
|
|
|
| 246 |
|
| 247 |
response = self.llm.invoke(prompt)
|
| 248 |
|
| 249 |
+
try:
|
| 250 |
+
# Parse with json5 for more tolerant parsing
|
| 251 |
+
evaluation = json5.loads(response)
|
| 252 |
+
# Validate with Pydantic
|
| 253 |
+
phase_transition = PhaseTransitionResponse.parse_obj(evaluation)
|
| 254 |
+
|
| 255 |
+
# Update goal progress metrics
|
| 256 |
+
for goal, score in phase_transition.goals_progress.items():
|
| 257 |
+
if goal in current_phase.goals:
|
| 258 |
+
current_phase.completion_metrics[goal] = score
|
| 259 |
+
|
| 260 |
+
# Check if we should transition
|
| 261 |
+
if phase_transition.should_transition:
|
| 262 |
+
if phase_transition.next_phase in self.PHASES:
|
| 263 |
+
self._transition_to_phase(user_id, phase_transition.next_phase, phase_transition.reasoning)
|
| 264 |
+
except (json5.Json5DecodeError, ValueError):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
self._check_time_based_transition(user_id)
|
| 266 |
|
| 267 |
def _check_time_based_transition(self, user_id: str):
|
|
|
|
| 377 |
|
| 378 |
response = self.llm.invoke(prompt)
|
| 379 |
|
| 380 |
+
try:
|
| 381 |
+
# Parse with json5 for more tolerant parsing
|
| 382 |
+
characteristics = json5.loads(response)
|
| 383 |
+
# Validate with Pydantic
|
| 384 |
+
session_chars = SessionCharacteristics.parse_obj(characteristics)
|
| 385 |
+
session['llm_context']['session_characteristics'] = session_chars.dict()
|
| 386 |
+
logger.info(f"Updated session characteristics for user {user_id}")
|
| 387 |
+
except (json5.Json5DecodeError, ValueError) as e:
|
| 388 |
+
logger.warning(f"Failed to parse session characteristics: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
def _create_flow_context(self, user_id: str) -> Dict[str, Any]:
|
| 391 |
|