import pickle

from typing import Any

from Constants import *
from db_types import *
from langchain.chains import (ConversationalRetrievalChain,
                              RetrievalQAWithSourcesChain)
from langchain.chains.query_constructor.ir import Visitor
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (ChatPromptTemplate,
                                    HumanMessagePromptTemplate,
                                    SystemMessagePromptTemplate)
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.vectorstores import Chroma
from langchain.vectorstores.base import VectorStoreRetriever
from metadatainfo import metadata_field_info
from notionMetadataInfo import notion_metadata_field_info

_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
You can assume the question is about persons.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

template = """You are an AI assistant for answering questions about persons.
You are given the following extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, do not try to make up an answer from other sources. If the answer is found, quote the source of the answer as "SOURCE:".
Also include topics in the answer as "TOPICS:" and tags as "TAGS:".
Question: {question}
=========
{context}
=========
Answer in Markdown:"""
QA_PROMPT = PromptTemplate(template=template,
                           input_variables=["question", "context"])

class MyVisitor(Visitor):
    """Debugging translator that prints each node of a structured query as it
    is visited. Plug an instance in via SelfQueryRetriever's
    structured_query_translator argument to trace the queries the LLM builds."""

    def visit_operation(self, operation) -> Any:
        print("in operation")
        return operation

    def visit_comparison(self, comparison) -> Any:
        print("in comparison")
        return comparison

    def visit_structured_query(self, structured_query) -> Any:
        print("in structured query " + structured_query.query)
        # Translators return the query string plus any search kwargs derived
        # from the filter; this debugging stub passes no filter through.
        return structured_query.query, {}

def load_retriever():
    # VectorStoreRetriever takes no metadata schema; metadata-aware filtering
    # is handled by the SelfQueryRetriever built in
    # get_retrievalqa_with_sources_chain.
    retriever = VectorStoreRetriever(vectorstore=get_vectorstore())
    return retriever

def get_vectorstore():
    print("Reading from vectorstore " + DB_TYPE)
    if DB_TYPE == DBTypes['FAISS'].value:
        print("reading faiss vectorstore")
        vectorstore_path = PERSIST_DIRECTORY + "myvectorstore.pkl"
        with open(vectorstore_path, "rb") as f:
            vectorstore = pickle.load(f)
    elif DB_TYPE == DBTypes['NOTION'].value:
        print("reading from Notion...")
        vectorstore = Chroma(persist_directory=NOTION_PERSIST_DIRECTORY,
                             embedding_function=OpenAIEmbeddings(),
                             collection_name=NOTION_COLLECTION_NAME)
        print("Notion collection count : " + str(vectorstore._collection.count()))
    else:
        vectorstore = Chroma(persist_directory=CHROMA_PERSIST_DIRECTORY,
                             collection_name=CHROMA_COLLECTION_NAME,
                             embedding_function=OpenAIEmbeddings())
        print("Chroma collection count : " + str(vectorstore._collection.count()))
    return vectorstore
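
# A minimal sanity-check sketch (illustrative, not used by the chains below):
# assuming OPENAI_API_KEY is set and the persisted store exists, this verifies
# that the vectorstore loads and returns documents before any chain is wired up.
def _smoke_test_vectorstore(query: str = "test"):
    store = get_vectorstore()
    for doc in store.similarity_search(query, k=2):
        print(doc.metadata, doc.page_content[:80])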

def get_basic_qa_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        verbose=True)
    return model


def get_custom_prompt_qa_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    # see: https://github.com/langchain-ai/langchain/issues/6635
    # see: https://github.com/langchain-ai/langchain/issues/1497
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        combine_docs_chain_kwargs={"prompt": QA_PROMPT})
    return model


def get_condense_prompt_qa_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    # see: https://github.com/langchain-ai/langchain/issues/5890
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        condense_question_prompt=CONDENSE_QUESTION_PROMPT,
        combine_docs_chain_kwargs={"prompt": QA_PROMPT})
    return model

def get_retrievalqa_with_sources_chain():
    system_template = """Use the following pieces of context to answer the user's question.
    Take note of the sources and include them in the answer in the format: "SOURCES: source1 source2", using "SOURCES" in capital letters regardless of the number of sources.
    Also include topics in the answer as "TOPICS", tags as "TAGS", and creationYear as "YEAR". If you don't know the answer, just say "I do not know"; don't try to make up an answer.
    ----------------
    {summaries}"""
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}")
    ]

    prompt = ChatPromptTemplate.from_messages(messages)

    chain_type_kwargs = {"prompt": prompt}

    document_content_description = "Personal files"
    # Use the metadata schema that matches the active vectorstore.
    field_info = (notion_metadata_field_info
                  if DB_TYPE == DBTypes['NOTION'].value
                  else metadata_field_info)
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    vectorstore = get_vectorstore()
    # The self-query retriever lets the LLM turn a question into a semantic
    # query plus a metadata filter over the declared fields.
    retriever = SelfQueryRetriever.from_llm(
        llm,
        vectorstore,
        document_content_description,
        field_info,
        # structured_query_translator=MyVisitor(),  # enable to trace queries
        verbose=True,
        enable_limit=True,
    )
    # Build the chain once and reuse it on every call.
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs)

    def model_func(question):
        return chain({"question": question})

    return model_func

def get_qa_with_sources_chain():
    llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    retriever = load_retriever()
    history = []
    model = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
        verbose=True)

    def model_func(question):
        # bug: this doesn't work with the built-in memory
        # hacking around it for the tutorial
        # see: https://github.com/langchain-ai/langchain/issues/5630
        new_input = {"question": question['question'], "chat_history": history}
        print("new_input: " + str(new_input))
        result = model(new_input)
        history.append((question['question'], result['answer']))
        return result

    return model_func


chain_options = {
    "basic": get_basic_qa_chain,
    "with_sources": get_qa_with_sources_chain,
    "custom_prompt": get_custom_prompt_qa_chain,
    "condense_prompt": get_condense_prompt_qa_chain,
    "retrieval_sources_chain" : get_retrievalqa_with_sources_chain,
}
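

# A minimal usage sketch (illustrative; assumes OPENAI_API_KEY is set and the
# persisted vectorstore has been built). The conversational factories return a
# chain or closure that accepts a {"question": ...} dict, while
# get_retrievalqa_with_sources_chain's closure takes the raw question string.
if __name__ == "__main__":
    qa = chain_options["basic"]()
    result = qa({"question": "Who is mentioned most often in the documents?"})
    print(result["answer"])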