import os
import uuid
import time
import shutil
from base64 import b64decode

from langchain_community.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
import chromadb
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate


class RAGService:
    """Multimodal RAG service.

    Wires a persistent Chroma vectorstore (Gemini embeddings) to a
    MultiVectorRetriever whose docstore is rebuilt from the
    ``original_content`` metadata stored alongside each vector, and answers
    questions through a Groq-hosted Llama model. Falls back to a direct LLM
    call when retrieval yields no context, and to clarifying
    counter-questions when the RAG answer looks invalid.
    """

    def __init__(self):
        # API keys are read from the environment; may be None if unset
        # (downstream clients will then fail at call time, not here).
        self.gemini_key = os.getenv("GOOGLE_API_KEY")
        self.groq_key = os.getenv("GROQ_API_KEY")

        # Initialize embeddings
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=self.gemini_key,
        )

        # Setup ChromaDB
        self.persist_directory = "/app/chromadb"
        self.vectorstore = None
        self.store = None
        self.retriever = None
        self.chain_with_sources = None
        self._setup_chromadb()
        self._setup_retriever()
        self._setup_chain()

    def _setup_chromadb(self):
        """Initialize ChromaDB """
        self.vectorstore = Chroma(
            collection_name="multi_modal_rag_new",
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory,
        )
        # In-memory docstore; repopulated from vector metadata in
        # _setup_retriever on every startup (nothing is persisted here).
        self.store = InMemoryStore()
        print(f"Number of documents in vectorstore: {self.vectorstore._collection.count()}")
        print("ChromaDB loaded successfully!")

    def _setup_retriever(self):
        """Setup the MultiVectorRetriever"""
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.store,
            id_key="doc_id",
        )

        # Load data into docstore: each vector's metadata is expected to
        # carry the parent document id plus its original (full) content.
        collection = self.vectorstore._collection
        all_data = collection.get(include=['metadatas'])

        doc_store_pairs = []
        for doc_id, metadata in zip(all_data['ids'], all_data['metadatas']):
            if metadata and 'original_content' in metadata and 'doc_id' in metadata:
                doc_store_pairs.append((metadata['doc_id'], metadata['original_content']))

        if doc_store_pairs:
            self.store.mset(doc_store_pairs)
            print(f"Populated docstore with {len(doc_store_pairs)} documents")

        print(f"Vectorstore count: {self.vectorstore._collection.count()}")
        print(f"Docstore count: {len(self.store.store)}")
        print("ChromaDB loaded and ready for querying!")

    def _setup_chain(self):
        """Setup the RAG chain"""
        # Pipeline: retrieve -> split images/texts -> build prompt -> Groq LLM.
        # The "sources" shape keeps the parsed context next to the response.
        self.chain_with_sources = {
            "context": self.retriever | RunnableLambda(self.parse_docs),
            "question": RunnablePassthrough(),
        } | RunnablePassthrough().assign(
            response=(
                RunnableLambda(self.build_prompt)
                | ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                | StrOutputParser()
            )
        )

    def parse_docs(self, docs):
        """Split base64-encoded images and texts"""
        b64 = []
        text = []
        for doc in docs:
            # Empty docs are never images; also b64decode("") "succeeds",
            # which would misclassify them below.
            if not doc:
                text.append(doc)
                continue
            try:
                # BUG FIX: validate=True rejects ordinary prose containing
                # spaces/punctuation, which the lenient default happily
                # decoded, silently classifying short texts as images.
                b64decode(doc, validate=True)
                b64.append(doc)
            except Exception:
                text.append(doc)
        return {"images": b64, "texts": text}

    def build_prompt(self, kwargs):
        """Build prompt with context and images"""
        docs_by_type = kwargs["context"]
        user_question = kwargs["question"]

        context_text = ""
        image_parts = []

        if len(docs_by_type["texts"]) > 0:
            for text_element in docs_by_type["texts"]:
                context_text += str(text_element)

            # Add images only if context exists
            if len(docs_by_type["images"]) > 0:
                for image in docs_by_type["images"]:
                    image_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                        }
                    )

            # Always use this flexible prompt
            prompt_template = f"""
You are a helpful AI assistant.

Context from documents (use if relevant):
{context_text}

Question: {user_question}

Instructions: Answer the question. If the provided context is relevant to the question, use it. If not, answer using your general knowledge.
"""
            # BUG FIX: the original reassigned prompt_content to the text
            # part only, silently dropping every image collected above.
            prompt_content = [{"type": "text", "text": prompt_template}] + image_parts
            return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])
        else:
            # Generic question with no context or images
            prompt_template = f"""
You are a helpful AI assistant.

Answer the following question using your general knowledge:

Question: {user_question}
"""
            return ChatPromptTemplate.from_messages(
                [HumanMessage(content=prompt_template.strip())]  # plain string
            )

    def ask_question(self, question: str):
        """Process a question and return response"""
        try:
            # Check if RAG retrieval finds relevant context
            context_length = self._check_context_length(question)

            # BUG FIX: was `>= 0`, which is always true for a length, making
            # the no-context fallback branch below unreachable.
            if context_length > 0:
                # Get the retrieved context for potential clarification
                retrieved_docs = self.retriever.invoke(question)
                parsed_context = self.parse_docs(retrieved_docs)

                # Build context text
                context_text = ""
                if len(parsed_context["texts"]) > 0:
                    for text_element in parsed_context["texts"]:
                        context_text += str(text_element)

                # First, try to get a normal response
                try:
                    response = self.chain_with_sources.invoke(question)
                    result = response.get('response') if response else None

                    # Check if response is None or invalid
                    if self._is_response_invalid(result):
                        return self._generate_counter_questions(question, context_text)

                    return result
                except Exception:
                    # If RAG fails, try to generate counter questions from context
                    return self._generate_counter_questions(question, context_text)
            else:
                # Direct LLM call for questions without relevant context
                llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                response = llm.invoke(question)
                return response.content
        except Exception as e:
            print(f"Error in ask_question: {e}")
            return "I encountered an error processing your question. Could you please rephrase it more clearly?"

    def _is_response_invalid(self, response):
        """Check if the response is None or invalid"""
        # None, empty, or too short (< 5 chars after stripping) counts as
        # invalid; the redundant second None check from the original is gone.
        return response is None or not response or len(response.strip()) < 5

    def _generate_counter_questions(self, original_question, context_text):
        """Generate counter questions based on retrieved context"""
        try:
            llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)

            counter_question_prompt = f"""
The user asked: "{original_question}"

Based on the following context from documents:
{context_text}

The question seems ambiguous. Generate 2-3 specific counter questions to help clarify what the user is asking about, using the context provided.

Format your response exactly like this:
"Your question seems ambiguous. Are you asking about:
1. [Specific question based on context]
2. [Another specific question based on context]
3. [Third specific question based on context]

Please choose one of these options or rephrase your question more specifically."

Make sure the counter questions are directly related to the content in the context.
"""
            response = llm.invoke(counter_question_prompt)
            return response.content
        except Exception:
            # Best-effort fallback if the clarification LLM call itself fails.
            return "Your question seems unclear based on the available information. Could you please be more specific about what you're looking for?"

    def _check_context_length(self, question: str):
        """Check if RAG retrieval returns meaningful context"""
        try:
            # Get retrieved documents
            retrieved_docs = self.retriever.invoke(question)

            # Parse the documents
            parsed_context = self.parse_docs(retrieved_docs)

            # Check context length
            context_text = ""
            if len(parsed_context["texts"]) > 0:
                for text_element in parsed_context["texts"]:
                    context_text += str(text_element)

            return len(context_text.strip())
        except Exception as e:
            # Treat retrieval failure as "no context" so ask_question falls
            # back to a direct LLM call instead of crashing.
            print(f"Error checking context length: {e}")
            return 0


# Create a global instance
rag_service = RAGService()