CodePathAI / rag.py
coolgandhi's picture
updating rag
6e9f8d9
from datasets import load_dataset
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.base import RunnableSequence
class RAGModel:
    """Retrieval-augmented generation over the rows of ``imdb.csv``.

    On construction this chunks the CSV, embeds the chunks with OpenAI
    embeddings (cached on local disk), and builds a FAISS index. ``query``
    then answers a question with GPT-4, grounded in the top retrieved chunks.

    Assumes ``imdb.csv`` already exists in the working directory — the
    original round-trip through ``datasets.load_dataset`` re-wrote the same
    file and has been removed.
    """

    def __init__(self, openai_api_key):
        """Build the FAISS index and the answer chain components.

        Args:
            openai_api_key: OpenAI API key used for both the embedding
                model and the chat model.
        """
        # One document per CSV row.
        loader = CSVLoader(file_path="imdb.csv")
        documents = loader.load()

        # Chunk with the configured splitter. (The original code built this
        # split, discarded it, and indexed a second default-parameter split
        # instead — the 1000/100 configuration is clearly what was intended.)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)

        # Disk-cached embeddings: re-embedding identical text is free across runs.
        self.embeddings = OpenAIEmbeddings(
            model="text-embedding-ada-002", openai_api_key=openai_api_key
        )
        self.store = LocalFileStore("./cache/")
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            self.embeddings, self.store, namespace=self.embeddings.model
        )

        # Vector store + retriever (default: top-4 similarity search).
        self.vector_store = FAISS.from_documents(docs, self.embedder)
        self.vector_store.save_local("faiss_index")
        self.retriever = self.vector_store.as_retriever()

        # Chat model, output parser, and prompt for the answer chain.
        self.chat_model = ChatOpenAI(model="gpt-4", temperature=0, openai_api_key=openai_api_key)
        self.parser = StrOutputParser()
        self.prompt_template = ChatPromptTemplate.from_template(
            "Answer the {question} based on the following context: {context}"
        )

    @staticmethod
    def _format_docs(docs):
        """Join retrieved documents into one newline-separated context string."""
        return "\n".join(doc.page_content for doc in docs)

    def query(self, question):
        """Answer ``question`` using context retrieved from the vector store.

        Args:
            question: The user's natural-language question.

        Returns:
            The model's answer as a plain string.
        """
        # Retrieval happens inside the chain (once), instead of the original
        # manual embed_query/similarity_search followed by a closure lambda
        # that ignored its input. Both paths perform the same default top-4
        # similarity search, so the retrieved context is unchanged.
        chain = (
            {
                "context": self.retriever | self._format_docs,
                "question": RunnablePassthrough(),
            }
            | self.prompt_template
            | self.chat_model
            | self.parser
        )
        return chain.invoke(question)