from datasets import load_dataset
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.base import RunnableSequence


class RAGModel:
    """Retrieval-augmented generation over a local ``imdb.csv`` dataset.

    On construction this loads the CSV, splits it into overlapping chunks,
    embeds the chunks with OpenAI embeddings (disk-cached under ``./cache/``),
    and indexes them in a FAISS vector store persisted to ``faiss_index``.
    ``query`` then answers a question grounded in the most similar chunks.
    """

    def __init__(self, openai_api_key):
        """Build the retrieval index and the chat pipeline.

        Args:
            openai_api_key: API key passed to both the embedding model and
                the chat model.

        Side effects: reads ``imdb.csv`` from the working directory, writes
        the embedding cache to ``./cache/`` and the FAISS index to
        ``faiss_index``.
        """
        # Validate the CSV up front via `datasets` so malformed rows fail
        # early with a clear error. (The original also wrote the dataset
        # back to the same 'imdb.csv' path — a no-op round-trip, dropped.)
        load_dataset('csv', data_files='imdb.csv')

        # Load one Document per CSV row, then chunk so each embedded
        # passage stays within the embedding model's context limits.
        # NOTE: the original split the documents twice and indexed the
        # second, default-parameter split, silently discarding these
        # explicit chunk settings; we index the configured split once.
        loader = CSVLoader(file_path="imdb.csv")
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)

        # Embeddings, wrapped in a byte-store cache namespaced by model
        # name so rebuilding the index does not re-bill already-embedded
        # chunks.
        self.embeddings = OpenAIEmbeddings(
            model="text-embedding-ada-002", openai_api_key=openai_api_key
        )
        self.store = LocalFileStore("./cache/")
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            self.embeddings, self.store, namespace=self.embeddings.model
        )

        # Build and persist the FAISS index; keep a retriever handle for
        # callers that want LangChain's retriever interface.
        self.vector_store = FAISS.from_documents(docs, self.embedder)
        self.vector_store.save_local("faiss_index")
        self.retriever = self.vector_store.as_retriever()

        # Deterministic (temperature=0) chat model plus a string parser
        # and the prompt used by `query`.
        self.chat_model = ChatOpenAI(
            model="gpt-4", temperature=0, openai_api_key=openai_api_key
        )
        self.parser = StrOutputParser()
        self.prompt_template = ChatPromptTemplate.from_template(
            "Answer the {question} based on the following context: {context}"
        )

    def query(self, question):
        """Answer ``question`` using context retrieved from the FAISS index.

        Args:
            question: Natural-language question to answer.

        Returns:
            The chat model's answer as a plain string.
        """
        # Embed the question and fetch the most similar chunks.
        # NOTE(review): this uses the raw embeddings object, bypassing the
        # CacheBackedEmbeddings cache — intentional per the original code,
        # since query strings rarely repeat.
        embedding_query = self.embeddings.embed_query(question)
        similar_documents = self.vector_store.similarity_search_by_vector(embedding_query)
        context = "\n".join(doc.page_content for doc in similar_documents)

        # LCEL chain: inject the pre-computed context, pass the question
        # through untouched, format the prompt, call the model, parse to str.
        chain = (
            {"context": lambda _: context, "question": RunnablePassthrough()}
            | self.prompt_template
            | self.chat_model
            | self.parser
        )
        # BUG FIX: the original commented out every assignment to `result`
        # (including `chain.invoke`) yet still executed `return result`,
        # raising NameError on every call. Invoke the chain and return.
        return chain.invoke(question)