# Term-4-Project / RAG.py
# Uploaded by Dhenenjay using huggingface_hub (commit adcfb91, verified).
from langchain_text_splitters import RecursiveCharacterTextSplitter
from retrieve_documents import retrieve_relevant_documents
from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import HumanMessagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
import json
import pathlib
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
from langchain_chroma import Chroma
class RAG_with_memory:
    """Retrieval-augmented chat around a Cohere chat model with a text transcript.

    The running conversation is kept in ``self.history`` as plain text made of
    ``Human: ...`` / ``AI: ...`` lines. It is spliced into the prompt templates
    read from the ``prompts`` directory that sits next to this file.
    """

    def __init__(self, retriever=None):
        """Read the prompt templates, then build the model and an empty history.

        Args:
            retriever: Optional retriever piped into the chain by ``generate``.
                It can also be supplied later through ``add_retriever``.
        """
        self.file_loc = pathlib.Path(__file__).parent.resolve()
        # Explicit encoding so the prompt text does not depend on the platform default.
        with open(self.file_loc / "prompts/search.txt", "r", encoding="utf-8") as f:
            self.raw_prompt_text = f.read()
        with open(self.file_loc / "prompts/rewrite_query.txt", "r", encoding="utf-8") as f:
            self.raw_query_rewrite_text = f.read()
        self.retriever = retriever
        self.reset()

    @staticmethod
    def _escape_braces(text):
        """Double every ``{`` and ``}`` in *text* so template parsing keeps them literal."""
        return text.replace("{", "{{").replace("}", "}}")

    def load_history(self, history):
        """Replace the stored transcript with one rebuilt from *history*.

        Args:
            history: Iterable of mappings with ``"role"`` and ``"content"`` keys.
                Items whose role is ``"user"`` are written as ``Human:`` lines,
                every other role as ``AI:`` lines.
        """
        lines = []
        for item in history:
            prefix = "Human: " if item["role"] == "user" else "AI: "
            lines.append(prefix + item["content"] + "\n")
        # NOTE(review): the source indentation was lost; the trailing blank line is
        # assumed to be appended once, after all messages - confirm against the
        # original file.
        self.history = "".join(lines) + "\n"

    def add_retriever(self, retriever):
        """Set the retriever used by ``generate``."""
        self.retriever = retriever

    def reset(self):
        """Re-read the Cohere key from ``api_keys.json``, rebuild the model, clear history."""
        with open(self.file_loc / "api_keys.json", "r", encoding="utf-8") as f:
            api_keys = json.load(f)
        COHERE_API_KEY = api_keys["cohere"]
        self.llm = ChatCohere(model="command-r", cohere_api_key=COHERE_API_KEY)
        self.history = ""

    def rewrite_query(self, original_query):
        """Run the query-rewrite prompt over *original_query* and return the model output.

        Args:
            original_query: The user's query, passed to the chain as ``query``.

        Returns:
            The string produced by the chain's ``StrOutputParser``.
        """
        print("Rewriting query")
        # The history is pasted into the template text, which PromptTemplate then
        # parses again; braces in earlier turns must be escaped or they would be
        # read as template placeholders.
        prompt_text = self.raw_query_rewrite_text.format(
            history=self._escape_braces(self.history),
            query="{query}",
        )
        prompt_template = PromptTemplate(
            input_variables=['query'],
            template=prompt_text
        )
        chain = (
            {"query": RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
        return chain.invoke(original_query)

    def generate(self, query):
        """Answer *query* through the retriever-backed chain and record the exchange.

        Args:
            query: The question; it is sent to the retriever and to the prompt.

        Returns:
            The model's response string.

        Raises:
            ValueError: If no retriever has been set.
        """
        # TODO: Generate should take a retriever, not constructor
        if self.retriever is None:
            raise ValueError("Retriever must be non-None")
        print("Prompting LLM")
        # Same brace escaping as in rewrite_query: only the placeholders re-inserted
        # here ({context}, {question}) should remain as template variables.
        prompt_text = self.raw_prompt_text.format(
            history=self._escape_braces(self.history),
            context="{context}",
            question="{question}",
        )
        prompt_template = PromptTemplate(
            input_variables=['context', 'question'],
            template=prompt_text
        )
        rag_chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
        response = rag_chain.invoke(query)
        # The transcript is stored unescaped; escaping happens only at prompt build time.
        self.history += "Human: " + query + "\n"
        self.history += "AI: " + response + "\n"
        print("Done")
        return response
def format_docs(docs):
    """Return the ``page_content`` of every document in *docs*, separated by blank lines."""
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)
def create_vector_store(documents, embedding, chunk_size=1000, chunk_overlap=200):
    """Split *documents* into overlapping chunks and load them into a Chroma store.

    Args:
        documents: Documents handed to ``RecursiveCharacterTextSplitter.split_documents``.
        embedding: Embedding object passed to ``Chroma.from_documents``.
        chunk_size: Forwarded to the splitter as ``chunk_size``. Defaults to
            1000, the value that used to be hard-coded here.
        chunk_overlap: Forwarded to the splitter as ``chunk_overlap``. Defaults
            to 200, the value that used to be hard-coded here.

    Returns:
        The vector store returned by ``Chroma.from_documents``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    splits = text_splitter.split_documents(documents)
    # NOTE(review): no collection name or persist directory is passed, so Chroma's
    # defaults apply - confirm whether repeated calls are meant to share them.
    vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
    return vectorstore
def get_retriever_links(query, api_key_file):
    """Build a retriever over the documents fetched for *query*.

    Args:
        query: Passed to ``retrieve_relevant_documents``.
        api_key_file: Path to a JSON file with a ``"cohere"`` entry; it is also
            forwarded to ``retrieve_relevant_documents``.

    Returns:
        A ``(retriever, links)`` tuple: the vector store's retriever and the
        links returned by ``retrieve_relevant_documents``.
    """
    fetched_docs, links = retrieve_relevant_documents(query, api_key_file)

    # Only the Cohere entry of the key file is needed in this function.
    with open(api_key_file, "r") as key_file:
        cohere_key = json.load(key_file)["cohere"]

    embedder = CohereEmbeddings(cohere_api_key=cohere_key, model='embed-english-v3.0')

    print("Splitting Documents + Loading into vectorstore")
    store = create_vector_store(fetched_docs, embedder)
    return store.as_retriever(), links