Nada
committed on
Commit
·
db641fe
1
Parent(s):
b67bb23
final
Browse files- chatbot.py +37 -25
chatbot.py
CHANGED
|
@@ -19,14 +19,15 @@ from peft import PeftModel, PeftConfig
|
|
| 19 |
from sentence_transformers import SentenceTransformer
|
| 20 |
|
| 21 |
# LangChain imports
|
| 22 |
-
|
| 23 |
-
from langchain.
|
| 24 |
-
from langchain.
|
| 25 |
-
from langchain.
|
| 26 |
-
from langchain.
|
| 27 |
-
from langchain.
|
| 28 |
-
from langchain.
|
| 29 |
-
from langchain.
|
|
|
|
| 30 |
|
| 31 |
# Import FlowManager
|
| 32 |
from conversation_flow import FlowManager
|
|
@@ -106,11 +107,8 @@ class SessionSummary(BaseModel):
|
|
| 106 |
user_id: str = Field(
|
| 107 |
...,
|
| 108 |
description="Identifier of the user",
|
| 109 |
-
examples=["user_123"]
|
| 110 |
-
|
| 111 |
-
start_time: str = Field(
|
| 112 |
-
...,
|
| 113 |
-
description="ISO format start time of the session"
|
| 114 |
)
|
| 115 |
end_time: str = Field(
|
| 116 |
...,
|
|
@@ -266,12 +264,24 @@ class MentalHealthChatbot:
|
|
| 266 |
self.flow_manager = FlowManager(self.llm)
|
| 267 |
|
| 268 |
# Setup conversation memory with LangChain
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
self.memory = ConversationBufferMemory(
|
| 270 |
return_messages=True,
|
| 271 |
input_key="input"
|
| 272 |
)
|
| 273 |
|
| 274 |
# Create conversation prompt template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
self.prompt_template = PromptTemplate(
|
| 276 |
input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
|
| 277 |
template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.
|
|
@@ -300,6 +310,12 @@ Response:"""
|
|
| 300 |
)
|
| 301 |
|
| 302 |
# Create the conversation chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
self.conversation = LLMChain(
|
| 304 |
llm=self.llm,
|
| 305 |
prompt=self.prompt_template,
|
|
@@ -308,11 +324,16 @@ Response:"""
|
|
| 308 |
)
|
| 309 |
|
| 310 |
# Setup embeddings for vector search
|
|
|
|
|
|
|
|
|
|
| 311 |
self.embeddings = HuggingFaceEmbeddings(
|
| 312 |
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
| 313 |
)
|
| 314 |
|
| 315 |
# Setup vector database for retrieving relevant past conversations
|
|
|
|
|
|
|
| 316 |
if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
|
| 317 |
self.setup_vector_db(therapy_guidelines_path)
|
| 318 |
else:
|
|
@@ -425,6 +446,9 @@ Response:"""
|
|
| 425 |
)
|
| 426 |
|
| 427 |
# Create LangChain wrapper
|
|
|
|
|
|
|
|
|
|
| 428 |
llm = HuggingFacePipeline(pipeline=text_generator)
|
| 429 |
|
| 430 |
return model, tokenizer, llm
|
|
@@ -579,19 +603,7 @@ Response:"""
|
|
| 579 |
response = response.split("Response:")[-1].strip()
|
| 580 |
response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
|
| 581 |
|
| 582 |
-
# Limit response length
|
| 583 |
-
max_words = 60
|
| 584 |
-
max_sentences = 4
|
| 585 |
-
|
| 586 |
-
# Split into sentences
|
| 587 |
-
sentences = re.split(r'(?<=[.!?]) +', response)
|
| 588 |
-
if len(sentences) > max_sentences:
|
| 589 |
-
response = ' '.join(sentences[:max_sentences])
|
| 590 |
|
| 591 |
-
# Split into words
|
| 592 |
-
words = response.split()
|
| 593 |
-
if len(words) > max_words:
|
| 594 |
-
response = ' '.join(words[:max_words]) + '...'
|
| 595 |
|
| 596 |
return response.strip()
|
| 597 |
|
|
|
|
| 19 |
from sentence_transformers import SentenceTransformer
|
| 20 |
|
| 21 |
# LangChain imports
|
| 22 |
+
# Core LangChain components for building conversational AI
|
| 23 |
+
from langchain.llms import HuggingFacePipeline # Wrapper for HuggingFace models
|
| 24 |
+
from langchain.chains import LLMChain # Chain for LLM interactions
|
| 25 |
+
from langchain.memory import ConversationBufferMemory # Memory for conversation history
|
| 26 |
+
from langchain.prompts import PromptTemplate # Template for structured prompts
|
| 27 |
+
from langchain.embeddings import HuggingFaceEmbeddings # Text embeddings for similarity search
|
| 28 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter # Document chunking
|
| 29 |
+
from langchain.document_loaders import TextLoader # Load text documents
|
| 30 |
+
from langchain.vectorstores import FAISS # Vector database for similarity search
|
| 31 |
|
| 32 |
# Import FlowManager
|
| 33 |
from conversation_flow import FlowManager
|
|
|
|
| 107 |
user_id: str = Field(
|
| 108 |
...,
|
| 109 |
description="Identifier of the user",
|
| 110 |
+
examples=["user_123"])
|
| 111 |
+
start_time: str = Field(..., description="ISO format start time of the session"
|
|
|
|
|
|
|
|
|
|
| 112 |
)
|
| 113 |
end_time: str = Field(
|
| 114 |
...,
|
|
|
|
| 264 |
self.flow_manager = FlowManager(self.llm)
|
| 265 |
|
| 266 |
# Setup conversation memory with LangChain
|
| 267 |
+
# ConversationBufferMemory stores the conversation history in a buffer
|
| 268 |
+
# This allows the chatbot to maintain context across multiple interactions
|
| 269 |
+
# - return_messages=True: Returns messages as a list of message objects
|
| 270 |
+
# - input_key="input": Specifies which key to use for the input in the memory
|
| 271 |
self.memory = ConversationBufferMemory(
|
| 272 |
return_messages=True,
|
| 273 |
input_key="input"
|
| 274 |
)
|
| 275 |
|
| 276 |
# Create conversation prompt template
|
| 277 |
+
# PromptTemplate defines the structure for generating responses
|
| 278 |
+
# It includes placeholders for dynamic content that gets filled during generation
|
| 279 |
+
# Input variables:
|
| 280 |
+
# - history: Previous conversation context from memory
|
| 281 |
+
# - input: Current user message
|
| 282 |
+
# - past_context: Relevant past conversations from vector search
|
| 283 |
+
# - emotion_context: Detected emotions and their context
|
| 284 |
+
# - guidelines: Relevant therapeutic guidelines from vector search
|
| 285 |
self.prompt_template = PromptTemplate(
|
| 286 |
input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
|
| 287 |
template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.
|
|
|
|
| 310 |
)
|
| 311 |
|
| 312 |
# Create the conversation chain
|
| 313 |
+
# LLMChain combines the language model, prompt template, and memory
|
| 314 |
+
# This creates a conversational agent that can:
|
| 315 |
+
# - Generate responses using the LLM
|
| 316 |
+
# - Use the prompt template for structured input
|
| 317 |
+
# - Maintain conversation history in memory
|
| 318 |
+
# - verbose=False: Disables detailed logging of chain operations
|
| 319 |
self.conversation = LLMChain(
|
| 320 |
llm=self.llm,
|
| 321 |
prompt=self.prompt_template,
|
|
|
|
| 324 |
)
|
| 325 |
|
| 326 |
# Setup embeddings for vector search
|
| 327 |
+
# HuggingFaceEmbeddings converts text to numerical vectors for similarity search
|
| 328 |
+
# all-MiniLM-L6-v2 is a lightweight but effective sentence embedding model
|
| 329 |
+
# These embeddings enable semantic search of past conversations and guidelines
|
| 330 |
self.embeddings = HuggingFaceEmbeddings(
|
| 331 |
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
| 332 |
)
|
| 333 |
|
| 334 |
# Setup vector database for retrieving relevant past conversations
|
| 335 |
+
# The vector database stores embeddings of therapy guidelines and past conversations
|
| 336 |
+
# This enables semantic search to find relevant context for each response
|
| 337 |
if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
|
| 338 |
self.setup_vector_db(therapy_guidelines_path)
|
| 339 |
else:
|
|
|
|
| 446 |
)
|
| 447 |
|
| 448 |
# Create LangChain wrapper
|
| 449 |
+
# HuggingFacePipeline wraps the HuggingFace pipeline for use with LangChain
|
| 450 |
+
# This enables the pipeline to work seamlessly with LangChain components
|
| 451 |
+
# like chains, memory, and prompts
|
| 452 |
llm = HuggingFacePipeline(pipeline=text_generator)
|
| 453 |
|
| 454 |
return model, tokenizer, llm
|
|
|
|
| 603 |
response = response.split("Response:")[-1].strip()
|
| 604 |
response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
|
| 605 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
|
| 608 |
return response.strip()
|
| 609 |
|