Nada committed on
Commit
5eea25b
·
1 Parent(s): 910249b
Files changed (1) hide show
  1. chatbot.py +167 -58
chatbot.py CHANGED
@@ -19,15 +19,14 @@ from peft import PeftModel, PeftConfig
19
  from sentence_transformers import SentenceTransformer
20
 
21
  # LangChain imports
22
- # Core LangChain components for building conversational AI
23
- from langchain.llms import HuggingFacePipeline # Wrapper for HuggingFace models
24
- from langchain.chains import LLMChain # Chain for LLM interactions
25
- from langchain.memory import ConversationBufferMemory # Memory for conversation history
26
- from langchain.prompts import PromptTemplate # Template for structured prompts
27
- from langchain.embeddings import HuggingFaceEmbeddings # Text embeddings for similarity search
28
- from langchain.text_splitter import RecursiveCharacterTextSplitter # Document chunking
29
- from langchain.document_loaders import TextLoader # Load text documents
30
- from langchain.vectorstores import FAISS # Vector database for similarity search
31
 
32
  # Import FlowManager
33
  from conversation_flow import FlowManager
@@ -214,7 +213,11 @@ class MentalHealthChatbot:
214
  peft_model_path: str = "nada013/mental-health-chatbot",
215
  therapy_guidelines_path: str = None,
216
  use_4bit: bool = True,
217
- device: str = None
 
 
 
 
218
  ):
219
  # Set device (cuda if available, otherwise cpu)
220
  if device is None:
@@ -234,6 +237,13 @@ class MentalHealthChatbot:
234
 
235
  logger.info(f"Using device: {self.device}")
236
 
 
 
 
 
 
 
 
237
  # Initialize models
238
  self.peft_model_path = peft_model_path
239
 
@@ -264,24 +274,12 @@ class MentalHealthChatbot:
264
  self.flow_manager = FlowManager(self.llm)
265
 
266
  # Setup conversation memory with LangChain
267
- # ConversationBufferMemory stores the conversation history in a buffer
268
- # This allows the chatbot to maintain context across multiple interactions
269
- # - return_messages=True: Returns messages as a list of message objects
270
- # - input_key="input": Specifies which key to use for the input in the memory
271
  self.memory = ConversationBufferMemory(
272
  return_messages=True,
273
  input_key="input"
274
  )
275
 
276
  # Create conversation prompt template
277
- # PromptTemplate defines the structure for generating responses
278
- # It includes placeholders for dynamic content that gets filled during generation
279
- # Input variables:
280
- # - history: Previous conversation context from memory
281
- # - input: Current user message
282
- # - past_context: Relevant past conversations from vector search
283
- # - emotion_context: Detected emotions and their context
284
- # - guidelines: Relevant therapeutic guidelines from vector search
285
  self.prompt_template = PromptTemplate(
286
  input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
287
  template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.
@@ -323,7 +321,6 @@ Response:"""
323
  )
324
 
325
  # Setup vector database for retrieving relevant past conversations
326
-
327
  if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
328
  self.setup_vector_db(therapy_guidelines_path)
329
  else:
@@ -502,6 +499,109 @@ Response:"""
502
  logger.error(f"Error detecting emotions: {e}")
503
  return {"neutral": 1.0}
504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  def retrieve_relevant_context(self, query: str, k: int = 3) -> str:
506
  # Retrieve relevant past conversations using vector similarity
507
  if not hasattr(self, 'vector_db'):
@@ -568,31 +668,25 @@ Response:"""
568
  guidelines=guidelines
569
  )
570
 
571
- # Clean up the response to only include the actual message
572
- response = response.split("Response:")[-1].strip()
573
- response = response.split("---")[0].strip()
574
- response = response.split("Note:")[0].strip()
575
-
576
- # Remove any casual greetings like "Hey" or "Hi"
577
- response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
578
 
579
  # Ensure the response is unique and not repeating previous messages
580
  if len(conversation_history) > 0:
581
  last_responses = [msg["text"] for msg in conversation_history[-4:] if msg["role"] == "assistant"]
582
  if response in last_responses:
 
583
  # Generate a new response with a different angle
584
- response = self.conversation.predict(
585
  input=f"{prompt} (Please provide a different perspective)",
586
  past_context=past_context,
587
  emotion_context=emotion_context,
588
  guidelines=guidelines
589
  )
590
- response = response.split("Response:")[-1].strip()
591
- response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
592
 
593
-
594
-
595
- return response.strip()
596
 
597
  def generate_session_summary(
598
  self,
@@ -838,6 +932,9 @@ Would you like to connect with a professional now, or would you prefer to keep t
838
 
839
  return crisis_response
840
 
 
 
 
841
  # Detect emotions
842
  emotions = self.detect_emotion(message)
843
  conversation.emotion_history.append(emotions)
@@ -854,34 +951,46 @@ Would you like to connect with a professional now, or would you prefer to keep t
854
  "role": msg.role
855
  })
856
 
857
- # Generate response
858
- response_text = self.generate_response(message, emotions, conversation_history)
859
-
860
- # Generate a follow-up question if the response is too short
861
- if len(response_text.split()) < 20 and not response_text.endswith('?'):
862
- follow_up_prompt = f"""
 
 
 
 
 
 
 
 
 
 
 
 
863
  Recent conversation:
864
  {chr(10).join([f"{msg['role']}: {msg['text']}" for msg in conversation_history[-3:]])}
865
 
866
  Now, write a single empathetic and open-ended question to encourage the user to share more.
867
  Respond with just the question, no explanation.
868
  """
869
- follow_up = self.llm.invoke(follow_up_prompt).strip()
870
- # Clean and extract only the actual question (first sentence ending with '?')
871
- matches = re.findall(r'([^\n.?!]*\?)', follow_up)
872
- if matches:
873
- question = matches[0].strip()
874
- else:
875
- question = follow_up.strip().split('\n')[0]
876
- # If the main response is very short, return just the question
877
- if len(response_text.split()) < 5:
878
- response_text = question
879
- else:
880
- response_text = f"{response_text}\n\n{question}"
881
-
882
- # Final post-processing: remove any LLM commentary that may have leaked in
883
- response_text = response_text.strip()
884
- response_text = re.sub(r"(Your response|This response).*", "", response_text, flags=re.IGNORECASE).strip()
885
 
886
  # assistant response -> conversation history
887
  assistant_message = Message(
 
19
  from sentence_transformers import SentenceTransformer
20
 
21
  # LangChain imports
22
+ from langchain.llms import HuggingFacePipeline
23
+ from langchain.chains import LLMChain
24
+ from langchain.memory import ConversationBufferMemory
25
+ from langchain.prompts import PromptTemplate
26
+ from langchain.embeddings import HuggingFaceEmbeddings
27
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
28
+ from langchain.document_loaders import TextLoader
29
+ from langchain.vectorstores import FAISS
 
30
 
31
  # Import FlowManager
32
  from conversation_flow import FlowManager
 
213
  peft_model_path: str = "nada013/mental-health-chatbot",
214
  therapy_guidelines_path: str = None,
215
  use_4bit: bool = True,
216
+ device: str = None,
217
+ max_response_length: int = 500, # Maximum characters in response
218
+ max_response_words: int = 100, # Maximum words in response
219
+ min_response_words: int = 10, # Minimum words in response
220
+ max_consecutive_responses: int = 3 # Max consecutive responses without user input
221
  ):
222
  # Set device (cuda if available, otherwise cpu)
223
  if device is None:
 
237
 
238
  logger.info(f"Using device: {self.device}")
239
 
240
+ # Set response limits
241
+ self.max_response_length = max_response_length
242
+ self.max_response_words = max_response_words
243
+ self.min_response_words = min_response_words
244
+ self.max_consecutive_responses = max_consecutive_responses
245
+ self.consecutive_response_count = 0 # Track consecutive responses
246
+
247
  # Initialize models
248
  self.peft_model_path = peft_model_path
249
 
 
274
  self.flow_manager = FlowManager(self.llm)
275
 
276
  # Setup conversation memory with LangChain
 
 
 
 
277
  self.memory = ConversationBufferMemory(
278
  return_messages=True,
279
  input_key="input"
280
  )
281
 
282
  # Create conversation prompt template
 
 
 
 
 
 
 
 
283
  self.prompt_template = PromptTemplate(
284
  input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
285
  template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.
 
321
  )
322
 
323
  # Setup vector database for retrieving relevant past conversations
 
324
  if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
325
  self.setup_vector_db(therapy_guidelines_path)
326
  else:
 
499
  logger.error(f"Error detecting emotions: {e}")
500
  return {"neutral": 1.0}
501
 
502
+ def _validate_and_limit_response(self, response: str, user_message: str) -> str:
503
+ """
504
+ Validate and limit response length and content.
505
+ Returns a properly limited response.
506
+ """
507
+ if not response or not response.strip():
508
+ return "I understand. Could you tell me more about that?"
509
+
510
+ # Clean the response
511
+ response = response.strip()
512
+
513
+ # Remove any LLM commentary or instructions
514
+ response = re.sub(r"(Your response|This response|Response:|Note:).*", "", response, flags=re.IGNORECASE).strip()
515
+ response = re.sub(r"---.*", "", response).strip()
516
+
517
+ # Remove casual greetings
518
+ response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
519
+
520
+ # Count words and characters
521
+ words = response.split()
522
+ word_count = len(words)
523
+ char_count = len(response)
524
+
525
+ # Check if response is too short
526
+ if word_count < self.min_response_words:
527
+ logger.info(f"Response too short ({word_count} words), adding follow-up question")
528
+ if not response.endswith('?'):
529
+ response += " Could you tell me more about that?"
530
+
531
+ # Check if response is too long
532
+ if char_count > self.max_response_length or word_count > self.max_response_words:
533
+ logger.info(f"Response too long ({char_count} chars, {word_count} words), truncating")
534
+
535
+ # Try to find a good breaking point
536
+ if word_count > self.max_response_words:
537
+ # Truncate to max words
538
+ truncated_words = words[:self.max_response_words]
539
+ response = ' '.join(truncated_words)
540
+
541
+ # Try to end at a sentence
542
+ last_period = response.rfind('.')
543
+ last_question = response.rfind('?')
544
+ last_exclamation = response.rfind('!')
545
+
546
+ end_point = max(last_period, last_question, last_exclamation)
547
+ if end_point > len(response) * 0.7: # If we can end at a sentence within 70% of the limit
548
+ response = response[:end_point + 1]
549
+ else:
550
+ # Add ellipsis if we can't end naturally
551
+ response = response.rstrip() + "..."
552
+
553
+ elif char_count > self.max_response_length:
554
+ # Truncate to max characters
555
+ response = response[:self.max_response_length]
556
+
557
+ # Try to end at a word boundary
558
+ last_space = response.rfind(' ')
559
+ if last_space > len(response) * 0.8: # If we can end at a word within 80% of the limit
560
+ response = response[:last_space]
561
+ else:
562
+ # Add ellipsis
563
+ response = response.rstrip() + "..."
564
+
565
+ # Check for repetitive content
566
+ if self._is_repetitive(response, user_message):
567
+ logger.info("Response detected as repetitive, generating alternative")
568
+ return "I hear what you're saying. Could you help me understand this better?"
569
+
570
+ # Ensure response ends properly
571
+ if not response.endswith(('.', '!', '?')):
572
+ response = response.rstrip() + '.'
573
+
574
+ return response.strip()
575
+
576
+ def _is_repetitive(self, response: str, user_message: str) -> bool:
577
+ """
578
+ Check if response is repetitive or too similar to user message.
579
+ """
580
+ # Convert to lowercase for comparison
581
+ response_lower = response.lower()
582
+ user_lower = user_message.lower()
583
+
584
+ # Check if response contains too much of the user's message
585
+ user_words = set(user_lower.split())
586
+ response_words = set(response_lower.split())
587
+
588
+ if len(user_words) > 3: # Only check if user message has enough words
589
+ common_words = user_words.intersection(response_words)
590
+ if len(common_words) / len(user_words) > 0.6: # If more than 60% of user words are in response
591
+ return True
592
+
593
+ # Check for repetitive phrases
594
+ repetitive_phrases = [
595
+ "i understand", "i hear you", "that sounds", "i can see",
596
+ "thank you for sharing", "i appreciate", "that must be"
597
+ ]
598
+
599
+ phrase_count = sum(1 for phrase in repetitive_phrases if phrase in response_lower)
600
+ if phrase_count > 2: # If more than 2 repetitive phrases
601
+ return True
602
+
603
+ return False
604
+
605
  def retrieve_relevant_context(self, query: str, k: int = 3) -> str:
606
  # Retrieve relevant past conversations using vector similarity
607
  if not hasattr(self, 'vector_db'):
 
668
  guidelines=guidelines
669
  )
670
 
671
+ # Validate and limit the response
672
+ response = self._validate_and_limit_response(response, prompt)
 
 
 
 
 
673
 
674
  # Ensure the response is unique and not repeating previous messages
675
  if len(conversation_history) > 0:
676
  last_responses = [msg["text"] for msg in conversation_history[-4:] if msg["role"] == "assistant"]
677
  if response in last_responses:
678
+ logger.info("Response detected as duplicate, generating alternative")
679
  # Generate a new response with a different angle
680
+ alternative_response = self.conversation.predict(
681
  input=f"{prompt} (Please provide a different perspective)",
682
  past_context=past_context,
683
  emotion_context=emotion_context,
684
  guidelines=guidelines
685
  )
686
+ alternative_response = self._validate_and_limit_response(alternative_response, prompt)
687
+ response = alternative_response
688
 
689
+ return response
 
 
690
 
691
  def generate_session_summary(
692
  self,
 
932
 
933
  return crisis_response
934
 
935
+ # Reset consecutive response counter when user sends a message
936
+ self.consecutive_response_count = 0
937
+
938
  # Detect emotions
939
  emotions = self.detect_emotion(message)
940
  conversation.emotion_history.append(emotions)
 
951
  "role": msg.role
952
  })
953
 
954
+ # Check rate limiting for consecutive responses
955
+ if self.consecutive_response_count >= self.max_consecutive_responses:
956
+ logger.warning(f"Rate limit reached for user {user_id}, sending brief response")
957
+ response_text = "I'm here to listen. Take your time to share what's on your mind."
958
+ self.consecutive_response_count = 0 # Reset counter
959
+ else:
960
+ # Generate response
961
+ response_text = self.generate_response(message, emotions, conversation_history)
962
+
963
+ # Increment consecutive response counter
964
+ self.consecutive_response_count += 1
965
+
966
+ # Generate a follow-up question if the response is too short and we haven't hit limits
967
+ if (len(response_text.split()) < self.min_response_words and
968
+ not response_text.endswith('?') and
969
+ self.consecutive_response_count < self.max_consecutive_responses):
970
+
971
+ follow_up_prompt = f"""
972
  Recent conversation:
973
  {chr(10).join([f"{msg['role']}: {msg['text']}" for msg in conversation_history[-3:]])}
974
 
975
  Now, write a single empathetic and open-ended question to encourage the user to share more.
976
  Respond with just the question, no explanation.
977
  """
978
+ follow_up = self.llm.invoke(follow_up_prompt).strip()
979
+ # Clean and extract only the actual question (first sentence ending with '?')
980
+ matches = re.findall(r'([^\n.?!]*\?)', follow_up)
981
+ if matches:
982
+ question = matches[0].strip()
983
+ else:
984
+ question = follow_up.strip().split('\n')[0]
985
+
986
+ # Validate the follow-up question
987
+ question = self._validate_and_limit_response(question, message)
988
+
989
+ # If the main response is very short, return just the question
990
+ if len(response_text.split()) < 5:
991
+ response_text = question
992
+ else:
993
+ response_text = f"{response_text}\n\n{question}"
994
 
995
  # assistant response -> conversation history
996
  assistant_message = Message(