"""Multi-modal RAG service: ChromaDB retrieval with Gemini embeddings and a Groq LLM."""
import os
import shutil
import time
import uuid
from base64 import b64decode

import chromadb
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
class RAGService:
    """Multi-modal RAG pipeline over a persistent ChromaDB collection.

    Text chunks and base64-encoded images are retrieved through a
    MultiVectorRetriever (vectors in Chroma, original content in an
    in-memory docstore); answers are generated by a Groq-hosted Llama model.
    Fixes vs. the previous revision:
      * ``ask_question`` used ``context_length >= 0``, which is always true
        (``len()`` is never negative), making the direct-LLM fallback dead code.
      * ``build_prompt`` built image parts and then overwrote the whole
        content list with the text part, so images never reached the model.
      * ``parse_docs`` used lenient base64 decoding, misclassifying some
        plain text as images; now validates strictly.
    """

    def __init__(self):
        # Keys come from the environment; a missing key surfaces as an auth
        # error on first use rather than at construction time.
        self.gemini_key = os.getenv("GOOGLE_API_KEY")
        self.groq_key = os.getenv("GROQ_API_KEY")

        # Embedding model used for both indexing and querying.
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=self.gemini_key,
        )

        # NOTE(review): hard-coded container path — confirm it matches the
        # deployment image.
        self.persist_directory = "/app/chromadb"
        self.vectorstore = None
        self.store = None
        self.retriever = None
        self.chain_with_sources = None

        self._setup_chromadb()
        self._setup_retriever()
        self._setup_chain()

    def _setup_chromadb(self):
        """Open (or create) the persistent Chroma collection and the docstore."""
        self.vectorstore = Chroma(
            collection_name="multi_modal_rag_new",
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory,
        )
        self.store = InMemoryStore()
        print(f"Number of documents in vectorstore: {self.vectorstore._collection.count()}")
        print("ChromaDB loaded successfully!")

    def _setup_retriever(self):
        """Build the MultiVectorRetriever and repopulate the docstore.

        Original parent content is stored in each vector's metadata under
        'original_content' keyed by 'doc_id'; rebuild the in-memory docstore
        from that so retrieved summaries map back to full content.
        """
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.store,
            id_key="doc_id",
        )
        # Private Chroma API: there is no public way to dump all metadata.
        collection = self.vectorstore._collection
        all_data = collection.get(include=["metadatas"])
        doc_store_pairs = [
            (metadata["doc_id"], metadata["original_content"])
            for metadata in all_data["metadatas"]
            if metadata and "original_content" in metadata and "doc_id" in metadata
        ]
        if doc_store_pairs:
            self.store.mset(doc_store_pairs)
            print(f"Populated docstore with {len(doc_store_pairs)} documents")
        print(f"Vectorstore count: {self.vectorstore._collection.count()}")
        print(f"Docstore count: {len(self.store.store)}")
        print("ChromaDB loaded and ready for querying!")

    def _setup_chain(self):
        """Assemble the LCEL chain: retrieve -> split text/images -> prompt -> LLM."""
        self.chain_with_sources = {
            "context": self.retriever | RunnableLambda(self.parse_docs),
            "question": RunnablePassthrough(),
        } | RunnablePassthrough().assign(
            response=(
                RunnableLambda(self.build_prompt)
                | ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                | StrOutputParser()
            )
        )

    def parse_docs(self, docs):
        """Split retrieved docs into base64-encoded images and plain texts.

        Returns ``{"images": [...], "texts": [...]}``. Strict validation
        (``validate=True``) keeps ordinary prose containing base64-alphabet
        characters from being misclassified as an image.
        """
        b64_images = []
        texts = []
        for doc in docs:
            try:
                b64decode(doc, validate=True)
                b64_images.append(doc)
            except Exception:
                texts.append(doc)
        return {"images": b64_images, "texts": texts}

    def build_prompt(self, kwargs):
        """Build the chat prompt from parsed context plus the user question.

        ``kwargs`` is the chain's dict: ``{"context": parse_docs output,
        "question": str}``. Image parts are kept alongside the text part
        (previously the text part overwrote them).
        """
        docs_by_type = kwargs["context"]
        user_question = kwargs["question"]

        context_text = "".join(str(t) for t in docs_by_type["texts"])
        image_parts = [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            for image in docs_by_type["images"]
        ]

        if context_text or image_parts:
            prompt_template = f"""
You are a helpful AI assistant.
Context from documents (use if relevant): {context_text}
Question: {user_question}
Instructions: Answer the question. If the provided context is relevant to the question, use it. If not, answer using your general knowledge.
"""
            # Text part first, then any image parts — do NOT drop the images.
            prompt_content = [{"type": "text", "text": prompt_template}] + image_parts
            return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])

        # Generic question with no context or images.
        prompt_template = f"""
You are a helpful AI assistant. Answer the following question using your general knowledge:
Question: {user_question}
"""
        return ChatPromptTemplate.from_messages(
            [HumanMessage(content=prompt_template.strip())]  # plain string
        )

    def ask_question(self, question: str):
        """Answer via RAG when relevant context exists, else via the LLM directly."""
        try:
            context_length = self._check_context_length(question)
            # Fix: the previous `>= 0` was always true (lengths are never
            # negative), which made the direct-LLM fallback unreachable.
            if context_length > 0:
                # Retrieve once more so counter-questions can use the context.
                retrieved_docs = self.retriever.invoke(question)
                parsed_context = self.parse_docs(retrieved_docs)
                context_text = "".join(str(t) for t in parsed_context["texts"])
                try:
                    response = self.chain_with_sources.invoke(question)
                    result = response.get("response") if response else None
                    # Empty/too-short answers trigger clarification instead.
                    if self._is_response_invalid(result):
                        return self._generate_counter_questions(question, context_text)
                    return result
                except Exception:
                    # RAG failed — fall back to asking clarifying questions.
                    return self._generate_counter_questions(question, context_text)
            else:
                # No relevant context: answer from the model's own knowledge.
                llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                response = llm.invoke(question)
                return response.content
        except Exception as e:
            print(f"Error in ask_question: {e}")
            return "I encountered an error processing your question. Could you please rephrase it more clearly?"

    def _is_response_invalid(self, response):
        """Return True when `response` is None, empty, or shorter than 5 chars."""
        return response is None or len(response.strip()) < 5

    def _generate_counter_questions(self, original_question, context_text):
        """Generate 2-3 clarifying questions grounded in the retrieved context."""
        try:
            llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
            counter_question_prompt = f"""
The user asked: "{original_question}"
Based on the following context from documents:
{context_text}
The question seems ambiguous. Generate 2-3 specific counter questions to help clarify what the user is asking about, using the context provided.
Format your response exactly like this:
"Your question seems ambiguous. Are you asking about:
1. [Specific question based on context]
2. [Another specific question based on context]
3. [Third specific question based on context]
Please choose one of these options or rephrase your question more specifically."
Make sure the counter questions are directly related to the content in the context.
"""
            response = llm.invoke(counter_question_prompt)
            return response.content
        except Exception:
            # Best-effort: a generic clarification beats an exception.
            return "Your question seems unclear based on the available information. Could you please be more specific about what you're looking for?"

    def _check_context_length(self, question: str):
        """Return total character count of retrieved text context (0 on error)."""
        try:
            retrieved_docs = self.retriever.invoke(question)
            parsed_context = self.parse_docs(retrieved_docs)
            context_text = "".join(str(t) for t in parsed_context["texts"])
            return len(context_text.strip())
        except Exception as e:
            print(f"Error checking context length: {e}")
            return 0
# Module-level singleton: constructed at import time so every consumer of
# this module shares one retriever, docstore, and chain.
rag_service = RAGService()