Nada committed on
Commit ·
5eea25b
1
Parent(s): 910249b
yy
Browse files- chatbot.py +167 -58
chatbot.py
CHANGED
|
@@ -19,15 +19,14 @@ from peft import PeftModel, PeftConfig
|
|
| 19 |
from sentence_transformers import SentenceTransformer
|
| 20 |
|
| 21 |
# LangChain imports
|
| 22 |
-
|
| 23 |
-
from langchain.
|
| 24 |
-
from langchain.
|
| 25 |
-
from langchain.
|
| 26 |
-
from langchain.
|
| 27 |
-
from langchain.
|
| 28 |
-
from langchain.
|
| 29 |
-
from langchain.
|
| 30 |
-
from langchain.vectorstores import FAISS # Vector database for similarity search
|
| 31 |
|
| 32 |
# Import FlowManager
|
| 33 |
from conversation_flow import FlowManager
|
|
@@ -214,7 +213,11 @@ class MentalHealthChatbot:
|
|
| 214 |
peft_model_path: str = "nada013/mental-health-chatbot",
|
| 215 |
therapy_guidelines_path: str = None,
|
| 216 |
use_4bit: bool = True,
|
| 217 |
-
device: str = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
):
|
| 219 |
# Set device (cuda if available, otherwise cpu)
|
| 220 |
if device is None:
|
|
@@ -234,6 +237,13 @@ class MentalHealthChatbot:
|
|
| 234 |
|
| 235 |
logger.info(f"Using device: {self.device}")
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
# Initialize models
|
| 238 |
self.peft_model_path = peft_model_path
|
| 239 |
|
|
@@ -264,24 +274,12 @@ class MentalHealthChatbot:
|
|
| 264 |
self.flow_manager = FlowManager(self.llm)
|
| 265 |
|
| 266 |
# Setup conversation memory with LangChain
|
| 267 |
-
# ConversationBufferMemory stores the conversation history in a buffer
|
| 268 |
-
# This allows the chatbot to maintain context across multiple interactions
|
| 269 |
-
# - return_messages=True: Returns messages as a list of message objects
|
| 270 |
-
# - input_key="input": Specifies which key to use for the input in the memory
|
| 271 |
self.memory = ConversationBufferMemory(
|
| 272 |
return_messages=True,
|
| 273 |
input_key="input"
|
| 274 |
)
|
| 275 |
|
| 276 |
# Create conversation prompt template
|
| 277 |
-
# PromptTemplate defines the structure for generating responses
|
| 278 |
-
# It includes placeholders for dynamic content that gets filled during generation
|
| 279 |
-
# Input variables:
|
| 280 |
-
# - history: Previous conversation context from memory
|
| 281 |
-
# - input: Current user message
|
| 282 |
-
# - past_context: Relevant past conversations from vector search
|
| 283 |
-
# - emotion_context: Detected emotions and their context
|
| 284 |
-
# - guidelines: Relevant therapeutic guidelines from vector search
|
| 285 |
self.prompt_template = PromptTemplate(
|
| 286 |
input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
|
| 287 |
template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.
|
|
@@ -323,7 +321,6 @@ Response:"""
|
|
| 323 |
)
|
| 324 |
|
| 325 |
# Setup vector database for retrieving relevant past conversations
|
| 326 |
-
|
| 327 |
if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
|
| 328 |
self.setup_vector_db(therapy_guidelines_path)
|
| 329 |
else:
|
|
@@ -502,6 +499,109 @@ Response:"""
|
|
| 502 |
logger.error(f"Error detecting emotions: {e}")
|
| 503 |
return {"neutral": 1.0}
|
| 504 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
def retrieve_relevant_context(self, query: str, k: int = 3) -> str:
|
| 506 |
# Retrieve relevant past conversations using vector similarity
|
| 507 |
if not hasattr(self, 'vector_db'):
|
|
@@ -568,31 +668,25 @@ Response:"""
|
|
| 568 |
guidelines=guidelines
|
| 569 |
)
|
| 570 |
|
| 571 |
-
#
|
| 572 |
-
response =
|
| 573 |
-
response = response.split("---")[0].strip()
|
| 574 |
-
response = response.split("Note:")[0].strip()
|
| 575 |
-
|
| 576 |
-
# Remove any casual greetings like "Hey" or "Hi"
|
| 577 |
-
response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
|
| 578 |
|
| 579 |
# Ensure the response is unique and not repeating previous messages
|
| 580 |
if len(conversation_history) > 0:
|
| 581 |
last_responses = [msg["text"] for msg in conversation_history[-4:] if msg["role"] == "assistant"]
|
| 582 |
if response in last_responses:
|
|
|
|
| 583 |
# Generate a new response with a different angle
|
| 584 |
-
|
| 585 |
input=f"{prompt} (Please provide a different perspective)",
|
| 586 |
past_context=past_context,
|
| 587 |
emotion_context=emotion_context,
|
| 588 |
guidelines=guidelines
|
| 589 |
)
|
| 590 |
-
|
| 591 |
-
response =
|
| 592 |
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
return response.strip()
|
| 596 |
|
| 597 |
def generate_session_summary(
|
| 598 |
self,
|
|
@@ -838,6 +932,9 @@ Would you like to connect with a professional now, or would you prefer to keep t
|
|
| 838 |
|
| 839 |
return crisis_response
|
| 840 |
|
|
|
|
|
|
|
|
|
|
| 841 |
# Detect emotions
|
| 842 |
emotions = self.detect_emotion(message)
|
| 843 |
conversation.emotion_history.append(emotions)
|
|
@@ -854,34 +951,46 @@ Would you like to connect with a professional now, or would you prefer to keep t
|
|
| 854 |
"role": msg.role
|
| 855 |
})
|
| 856 |
|
| 857 |
-
#
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
Recent conversation:
|
| 864 |
{chr(10).join([f"{msg['role']}: {msg['text']}" for msg in conversation_history[-3:]])}
|
| 865 |
|
| 866 |
Now, write a single empathetic and open-ended question to encourage the user to share more.
|
| 867 |
Respond with just the question, no explanation.
|
| 868 |
"""
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
|
| 886 |
# assistant response -> conversation history
|
| 887 |
assistant_message = Message(
|
|
|
|
| 19 |
from sentence_transformers import SentenceTransformer
|
| 20 |
|
| 21 |
# LangChain imports
|
| 22 |
+
from langchain.llms import HuggingFacePipeline
|
| 23 |
+
from langchain.chains import LLMChain
|
| 24 |
+
from langchain.memory import ConversationBufferMemory
|
| 25 |
+
from langchain.prompts import PromptTemplate
|
| 26 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 27 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 28 |
+
from langchain.document_loaders import TextLoader
|
| 29 |
+
from langchain.vectorstores import FAISS
|
|
|
|
| 30 |
|
| 31 |
# Import FlowManager
|
| 32 |
from conversation_flow import FlowManager
|
|
|
|
| 213 |
peft_model_path: str = "nada013/mental-health-chatbot",
|
| 214 |
therapy_guidelines_path: str = None,
|
| 215 |
use_4bit: bool = True,
|
| 216 |
+
device: str = None,
|
| 217 |
+
max_response_length: int = 500, # Maximum characters in response
|
| 218 |
+
max_response_words: int = 100, # Maximum words in response
|
| 219 |
+
min_response_words: int = 10, # Minimum words in response
|
| 220 |
+
max_consecutive_responses: int = 3 # Max consecutive responses without user input
|
| 221 |
):
|
| 222 |
# Set device (cuda if available, otherwise cpu)
|
| 223 |
if device is None:
|
|
|
|
| 237 |
|
| 238 |
logger.info(f"Using device: {self.device}")
|
| 239 |
|
| 240 |
+
# Set response limits
|
| 241 |
+
self.max_response_length = max_response_length
|
| 242 |
+
self.max_response_words = max_response_words
|
| 243 |
+
self.min_response_words = min_response_words
|
| 244 |
+
self.max_consecutive_responses = max_consecutive_responses
|
| 245 |
+
self.consecutive_response_count = 0 # Track consecutive responses
|
| 246 |
+
|
| 247 |
# Initialize models
|
| 248 |
self.peft_model_path = peft_model_path
|
| 249 |
|
|
|
|
| 274 |
self.flow_manager = FlowManager(self.llm)
|
| 275 |
|
| 276 |
# Setup conversation memory with LangChain
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
self.memory = ConversationBufferMemory(
|
| 278 |
return_messages=True,
|
| 279 |
input_key="input"
|
| 280 |
)
|
| 281 |
|
| 282 |
# Create conversation prompt template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
self.prompt_template = PromptTemplate(
|
| 284 |
input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
|
| 285 |
template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.
|
|
|
|
| 321 |
)
|
| 322 |
|
| 323 |
# Setup vector database for retrieving relevant past conversations
|
|
|
|
| 324 |
if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
|
| 325 |
self.setup_vector_db(therapy_guidelines_path)
|
| 326 |
else:
|
|
|
|
| 499 |
logger.error(f"Error detecting emotions: {e}")
|
| 500 |
return {"neutral": 1.0}
|
| 501 |
|
| 502 |
+
def _validate_and_limit_response(self, response: str, user_message: str) -> str:
    """
    Clean a raw model response and clamp it to the configured size limits.

    Steps, in order: strip trailing LLM commentary and a leading casual
    greeting, append a follow-up question when the text is shorter than
    ``self.min_response_words``, truncate when it exceeds
    ``self.max_response_words`` or ``self.max_response_length``, replace it
    with a stock prompt when ``self._is_repetitive`` flags it, and make sure
    it ends with terminal punctuation.

    Args:
        response: Raw text produced by the model.
        user_message: The user's message, forwarded to ``_is_repetitive``.

    Returns:
        The cleaned response. Text that had to be truncated never exceeds
        ``self.max_response_length`` characters (including any ellipsis).
    """
    if not response or not response.strip():
        return "I understand. Could you tell me more about that?"

    # Clean the response
    response = response.strip()

    # Remove any LLM commentary or instructions (to the end of that line)
    response = re.sub(r"(Your response|This response|Response:|Note:).*", "", response, flags=re.IGNORECASE).strip()
    response = re.sub(r"---.*", "", response).strip()

    # Remove casual greetings
    response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)

    # Sizes of the cleaned text, measured before any follow-up is appended
    words = response.split()
    word_count = len(words)
    char_count = len(response)

    # Check if response is too short
    if word_count < self.min_response_words:
        logger.info(f"Response too short ({word_count} words), adding follow-up question")
        if not response.endswith('?'):
            response += " Could you tell me more about that?"

    # Check if response is too long
    if char_count > self.max_response_length or word_count > self.max_response_words:
        logger.info(f"Response too long ({char_count} chars, {word_count} words), truncating")

        if word_count > self.max_response_words:
            # Truncate to max words, then prefer ending on a full sentence
            response = ' '.join(words[:self.max_response_words])
            end_point = max(response.rfind('.'), response.rfind('?'), response.rfind('!'))
            if end_point > len(response) * 0.7:
                response = response[:end_point + 1]
            else:
                # Add ellipsis if we can't end naturally
                response = response.rstrip() + "..."

        # The word cap alone does not bound the character count, so the
        # character cap is enforced afterwards rather than as an alternative.
        if len(response) > self.max_response_length:
            clipped = response[:self.max_response_length]
            last_space = clipped.rfind(' ')
            if last_space > len(clipped) * 0.8:
                # A word boundary close to the limit: cut there
                response = clipped[:last_space]
            else:
                # Reserve room for the ellipsis so the result stays in bounds
                keep = max(self.max_response_length - 3, 0)
                response = clipped[:keep].rstrip() + "..."

    # Check for repetitive content
    if self._is_repetitive(response, user_message):
        logger.info("Response detected as repetitive, generating alternative")
        return "I hear what you're saying. Could you help me understand this better?"

    # Ensure response ends properly
    if not response.endswith(('.', '!', '?')):
        response = response.rstrip() + '.'

    return response.strip()
|
| 575 |
+
|
| 576 |
+
def _is_repetitive(self, response: str, user_message: str) -> bool:
    """
    Report whether a response parrots the user or leans on stock phrases.

    Two case-insensitive checks are made:
      * when the user's message has more than three distinct
        whitespace-separated tokens, more than 60% of them also occur
        among the response's tokens;
      * more than two of a fixed list of filler phrases occur in the
        response as substrings.

    Args:
        response: Candidate assistant reply.
        user_message: The message the reply answers.

    Returns:
        True if either check fires, otherwise False.
    """
    reply = response.lower()
    user_vocab = set(user_message.lower().split())

    # Echo check: skipped for very short user messages
    if len(user_vocab) > 3:
        shared = user_vocab & set(reply.split())
        if len(shared) / len(user_vocab) > 0.6:
            return True

    # Filler check: count how many distinct stock phrases appear
    stock_phrases = (
        "i understand", "i hear you", "that sounds", "i can see",
        "thank you for sharing", "i appreciate", "that must be",
    )
    hits = [phrase for phrase in stock_phrases if phrase in reply]
    return len(hits) > 2
|
| 604 |
+
|
| 605 |
def retrieve_relevant_context(self, query: str, k: int = 3) -> str:
|
| 606 |
# Retrieve relevant past conversations using vector similarity
|
| 607 |
if not hasattr(self, 'vector_db'):
|
|
|
|
| 668 |
guidelines=guidelines
|
| 669 |
)
|
| 670 |
|
| 671 |
+
# Validate and limit the response
|
| 672 |
+
response = self._validate_and_limit_response(response, prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
|
| 674 |
# Ensure the response is unique and not repeating previous messages
|
| 675 |
if len(conversation_history) > 0:
|
| 676 |
last_responses = [msg["text"] for msg in conversation_history[-4:] if msg["role"] == "assistant"]
|
| 677 |
if response in last_responses:
|
| 678 |
+
logger.info("Response detected as duplicate, generating alternative")
|
| 679 |
# Generate a new response with a different angle
|
| 680 |
+
alternative_response = self.conversation.predict(
|
| 681 |
input=f"{prompt} (Please provide a different perspective)",
|
| 682 |
past_context=past_context,
|
| 683 |
emotion_context=emotion_context,
|
| 684 |
guidelines=guidelines
|
| 685 |
)
|
| 686 |
+
alternative_response = self._validate_and_limit_response(alternative_response, prompt)
|
| 687 |
+
response = alternative_response
|
| 688 |
|
| 689 |
+
return response
|
|
|
|
|
|
|
| 690 |
|
| 691 |
def generate_session_summary(
|
| 692 |
self,
|
|
|
|
| 932 |
|
| 933 |
return crisis_response
|
| 934 |
|
| 935 |
+
# Reset consecutive response counter when user sends a message
|
| 936 |
+
self.consecutive_response_count = 0
|
| 937 |
+
|
| 938 |
# Detect emotions
|
| 939 |
emotions = self.detect_emotion(message)
|
| 940 |
conversation.emotion_history.append(emotions)
|
|
|
|
| 951 |
"role": msg.role
|
| 952 |
})
|
| 953 |
|
| 954 |
+
# Check rate limiting for consecutive responses
|
| 955 |
+
if self.consecutive_response_count >= self.max_consecutive_responses:
|
| 956 |
+
logger.warning(f"Rate limit reached for user {user_id}, sending brief response")
|
| 957 |
+
response_text = "I'm here to listen. Take your time to share what's on your mind."
|
| 958 |
+
self.consecutive_response_count = 0 # Reset counter
|
| 959 |
+
else:
|
| 960 |
+
# Generate response
|
| 961 |
+
response_text = self.generate_response(message, emotions, conversation_history)
|
| 962 |
+
|
| 963 |
+
# Increment consecutive response counter
|
| 964 |
+
self.consecutive_response_count += 1
|
| 965 |
+
|
| 966 |
+
# Generate a follow-up question if the response is too short and we haven't hit limits
|
| 967 |
+
if (len(response_text.split()) < self.min_response_words and
|
| 968 |
+
not response_text.endswith('?') and
|
| 969 |
+
self.consecutive_response_count < self.max_consecutive_responses):
|
| 970 |
+
|
| 971 |
+
follow_up_prompt = f"""
|
| 972 |
Recent conversation:
|
| 973 |
{chr(10).join([f"{msg['role']}: {msg['text']}" for msg in conversation_history[-3:]])}
|
| 974 |
|
| 975 |
Now, write a single empathetic and open-ended question to encourage the user to share more.
|
| 976 |
Respond with just the question, no explanation.
|
| 977 |
"""
|
| 978 |
+
follow_up = self.llm.invoke(follow_up_prompt).strip()
|
| 979 |
+
# Clean and extract only the actual question (first sentence ending with '?')
|
| 980 |
+
matches = re.findall(r'([^\n.?!]*\?)', follow_up)
|
| 981 |
+
if matches:
|
| 982 |
+
question = matches[0].strip()
|
| 983 |
+
else:
|
| 984 |
+
question = follow_up.strip().split('\n')[0]
|
| 985 |
+
|
| 986 |
+
# Validate the follow-up question
|
| 987 |
+
question = self._validate_and_limit_response(question, message)
|
| 988 |
+
|
| 989 |
+
# If the main response is very short, return just the question
|
| 990 |
+
if len(response_text.split()) < 5:
|
| 991 |
+
response_text = question
|
| 992 |
+
else:
|
| 993 |
+
response_text = f"{response_text}\n\n{question}"
|
| 994 |
|
| 995 |
# assistant response -> conversation history
|
| 996 |
assistant_message = Message(
|