# RAG-API / rag_service.py
# (source: Amna2024's Space, commit eab31fc)
import os
import uuid
import time
import shutil
from base64 import b64decode
from langchain_community.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
import chromadb
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
class RAGService:
    """Multi-modal RAG service over a persisted Chroma collection.

    Embeddings come from Google Gemini (text-embedding-004), generation
    from Groq (llama-3.1-8b-instant). The vectorstore holds the embedded
    summaries; the original content is mirrored into an in-memory
    docstore keyed by ``doc_id`` so the MultiVectorRetriever can return
    the full documents (texts and base64-encoded images).
    """

    def __init__(self):
        # Keys are read from the environment; a missing key is not fatal
        # here but surfaces as an auth error on the first API call.
        self.gemini_key = os.getenv("GOOGLE_API_KEY")
        self.groq_key = os.getenv("GROQ_API_KEY")
        # Embedding model used for both indexing and querying.
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=self.gemini_key,
        )
        # Persisted ChromaDB location (container path).
        self.persist_directory = "/app/chromadb"
        self.vectorstore = None
        self.store = None
        self.retriever = None
        self.chain_with_sources = None
        self._setup_chromadb()
        self._setup_retriever()
        self._setup_chain()

    def _setup_chromadb(self):
        """Open (or create) the persisted Chroma collection and the docstore."""
        self.vectorstore = Chroma(
            collection_name="multi_modal_rag_new",
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory,
        )
        # The docstore is in-memory only; it is repopulated from vector
        # metadata in _setup_retriever on every startup.
        self.store = InMemoryStore()
        print(f"Number of documents in vectorstore: {self.vectorstore._collection.count()}")
        print("ChromaDB loaded successfully!")

    def _setup_retriever(self):
        """Build the MultiVectorRetriever and repopulate the docstore.

        Original content is reloaded from each vector's metadata
        (``original_content``/``doc_id`` keys written at index time).
        """
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.store,
            id_key="doc_id",
        )
        collection = self.vectorstore._collection
        all_data = collection.get(include=['metadatas'])
        # Keep only entries carrying both the id and the raw content.
        # (The original zipped in all_data['ids'] but never used them.)
        doc_store_pairs = [
            (meta['doc_id'], meta['original_content'])
            for meta in (all_data['metadatas'] or [])
            if meta and 'original_content' in meta and 'doc_id' in meta
        ]
        if doc_store_pairs:
            self.store.mset(doc_store_pairs)
            print(f"Populated docstore with {len(doc_store_pairs)} documents")
        print(f"Vectorstore count: {self.vectorstore._collection.count()}")
        print(f"Docstore count: {len(self.store.store)}")
        print("ChromaDB loaded and ready for querying!")

    def _setup_chain(self):
        """Assemble the RAG chain: retrieve -> parse -> prompt -> LLM -> text."""
        self.chain_with_sources = {
            "context": self.retriever | RunnableLambda(self.parse_docs),
            "question": RunnablePassthrough(),
        } | RunnablePassthrough().assign(
            response=(
                RunnableLambda(self.build_prompt)
                | ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                | StrOutputParser()
            )
        )

    def parse_docs(self, docs):
        """Split retrieved docs into base64-encoded images and plain texts.

        ``validate=True`` makes the base64 check strict: without it,
        short plain-text chunks whose length happens to be a multiple of
        four could be misclassified as images.
        """
        b64 = []
        text = []
        for doc in docs:
            try:
                b64decode(doc, validate=True)
                b64.append(doc)
            except Exception:
                # Not valid base64 (or not a string at all): treat as text.
                text.append(doc)
        return {"images": b64, "texts": text}

    def build_prompt(self, kwargs):
        """Build a ChatPromptTemplate from retrieved context and the question.

        With textual context, emits a multimodal message: one text part
        followed by any retrieved images. Without context, emits a plain
        general-knowledge prompt.
        """
        docs_by_type = kwargs["context"]
        user_question = kwargs["question"]

        if docs_by_type["texts"]:
            context_text = "".join(str(t) for t in docs_by_type["texts"])
            prompt_template = f"""
You are a helpful AI assistant.
Context from documents (use if relevant): {context_text}
Question: {user_question}
Instructions: Answer the question. If the provided context is relevant to the question, use it. If not, answer using your general knowledge.
"""
            # Text part first, then image parts. (Bug fix: the original
            # appended the images and then REASSIGNED prompt_content to
            # the text part, discarding every retrieved image.)
            prompt_content = [{"type": "text", "text": prompt_template}]
            for image in docs_by_type["images"]:
                prompt_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                    }
                )
            return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])

        # Generic question with no context or images.
        prompt_template = f"""
You are a helpful AI assistant. Answer the following question using your general knowledge:
Question: {user_question}
"""
        return ChatPromptTemplate.from_messages(
            [HumanMessage(content=prompt_template.strip())]  # plain string
        )

    def ask_question(self, question: str):
        """Answer ``question``, preferring retrieved document context.

        Falls back to a direct LLM call when retrieval yields no context,
        and to clarification (counter) questions when the RAG answer is
        empty/too short or the chain raises.
        """
        try:
            context_length = self._check_context_length(question)
            # Bug fix: the original compared ``>= 0``, which a length
            # always satisfies, so the no-context fallback never ran.
            if context_length > 0:
                # Retrieve once more so we can hand the raw context to
                # the counter-question generator if the answer is weak.
                retrieved_docs = self.retriever.invoke(question)
                parsed_context = self.parse_docs(retrieved_docs)
                context_text = "".join(str(t) for t in parsed_context["texts"])
                try:
                    response = self.chain_with_sources.invoke(question)
                    result = response.get('response') if response else None
                    if self._is_response_invalid(result):
                        return self._generate_counter_questions(question, context_text)
                    return result
                except Exception:
                    # If the RAG chain fails, ask for clarification
                    # instead of surfacing the error.
                    return self._generate_counter_questions(question, context_text)
            else:
                # No relevant context: answer from general knowledge.
                llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
                response = llm.invoke(question)
                return response.content
        except Exception as e:
            print(f"Error in ask_question: {e}")
            return "I encountered an error processing your question. Could you please rephrase it more clearly?"

    def _is_response_invalid(self, response):
        """Return True when ``response`` is None, empty, or trivially short."""
        return not response or len(response.strip()) < 5

    def _generate_counter_questions(self, original_question, context_text):
        """Ask the LLM for 2-3 clarifying questions grounded in the context.

        Used when the question looks ambiguous or the RAG chain failed;
        degrades to a generic clarification request on any error.
        """
        try:
            llm = ChatGroq(model="llama-3.1-8b-instant", groq_api_key=self.groq_key)
            counter_question_prompt = f"""
The user asked: "{original_question}"
Based on the following context from documents:
{context_text}
The question seems ambiguous. Generate 2-3 specific counter questions to help clarify what the user is asking about, using the context provided.
Format your response exactly like this:
"Your question seems ambiguous. Are you asking about:
1. [Specific question based on context]
2. [Another specific question based on context]
3. [Third specific question based on context]
Please choose one of these options or rephrase your question more specifically."
Make sure the counter questions are directly related to the content in the context.
"""
            response = llm.invoke(counter_question_prompt)
            return response.content
        except Exception:
            return "Your question seems unclear based on the available information. Could you please be more specific about what you're looking for?"

    def _check_context_length(self, question: str):
        """Return the stripped length of the textual context retrieved for
        ``question``; 0 on retrieval failure."""
        try:
            retrieved_docs = self.retriever.invoke(question)
            parsed_context = self.parse_docs(retrieved_docs)
            context_text = "".join(str(t) for t in parsed_context["texts"])
            return len(context_text.strip())
        except Exception as e:
            print(f"Error checking context length: {e}")
            return 0
# Module-level singleton shared by the API layer. NOTE: constructing it
# here means importing this module performs the full RAG setup (opens
# ChromaDB, repopulates the docstore, builds the chain) as a side effect.
rag_service = RAGService()