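# Chain factories for a LangChain question-answering app over personal
# documents stored in a FAISS pickle, a Chroma collection, or a Notion
# export indexed into Chroma.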
import pickle
from typing import Any, List, Optional, Sequence, Union

import chromadb

from Constants import *
from db_types import *
from langchain.chains import (ConversationalRetrievalChain, RetrievalQA,
                              RetrievalQAWithSourcesChain)
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.chains.query_constructor.ir import Visitor
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (ChatPromptTemplate,
                                    HumanMessagePromptTemplate,
                                    SystemMessagePromptTemplate)
from langchain.prompts.prompt import PromptTemplate
# from langchain.retrievers.self_query import BaseTranslator
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.vectorstores import Chroma
from langchain.vectorstores.base import VectorStoreRetriever
from metadatainfo import metadata_field_info
from notionMetadataInfo import notion_metadata_field_info
| _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. | |
| You can assume the question about persons. | |
| Chat History: | |
| {chat_history} | |
| Follow Up Input: {question} | |
| Standalone question:""" | |
| CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) | |
| template = """You are an AI assistant for answering questions about persons. | |
| You are given the following extracted parts of a long document and a question. Provide a conversational answer. | |
| If you don't know the answer, do not try to makeup the answer from other sources. If the answer is found quote the source of the answer as SOURCE: | |
| Also include Topics in the answers as "TOPICS": Also include tags in the answers as "TAGS": | |
| Question: {question} | |
| ========= | |
| {context} | |
| ========= | |
| Answer in Markdown:""" | |
| QA_PROMPT = PromptTemplate(template=template, input_variables=[ | |
| "question", "context"]) | |
class MyVisitor(Visitor):
    """Debugging translator that logs each node of a structured query."""

    def visit_operation(self, operation) -> Any:
        print("in operation")
        return operation

    def visit_comparison(self, comparison) -> Any:
        print("in comparison")
        return comparison

    def visit_structured_query(self, structured_query) -> Any:
        print("in structured query " + structured_query.query)
        # Translators are expected to return (query, search_kwargs);
        # returning the visitor itself (as the original did) breaks that contract.
        kwargs = ({} if structured_query.filter is None
                  else {"filter": structured_query.filter.accept(self)})
        return structured_query.query, kwargs
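# A translator implementing the Visitor interface can be plugged into
# SelfQueryRetriever to trace or customize query construction; a minimal
# sketch (variable names here are illustrative):
#
#   translator = MyVisitor()
#   retriever = SelfQueryRetriever.from_llm(
#       llm, vectorstore, "Personal files", metadata_field_info,
#       structured_query_translator=translator)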
def load_retriever():
    # VectorStoreRetriever takes no `dict` argument; the original
    # `dict=metadata_field_info` keyword was a no-op and has been dropped.
    retriever = VectorStoreRetriever(vectorstore=get_vectorstore())
    return retriever
def get_vectorstore():
    print("Reading from vectorstore " + DB_TYPE)
    custom_meta_data_info = metadata_field_info  # currently unused here
    if DB_TYPE == DBTypes['FAISS'].value:
        print("reading faiss vectorstore")
        vectorstore_path = PERSIST_DIRECTORY + "myvectorstore.pkl"
        with open(vectorstore_path, "rb") as f:
            vectorstore = pickle.load(f)
    elif DB_TYPE == DBTypes['NOTION'].value:
        print("reading from Notion...")
        custom_meta_data_info = notion_metadata_field_info
        vectorstore = Chroma(persist_directory=NOTION_PERSIST_DIRECTORY,
                             embedding_function=OpenAIEmbeddings(),
                             collection_name=NOTION_COLLECTION_NAME)
        print("Notion collection count: " + str(vectorstore._collection.count()))
    else:
        vectorstore = Chroma(persist_directory=CHROMA_PERSIST_DIRECTORY,
                             collection_name=CHROMA_COLLECTION_NAME,
                             embedding_function=OpenAIEmbeddings())
        print("Chroma collection count: " + str(vectorstore._collection.count()))
    return vectorstore
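# Usage sketch (DB_TYPE and the persist directories come from Constants;
# the query string is illustrative):
#
#   store = get_vectorstore()
#   docs = store.similarity_search("Who is Alice?", k=4)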
def get_basic_qa_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        verbose=True)
    return model
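# Usage sketch: the chain carries its own memory, so successive calls with
# just {"question": ...} keep conversational context (question text is
# illustrative):
#
#   chain = get_basic_qa_chain()
#   result = chain({"question": "Who is mentioned in the notes?"})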
def get_custom_prompt_qa_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    # see: https://github.com/langchain-ai/langchain/issues/6635
    # see: https://github.com/langchain-ai/langchain/issues/1497
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        combine_docs_chain_kwargs={"prompt": QA_PROMPT})
    return model
def get_condense_prompt_qa_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    # see: https://github.com/langchain-ai/langchain/issues/5890
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        condense_question_prompt=CONDENSE_QUESTION_PROMPT,
        combine_docs_chain_kwargs={"prompt": QA_PROMPT})
    return model
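# The condense step uses CONDENSE_QUESTION_PROMPT to rewrite a follow-up
# such as "Where does she live?" into a standalone question before
# retrieval, so the retriever always sees a self-contained query.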
def get_retrievalqa_with_sources_chain():
    system_template = """Use the following pieces of context to answer the user's question.
Take note of the sources and include them in the answer in the format: "SOURCES: source1 source2", using "SOURCES" in capital letters regardless of the number of sources.
Also include topics in the answer as "TOPICS" and tags in the answer as "TAGS". Include creationYear in the answer as "YEAR". If you don't know the answer, just say "I do not know"; don't try to make up an answer.
----------------
{summaries}"""
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    chain_type_kwargs = {"prompt": prompt}
    document_content_description = "Personal files"
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    vectorstore = get_vectorstore()
    retriever = SelfQueryRetriever.from_llm(
        llm,
        vectorstore,
        document_content_description,
        metadata_field_info,
        # structured_query_translator=MyVisitor(),
        verbose=True,
        enable_limit=True,
    )

    def model_func(question):
        # Assumes callers pass {"question": ...}, matching the convention
        # used by get_qa_with_sources_chain below.
        chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs=chain_type_kwargs,
        )
        results = chain({"question": question["question"]})
        return results
    return model_func
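# Usage sketch: the self-query retriever lets the LLM turn natural-language
# constraints into metadata filters over metadata_field_info (the question
# below is illustrative):
#
#   qa = get_retrievalqa_with_sources_chain()
#   result = qa({"question": "Show 2 documents tagged travel from 2021"})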
def get_qa_with_sources_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    history = []
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
        verbose=True)

    def model_func(question):
        # bug: this doesn't work with the built-in memory
        # hacking around it for the tutorial
        # see: https://github.com/langchain-ai/langchain/issues/5630
        new_input = {"question": question['question'], "chat_history": history}
        for key in new_input:
            print("new_input " + key)
        result = model(new_input)
        history.append((question['question'], result['answer']))
        return result
    return model_func
chain_options = {
    "basic": get_basic_qa_chain,
    "with_sources": get_qa_with_sources_chain,
    "custom_prompt": get_custom_prompt_qa_chain,
    "condense_prompt": get_condense_prompt_qa_chain,
    "retrieval_sources_chain": get_retrievalqa_with_sources_chain,
}
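# Usage sketch (the chain name and question are illustrative; a UI layer
# such as a Gradio app would typically drive this dispatch table):
#
#   chain = chain_options["with_sources"]()
#   result = chain({"question": "What topics appear in the 2021 notes?"})
#   print(result["answer"])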