from datasets import load_dataset
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.runnables import RunnablePassthrough
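
# Note: the code below assumes the caller passes in an OpenAI API key and that an
# imdb.csv file is present in the working directory; adjust both to your setup.
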
class RAGModel:
    def __init__(self, openai_api_key):
        # Load the IMDB CSV with the datasets library and re-export it so that
        # CSVLoader reads a clean copy
        dataset = load_dataset('csv', data_files='imdb.csv')
        dataset["train"].to_csv('imdb.csv')

        # Load the CSV rows as documents
        loader = CSVLoader(file_path="imdb.csv")
        data = loader.load()

        # Split documents into overlapping chunks
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunked_documents = text_splitter.split_documents(data)

        # Create embeddings
        self.embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_api_key)

        # Cache embeddings on disk so repeated runs do not re-embed the same chunks
        self.store = LocalFileStore("./cache/")
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            self.embeddings, self.store, namespace=self.embeddings.model
        )
        # Build the FAISS vector store from the chunked documents, using the
        # cache-backed embedder, and persist the index to disk
        self.vector_store = FAISS.from_documents(chunked_documents, self.embedder)
        self.vector_store.save_local("faiss_index")
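        # On later runs the persisted index could be reloaded instead of rebuilt
        # (a sketch, not in the original code; allow_dangerous_deserialization is
        # only required on recent langchain_community releases):
        # self.vector_store = FAISS.load_local(
        #     "faiss_index", self.embedder, allow_dangerous_deserialization=True
        # )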
        # Expose the vector store as a retriever
        self.retriever = self.vector_store.as_retriever()

        # Chat model, output parser, and prompt template for the RAG chain
        self.chat_model = ChatOpenAI(model="gpt-4", temperature=0, openai_api_key=openai_api_key)
        self.parser = StrOutputParser()
        template = "Answer the {question} based on the following context: {context}"
        self.prompt_template = ChatPromptTemplate.from_template(template)
    def query(self, question):
        # Retrieve the documents most similar to the question
        embedding_query = self.embeddings.embed_query(question)
        similar_documents = self.vector_store.similarity_search_by_vector(embedding_query)

        # Join the retrieved documents into a single context string
        context = "\n".join([doc.page_content for doc in similar_documents])

        # Compose the RAG chain: inject the pre-built context, pass the question
        # through, format the prompt, call the chat model, and parse the reply to a string
        chain = (
            {"context": lambda x: context, "question": RunnablePassthrough()}
            | self.prompt_template
            | self.chat_model
            | self.parser
        )
        result = chain.invoke(question)
        return result
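

# Minimal usage sketch (not part of the original Space code). It assumes the
# OPENAI_API_KEY environment variable is set and imdb.csv is available; the
# question is only illustrative.
if __name__ == "__main__":
    import os

    model = RAGModel(openai_api_key=os.getenv("OPENAI_API_KEY"))
    answer = model.query("What is this dataset about?")
    print(answer)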