File size: 3,687 Bytes
6f77181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e9f8d9
 
6f77181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e9f8d9
6f77181
 
 
 
 
 
 
6e9f8d9
6f77181
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from datasets import load_dataset
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.base import RunnableSequence


class RAGModel:
    """Retrieval-Augmented Generation over the local ``imdb.csv`` dataset.

    On construction this builds (and persists) a FAISS vector index of the
    CSV rows using cache-backed OpenAI embeddings, then wires up a single
    LCEL chain: retrieve context -> fill prompt -> GPT-4 -> string output.

    NOTE(review): ``imdb.csv`` must already exist in the working directory;
    the original code round-tripped it through ``datasets.load_dataset`` and
    back to the same path, which was a no-op and has been removed.
    """

    def __init__(self, openai_api_key):
        """Build the vector index and the question-answering chain.

        Args:
            openai_api_key: OpenAI API key used for both the embedding
                model and the chat model.

        Side effects:
            Reads ``imdb.csv``, writes the embedding byte cache under
            ``./cache/`` and the FAISS index under ``./faiss_index``.
        """
        # Load the CSV rows and split them ONCE with the configured splitter.
        # The original code split twice and discarded the first (configured)
        # split, so the index was actually built with default chunk sizes.
        loader = CSVLoader(file_path="imdb.csv")
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)

        # Cache-backed embeddings: repeated runs reuse bytes from ./cache/
        # instead of re-calling the embeddings API for identical chunks.
        self.embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_api_key)
        self.store = LocalFileStore("./cache/")
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            self.embeddings, self.store, namespace=self.embeddings.model
        )

        # Build and persist the FAISS vector store, then expose a retriever
        # (default top-k similarity search, same as the original manual
        # embed-and-search in query()).
        self.vector_store = FAISS.from_documents(docs, self.embedder)
        self.vector_store.save_local("faiss_index")
        self.retriever = self.vector_store.as_retriever()

        self.chat_model = ChatOpenAI(model="gpt-4", temperature=0, openai_api_key=openai_api_key)
        self.parser = StrOutputParser()
        self.prompt_template = ChatPromptTemplate.from_template(
            "Answer the {question} based on the following context: {context}"
        )

        # Compose the chain once; the original rebuilt it on every query and
        # duplicated the retrieval work outside the chain.
        self.chain = (
            {
                "context": self.retriever | self._format_docs,
                "question": RunnablePassthrough(),
            }
            | self.prompt_template
            | self.chat_model
            | self.parser
        )

    @staticmethod
    def _format_docs(docs):
        """Join retrieved documents into a single newline-separated context string."""
        return "\n".join(doc.page_content for doc in docs)

    def query(self, question):
        """Answer ``question`` using retrieved context from the FAISS index.

        Args:
            question: Natural-language question to answer.

        Returns:
            The model's answer as a plain string (parsed by StrOutputParser).
        """
        return self.chain.invoke(question)