| import gradio as gr |
| from typing import Any, Dict, List, Tuple |
| from langchain_chroma import Chroma |
| from langchain_core.callbacks import BaseCallbackHandler |
| from langchain_core.messages import AIMessage, HumanMessage |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
| from langchain.chains.history_aware_retriever import create_history_aware_retriever |
| from langchain.chains.retrieval import create_retrieval_chain |
| from langchain.chains.combine_documents import create_stuff_documents_chain |
| from langchain_core.callbacks import CallbackManagerForRetrieverRun |
| from langchain_core.documents import Document |
| from langchain_core.retrievers import BaseRetriever |
| import pandas as pd |
|
|
class CustomHandler(BaseCallbackHandler):
    """Callback handler that remembers the prompt text of the latest LLM call.

    Every ``on_llm_start`` event overwrites ``self.prompt``, so after a chain
    run it holds the prompt(s) of the last LLM call made during that run.
    """

    def __init__(self):
        # Run the base-class initializer so any state it sets up is present.
        super().__init__()
        # Newline-joined prompts from the most recent on_llm_start; "" until then.
        self.prompt = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Store ``prompts`` joined by newlines, replacing any earlier value.

        Args:
            serialized: Description of the LLM being called (unused here).
            prompts: Prompt strings about to be sent to the LLM.
            **kwargs: Extra callback arguments (unused here).
        """
        self.prompt = "\n".join(prompts)
|
|
class CustomRetriever(BaseRetriever):
    """Retriever that turns similar posts into "User:/Assistant:" examples.

    Each post returned by ``vectorstore.similarity_search`` is paired with the
    first matching reply from ``comments`` (joined on ``Post_ID``). The
    resulting document therefore reads as a short example exchange.
    """

    # Vector store searched for posts similar to the query.
    vectorstore: Chroma
    # Reply table; must contain "Post_ID" and "Comment_content" columns.
    comments: pd.DataFrame

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return example exchanges built from posts similar to ``query``.

        Args:
            query: Search text passed to the vector store.
            run_manager: Callback manager supplied by LangChain (unused here).

        Returns:
            One ``Document`` per retrieved post that has a reply in
            ``comments``. Its ``page_content`` is the post text with
            "Content: " relabelled as "User: ", followed by
            "\\nAssistant: <reply>". Metadata is copied from the retrieved
            document.
        """
        import logging

        logger = logging.getLogger(__name__)
        examples: List[Document] = []
        for doc in self.vectorstore.similarity_search(query):
            # metadata["source"] is interpreted as the post id to join on.
            post_id = int(doc.metadata["source"])
            replies = self.comments.loc[
                self.comments["Post_ID"] == post_id, "Comment_content"
            ].values
            if len(replies) == 0:
                # Without a reply the post cannot serve as an example; skip it
                # rather than failing the whole request with IndexError.
                logger.warning("No comment found for Post_ID %s; skipping.", post_id)
                continue
            # Use a separate name so the caller's ``query`` is not shadowed.
            user_turn = doc.page_content.replace("Content: ", "User: ")
            examples.append(
                Document(
                    page_content=f"{user_turn}\nAssistant: {replies[0]}",
                    metadata=doc.metadata,
                )
            )

        logger.debug("Retrieved documents: %s", examples)
        return examples
|
|
class ChatBot:
    """Retrieval-augmented chat assistant built on LangChain and OpenAI.

    Past posts similar to the conversation are retrieved, paired with their
    replies, and placed in the system prompt as example answers.
    """

    def __init__(self, is_debug=False):
        """Create the model, vector store, reply table and retriever.

        Args:
            is_debug: When true, ``generate_response`` returns the last prompt
                sent to the LLM as its second value.
        """
        self.is_debug = is_debug
        self.model = ChatOpenAI()
        # Attached to each chain invocation to capture the prompt text.
        self.handler = CustomHandler()
        self.embedding_function = OpenAIEmbeddings()
        # Collection "documents" backed by the local "chroma" persist directory.
        self.vectorstore = Chroma(
            embedding_function=self.embedding_function,
            collection_name="documents",
            persist_directory="chroma",
        )
        # Replies that CustomRetriever joins onto retrieved posts.
        self.comments = pd.read_csv("data/comments.csv")
        self.retriever = CustomRetriever(
            vectorstore=self.vectorstore, comments=self.comments
        )

    def create_chain(self):
        """Build the history-aware retrieval QA chain.

        The conversation is first rewritten into a search query (using
        ``retriever_prompt``). The query is sent to ``self.retriever``. The
        retrieved documents are stuffed into ``{context}`` of the system prompt
        before the model answers.

        Returns:
            A runnable invoked with ``{"input": ..., "chat_history": [...]}``
            whose output contains an ``"answer"`` key.
        """
        qa_system_prompt = """
You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.

Here are a few examples of answers:
{context}

"""
        prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])

        # Answers the question using the retrieved documents as {context}.
        chain = create_stuff_documents_chain(
            llm=self.model,
            prompt=prompt,
        )

        # Turns the conversation so far into a standalone search query.
        retriever_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
        ])
        history_aware_retriever = create_history_aware_retriever(
            llm=self.model,
            retriever=self.retriever,
            prompt=retriever_prompt,
        )

        return create_retrieval_chain(history_aware_retriever, chain)

    def process_chat_history(self, chat_history):
        """Convert ``(query, response)`` pairs into alternating LangChain messages.

        Args:
            chat_history: Iterable of ``(user_text, assistant_text)`` pairs.

        Returns:
            A list of ``HumanMessage``/``AIMessage`` objects in conversation order.
        """
        history = []
        for (query, response) in chat_history:
            history.append(HumanMessage(content=query))
            history.append(AIMessage(content=response))
        return history

    def generate_response(self, query, chat_history):
        """Answer ``query`` in the context of the previous conversation turns.

        Args:
            query: The user's new message.
            chat_history: Previous turns as ``(query, response)`` pairs.

        Returns:
            ``(answer, references)``. In debug mode, ``references`` is the last
            prompt captured by ``self.handler``. Otherwise it is a fixed
            placeholder string.

        Raises:
            gr.Error: If ``query`` is empty or contains only whitespace.
        """
        # Validate the actual argument. The builtin ``input`` is always truthy,
        # so testing it would never reject anything.
        if not query or not str(query).strip():
            raise gr.Error("Please enter a question.")

        history = self.process_chat_history(chat_history)
        conversational_chain = self.create_chain()
        response = conversational_chain.invoke(
            {
                "input": query,
                "chat_history": history,
            },
            config={"callbacks": [self.handler]},
        )["answer"]

        references = self.handler.prompt if self.is_debug else "This is for debugging purposes only."
        return response, references