# NOTE: scraped from a Hugging Face Space (page status: "Paused") — non-code artifact.
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from retrieve_documents import retrieve_relevant_documents | |
| from langchain_cohere import ChatCohere, CohereEmbeddings | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.prompts.chat import HumanMessagePromptTemplate | |
| from langchain_core.prompts.prompt import PromptTemplate | |
| import json | |
| import pathlib | |
| __import__('pysqlite3') | |
| import sys | |
| sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | |
| from langchain_chroma import Chroma | |
class RAG_with_memory:
    """Conversational RAG pipeline backed by a Cohere chat model.

    Keeps a running plain-text transcript in ``self.history`` ("Human: ..." /
    "AI: ..." lines) and bakes it into the prompts used for both query
    rewriting and answer generation.
    """

    def __init__(self, retriever=None):
        """Load the prompt templates from disk and initialise the LLM and history.

        Args:
            retriever: optional LangChain retriever; may also be supplied
                later via ``add_retriever``.
        """
        self.file_loc = pathlib.Path(__file__).parent.resolve()
        with open(self.file_loc / "prompts/search.txt", "r") as f:
            self.raw_prompt_text = f.read()
        with open(self.file_loc / "prompts/rewrite_query.txt", "r") as f:
            self.raw_query_rewrite_text = f.read()
        self.retriever = retriever
        self.reset()

    def load_history(self, history):
        """Rebuild the transcript from a list of chat messages.

        Args:
            history: iterable of dicts with ``"role"`` and ``"content"`` keys;
                any role other than ``"user"`` is treated as the AI.
        """
        self.history = ""
        for item in history:
            if item["role"] == "user":
                self.history += "Human: " + item["content"] + "\n"
            else:
                self.history += "AI: " + item["content"] + "\n"
        self.history += "\n"

    def add_retriever(self, retriever):
        """Attach (or replace) the retriever used by ``generate``."""
        self.retriever = retriever

    def reset(self):
        """Re-create the LLM from the on-disk API key and clear the history."""
        with open(self.file_loc / "api_keys.json", "r") as f:
            api_keys = json.load(f)
        COHERE_API_KEY = api_keys["cohere"]
        self.llm = ChatCohere(model="command-r", cohere_api_key=COHERE_API_KEY)
        self.history = ""

    def rewrite_query(self, original_query):
        """Rewrite a follow-up question into a standalone query using the history.

        Args:
            original_query: the user's (possibly context-dependent) question.

        Returns:
            The rewritten query string produced by the LLM.
        """
        print("Rewriting query")
        # Bake the history into the template now; leave {query} for the chain.
        prompt_text = self.raw_query_rewrite_text.format(history=self.history, query="{query}")
        prompt_template = PromptTemplate(
            input_variables=['query'],
            template=prompt_text
        )
        chain = (
            {"query": RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
        return chain.invoke(original_query)

    def generate(self, query):
        """Answer ``query`` with retrieved context and update the transcript.

        Args:
            query: the user's question.

        Returns:
            The LLM's answer string.

        Raises:
            ValueError: if no retriever has been attached.
        """
        # TODO: Generate should take a retriever, not constructor
        if self.retriever is None:
            # Fix: was `raise Exception("Retriever must non-None")` — overly
            # broad exception type and a garbled message.
            raise ValueError("Retriever must be non-None")
        print("Prompting LLM")
        # Bake the history in now; {context} and {question} are filled by the chain.
        prompt_text = self.raw_prompt_text.format(history=self.history, context="{context}", question="{question}")
        prompt_template = PromptTemplate(
            input_variables=['context', 'question'],
            template=prompt_text
        )
        rag_chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
        response = rag_chain.invoke(query)
        self.history += "Human: " + query + "\n"
        self.history += "AI: " + response + "\n"
        print("Done")
        return response
def format_docs(docs):
    """Concatenate the page contents of *docs*, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
def create_vector_store(documents, embedding):
    """Chunk *documents* and index the chunks in a Chroma vector store.

    Args:
        documents: LangChain documents to index.
        embedding: embedding model used to embed each chunk.

    Returns:
        The populated ``Chroma`` vector store.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(documents)
    return Chroma.from_documents(documents=chunks, embedding=embedding)
def get_retriever_links(query, api_key_file):
    """Fetch documents relevant to *query* and wrap them in a retriever.

    Args:
        query: search query passed to the document-retrieval step.
        api_key_file: path to the JSON file holding the Cohere API key.

    Returns:
        A ``(retriever, links)`` tuple — a Chroma retriever over the fetched
        documents plus the source links reported by the retrieval step.
    """
    documents, links = retrieve_relevant_documents(query, api_key_file)
    with open(api_key_file, "r") as f:
        cohere_key = json.load(f)["cohere"]
    embedding_model = CohereEmbeddings(cohere_api_key=cohere_key, model='embed-english-v3.0')
    print("Splitting Documents + Loading into vectorstore")
    store = create_vector_store(documents, embedding_model)
    return store.as_retriever(), links