|
|
import os |
|
|
import logging |
|
|
import json |
|
|
import torch |
|
|
import re |
|
|
from typing import List, Dict, Any, Optional, Union |
|
|
from datetime import datetime |
|
|
from pydantic import BaseModel, Field |
|
|
import tempfile |
|
|
|
|
|
|
|
|
from transformers import ( |
|
|
pipeline, |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
BitsAndBytesConfig |
|
|
) |
|
|
from peft import PeftModel, PeftConfig |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
from langchain.llms import HuggingFacePipeline |
|
|
from langchain.chains import LLMChain |
|
|
from langchain.memory import ConversationBufferMemory |
|
|
from langchain.prompts import PromptTemplate |
|
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain.document_loaders import TextLoader |
|
|
from langchain.vectorstores import FAISS |
|
|
|
|
|
|
|
|
from conversation_flow import FlowManager |
|
|
|
|
|
|
|
|
# Root logging for the whole process: INFO and above, timestamped, to a stream handler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

import warnings  # noqa: E402  (kept beside the filter it configures)

# Suppress every UserWarning for the rest of the process.
warnings.filterwarnings('ignore', category=UserWarning)
|
|
|
|
|
|
|
|
def setup_cache_dirs():
    """Configure Hugging Face related environment variables and return the cache directory.

    On Hugging Face Spaces (detected by the presence of ``SPACE_ID``) all HF caches
    are redirected to ``/tmp/huggingface``. Otherwise the standard user cache
    ``~/.cache/huggingface`` is used. Quiet-logging settings are applied in both cases.

    The auth token is only exported when one is actually configured. ``HF_TOKEN``
    wins over the legacy ``HUGGING_FACE_HUB_TOKEN``, and both variables are set to
    the same value. An unset token is left unset rather than being turned into an
    empty string, which previously also wiped a legacy-only token.

    NOTE(review): ``transformers`` is imported at the top of this module before this
    function runs. Libraries that read these variables at import time may not see
    the cache overrides; confirm if cache placement matters.

    Returns:
        str: The cache directory path (created if it does not exist).
    """
    is_spaces = os.environ.get('SPACE_ID') is not None

    # Settings applied in every environment.
    env_updates = {
        'TOKENIZERS_PARALLELISM': 'false',
        'TRANSFORMERS_VERBOSITY': 'error',
        'BITSANDBYTES_NOWELCOME': '1',
    }

    if is_spaces:
        cache_dir = '/tmp/huggingface'
        env_updates.update({
            'TRANSFORMERS_CACHE': cache_dir,
            'HF_HOME': cache_dir,
            'HF_DATASETS_CACHE': cache_dir,
            'HF_METRICS_CACHE': cache_dir,
            'HF_MODULES_CACHE': cache_dir,
        })
        token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
        if token:
            env_updates['HF_TOKEN'] = token
            env_updates['HUGGING_FACE_HUB_TOKEN'] = token
    else:
        cache_dir = os.path.expanduser('~/.cache/huggingface')

    os.environ.update(env_updates)
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
|
|
|
|
|
|
|
|
# Resolve (and create) the Hugging Face cache once, at import time.
CACHE_DIR = setup_cache_dirs()

# Runtime data directories, all located next to this module.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
MODELS_DIR, VECTOR_DB_PATH, SESSION_DATA_PATH, SUMMARIES_DIR = (
    os.path.join(BASE_DIR, subdir)
    for subdir in ("models", "vector_db", "session_data", "session_summaries")
)

# Create them eagerly so later reads/writes never hit a missing directory.
for directory in (MODELS_DIR, VECTOR_DB_PATH, SESSION_DATA_PATH, SUMMARIES_DIR):
    os.makedirs(directory, exist_ok=True)
|
|
|
|
|
|
|
|
class Message(BaseModel):
    # A single chat turn stored in a Conversation.
    text: str = Field(..., description="The content of the message")
    # Optional: callers may omit it or pass None explicitly. The previous `str`
    # annotation rejected an explicit None under pydantic v2.
    timestamp: Optional[str] = Field(None, description="ISO format timestamp of the message")
    role: str = Field("user", description="The role of the message sender (user or assistant)")
|
|
|
|
|
class SessionSummary(BaseModel):
    # Schema describing a finished session's summary record.
    # Every field except `recommendations` is required.

    # Identity of the session and its user.
    session_id: str = Field(
        default=...,
        description="Unique identifier for the session",
        examples=["user_789_session_20240314"],
    )
    user_id: str = Field(
        default=...,
        description="Identifier of the user",
        examples=["user_123"],
    )

    # Timing information.
    start_time: str = Field(default=..., description="ISO format start time of the session")
    end_time: str = Field(default=..., description="ISO format end time of the session")
    message_count: int = Field(default=..., description="Total number of messages in the session")
    duration_minutes: float = Field(default=..., description="Duration of the session in minutes")

    # Emotional analysis; at least one primary emotion is required.
    primary_emotions: List[str] = Field(
        default=...,
        min_items=1,
        description="List of primary emotions detected",
        examples=[
            ["anxiety", "stress"],
            ["joy", "excitement"],
            ["sadness", "loneliness"],
        ],
    )
    emotion_progression: List[Dict[str, float]] = Field(
        default=...,
        description="Progression of emotions throughout the session",
        examples=[
            [
                {"anxiety": 0.8, "stress": 0.6},
                {"calm": 0.7, "anxiety": 0.3},
                {"joy": 0.9, "calm": 0.8},
            ],
        ],
    )

    # Free-text outcome of the session.
    summary_text: str = Field(
        default=...,
        description="Text summary of the session",
        examples=[
            "The session focused on managing work-related stress and developing coping strategies. The client showed improvement in recognizing stress triggers and implementing relaxation techniques.",
            "Discussion centered around relationship challenges and self-esteem issues. The client expressed willingness to try new communication strategies.",
        ],
    )
    recommendations: Optional[List[str]] = Field(
        default=None,
        description="Optional recommendations based on the session",
    )
|
|
|
|
|
class Conversation(BaseModel):
    # In-memory state of one user's chat session.
    user_id: str = Field(
        ...,
        description="Identifier of the user",
        examples=["user_123"]
    )
    session_id: str = Field(
        "",
        description="Identifier of the current session"
    )
    start_time: str = Field(
        "",
        description="ISO format start time of the conversation"
    )
    # default_factory gives every instance its own fresh container instead of
    # relying on a shared mutable literal default.
    messages: List[Message] = Field(
        default_factory=list,
        description="List of messages in the conversation",
        examples=[
            [
                Message(text="I'm feeling anxious", role="user"),
                Message(text="I understand you're feeling anxious. Can you tell me more about what's causing this?", role="assistant")
            ]
        ]
    )
    emotion_history: List[Dict[str, float]] = Field(
        default_factory=list,
        description="History of emotions detected",
        examples=[
            [
                {"anxiety": 0.8, "stress": 0.6},
                {"calm": 0.7, "anxiety": 0.3}
            ]
        ]
    )
    context: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional context for the conversation",
        examples=[
            {
                "last_emotion": "anxiety",
                "conversation_topic": "work stress",
                "previous_sessions": 3
            }
        ]
    )
    is_active: bool = Field(
        True,
        description="Whether the conversation is currently active",
        examples=[True, False]
    )
|
|
|
|
|
class MentalHealthChatbot:
    """Supportive chatbot combining a PEFT-tuned causal LM with helper models.

    ``__init__`` loads, in order:

    * an emotion classifier,
    * the base LLM plus its PEFT adapter (wrapped as a LangChain LLM),
    * a BART summarization pipeline,
    * a ``FlowManager``,
    * the reply prompt/chain,
    * sentence-transformer embeddings and a FAISS store persisted in ``VECTOR_DB_PATH``.

    Per-user state is kept in ``self.conversations`` (``Conversation`` models keyed
    by ``user_id``) and in per-user LLM chat memories. One user's history or stored
    interactions are therefore never rendered into another user's prompt.
    """

    # Neighbours fetched from FAISS before metadata filtering. Guideline chunks and
    # every user's past interactions share a single index, so filtered lookups must
    # over-fetch to still find k matching documents.
    _RETRIEVAL_FETCH_K = 50

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.2-3B-Instruct",
        peft_model_path: str = "nada013/mental-health-chatbot",
        therapy_guidelines_path: Optional[str] = None,
        use_4bit: bool = True,
        device: Optional[str] = None,
        max_response_length: int = 500,
        max_response_words: int = 100,
        min_response_words: int = 10,
        max_consecutive_responses: int = 3
    ):
        """Load every model and build the conversation pipeline.

        Args:
            model_name: Hugging Face id of the base causal LM.
            peft_model_path: Hub id or local path of the PEFT adapter applied on top.
            therapy_guidelines_path: Optional text file to index into the vector store.
            use_4bit: Use 4-bit bitsandbytes quantization (only when running on CUDA).
            device: ``"cuda"``/``"cpu"``; auto-detected when ``None``.
            max_response_length: Character cap applied to generated replies.
            max_response_words: Word cap applied to generated replies.
            min_response_words: Replies shorter than this get a follow-up question.
            max_consecutive_responses: Threshold for the brief "rate limit" reply.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        use_cuda = self._use_cuda()

        if use_cuda:
            # Set before releasing cached blocks so the allocator config is in
            # place before any further CUDA allocations.
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
            torch.cuda.empty_cache()
            self.batch_size = 4
        else:
            self.batch_size = 8

        logger.info(f"Using device: {self.device}")

        # Response shaping limits.
        self.max_response_length = max_response_length
        self.max_response_words = max_response_words
        self.min_response_words = min_response_words
        self.max_consecutive_responses = max_consecutive_responses
        self.consecutive_response_count = 0

        self.peft_model_path = peft_model_path

        logger.info("Loading emotion detection model")
        self.emotion_classifier = self._load_emotion_model()

        logger.info(f"Loading LLAMA model: {model_name}")
        self.llama_model, self.llama_tokenizer, self.llm = self._initialize_llm(model_name, use_4bit)

        logger.info("Loading summary model")
        self.summary_model = pipeline(
            "summarization",
            model="philschmid/bart-large-cnn-samsum",
            device=0 if use_cuda else -1,
            model_kwargs={
                "cache_dir": CACHE_DIR,
                # float16 only on GPU; half precision on CPU is slow and not
                # supported by every op in older PyTorch builds.
                "torch_dtype": torch.float16 if use_cuda else torch.float32,
                "max_memory": {0: "2GB"} if use_cuda else None
            }
        )
        logger.info("Summary model loaded successfully")

        logger.info("Initializing FlowManager")
        self.flow_manager = FlowManager(self.llm)

        # Shared memory/chain, used only when no user_id is supplied (legacy callers).
        # return_messages=False: the prompt below is a plain string template, so the
        # history must be rendered as text rather than as a list of message objects.
        self.memory = ConversationBufferMemory(
            return_messages=False,
            input_key="input"
        )

        self.prompt_template = PromptTemplate(
            input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
            template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.





Previous conversation:


{history}





EMOTIONAL CONTEXT:


{emotion_context}





Past context: {past_context}





Relevant therapeutic guidelines:


{guidelines}





Current message: {input}





Provide a supportive response that:


1. Validates the user's feelings without using casual greetings


2. Asks relevant follow-up questions


3. Maintains a conversational tone , professional and empathetic tone


4. Focuses on understanding and support


5. Avoids repeating previous responses





Response:"""
        )

        self.conversation = LLMChain(
            llm=self.llm,
            prompt=self.prompt_template,
            memory=self.memory,
            verbose=False
        )
        # user_id -> LLMChain with that user's own ConversationBufferMemory.
        self._user_chains = {}

        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
            self.setup_vector_db(therapy_guidelines_path)
        else:
            self.setup_vector_db(None)

        # user_id -> Conversation
        self.conversations = {}

        # session_id -> summary dict (reloaded from SUMMARIES_DIR on startup).
        self.session_summaries = {}
        self._load_existing_summaries()

        logger.info("All models and components initialized successfully")

    def _use_cuda(self) -> bool:
        """Return True when ``self.device`` names a CUDA device and a GPU is available."""
        return str(self.device).startswith("cuda") and torch.cuda.is_available()

    @staticmethod
    def _neutral_emotion_classifier(text, **kwargs):
        """Last-resort classifier with the same nested output shape as the HF pipelines.

        A single string input yields ``[[{"label": ..., "score": ...}, ...]]``, which
        is the shape ``detect_emotion`` indexes with ``[0]``.
        """
        return [[{"label": "neutral", "score": 1.0}]]

    def _load_emotion_model(self):
        """Load the GoEmotions classifier, falling back to a smaller model, then to neutral.

        Returns:
            A callable ``classifier(text, **kwargs)``: a transformers text-classification
            pipeline or ``_neutral_emotion_classifier``.
        """
        use_cuda = self._use_cuda()
        device_map = "auto" if use_cuda else None

        def model_kwargs():
            # A fresh dict per call, so nothing one pipeline() call stores in it can
            # leak into the fallback attempt.
            return {
                "cache_dir": CACHE_DIR,
                "torch_dtype": torch.float16 if use_cuda else torch.float32,
                "max_memory": {0: "2GB"} if use_cuda else None
            }

        try:
            return pipeline(
                "text-classification",
                model="SamLowe/roberta-base-go_emotions",
                top_k=None,
                device_map=device_map,
                model_kwargs=model_kwargs(),
            )
        except Exception as e:
            logger.error(f"Error loading emotion model: {e}")

        try:
            return pipeline(
                "text-classification",
                model="j-hartmann/emotion-english-distilroberta-base",
                return_all_scores=True,
                device_map=device_map,
                model_kwargs=model_kwargs(),
            )
        except Exception as e:
            logger.error(f"Error loading fallback emotion model: {e}")

        return self._neutral_emotion_classifier

    def _initialize_llm(self, model_name: str, use_4bit: bool):
        """Load the base causal LM, its tokenizer and PEFT adapter, and wrap them for LangChain.

        Returns:
            tuple: ``(peft_model, tokenizer, HuggingFacePipeline)``.

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        try:
            use_cuda = self._use_cuda()
            if use_4bit and use_cuda:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )
                max_memory = {0: "14GB"}
            else:
                quantization_config = None
                max_memory = None
                if not use_cuda:
                    logger.info("CUDA not available, running without quantization")

            hf_token = os.environ.get('HF_TOKEN')

            logger.info(f"Loading base model: {model_name}")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                max_memory=max_memory,
                # NOTE(review): executes code shipped in the model repo; only use
                # with trusted model_name values.
                trust_remote_code=True,
                cache_dir=CACHE_DIR,
                use_auth_token=hf_token,
                # Half precision only on GPU; CPU inference runs in float32.
                torch_dtype=torch.float16 if use_cuda else torch.float32
            )

            logger.info("Loading tokenizer")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=CACHE_DIR,
                use_auth_token=hf_token
            )
            # Reuse EOS as padding so generation does not fail on models without a pad token.
            tokenizer.pad_token = tokenizer.eos_token

            logger.info(f"Loading PEFT model from {self.peft_model_path}")
            model = PeftModel.from_pretrained(
                base_model,
                self.peft_model_path,
                cache_dir=CACHE_DIR,
                use_auth_token=hf_token
            )
            logger.info("Successfully loaded PEFT model")

            text_generator = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                do_sample=True,
                device_map="auto" if use_cuda else None
            )

            llm = HuggingFacePipeline(pipeline=text_generator)
            return model, tokenizer, llm
        except Exception as e:
            logger.error(f"Error initializing LLM: {str(e)}")
            raise

    def _load_guideline_chunks(self, guidelines_path: Optional[str]):
        """Read ``guidelines_path`` and split it into overlapping ~500-char chunks.

        Returns an empty list when no usable path is given.
        """
        if not guidelines_path or not os.path.exists(guidelines_path):
            return []
        documents = TextLoader(guidelines_path).load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", " ", ""]
        )
        return text_splitter.split_documents(documents)

    def _contains_document(self, text: str) -> bool:
        """Return True if the vector store already holds a document whose content is exactly ``text``."""
        hits = self.vector_db.similarity_search(text, k=1)
        return bool(hits) and hits[0].page_content == text

    def setup_vector_db(self, guidelines_path: Optional[str] = None):
        """Create or load the FAISS store in ``VECTOR_DB_PATH`` and ensure guidelines are indexed.

        * **Existing index:** load it. If guidelines are given and not yet present
          (e.g. the index was first created without them, or only holds chat
          interactions), add them and persist.
        * **No index, guidelines available:** index the chunks and persist.
        * **No index, no guidelines:** persist a one-entry placeholder store.
        """
        logger.info("Setting up FAISS vector database")

        vector_db_exists = os.path.exists(os.path.join(VECTOR_DB_PATH, "index.faiss"))

        if vector_db_exists:
            # allow_dangerous_deserialization unpickles the docstore. This assumes
            # VECTOR_DB_PATH is only ever written by this application.
            self.vector_db = FAISS.load_local(VECTOR_DB_PATH, self.embeddings, allow_dangerous_deserialization=True)
            logger.info("Loaded existing vector database")

            if guidelines_path:
                try:
                    chunks = self._load_guideline_chunks(guidelines_path)
                except Exception as e:
                    # The index is usable without guidelines; don't fail startup.
                    logger.warning(f"Could not read therapy guidelines from {guidelines_path}: {e}")
                    chunks = []
                if chunks and not self._contains_document(chunks[0].page_content):
                    self.vector_db.add_documents(chunks)
                    self.vector_db.save_local(VECTOR_DB_PATH)
                    logger.info("Added therapy guidelines to existing vector database")
            return

        chunks = self._load_guideline_chunks(guidelines_path)
        if chunks:
            self.vector_db = FAISS.from_documents(chunks, self.embeddings)
            self.vector_db.save_local(VECTOR_DB_PATH)
            logger.info("Successfully loaded and indexed therapy guidelines")
        else:
            self.vector_db = FAISS.from_texts(["Initial empty vector store"], self.embeddings)
            self.vector_db.save_local(VECTOR_DB_PATH)
            logger.warning("No guidelines file provided, using empty vector store")

    def _load_existing_summaries(self):
        """Populate ``self.session_summaries`` from the JSON files in ``SUMMARIES_DIR``.

        Unreadable files are logged and skipped.
        """
        if not os.path.exists(SUMMARIES_DIR):
            return

        for filename in os.listdir(SUMMARIES_DIR):
            if not filename.endswith('.json'):
                continue
            try:
                with open(os.path.join(SUMMARIES_DIR, filename), 'r', encoding='utf-8') as f:
                    summary_data = json.load(f)
                session_id = summary_data.get('session_id')
                if session_id:
                    self.session_summaries[session_id] = summary_data
            except Exception as e:
                logger.warning(f"Failed to load summary from {filename}: {e}")

    def detect_emotion(self, text: str) -> Dict[str, float]:
        """Return ``{label: score}`` for every emotion the classifier reports for ``text``.

        Over-long inputs are truncated to the model's maximum length. Any
        classifier failure yields ``{"neutral": 1.0}``.
        """
        try:
            # For a single string the pipelines return [[{label, score}, ...]].
            results = self.emotion_classifier(text, truncation=True)[0]
            return {result['label']: result['score'] for result in results}
        except Exception as e:
            logger.error(f"Error detecting emotions: {e}")
            return {"neutral": 1.0}

    def _truncate_to_word_limit(self, words: List[str]) -> str:
        """Join the first ``max_response_words`` words.

        Cuts at the last sentence end when it lies in the final 30% of the text;
        otherwise appends "...".
        """
        response = ' '.join(words[:self.max_response_words])
        end_point = max(response.rfind('.'), response.rfind('?'), response.rfind('!'))
        if end_point > len(response) * 0.7:
            return response[:end_point + 1]
        return response.rstrip() + "..."

    def _truncate_to_char_limit(self, response: str) -> str:
        """Cut ``response`` to ``max_response_length`` chars.

        Backs off to the last space when it lies in the final 20% of the text;
        otherwise appends "...".
        """
        response = response[:self.max_response_length]
        last_space = response.rfind(' ')
        if last_space > len(response) * 0.8:
            return response[:last_space]
        return response.rstrip() + "..."

    def _validate_and_limit_response(self, response: str, user_message: str) -> str:
        """
        Validate and limit response length and content.
        Returns a properly limited response.

        Steps:

        1. Strip meta commentary and leading greetings.
        2. Pad very short replies with a follow-up question.
        3. Enforce the word limit and then the character limit. Word truncation
           alone can still exceed the character cap when words are long.
        4. Swap in a canned reply when the text is repetitive.
        5. Ensure terminal punctuation.
        """
        if not response or not response.strip():
            return "I understand. Could you tell me more about that?"

        response = response.strip()

        # Remove model self-commentary ("Note: ...", "Response: ...") and separators.
        response = re.sub(r"(Your response|This response|Response:|Note:).*", "", response, flags=re.IGNORECASE).strip()
        response = re.sub(r"---.*", "", response).strip()

        # The prompt asks for no casual greetings; strip any that slip through.
        response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)

        words = response.split()
        word_count = len(words)
        char_count = len(response)

        if word_count < self.min_response_words:
            logger.info(f"Response too short ({word_count} words), adding follow-up question")
            if not response.endswith('?'):
                response += " Could you tell me more about that?"

        if char_count > self.max_response_length or word_count > self.max_response_words:
            logger.info(f"Response too long ({char_count} chars, {word_count} words), truncating")
            if word_count > self.max_response_words:
                response = self._truncate_to_word_limit(words)
            # Re-check characters even after word truncation: 100 long words can
            # still exceed max_response_length.
            if len(response) > self.max_response_length:
                response = self._truncate_to_char_limit(response)

        if self._is_repetitive(response, user_message):
            logger.info("Response detected as repetitive, generating alternative")
            return "I hear what you're saying. Could you help me understand this better?"

        if not response.endswith(('.', '!', '?')):
            response = response.rstrip() + '.'

        return response.strip()

    def _is_repetitive(self, response: str, user_message: str) -> bool:
        """
        Check if response is repetitive or too similar to user message.

        Returns True when either condition holds:

        * more than 60% of the user's words (for messages longer than 3 words)
          reappear in the response, or
        * the response uses more than two stock empathy phrases.
        """
        response_lower = response.lower()
        user_words = set(user_message.lower().split())
        response_words = set(response_lower.split())

        if len(user_words) > 3:
            overlap = user_words & response_words
            if len(overlap) / len(user_words) > 0.6:
                return True

        repetitive_phrases = [
            "i understand", "i hear you", "that sounds", "i can see",
            "thank you for sharing", "i appreciate", "that must be"
        ]
        phrase_count = sum(1 for phrase in repetitive_phrases if phrase in response_lower)
        return phrase_count > 2

    @staticmethod
    def _is_guideline_metadata(metadata: Dict[str, Any]) -> bool:
        """True for guideline chunks.

        Guideline chunks are loaded from a file and carry a ``source``. Stored chat
        interactions and the placeholder entry do not qualify.
        """
        return metadata.get("type") != "interaction" and "source" in metadata

    def _filtered_search(self, query: str, k: int, keep) -> List[Any]:
        """Similarity-search ``query`` and return the first ``k`` hits whose metadata satisfies ``keep``."""
        candidates = self.vector_db.similarity_search(query, k=max(k, self._RETRIEVAL_FETCH_K))
        return [doc for doc in candidates if keep(doc.metadata or {})][:k]

    def retrieve_relevant_context(self, query: str, k: int = 3, user_id: Optional[str] = None) -> str:
        """Return up to ``k`` stored snippets similar to ``query``, newline-joined.

        Args:
            query: Text to search for.
            k: Maximum number of snippets.
            user_id: When given, only that user's own stored interactions are
                eligible, so other users' messages never leak into the prompt.
                ``None`` keeps the legacy unfiltered search.

        Returns:
            The joined snippets, or ``""`` when no store exists or the search fails.
        """
        if not hasattr(self, 'vector_db'):
            return ""

        try:
            if user_id is None:
                docs = self.vector_db.similarity_search(query, k=k)
            else:
                docs = self._filtered_search(
                    query,
                    k,
                    lambda meta: meta.get("type") == "interaction" and meta.get("user_id") == user_id,
                )
            return "\n".join(doc.page_content for doc in docs)
        except Exception as e:
            logger.error(f"Error retrieving context: {e}")
            return ""

    def retrieve_relevant_guidelines(self, query: str, emotion_context: str) -> str:
        """Return the 2 guideline chunks most similar to ``query`` plus ``emotion_context``.

        Stored chat interactions and the placeholder entry are excluded.
        Returns ``""`` when no store exists or the search fails.
        """
        if not hasattr(self, 'vector_db'):
            return ""

        try:
            search_query = f"{query} {emotion_context}"
            docs = self._filtered_search(search_query, 2, self._is_guideline_metadata)
            return "\n".join(doc.page_content for doc in docs)
        except Exception as e:
            logger.error(f"Error retrieving guidelines: {e}")
            return ""

    def _get_chain(self, user_id: Optional[str] = None):
        """Return the reply chain whose chat memory belongs to ``user_id``.

        ``None`` selects the shared ``self.conversation`` chain. Otherwise each user
        gets a lazily created ``LLMChain`` with its own ``ConversationBufferMemory``.
        """
        if user_id is None:
            return self.conversation
        chain = self._user_chains.get(user_id)
        if chain is None:
            chain = LLMChain(
                llm=self.llm,
                prompt=self.prompt_template,
                memory=ConversationBufferMemory(return_messages=False, input_key="input"),
                verbose=False
            )
            self._user_chains[user_id] = chain
        return chain

    def generate_response(
        self,
        prompt: str,
        emotion_data: Dict[str, float],
        conversation_history: List[Dict],
        user_id: Optional[str] = None
    ) -> str:
        """Generate a validated reply to ``prompt``.

        Args:
            prompt: The user's message.
            emotion_data: ``{emotion: score}`` as returned by ``detect_emotion``.
            conversation_history: Dicts with ``text``/``role`` keys, oldest first.
            user_id: Selects the per-user chat memory and restricts retrieved past
                context to this user. ``None`` uses the shared memory.

        Returns:
            The reply. It is regenerated once when it duplicates one of the recent
            assistant messages.
        """
        sorted_emotions = sorted(emotion_data.items(), key=lambda x: x[1], reverse=True)
        primary_emotion = sorted_emotions[0][0] if sorted_emotions else "neutral"

        # Up to two further emotions, kept only if reasonably strong.
        secondary_emotions = [emotion for emotion, score in sorted_emotions[1:3] if score > 0.2]

        emotion_context = f"User is primarily feeling {primary_emotion}"
        if secondary_emotions:
            emotion_context += f" with elements of {' and '.join(secondary_emotions)}"
        emotion_context += "."

        guidelines = self.retrieve_relevant_guidelines(prompt, emotion_context)
        past_context = self.retrieve_relevant_context(prompt, user_id=user_id)
        chain = self._get_chain(user_id)

        response = chain.predict(
            input=prompt,
            past_context=past_context,
            emotion_context=emotion_context,
            guidelines=guidelines
        )
        response = self._validate_and_limit_response(response, prompt)

        if conversation_history:
            last_responses = [msg["text"] for msg in conversation_history[-4:] if msg["role"] == "assistant"]
            if response in last_responses:
                logger.info("Response detected as duplicate, generating alternative")
                alternative_response = chain.predict(
                    input=f"{prompt} (Please provide a different perspective)",
                    past_context=past_context,
                    emotion_context=emotion_context,
                    guidelines=guidelines
                )
                response = self._validate_and_limit_response(alternative_response, prompt)

        return response

    @staticmethod
    def _top_emotions(emotion_history, limit: int = 3) -> List[str]:
        """Return up to ``limit`` emotions that were most often an entry's highest-scoring emotion.

        Each entry is expected to be a dict with an ``emotions`` mapping of
        ``{emotion: score}``. Entries whose mapping is missing, empty or not a dict
        are skipped. Ties keep first-seen order.
        """
        counts: Dict[str, int] = {}
        for entry in emotion_history:
            emotions = entry.get('emotions', {})
            if isinstance(emotions, dict) and emotions:
                top = max(emotions.items(), key=lambda item: item[1])[0]
                counts[top] = counts.get(top, 0) + 1
        ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
        return [emotion for emotion, _ in ranked[:limit]]

    def generate_session_summary(
        self,
        flow_manager_session: Optional[Dict] = None
    ) -> Dict:
        """Build a summary dict for a FlowManager session.

        The summary is produced by the BART summarizer; recommendations come from
        the LLM.

        Args:
            flow_manager_session: The FlowManager's session dict. Keys read:
                ``session_id``, ``user_id``, ``current_phase``, ``started_at``,
                ``emotion_progression``, ``emotion_history`` and
                ``llm_context.session_characteristics``.

        Returns:
            A JSON-serialisable-looking dict (``current_phase`` reduced to its name).
            When no session is given, a placeholder dict with empty ids is returned.
        """
        if not flow_manager_session:
            return {
                "session_id": "",
                "user_id": "",
                "start_time": "",
                "end_time": datetime.now().isoformat(),
                "duration_minutes": 0,
                "current_phase": "unknown",
                "primary_emotions": [],
                "emotion_progression": [],
                "summary": "Error: No session data provided",
                "recommendations": ["Unable to generate recommendations"],
                "session_characteristics": {}
            }

        session_id = flow_manager_session.get('session_id', '')
        user_id = flow_manager_session.get('user_id', '')

        current_phase = flow_manager_session.get('current_phase')
        if current_phase:
            # Flatten the phase object's attributes into a plain dict.
            current_phase = {
                'name': current_phase.name,
                'description': current_phase.description,
                'goals': current_phase.goals,
                'started_at': current_phase.started_at,
                'ended_at': current_phase.ended_at,
                'completion_metrics': current_phase.completion_metrics
            }
        phase_name = current_phase.get('name', 'unknown') if current_phase else 'unknown'
        phase_focus = current_phase.get('description', 'general discussion') if current_phase else 'general discussion'

        session_start = flow_manager_session.get('started_at')
        if isinstance(session_start, str):
            session_start = datetime.fromisoformat(session_start)
        session_duration = (datetime.now() - session_start).total_seconds() / 60 if session_start else 0

        emotion_progression = flow_manager_session.get('emotion_progression', [])
        emotion_history = flow_manager_session.get('emotion_history', [])
        primary_emotions = self._top_emotions(emotion_history)

        session_characteristics = flow_manager_session.get('llm_context', {}).get('session_characteristics', {})

        primary_text = ', '.join(primary_emotions) if primary_emotions else 'No primary emotions detected'
        themes_text = ', '.join(primary_emotions) if primary_emotions else 'not identified'
        # str() each item: the progression entries are not guaranteed to be strings.
        progression_text = (
            ', '.join(str(item) for item in emotion_progression)
            if emotion_progression else 'No significant emotion changes noted'
        )
        alliance = session_characteristics.get('alliance_strength', 'N/A')
        engagement = session_characteristics.get('engagement_level', 'N/A')
        emotional_pattern = session_characteristics.get('emotional_pattern', 'N/A')
        cognitive_pattern = session_characteristics.get('cognitive_pattern', 'N/A')
        progress = session_characteristics.get('progress_quality', 'N/A')

        summary_text = f"""


Session Overview:


- Session ID: {session_id}


- User ID: {user_id}


- Phase: {phase_name}


- Duration: {session_duration:.1f} minutes





Emotional Analysis:


- Primary Emotions: {primary_text}


- Emotion Progression: {progression_text}





Session Characteristics:


- Therapeutic Alliance: {alliance}


- Engagement Level: {engagement}


- Emotional Pattern: {emotional_pattern}


- Cognitive Pattern: {cognitive_pattern}





Key Observations:


- The session focused on {phase_focus}


- Main emotional themes: {themes_text}


- Session progress: {progress}


"""

        summary = self.summary_model(
            summary_text,
            max_length=150,
            min_length=50,
            do_sample=False
        )[0]['summary_text']

        recommendations_prompt = f"""


Based on the following session summary, provide 2-3 specific recommendations for follow-up:





{summary}





Session Characteristics:


- Therapeutic Alliance: {alliance}


- Engagement Level: {engagement}


- Emotional Pattern: {emotional_pattern}


- Cognitive Pattern: {cognitive_pattern}





Recommendations should be:


1. Actionable and specific


2. Based on the session content


3. Focused on next steps


"""

        raw_recommendations = self.llm.invoke(recommendations_prompt)

        # One recommendation per non-empty line, minus echoed prompt headings.
        recommendations = [line.strip() for line in raw_recommendations.split('\n') if line.strip()]
        recommendations = [r for r in recommendations if not r.startswith(('Based on', 'Session', 'Recommendations'))]

        return {
            "session_id": session_id,
            "user_id": user_id,
            "start_time": session_start.isoformat() if isinstance(session_start, datetime) else str(session_start),
            "end_time": datetime.now().isoformat(),
            "duration_minutes": session_duration,
            "current_phase": phase_name,
            "primary_emotions": primary_emotions,
            "emotion_progression": emotion_progression,
            "summary": summary,
            "recommendations": recommendations,
            "session_characteristics": session_characteristics
        }

    def start_session(self, user_id: str) -> tuple[str, str]:
        """Open a new session for ``user_id``.

        Resets the user's conversation and chat memory and records the greeting.

        Returns:
            ``(session_id, greeting_text)``.
        """
        session_id = f"{user_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}"

        self.flow_manager.initialize_session(user_id)

        self.conversations[user_id] = Conversation(
            user_id=user_id,
            session_id=session_id,
            start_time=datetime.now().isoformat(),
            is_active=True
        )

        # Fresh LLM history for this user; the shared memory is also cleared as before.
        self._user_chains.pop(user_id, None)
        self.memory.clear()

        initial_message = """Hello! I'm here to support you today. How have you been feeling lately?"""

        self.conversations[user_id].messages.append(Message(
            text=initial_message,
            timestamp=datetime.now().isoformat(),
            role="assistant"
        ))

        logger.info(f"Session started for user {user_id}")
        return session_id, initial_message

    def end_session(
        self,
        user_id: str,
        flow_manager: Optional[Any] = None
    ) -> Optional[Dict]:
        """Close ``user_id``'s active session, persist its summary and return it.

        Args:
            user_id: The user whose session ends.
            flow_manager: Accepted for backward compatibility; not used.

        Returns:
            The summary dict, or ``None`` when there was no active session or
            summarisation/persistence failed (the error is logged).
        """
        if user_id not in self.conversations or not self.conversations[user_id].is_active:
            return None

        conversation = self.conversations[user_id]
        conversation.is_active = False
        # The session is over; release this user's LLM chat memory.
        self._user_chains.pop(user_id, None)

        flow_manager_session = self.flow_manager.user_sessions.get(user_id)

        try:
            session_summary = self.generate_session_summary(flow_manager_session)

            # Without flow-manager data the summary has empty ids. Fall back to this
            # conversation's ids so the file is not saved as a bare ".json".
            if not session_summary.get('session_id'):
                session_summary['session_id'] = conversation.session_id
            if not session_summary.get('user_id'):
                session_summary['user_id'] = user_id

            # Ids may come from callers; keep the filename inside SUMMARIES_DIR.
            safe_name = re.sub(r'[^A-Za-z0-9_.-]', '_', str(session_summary['session_id']))
            summary_path = os.path.join(SUMMARIES_DIR, f"{safe_name}.json")
            with open(summary_path, 'w', encoding='utf-8') as f:
                json.dump(session_summary, f, indent=2)

            self.session_summaries[session_summary['session_id']] = session_summary

            self.memory.clear()

            return session_summary
        except Exception as e:
            logger.error(f"Failed to generate session summary: {e}")
            return None

    def process_message(self, user_id: str, message: str) -> str:
        """Handle one user message end to end and return the assistant's reply.

        Steps:

        1. Crisis keywords short-circuit to a fixed safety response.
        2. Emotion detection and flow tracking.
        3. Reply generation with this user's memory and retrieved context.
        4. An optional follow-up question.
        5. Storing the exchange in the vector store, tagged with ``user_id``.
        """
        risk_keywords = ["suicide", "kill myself", "end my life", "self-harm", "hurt myself"]
        lowered = message.lower()
        risk_detected = any(keyword in lowered for keyword in risk_keywords)

        if user_id not in self.conversations or not self.conversations[user_id].is_active:
            self.start_session(user_id)

        conversation = self.conversations[user_id]

        conversation.messages.append(Message(
            text=message,
            timestamp=datetime.now().isoformat(),
            role="user"
        ))

        if risk_detected:
            logger.warning(f"Risk flag detected in session {user_id}")

            crisis_response = """ I'm really sorry you're feeling this way — it sounds incredibly heavy, and I want you to know that you're not alone.





You don't have to face this by yourself. Our app has licensed mental health professionals who are ready to support you. I can connect you right now if you'd like.





In the meantime, I'm here to listen and talk with you. You can also do grounding exercises or calming techniques with me if you prefer. Just say "help me calm down" or "I need a break."





Would you like to connect with a professional now, or would you prefer to keep talking with me for a bit? Either way, I'm here for you."""

            conversation.messages.append(Message(
                text=crisis_response,
                timestamp=datetime.now().isoformat(),
                role="assistant"
            ))
            return crisis_response

        # NOTE(review): the counter is reset on every user message, so with a positive
        # max_consecutive_responses the "rate limit" branch below never fires.
        self.consecutive_response_count = 0

        emotions = self.detect_emotion(message)
        conversation.emotion_history.append(emotions)

        flow_context = self.flow_manager.process_message(user_id, message, emotions)

        conversation_history = [
            {"text": msg.text, "timestamp": msg.timestamp, "role": msg.role}
            for msg in conversation.messages
        ]

        if self.consecutive_response_count >= self.max_consecutive_responses:
            logger.warning(f"Rate limit reached for user {user_id}, sending brief response")
            response_text = "I'm here to listen. Take your time to share what's on your mind."
            self.consecutive_response_count = 0
        else:
            response_text = self.generate_response(message, emotions, conversation_history, user_id=user_id)

            self.consecutive_response_count += 1

            # Very short replies without a question get an LLM-written follow-up question.
            if (len(response_text.split()) < self.min_response_words and
                    not response_text.endswith('?') and
                    self.consecutive_response_count < self.max_consecutive_responses):
                recent_exchange = "\n".join(
                    f"{msg['role']}: {msg['text']}" for msg in conversation_history[-3:]
                )
                follow_up_prompt = f"""


Recent conversation:


{recent_exchange}





Now, write a single empathetic and open-ended question to encourage the user to share more.


Respond with just the question, no explanation.


"""
                follow_up = self.llm.invoke(follow_up_prompt).strip()

                # Prefer the first question sentence; otherwise the first line.
                matches = re.findall(r'([^\n.?!]*\?)', follow_up)
                if matches:
                    question = matches[0].strip()
                else:
                    question = follow_up.strip().split('\n')[0]

                question = self._validate_and_limit_response(question, message)

                if len(response_text.split()) < 5:
                    response_text = question
                else:
                    response_text = f"{response_text}\n\n{question}"

        conversation.messages.append(Message(
            text=response_text,
            timestamp=datetime.now().isoformat(),
            role="assistant"
        ))

        conversation.context.update({
            "last_emotion": emotions,
            "last_interaction": datetime.now().isoformat(),
            "flow_context": flow_context
        })

        # Tag the exchange so retrieval can keep it private to this user and out of guidelines.
        current_interaction = f"User: {message}\nChatbot: {response_text}"
        self.vector_db.add_texts(
            [current_interaction],
            metadatas=[{"type": "interaction", "user_id": user_id, "session_id": conversation.session_id}]
        )
        self.vector_db.save_local(VECTOR_DB_PATH)

        return response_text

    def get_session_summary(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored summary for ``session_id``, or ``None`` if unknown."""
        return self.session_summaries.get(session_id)

    def get_user_replies(self, user_id: str) -> List[Dict[str, Any]]:
        """Return the user-authored messages of ``user_id``'s current conversation.

        Each item has ``text``, ``timestamp`` and ``session_id``; ``[]`` for unknown users.
        """
        if user_id not in self.conversations:
            return []

        conversation = self.conversations[user_id]
        return [
            {
                "text": message.text,
                "timestamp": message.timestamp,
                "session_id": conversation.session_id
            }
            for message in conversation.messages
            if message.role == "user"
        ]
|
|
|
|
|
if __name__ == "__main__":
    # This module defines no command-line entry point; running it directly only
    # performs the import-time setup above.
    ...