# Multi_Modal_RAG / query_service.py
# Committed by: Sameer-Handsome173
# Commit c56a43d (verified): "Update query_service.py"
import asyncio
import base64
import json
import os
import re
import unicodedata

import requests
from fastapi import FastAPI
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
# ───────────────────────────────────────────────
# Configuration
# ───────────────────────────────────────────────
# VECTOR_PATH, DOCSTORE_PATH and FINAL_ANSWER_URL can be overridden through
# environment variables of the same name; EXTENDED_TIMEOUT is read from
# FINAL_ANSWER_TIMEOUT. The defaults are the values this service shipped with.

# Saved FAISS index holding the summary embeddings that queries are matched against.
VECTOR_PATH = os.getenv("VECTOR_PATH", "./vectorstore/faiss_index")
# Directory of "<doc_id>.json" files with the original documents.
DOCSTORE_PATH = os.getenv("DOCSTORE_PATH", "./docstore")
# Remote endpoint that produces the final answer from context + question (+ images).
FINAL_ANSWER_URL = os.getenv(
    "FINAL_ANSWER_URL",
    "https://sameer-handsome173-multi-modal.hf.space/final_answer",
)
# Request timeout (seconds) for the final_answer call; multimodal inference is slow.
EXTENDED_TIMEOUT = int(os.getenv("FINAL_ANSWER_TIMEOUT", 150))
app = FastAPI(title="πŸ” Multimodal RAG Query Service")
# ───────────────────────────────────────────────
# JSONFileStore
# ───────────────────────────────────────────────
class JSONFileStore:
    """Minimal read-only key/value store of Documents kept as JSON files.

    Key ``k`` maps to the file ``<store_path>/k.json``, which is expected to
    contain an object with ``"page_content"`` and ``"metadata"`` entries.
    """

    def __init__(self, store_path: str):
        self.store_path = store_path
        # Create the directory if needed so lookups on a fresh store just miss.
        os.makedirs(self.store_path, exist_ok=True)

    def mget(self, keys: list[str]) -> "list[Document | None]":
        """Retrieve multiple documents by their keys.

        Args:
            keys: Document ids to look up.

        Returns:
            A list aligned with ``keys``. Each entry is the loaded ``Document``,
            or ``None`` when the file is missing or cannot be read or parsed.
            Callers must filter out the ``None`` entries.
        """
        documents = []
        for key in keys:
            file_path = os.path.join(self.store_path, f"{key}.json")
            # EAFP: open directly instead of checking existence first, which
            # avoids a race between the check and the open.
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    doc_dict = json.load(f)
                doc = Document(page_content=doc_dict["page_content"], metadata=doc_dict["metadata"])
            except FileNotFoundError:
                # Unknown key: a plain miss, not an error worth logging.
                doc = None
            except Exception as e:
                # Best-effort lookup: unreadable, corrupt or malformed files are
                # logged and reported as misses instead of failing the batch.
                print(f"Error loading {key}: {e}")
                doc = None
            documents.append(doc)
        return documents
# ───────────────────────────────────────────────
# Initialize embeddings, vectorstore, docstore
# ───────────────────────────────────────────────
# Each resource is loaded at import time. Any failure is printed and re-raised,
# so importing this module fails rather than leaving the service half set up.
print("πŸ”„ Loading embedding model...")
try:
    embedding_fn = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    print("βœ… Embedding model loaded")
except Exception as err:
    print(f"❌ Error loading embeddings: {err}")
    raise

try:
    if not os.path.exists(VECTOR_PATH):
        raise FileNotFoundError("Vectorstore not found")
    # NOTE(security): allow_dangerous_deserialization permits LangChain to
    # unpickle the saved index files; only point VECTOR_PATH at an index this
    # service produced itself.
    vectorstore = FAISS.load_local(
        VECTOR_PATH,
        embedding_fn,
        allow_dangerous_deserialization=True,
    )
    print("βœ… Loaded FAISS vectorstore")
except Exception as err:
    print(f"❌ Error loading vectorstore: {err}")
    raise

try:
    # JSONFileStore would silently create a missing directory, so require it up front.
    if not os.path.exists(DOCSTORE_PATH):
        raise FileNotFoundError("Docstore not found")
    store = JSONFileStore(DOCSTORE_PATH)
    print("βœ… Loaded JSONFileStore")
except Exception as err:
    print(f"❌ Error loading docstore: {err}")
    raise
# ───────────────────────────────────────────────
# Response cleaning helper
# ───────────────────────────────────────────────
# Code points that only occur inside emoji sequences: zero-width joiner,
# text/emoji variation selectors, and the skin-tone modifiers U+1F3FB..U+1F3FF.
_EMOJI_JOINERS = frozenset({0x200D, 0xFE0E, 0xFE0F, *range(0x1F3FB, 0x1F400)})
# Unicode categories removed from answers: "So" (other symbols: emoji,
# pictographs, dingbats), "Cs" (surrogates) and "Co" (private use).
_DROPPED_CATEGORIES = frozenset({"So", "Cs", "Co"})


def _strip_symbols(text: str) -> str:
    """Remove emoji and pictographic symbols while keeping all letters and digits."""
    return "".join(
        ch
        for ch in text
        if ord(ch) not in _EMOJI_JOINERS and unicodedata.category(ch) not in _DROPPED_CATEGORIES
    )


def clean_response_text(text: str) -> str:
    """Clean the model's response to remove hashtags, emojis, repetitions and weird tails.

    Steps, in order:
      1. Drop hashtag-like tokens (``#`` plus any non-space run, so ``#42``
         goes too) and ``http...`` URLs.
      2. Drop emoji/symbol characters. Accented and non-Latin letters are kept,
         unlike a plain ASCII filter.
      3. Cut each "I'm sorry" (straight or curly apostrophe) up to the end of
         its line.
      4. Collapse newline runs to one newline and space runs to one space.
      5. Collapse immediately repeated words ("yes yes YES" -> "yes").

    Args:
        text: Raw model output. Falsy values (``""``/``None``) are returned unchanged.

    Returns:
        The cleaned, stripped text.
    """
    if not text:
        return text
    text = re.sub(r"#\S+", "", text)
    text = re.sub(r"http\S+", "", text)
    text = _strip_symbols(text)
    # Remove apologies before normalizing whitespace so the gaps they leave
    # behind are collapsed as well.
    text = re.sub(r"I['\u2019]m sorry.*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\n{2,}", "\n", text)
    text = re.sub(r" {2,}", " ", text)
    # Runs after space collapsing so repeats separated by removed tokens or
    # several spaces are caught too.
    text = re.sub(r"\b(\w+)( \1\b)+", r"\1", text, flags=re.IGNORECASE)
    return text.strip()
# ───────────────────────────────────────────────
# Helpers for parsing, retrieval and final call
# ───────────────────────────────────────────────
def parse_docs(docs: list[Document]) -> dict:
    """Bucket retrieved documents by ``metadata["type"]``.

    Routing:
      * ``type == "image"`` with a truthy ``is_base64`` flag -> ``images``
        (the ``page_content`` is assumed to be the base64 payload).
      * ``type == "table"`` -> ``tables``.
      * anything else, including untyped documents and images without
        ``is_base64`` -> ``texts``.

    Returns:
        ``{"images": [...], "texts": [...], "tables": [...]}`` holding each
        document's ``page_content`` in retrieval order.
    """
    buckets = {"images": [], "texts": [], "tables": []}
    for document in docs:
        meta = document.metadata
        kind = meta.get("type", "text")
        if kind == "table":
            target = "tables"
        elif kind == "image" and meta.get("is_base64", False):
            target = "images"
        else:
            target = "texts"
        buckets[target].append(document.page_content)
    return buckets
def retrieve_documents(query: str, k: int = 5) -> list[Document]:
    """Return the original documents behind the top-``k`` matching summaries.

    Steps:
      1. Run a similarity search for ``query`` over the summary vectorstore.
      2. Collect each hit's ``metadata["doc_id"]``, keeping first-seen order
         and dropping duplicates and empty ids.
      3. Load those ids from the JSON docstore and discard misses.

    Any exception is printed and turned into an empty list, so callers only
    need to handle the "no documents" case.
    """
    try:
        hits = vectorstore.similarity_search(query, k=k)
        if not hits:
            print("⚠️ No similar documents found")
            return []

        candidate_ids = (hit.metadata.get("doc_id") for hit in hits)
        # dict.fromkeys de-duplicates while preserving insertion order.
        unique_ids = list(dict.fromkeys(doc_id for doc_id in candidate_ids if doc_id))
        if not unique_ids:
            print("⚠️ No doc_ids found in metadata")
            return []
        print(f"πŸ”‘ Found {len(unique_ids)} unique doc_ids")

        loaded = [doc for doc in store.mget(unique_ids) if doc is not None]
        print(f"πŸ“„ Retrieved {len(loaded)} unique documents")
        return loaded
    except Exception as exc:
        print(f"❌ Error in retrieval: {exc}")
        return []
def build_context_and_images(docs_by_type: dict) -> tuple[str, list[str]]:
    """Flatten parsed documents into a prompt context plus an image list.

    Texts come first as ``--- Text Document N ---`` sections, then tables as
    ``--- Table N ---`` sections. Each series is numbered from 1, and sections
    are separated by a blank line. Images are passed through untouched.

    Returns:
        ``(context_text, images_b64)``: the joined, stripped context string and
        the ``"images"`` list from ``docs_by_type`` (``[]`` if absent).
    """
    sections = [
        f"--- Text Document {idx} ---\n{body}"
        for idx, body in enumerate(docs_by_type.get("texts", []), start=1)
    ]
    sections.extend(
        f"--- Table {idx} ---\n{body}"
        for idx, body in enumerate(docs_by_type.get("tables", []), start=1)
    )
    return "\n\n".join(sections).strip(), docs_by_type.get("images", [])
# Leading magic bytes of image formats that can be labelled precisely.
# WebP is checked separately because its signature is split (RIFF....WEBP).
_IMAGE_SIGNATURES = (
    (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
    (b"\xff\xd8\xff", "jpg", "image/jpeg"),
    (b"GIF87a", "gif", "image/gif"),
    (b"GIF89a", "gif", "image/gif"),
)


def _guess_image_type(img_bytes: bytes) -> tuple[str, str]:
    """Return ``(extension, mime_type)`` for decoded image bytes.

    Recognises PNG, JPEG, GIF and WebP by their magic bytes. Anything else
    falls back to JPEG, the label this service always used.
    """
    if img_bytes[:4] == b"RIFF" and img_bytes[8:12] == b"WEBP":
        return "webp", "image/webp"
    for signature, ext, mime in _IMAGE_SIGNATURES:
        if img_bytes.startswith(signature):
            return ext, mime
    return "jpg", "image/jpeg"


def call_final_answer_endpoint(context: str, question: str, images_b64: list[str]) -> dict:
    """
    Call the /final_answer endpoint with context, question, and images.
    Uses extended timeout to allow for slow multimodal inference.

    Args:
        context: Retrieved text/table context sent as the "context" form field.
        question: User question; it is prefixed with an instruction asking for
            a concise answer without hashtags or emojis.
        images_b64: Base64-encoded images, uploaded as multipart "images"
            parts. Entries that fail to decode are reported and skipped.

    Returns:
        The endpoint's JSON body on HTTP 200. Otherwise a dict with an "error"
        key, plus "details" (the response text) for non-200 replies.
        Transport, timeout and JSON-decoding failures are also reported
        through "error" instead of being raised.
    """
    try:
        # Make prompt instruction clearer for concise output
        data = {
            "context": context,
            "question": f"Answer concisely and without hashtags or emojis.\n\nQuestion: {question}",
        }
        files = []
        for i, img_b64 in enumerate(images_b64 or []):
            try:
                img_bytes = base64.b64decode(img_b64)
            except (ValueError, TypeError) as e:
                # binascii.Error (bad padding/characters) subclasses ValueError;
                # TypeError covers non-str/bytes entries.
                print(f"⚠️ Error decoding image {i}: {e}")
                continue
            # Label each upload with its real format instead of always JPEG.
            ext, mime = _guess_image_type(img_bytes)
            files.append(("images", (f"image_{i}.{ext}", img_bytes, mime)))
        # requests treats files=None exactly like omitting it, so one call
        # covers both the multipart and the plain form-encoded request.
        response = requests.post(
            FINAL_ANSWER_URL,
            data=data,
            files=files or None,
            timeout=EXTENDED_TIMEOUT,
        )
        if response.status_code == 200:
            return response.json()
        return {"error": f"API returned status {response.status_code}", "details": response.text}
    except Exception as e:
        return {"error": f"Error calling final_answer endpoint: {str(e)}"}
# ───────────────────────────────────────────────
# FastAPI endpoints
# ───────────────────────────────────────────────
@app.get("/")
def home():
    """Landing endpoint: liveness message, configured timeout and available routes."""
    endpoints = {
        "query": "/query?question=Your+Question",
        "query_with_details": "/query_with_details?question=Your+Question",
        "stats": "/stats",
    }
    payload = {
        "message": "βœ… Multimodal RAG Query Service is running",
        "timeout_seconds": EXTENDED_TIMEOUT,
        "endpoints": endpoints,
    }
    return payload
@app.get("/stats")
def get_stats():
    """Report the number of indexed vectors and of ``.json`` files in the docstore.

    Failures are returned as ``{"status": "error", "error": ...}`` instead of raising.
    """
    try:
        if hasattr(vectorstore, "index"):
            vector_count = vectorstore.index.ntotal
        else:
            vector_count = 0
        docstore_files = 0
        if os.path.exists(DOCSTORE_PATH):
            docstore_files = sum(1 for name in os.listdir(DOCSTORE_PATH) if name.endswith(".json"))
        return {"status": "ready", "vectorstore_count": vector_count, "docstore_count": docstore_files}
    except Exception as exc:
        return {"status": "error", "error": str(exc)}
@app.post("/query")
async def query_rag(question: str, k: int = 5):
    """
    Query the Multimodal RAG system:
    1. Search vectorstore for relevant summaries
    2. Retrieve original documents (text + tables + images)
    3. Parse into texts, tables, and images
    4. Call final_answer endpoint with all content
    5. Return cleaned answer

    `k` is the number of summaries fetched from the vectorstore; several may
    point at the same original document. Failures are reported in the JSON
    body (with an "error" key), not as an HTTP error status.
    """
    try:
        print(f"\nπŸ” Query: {question}")
        # retrieve_documents (embedding + FAISS search + file reads) and
        # call_final_answer_endpoint (a blocking HTTP request) would otherwise
        # run on the event loop and stall every other request while they work;
        # run them in worker threads instead.
        docs = await asyncio.to_thread(retrieve_documents, question, k)
        if not docs:
            return {"question": question, "answer": "No relevant documents found. Please ingest documents first.", "retrieved_docs": 0}
        docs_by_type = parse_docs(docs)
        print(f"πŸ“Š Parsed: {len(docs_by_type['texts'])} texts, {len(docs_by_type['tables'])} tables, {len(docs_by_type['images'])} images")
        context_text, images_b64 = build_context_and_images(docs_by_type)
        context_preview = context_text[:300] if context_text else "No context"
        print("πŸš€ Calling final_answer endpoint...")
        result = await asyncio.to_thread(call_final_answer_endpoint, context_text, question, images_b64)
        if "error" in result:
            return {
                "question": question,
                "error": result["error"],
                "details": result.get("details"),
                "retrieved_docs": len(docs),
                "context_preview": context_preview,
            }
        cleaned_answer = clean_response_text(result.get("response", "No response generated"))
        return {
            "question": question,
            "answer": cleaned_answer,
            "retrieved_docs": len(docs),
            "docs_info": {
                "texts": len(docs_by_type["texts"]),
                "tables": len(docs_by_type["tables"]),
                "images": len(docs_by_type["images"]),
            },
            "context_preview": context_preview,
        }
    except Exception as e:
        import traceback
        # NOTE(security): the full traceback is returned to the caller, which
        # exposes server internals; consider logging it server-side instead.
        return {"question": question, "error": str(e), "traceback": traceback.format_exc()}
@app.post("/query_with_details")
async def query_with_details(question: str, k: int = 5):
    """Query with detailed document information.

    Runs the same retrieve -> parse -> final_answer pipeline as /query and
    also lists every retrieved document. Each entry has its doc_id, type,
    source, a 200-character summary preview and a 300-character content
    preview; image payloads are replaced by a placeholder. Failures are
    reported in the JSON body under "error".
    """
    try:
        print(f"\nπŸ” Detailed Query: {question}")
        # Retrieval and the final_answer HTTP call (timeout EXTENDED_TIMEOUT)
        # are blocking; run them in worker threads so this async endpoint does
        # not freeze the event loop for other requests.
        docs = await asyncio.to_thread(retrieve_documents, question, k)
        if not docs:
            return {"question": question, "answer": "No relevant documents found.", "retrieved_docs": []}
        docs_by_type = parse_docs(docs)
        context_text, images_b64 = build_context_and_images(docs_by_type)
        result = await asyncio.to_thread(call_final_answer_endpoint, context_text, question, images_b64)
        if "error" in result:
            return {"question": question, "error": result["error"], "details": result.get("details")}
        docs_info = []
        for doc in docs:
            metadata = doc.metadata
            is_image = metadata.get("type") == "image"
            docs_info.append({
                "doc_id": metadata.get("doc_id"),
                "type": metadata.get("type"),
                "source": metadata.get("source"),
                # `or ""` also covers a summary stored as null, which would
                # otherwise make the slice raise and fail the whole request.
                "summary": (metadata.get("summary") or "")[:200],
                "content": "[Base64 Image Data]" if is_image else doc.page_content[:300],
            })
        cleaned_answer = clean_response_text(result.get("response", "No response generated"))
        return {
            "question": question,
            "answer": cleaned_answer,
            "retrieved_docs": docs_info,
            "stats": {
                "total_retrieved": len(docs),
                "texts": len(docs_by_type["texts"]),
                "tables": len(docs_by_type["tables"]),
                "images": len(docs_by_type["images"]),
            },
        }
    except Exception as e:
        import traceback
        # NOTE(security): the full traceback is returned to the caller, which
        # exposes server internals; consider logging it server-side instead.
        return {"question": question, "error": str(e), "traceback": traceback.format_exc()}