Amna2024 committed on
Commit
ebe8786
·
verified ·
1 Parent(s): 7bd0f22

Create rag_service.py

Browse files
Files changed (1) hide show
  1. rag_service.py +146 -0
rag_service.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import time
4
+ import shutil
5
+ from base64 import b64decode
6
+ from langchain_community.vectorstores import Chroma
7
+ from langchain.storage import InMemoryStore
8
+ from langchain.schema.document import Document
9
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
10
+ from langchain.retrievers.multi_vector import MultiVectorRetriever
11
+ import chromadb
12
+ from langchain_core.runnables import RunnablePassthrough, RunnableLambda
13
+ from langchain_core.messages import SystemMessage, HumanMessage
14
+ from langchain_groq import ChatGroq
15
+ from langchain_core.output_parsers import StrOutputParser
16
+ from langchain_core.prompts import ChatPromptTemplate
17
+
18
+
19
class RAGService:
    """Multimodal RAG service backed by ChromaDB and a Groq-hosted LLM.

    Retrieves text chunks and base64-encoded image chunks through a
    ``MultiVectorRetriever`` and answers questions with a Groq Llama model.
    Construction opens the persistent Chroma collection, rehydrates the
    in-memory docstore, and assembles the LCEL chain.
    """

    def __init__(self):
        # API keys come from the environment; a missing key surfaces as an
        # auth error on first use rather than at construction time.
        self.gemini_key = os.getenv("GEMINI_API_KEY")
        self.groq_key = os.getenv("GROQ_API_KEY")

        # Embedding model used both for the stored vectors and query-time
        # similarity search.
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=self.gemini_key
        )

        # Persistent ChromaDB location (container path — TODO confirm it is
        # mounted/writable in the deployment environment).
        self.persist_directory = "/app/chromadb"
        self.vectorstore = None
        self.store = None
        self.retriever = None
        self.chain_with_sources = None

        self._setup_chromadb()
        self._setup_retriever()
        self._setup_chain()

    def _setup_chromadb(self):
        """Open (or create) the persistent Chroma collection and the docstore."""
        self.vectorstore = Chroma(
            collection_name="multi_modal_rag_new",
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory
        )

        # InMemoryStore is empty on every process start; _setup_retriever
        # repopulates it from vector metadata.
        self.store = InMemoryStore()

        print(f"Number of documents in vectorstore: {self.vectorstore._collection.count()}")
        print("ChromaDB loaded successfully!")

    def _setup_retriever(self):
        """Build the MultiVectorRetriever and rehydrate the docstore.

        The parent documents are not persisted separately, so the original
        content is restored from the ``original_content`` metadata field
        stored alongside each vector in the collection.
        """
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.store,
            id_key="doc_id",
        )

        collection = self.vectorstore._collection
        all_data = collection.get(include=['metadatas'])

        # Only entries carrying both the parent id and the original payload
        # can be restored into the docstore.
        doc_store_pairs = [
            (metadata['doc_id'], metadata['original_content'])
            for metadata in all_data['metadatas']
            if metadata and 'original_content' in metadata and 'doc_id' in metadata
        ]

        if doc_store_pairs:
            self.store.mset(doc_store_pairs)
            print(f"Populated docstore with {len(doc_store_pairs)} documents")

        print(f"Vectorstore count: {self.vectorstore._collection.count()}")
        print(f"Docstore count: {len(self.store.store)}")
        print("ChromaDB loaded and ready for querying!")

    def _setup_chain(self):
        """Assemble the LCEL chain: retrieve -> build prompt -> Groq LLM -> str."""
        self.chain_with_sources = {
            "context": self.retriever | RunnableLambda(self.parse_docs),
            "question": RunnablePassthrough(),
        } | RunnablePassthrough().assign(
            response=(
                RunnableLambda(self.build_prompt)
                | ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                | StrOutputParser()
            )
        )

    def parse_docs(self, docs):
        """Split retrieved chunks into base64-encoded images and plain texts.

        Args:
            docs: iterable of retrieved chunks (expected to be strings).

        Returns:
            dict with keys ``"images"`` (valid base64 strings) and
            ``"texts"`` (everything else).
        """
        b64 = []
        text = []
        for doc in docs:
            try:
                # validate=True rejects any non-base64 character instead of
                # silently discarding it; without it, ordinary prose whose
                # stripped length is a multiple of 4 is misclassified as an
                # image.
                b64decode(doc, validate=True)
                b64.append(doc)
            except (ValueError, TypeError):
                # binascii.Error (bad chars / padding) subclasses ValueError;
                # TypeError covers non-string/bytes chunks.
                text.append(doc)
        return {"images": b64, "texts": text}

    def build_prompt(self, kwargs):
        """Build a multimodal chat prompt from parsed context and question.

        Args:
            kwargs: dict with ``"context"`` (output of :meth:`parse_docs`)
                and ``"question"`` (str), as wired up in :meth:`_setup_chain`.

        Returns:
            ChatPromptTemplate containing one human message whose content
            mixes the text context and any base64 images.
        """
        docs_by_type = kwargs["context"]
        user_question = kwargs["question"]

        # Concatenate all text chunks into a single context string.
        context_text = "".join(str(t) for t in docs_by_type["texts"])

        prompt_template = f"""
        Answer the question based only on the following context, which can include text, tables, and the below image.
        Context: {context_text}
        Question: {user_question}
        """

        prompt_content = [{"type": "text", "text": prompt_template}]

        # Attach each retrieved image as an inline data URL part.
        for image in docs_by_type["images"]:
            prompt_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                }
            )

        return ChatPromptTemplate.from_messages(
            [
                HumanMessage(content=prompt_content),
            ]
        )

    def ask_question(self, question: str):
        """Run the RAG chain on *question* and return the model's answer text."""
        response = self.chain_with_sources.invoke(question)
        return response['response']
143
+
144
+
145
# Module-level singleton. NOTE: constructing RAGService here means merely
# importing this module opens ChromaDB and builds the chain (import-time
# side effect) — importers must have the environment/DB available.
rag_service = RAGService()