File size: 3,967 Bytes
adcfb91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from langchain_text_splitters import RecursiveCharacterTextSplitter
from retrieve_documents import retrieve_relevant_documents
from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import HumanMessagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
import json
import pathlib

# Replace the stdlib sqlite3 module with pysqlite3 BEFORE importing Chroma,
# so Chroma picks up the swapped module.
# NOTE(review): presumably done because the system SQLite is older than what
# Chroma requires — confirm against deployment environment.
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
from langchain_chroma import Chroma

class RAG_with_memory:
    """Conversational RAG helper.

    Keeps a running plain-text chat transcript (``self.history``) and uses it
    both to rewrite follow-up queries into standalone ones and to ground
    answers in documents pulled from an attached retriever.
    """

    def __init__(self, retriever=None):
        # Prompt templates live next to this file under prompts/.
        self.file_loc = pathlib.Path(__file__).parent.resolve()
        with open(self.file_loc / "prompts/search.txt", "r") as f:
            self.raw_prompt_text = f.read()
        with open(self.file_loc / "prompts/rewrite_query.txt", "r") as f:
            self.raw_query_rewrite_text = f.read()
        self.retriever = retriever
        self.reset()

    def load_history(self, history):
        """Replace the stored transcript with *history*.

        *history* is a list of ``{"role": ..., "content": ...}`` dicts; any
        role other than "user" is rendered as the AI speaker.
        """
        self.history = ""
        for item in history:
            if item["role"] == "user":
                self.history += "Human: " + item["content"] + "\n"
            else:
                self.history += "AI: " + item["content"] + "\n"
        self.history += "\n"

    def add_retriever(self, retriever):
        """Attach the retriever that generate() will use for context."""
        self.retriever = retriever

    def reset(self):
        """Re-create the LLM client from api_keys.json and clear the transcript."""
        with open(self.file_loc / "api_keys.json", "r") as f:
            COHERE_API_KEY = json.load(f)["cohere"]
        self.llm = ChatCohere(model="command-r", cohere_api_key=COHERE_API_KEY)
        self.history = ""

    def rewrite_query(self, original_query):
        """Rewrite *original_query* into a standalone query using the chat history."""
        print("Rewriting query")
        # Bake the history into the template now; keep {query} as the sole
        # runtime variable for the chain.
        # NOTE(review): str.format will raise if the history or template text
        # contains literal braces — confirm prompt files are brace-free.
        prompt_text = self.raw_query_rewrite_text.format(history=self.history, query="{query}")
        prompt_template = PromptTemplate(
            input_variables=['query'],
            template=prompt_text
        )
        chain = (
            {"query" : RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
        return chain.invoke(original_query)

    def generate(self, query):
        """Answer *query* using retrieved context, then append the exchange to history.

        Raises ValueError if no retriever has been attached.
        """
        # TODO: Generate should take a retriever, not constructor
        if self.retriever is None:
            # ValueError is a subclass of Exception, so existing broad handlers
            # still catch this; the original message was also ungrammatical.
            raise ValueError("Retriever must be non-None; pass one to __init__ or add_retriever()")
        print("Prompting LLM")
        prompt_text = self.raw_prompt_text.format(history=self.history, context="{context}", question="{question}")
        prompt_template = PromptTemplate(
            input_variables=['context', 'question'],
            template=prompt_text
        )
        rag_chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )

        response = rag_chain.invoke(query)
        self.history += "Human: " + query + "\n"
        self.history += "AI: " + response + "\n"
        print("Done")
        return response


def format_docs(docs):
    """Concatenate the page contents of *docs*, separated by blank lines."""
    pieces = [document.page_content for document in docs]
    return "\n\n".join(pieces)


def create_vector_store(documents, embedding):
    """Chunk *documents* and index the chunks in a Chroma vector store.

    Returns the populated Chroma store.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(documents)
    return Chroma.from_documents(documents=chunks, embedding=embedding)

def get_retriever_links(query, api_key_file):
    """Fetch documents relevant to *query*, embed them, and build a retriever.

    *api_key_file* is a JSON file containing a "cohere" key.
    Returns a (retriever, links) pair.
    """
    documents, links = retrieve_relevant_documents(query, api_key_file)

    with open(api_key_file, "r") as key_file:
        cohere_key = json.load(key_file)["cohere"]
    embedder = CohereEmbeddings(cohere_api_key=cohere_key, model='embed-english-v3.0')
    print("Splitting Documents + Loading into vectorstore")
    store = create_vector_store(documents, embedder)
    return store.as_retriever(), links