File size: 10,056 Bytes
ebe8786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2b8912
ebe8786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99330c6
ebe8786
9134ea2
 
 
99330c6
9134ea2
d72aee6
9134ea2
 
 
99330c6
e59fcc3
 
 
 
 
 
 
 
 
 
f391abf
 
 
 
 
 
 
 
 
 
9134ea2
f391abf
e59fcc3
 
f391abf
 
e59fcc3
f391abf
 
 
 
 
e59fcc3
 
 
ebe8786
 
6e23909
c07fdcc
5df0348
 
c07fdcc
d5b4a47
5df0348
 
 
 
 
 
 
 
 
 
 
 
 
6a6bb58
5df0348
6a6bb58
 
5df0348
 
 
 
 
 
 
 
 
 
 
 
 
c07fdcc
 
5df0348
d72aee6
6a6bb58
 
 
 
 
eab31fc
5df0348
 
6a6bb58
5df0348
 
 
d72aee6
 
 
5df0348
 
 
 
 
d72aee6
5df0348
d72aee6
5df0348
 
d72aee6
5df0348
 
 
d72aee6
5df0348
d72aee6
5df0348
d72aee6
 
5df0348
d72aee6
 
 
5df0348
ebe8786
5df0348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebe8786
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import os
import uuid
import time
import shutil
from base64 import b64decode
from langchain_community.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
import chromadb
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate


class RAGService:
    """Multimodal RAG service over a persisted ChromaDB collection.

    On construction it opens the persisted Chroma vectorstore (Google
    embeddings), rebuilds a MultiVectorRetriever docstore from the
    ``doc_id`` / ``original_content`` metadata stored with each vector, and
    wires a retrieval chain whose answers are produced by a Groq LLM.
    Questions with no relevant retrieved context fall back to a direct LLM
    call; ambiguous/failed answers fall back to generated counter-questions.
    """

    # Single source of truth for the Groq chat model used in three places below.
    GROQ_MODEL = "llama-3.1-8b-instant"

    def __init__(self):
        # API keys are read from the environment; None if unset — the
        # downstream clients will fail at call time, not here.
        self.gemini_key = os.getenv("GOOGLE_API_KEY")
        self.groq_key = os.getenv("GROQ_API_KEY")

        # Embedding function backing the Chroma vectorstore.
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=self.gemini_key
        )

        # Persisted ChromaDB location (container path).
        self.persist_directory = "/app/chromadb"
        self.vectorstore = None
        self.store = None
        self.retriever = None
        self.chain_with_sources = None

        self._setup_chromadb()
        self._setup_retriever()
        self._setup_chain()

    def _setup_chromadb(self):
        """Open the persisted Chroma collection and create an empty docstore."""
        self.vectorstore = Chroma(
            collection_name="multi_modal_rag_new",
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory
        )

        self.store = InMemoryStore()

        print(f"Number of documents in vectorstore: {self.vectorstore._collection.count()}")
        print("ChromaDB loaded successfully!")

    def _setup_retriever(self):
        """Build the MultiVectorRetriever and repopulate its docstore.

        Parent documents are recovered from the ``original_content`` and
        ``doc_id`` metadata fields persisted alongside each vector.
        """
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.store,
            id_key="doc_id",
        )

        # Rebuild the in-memory docstore from persisted vector metadata.
        collection = self.vectorstore._collection
        all_data = collection.get(include=['metadatas'])

        doc_store_pairs = [
            (metadata['doc_id'], metadata['original_content'])
            for metadata in all_data['metadatas']
            if metadata and 'original_content' in metadata and 'doc_id' in metadata
        ]

        if doc_store_pairs:
            self.store.mset(doc_store_pairs)
            print(f"Populated docstore with {len(doc_store_pairs)} documents")

        print(f"Vectorstore count: {self.vectorstore._collection.count()}")
        print(f"Docstore count: {len(self.store.store)}")
        print("ChromaDB loaded and ready for querying!")

    def _setup_chain(self):
        """Wire the RAG chain: retrieve -> split texts/images -> prompt -> Groq LLM."""
        self.chain_with_sources = {
                                      "context": self.retriever | RunnableLambda(self.parse_docs),
                                      "question": RunnablePassthrough(),
                                  } | RunnablePassthrough().assign(
            response=(
                    RunnableLambda(self.build_prompt)
                    | ChatGroq(model=self.GROQ_MODEL, groq_api_key=self.groq_key)
                    | StrOutputParser()
            )
        )

    def parse_docs(self, docs):
        """Split retrieved docs into base64-encoded images and plain texts.

        Returns:
            dict: ``{"images": [...], "texts": [...]}``.
        """
        b64 = []
        text = []
        for doc in docs:
            try:
                # BUG FIX: validate=True is required here. Without it,
                # b64decode silently discards non-alphabet characters, so
                # short plain-text snippets (e.g. "ab cd") decode "successfully"
                # and get misclassified as image payloads.
                b64decode(doc, validate=True)
                b64.append(doc)
            except Exception:
                text.append(doc)
        return {"images": b64, "texts": text}

    @staticmethod
    def _join_texts(parsed_context):
        """Concatenate the ``texts`` entries of a parse_docs() result into one string."""
        return "".join(str(text_element) for text_element in parsed_context["texts"])

    def build_prompt(self, kwargs):
        """Build a ChatPromptTemplate from retrieved context, images and the question.

        Args:
            kwargs: chain payload — ``{"context": {"images": [...],
                "texts": [...]}, "question": str}``.
        """
        docs_by_type = kwargs["context"]
        user_question = kwargs["question"]

        if len(docs_by_type["texts"]) > 0:
            context_text = self._join_texts(docs_by_type)

            # Always use this flexible prompt
            prompt_template = f"""
            You are a helpful AI assistant. 
            
            Context from documents (use if relevant): {context_text}
            
            Question: {user_question}
            
            Instructions: Answer the question. If the provided context is relevant to the question, use it. If not, answer using your general knowledge.
            """

            prompt_content = [{"type": "text", "text": prompt_template}]

            # BUG FIX: images were previously appended first and then the whole
            # list was overwritten by the text part, so they never reached the
            # model. Append them after the text instead (still only when text
            # context exists, as before).
            if len(docs_by_type["images"]) > 0:
                for image in docs_by_type["images"]:
                    prompt_content.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                        }
                    )

            return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])

        else:
            # Generic question with no context or images
            prompt_template = f"""
            You are a helpful AI assistant. Answer the following question using your general knowledge:
            Question: {user_question}
            """

            return ChatPromptTemplate.from_messages(
                [HumanMessage(content=prompt_template.strip())]  # plain string
            )

    def ask_question(self, question: str):
        """Answer *question*; returns the answer text, counter-questions, or an error message."""
        try:
            # Check if RAG retrieval finds relevant context
            context_length = self._check_context_length(question)

            # BUG FIX: was `>= 0`, which is true for every possible length
            # (the error path returns 0), making the direct-LLM fallback
            # below unreachable dead code.
            if context_length > 0:
                # Get the retrieved context for potential clarification
                retrieved_docs = self.retriever.invoke(question)
                parsed_context = self.parse_docs(retrieved_docs)
                context_text = self._join_texts(parsed_context)

                # First, try to get a normal response
                try:
                    response = self.chain_with_sources.invoke(question)
                    result = response.get('response') if response else None

                    # None/empty/too-short answers trigger clarification instead
                    if self._is_response_invalid(result):
                        return self._generate_counter_questions(question, context_text)

                    return result

                except Exception:
                    # If RAG fails, try to generate counter questions from context
                    return self._generate_counter_questions(question, context_text)
            else:
                # Direct LLM call for questions without relevant context
                llm = ChatGroq(model=self.GROQ_MODEL, groq_api_key=self.groq_key)
                response = llm.invoke(question)
                return response.content

        except Exception as e:
            print(f"Error in ask_question: {e}")
            return "I encountered an error processing your question. Could you please rephrase it more clearly?"

    def _is_response_invalid(self, response):
        """Return True when *response* is None, empty, or too short (< 5 chars stripped)."""
        return not response or len(response.strip()) < 5

    def _generate_counter_questions(self, original_question, context_text):
        """Ask the LLM to turn an ambiguous question into 2-3 clarifying options.

        Falls back to a canned clarification request if the LLM call fails.
        """
        try:
            llm = ChatGroq(model=self.GROQ_MODEL, groq_api_key=self.groq_key)

            counter_question_prompt = f"""
            The user asked: "{original_question}"
            
            Based on the following context from documents:
            {context_text}
            
            The question seems ambiguous. Generate 2-3 specific counter questions to help clarify what the user is asking about, using the context provided.
            
            Format your response exactly like this:
            "Your question seems ambiguous. Are you asking about:
            
            1. [Specific question based on context]
            2. [Another specific question based on context]  
            3. [Third specific question based on context]
            
            Please choose one of these options or rephrase your question more specifically."
            
            Make sure the counter questions are directly related to the content in the context.
            """

            response = llm.invoke(counter_question_prompt)
            return response.content

        except Exception:
            return "Your question seems unclear based on the available information. Could you please be more specific about what you're looking for?"

    def _check_context_length(self, question: str):
        """Return the length of the stripped retrieved context for *question*.

        Returns 0 on retrieval errors, which routes ask_question() to the
        direct-LLM fallback.
        """
        try:
            retrieved_docs = self.retriever.invoke(question)
            parsed_context = self.parse_docs(retrieved_docs)
            return len(self._join_texts(parsed_context).strip())
        except Exception as e:
            print(f"Error checking context length: {e}")
            return 0

    

# Create a global instance
# NOTE(review): instantiated at import time — construction reads
# GOOGLE_API_KEY / GROQ_API_KEY from the environment and opens the
# persisted ChromaDB at /app/chromadb, so importing this module has
# side effects and can fail outside that environment.
rag_service = RAGService()