import json
import pathlib
import sys

# Chroma requires a newer sqlite3 than some hosts ship; swap in the pysqlite3
# binary before langchain_chroma (and through it, sqlite3) gets imported.
__import__('pysqlite3')
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

from langchain_chroma import Chroma
from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter

from retrieve_documents import retrieve_relevant_documents

class RAG_with_memory:
    """Conversational RAG pipeline that threads the chat history into its prompts."""

    def __init__(self, retriever=None):
        self.file_loc = pathlib.Path(__file__).parent.resolve()
        with open(self.file_loc / "prompts/search.txt", "r") as f:
            self.raw_prompt_text = f.read()
        with open(self.file_loc / "prompts/rewrite_query.txt", "r") as f:
            self.raw_query_rewrite_text = f.read()
        self.retriever = retriever
        self.reset()

    def load_history(self, history):
        # Flatten a list of {"role", "content"} messages into the plain-text
        # transcript format the prompt templates expect.
        self.history = ""
        for item in history:
            if item["role"] == "user":
                self.history += "Human: " + item["content"] + "\n"
            else:
                self.history += "AI: " + item["content"] + "\n"
        self.history += "\n"

    def add_retriever(self, retriever):
        self.retriever = retriever

    def reset(self):
        # Rebuild the LLM client and clear the conversation transcript.
        with open(self.file_loc / "api_keys.json", "r") as f:
            api_keys = json.load(f)
        COHERE_API_KEY = api_keys["cohere"]
        self.llm = ChatCohere(model="command-r", cohere_api_key=COHERE_API_KEY)
        self.history = ""

    def rewrite_query(self, original_query):
        # Fold the chat history into a standalone search query. The history is
        # baked in now, while {query} is kept as a placeholder for the
        # PromptTemplate below.
        print("Rewriting query")
        prompt_text = self.raw_query_rewrite_text.format(history=self.history, query="{query}")
        prompt_template = PromptTemplate(
            input_variables=['query'],
            template=prompt_text,
        )
        chain = (
            {"query": RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
        return chain.invoke(original_query)

    def generate(self, query):
        # TODO: generate() should take the retriever as an argument rather than
        # relying on the constructor / add_retriever().
        if self.retriever is None:
            raise ValueError("Retriever must be non-None; pass one to the constructor or call add_retriever()")
        print("Prompting LLM")
        # As in rewrite_query: bake the history in, leaving {context} and
        # {question} for the chain to fill at invoke time.
        prompt_text = self.raw_prompt_text.format(history=self.history, context="{context}", question="{question}")
        prompt_template = PromptTemplate(
            input_variables=['context', 'question'],
            template=prompt_text,
        )
        rag_chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
        response = rag_chain.invoke(query)
        self.history += "Human: " + query + "\n"
        self.history += "AI: " + response + "\n"
        print("Done")
        return response


def format_docs(docs):
    # Join retrieved documents into a single context string for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)


def create_vector_store(documents, embedding):
    # Chunk the documents and index them in an in-memory Chroma store.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(documents)
    vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
    return vectorstore


def get_retriever_links(query, api_key_file):
    # Fetch documents relevant to the query, embed them with Cohere, and return
    # a retriever over the resulting vector store plus the source links.
    documents, links = retrieve_relevant_documents(query, api_key_file)
    with open(api_key_file, "r") as f:
        api_keys = json.load(f)
    COHERE_API_KEY = api_keys["cohere"]
    embedding_model = CohereEmbeddings(cohere_api_key=COHERE_API_KEY, model='embed-english-v3.0')
    print("Splitting Documents + Loading into vectorstore")
    vectorstore = create_vector_store(documents, embedding_model)
    return vectorstore.as_retriever(), links
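

if __name__ == "__main__":
    # Usage sketch (not part of the original module): wires the pieces above
    # into a single conversational turn. It assumes api_keys.json holds a
    # "cohere" key, the prompts/ templates exist alongside this file, and
    # retrieve_relevant_documents() can reach its search backend; the question
    # below is purely illustrative.
    rag = RAG_with_memory()
    question = "What is retrieval-augmented generation?"

    # With an empty history the rewrite is close to a no-op; on later turns it
    # folds the transcript into a standalone search query.
    search_query = rag.rewrite_query(question)

    # Build a fresh retriever for this query and attach it before generating.
    retriever, links = get_retriever_links(search_query, rag.file_loc / "api_keys.json")
    rag.add_retriever(retriever)

    print(rag.generate(question))
    print("Sources:", links)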