| import os |
| import uuid |
| import time |
| import shutil |
| from base64 import b64decode |
| from langchain_community.vectorstores import Chroma |
| from langchain.storage import InMemoryStore |
| from langchain.schema.document import Document |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings |
| from langchain.retrievers.multi_vector import MultiVectorRetriever |
| import chromadb |
| from langchain_core.runnables import RunnablePassthrough, RunnableLambda |
| from langchain_core.messages import SystemMessage, HumanMessage |
| from langchain_groq import ChatGroq |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.prompts import ChatPromptTemplate |
|
|
|
|
class RAGService:
    """Multi-modal RAG service backed by a persistent Chroma vector store.

    Embeds with Google ``text-embedding-004``, retrieves through a
    ``MultiVectorRetriever`` (vectors in Chroma, original contents in an
    in-memory docstore repopulated from Chroma metadata at startup), and
    answers with a Groq-hosted Llama model.  Falls back to the bare LLM when
    retrieval yields no text context, and to generated clarifying
    counter-questions when the RAG answer looks empty or invalid.
    """

    def __init__(self):
        # API keys come from the environment; no validation here -- a missing
        # key surfaces as an error on first use of the corresponding client.
        self.gemini_key = os.getenv("GOOGLE_API_KEY")
        self.groq_key = os.getenv("GROQ_API_KEY")

        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=self.gemini_key,
        )

        self.persist_directory = "/app/chromadb"
        self.vectorstore = None
        self.store = None
        self.retriever = None
        self.chain_with_sources = None

        self._setup_chromadb()
        self._setup_retriever()
        self._setup_chain()

    def _setup_chromadb(self):
        """Open (or create) the persistent Chroma collection and the docstore."""
        self.vectorstore = Chroma(
            collection_name="multi_modal_rag_new",
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory,
        )
        self.store = InMemoryStore()

        print(f"Number of documents in vectorstore: {self.vectorstore._collection.count()}")
        print("ChromaDB loaded successfully!")

    def _setup_retriever(self):
        """Build the MultiVectorRetriever and repopulate the in-memory docstore.

        The docstore is volatile, so on every startup the original contents
        are restored from the ``original_content`` / ``doc_id`` metadata that
        was persisted alongside each vector in Chroma.
        """
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.store,
            id_key="doc_id",
        )

        collection = self.vectorstore._collection
        all_data = collection.get(include=["metadatas"])

        # Only metadata is needed to rebuild the docstore; entries without
        # both keys are skipped.
        doc_store_pairs = [
            (metadata["doc_id"], metadata["original_content"])
            for metadata in all_data["metadatas"]
            if metadata and "original_content" in metadata and "doc_id" in metadata
        ]

        if doc_store_pairs:
            self.store.mset(doc_store_pairs)
            print(f"Populated docstore with {len(doc_store_pairs)} documents")

        print(f"Vectorstore count: {self.vectorstore._collection.count()}")
        print(f"Docstore count: {len(self.store.store)}")
        print("ChromaDB loaded and ready for querying!")

    def _setup_chain(self):
        """Wire the LCEL chain: retrieve -> split text/images -> prompt -> LLM."""
        self.chain_with_sources = {
            "context": self.retriever | RunnableLambda(self.parse_docs),
            "question": RunnablePassthrough(),
        } | RunnablePassthrough().assign(
            response=(
                RunnableLambda(self.build_prompt)
                | ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                | StrOutputParser()
            )
        )

    def parse_docs(self, docs):
        """Split retrieved docs into base64-encoded images and plain texts.

        NOTE(review): classification is heuristic -- anything that
        ``b64decode``s without error is treated as an image, so a short
        alphanumeric string whose valid-character count is a multiple of 4
        can be misclassified as an image.
        """
        images = []
        texts = []
        for doc in docs:
            try:
                b64decode(doc)
            except Exception:
                texts.append(doc)
            else:
                images.append(doc)
        return {"images": images, "texts": texts}

    def build_prompt(self, kwargs):
        """Build the chat prompt from retrieved context and the user question.

        ``kwargs`` is the dict produced by the chain:
        ``{"context": {"images": [...], "texts": [...]}, "question": str}``.
        Image parts are attached as ``data:`` URLs so multi-modal models can
        consume them.
        """
        docs_by_type = kwargs["context"]
        user_question = kwargs["question"]

        image_parts = [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            for image in docs_by_type["images"]
        ]

        if docs_by_type["texts"]:
            context_text = "".join(str(t) for t in docs_by_type["texts"])

            prompt_template = f"""
            You are a helpful AI assistant.

            Context from documents (use if relevant): {context_text}

            Question: {user_question}

            Instructions: Answer the question. If the provided context is relevant to the question, use it. If not, answer using your general knowledge.
            """

            # BUG FIX: the image parts used to be built and then discarded by
            # reassigning the content list; they are now kept alongside the
            # text part.
            content = [{"type": "text", "text": prompt_template}] + image_parts
            return ChatPromptTemplate.from_messages([HumanMessage(content=content)])

        # No retrieved text: answer from general knowledge, still attaching
        # any retrieved images (previously dropped entirely).
        prompt_template = f"""
        You are a helpful AI assistant. Answer the following question using your general knowledge:
        Question: {user_question}
        """

        if image_parts:
            content = [{"type": "text", "text": prompt_template.strip()}] + image_parts
            return ChatPromptTemplate.from_messages([HumanMessage(content=content)])

        return ChatPromptTemplate.from_messages(
            [HumanMessage(content=prompt_template.strip())]
        )

    def ask_question(self, question: str):
        """Answer ``question`` via RAG, falling back gracefully.

        Flow:
          1. Retrieve once and measure the usable text context.
          2. If context exists, run the RAG chain; an empty/invalid answer
             (or a chain error) triggers clarifying counter-questions.
          3. With no context, answer directly with the bare LLM.
        """
        try:
            # Retrieve once up front; previously the question was retrieved
            # three times (length check, counter-question context, and again
            # inside the chain itself).
            context_text = self._retrieve_context_text(question)

            if len(context_text.strip()) > 0:
                # BUG FIX: the original compared the length with `>= 0`,
                # which is always true and made the no-context fallback
                # below unreachable.
                try:
                    response = self.chain_with_sources.invoke(question)
                    result = response.get("response") if response else None

                    if self._is_response_invalid(result):
                        return self._generate_counter_questions(question, context_text)

                    return result
                except Exception:
                    # Chain failure: ask the user to clarify instead of
                    # surfacing an internal error.
                    return self._generate_counter_questions(question, context_text)

            # No usable context: answer with the bare LLM.
            llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
            response = llm.invoke(question)
            return response.content

        except Exception as e:
            print(f"Error in ask_question: {e}")
            return "I encountered an error processing your question. Could you please rephrase it more clearly?"

    def _retrieve_context_text(self, question: str) -> str:
        """Retrieve docs for ``question`` and concatenate their text parts."""
        retrieved_docs = self.retriever.invoke(question)
        parsed_context = self.parse_docs(retrieved_docs)
        return "".join(str(t) for t in parsed_context["texts"])

    def _is_response_invalid(self, response):
        """Return True when ``response`` is None, empty, or shorter than 5 chars."""
        if response is None:
            return True
        if not response or len(response.strip()) < 5:
            return True
        return False

    def _generate_counter_questions(self, original_question, context_text):
        """Generate 2-3 clarifying counter-questions grounded in ``context_text``."""
        try:
            llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)

            counter_question_prompt = f"""
            The user asked: "{original_question}"

            Based on the following context from documents:
            {context_text}

            The question seems ambiguous. Generate 2-3 specific counter questions to help clarify what the user is asking about, using the context provided.

            Format your response exactly like this:
            "Your question seems ambiguous. Are you asking about:

            1. [Specific question based on context]
            2. [Another specific question based on context]
            3. [Third specific question based on context]

            Please choose one of these options or rephrase your question more specifically."

            Make sure the counter questions are directly related to the content in the context.
            """

            response = llm.invoke(counter_question_prompt)
            return response.content

        except Exception:
            # Best-effort fallback when the LLM call itself fails.
            return "Your question seems unclear based on the available information. Could you please be more specific about what you're looking for?"

    def _check_context_length(self, question: str):
        """Return the length of the stripped text context retrieved for ``question``.

        Returns 0 on retrieval errors so callers can treat "no context" and
        "retrieval failed" uniformly.
        """
        try:
            return len(self._retrieve_context_text(question).strip())
        except Exception as e:
            print(f"Error checking context length: {e}")
            return 0
|
|
| |
|
|
| |
| rag_service = RAGService() |