# zym / chat.py
# zym0216's upload — "Update chat.py" (commit 4e088e2, verified)
import os
import openai
import langchain
from langchain.document_loaders import PyMuPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.globals import set_verbose
from langchain.llms import VLLM
class RetrievalChatbot():
    """Retrieval-augmented chatbot.

    Splits an incoming question into subquestions (or rephrasings) via an
    OpenAI chat model, answers each one against a Chroma vector store using
    a LangChain RetrievalQA chain, then summarizes the partial answers into
    one coherent reply.
    """

    def __init__(self, api_key, api_base, model_name):
        """Configure the OpenAI client, vector store, and QA chain.

        Args:
            api_key: OpenAI API key, set on the global ``openai`` module.
            api_base: Alternate API endpoint; currently unused (the
                assignment is deliberately commented out).
            model_name: Chat model identifier used for the rephrase and
                summarize calls made through ``get_openai_response``.
        """
        openai.api_key = api_key
        #openai.api_base = api_base
        self.model_name = model_name
        set_verbose(True)
        # High retry count to ride out transient embedding-API rate limits.
        embeddings = OpenAIEmbeddings(max_retries=100)
        if os.path.exists("persist"):
            vectordb = Chroma(persist_directory="persist", embedding_function=embeddings)
            print("loaded existing database")
        else:
            # BUG FIX: the original branch only created the directory and
            # left `vectordb` unbound, so the `as_retriever` call below
            # raised NameError on first run. Create a fresh (empty)
            # persisted store instead; documents can be ingested into it
            # later (e.g. via DirectoryLoader/PyMuPDFLoader, imported above).
            os.mkdir("persist")
            vectordb = Chroma(persist_directory="persist", embedding_function=embeddings)
            print("created new empty database")
        # MMR retrieval: fetch 50 candidates, keep the 10 most diverse.
        retriever = vectordb.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 10, "fetch_k": 50}
        )
        llm = ChatOpenAI(model_name="gpt-4-1106-preview")
        self.qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
        # Prompt fragments are stored as tuples and "".join-ed before use.
        self.prompt = (
            "Please answer the following question using information with the assistance of the given context.\n",
            "Question: {message}\n",
            "Directly output your answer without additional explanations.\n",
        )
        self.rephrase_prompt = (
            "You will be given a question. Break down the question into subquestions, and output the subquestions in the format of a python list: [\"<subquestion1>\", \"<subquestion2>\", ...].\n",
            "If the question cannot be broken down into subquestions, rephrase it 3 times and output the rephrases in the format of a python list: [\"<rephrase1>\", \"<rephrase>\", ...].\n",
            "Do not include line breaks in any of your answers.\n",
            "Question: {question}"
        )
        self.summarize_prompt = (
            "You will be given a question and several answers to it. Please organize and summarize the answers to form one coherent answer to the question.\n",
            "Question: {question}\n",
            "Answers: {answers}\n",
            "Directly output your answer without additional explanations.\n",
        )

    def get_openai_response(self, query, model_name):
        """Send a single-turn chat completion request and return its text.

        Args:
            query: User message content.
            model_name: Chat model to use for this request.

        Returns:
            The assistant's reply text (first choice).
        """
        # BUG FIX: the original hardcoded "gpt-4-1106-preview" here,
        # silently ignoring the `model_name` argument that callers pass
        # (always `self.model_name`). Honor the parameter instead.
        response = openai.ChatCompletion.create(
            model=model_name,
            messages=[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": query}],
        )
        return response["choices"][0]["message"]["content"]

    def get_response(self, chat_history, message):
        """Answer `message` via decompose -> retrieve -> summarize.

        Args:
            chat_history: Accepted for interface compatibility; currently
                unused by this implementation.
            message: The user's question.

        Returns:
            A single summarized answer string.
        """
        # Ask the LLM to split the question into subquestions (or 3
        # rephrasings), formatted as a Python-style list literal.
        rephrase_prompt = "".join(self.rephrase_prompt).format(question=message)
        raw_subquestions = self.get_openai_response(rephrase_prompt, self.model_name)
        # Best-effort normalization of the model's list output: strip line
        # breaks and coerce the various quote/separator styles the model
        # may emit into a single '","' delimiter before splitting.
        raw_subquestions = raw_subquestions.replace("\n", "").replace("', '", '","').replace('", "', '","').replace("','", '","')
        subquestion_list = raw_subquestions.lstrip('[').rstrip(']').split('","')
        subquestion_list = [subquestion.strip('"').strip("'") for subquestion in subquestion_list]
        # Answer each subquestion independently against the vector store.
        responses_list = []
        for rephrase in subquestion_list:
            query = "".join(self.prompt).format(message=rephrase)
            retrieval_response = self.qa(query)["result"]
            responses_list.append(retrieval_response)
        # Fold the partial answers into one coherent reply.
        summarize_prompt = "".join(self.summarize_prompt).format(question=message, answers=responses_list)
        summarized_answer = self.get_openai_response(summarize_prompt, self.model_name)
        return summarized_answer