import os import json import requests import base64 import re from fastapi import FastAPI from langchain_community.vectorstores import FAISS from langchain_community.embeddings import SentenceTransformerEmbeddings from langchain_core.documents import Document # ─────────────────────────────────────────────── # Configuration # ─────────────────────────────────────────────── VECTOR_PATH = "./vectorstore/faiss_index" DOCSTORE_PATH = "./docstore" FINAL_ANSWER_URL = "https://sameer-handsome173-multi-modal.hf.space/final_answer" EXTENDED_TIMEOUT = int(os.getenv("FINAL_ANSWER_TIMEOUT", 150)) app = FastAPI(title="🔍 Multimodal RAG Query Service") # ─────────────────────────────────────────────── # JSONFileStore # ─────────────────────────────────────────────── class JSONFileStore: def __init__(self, store_path: str): self.store_path = store_path os.makedirs(self.store_path, exist_ok=True) def mget(self, keys: list[str]) -> list[Document]: """Retrieve multiple documents by their keys.""" documents = [] for key in keys: file_path = os.path.join(self.store_path, f"{key}.json") if os.path.exists(file_path): try: with open(file_path, "r", encoding="utf-8") as f: doc_dict = json.load(f) documents.append( Document(page_content=doc_dict["page_content"], metadata=doc_dict["metadata"]) ) except Exception as e: print(f"Error loading {key}: {e}") documents.append(None) else: documents.append(None) return documents # ─────────────────────────────────────────────── # Initialize embeddings, vectorstore, docstore # ─────────────────────────────────────────────── print("🔄 Loading embedding model...") try: embedding_fn = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2") print("✅ Embedding model loaded") except Exception as e: print(f"❌ Error loading embeddings: {e}") raise try: if os.path.exists(VECTOR_PATH): vectorstore = FAISS.load_local(VECTOR_PATH, embedding_fn, allow_dangerous_deserialization=True) print("✅ Loaded FAISS vectorstore") else: raise FileNotFoundError("Vectorstore not found") except Exception as e: print(f"❌ Error loading vectorstore: {e}") raise try: if not os.path.exists(DOCSTORE_PATH): raise FileNotFoundError("Docstore not found") store = JSONFileStore(DOCSTORE_PATH) print("✅ Loaded JSONFileStore") except Exception as e: print(f"❌ Error loading docstore: {e}") raise # ─────────────────────────────────────────────── # Response cleaning helper # ─────────────────────────────────────────────── def clean_response_text(text: str) -> str: """Clean the model's response to remove hashtags, emojis, repetitions and weird tails.""" if not text: return text # Remove hashtags and URLs text = re.sub(r"#\S+", "", text) text = re.sub(r"http\S+", "", text) # Remove non-ASCII characters (emojis, special symbols) text = text.encode("ascii", "ignore").decode() # Remove repeated words sequences (e.g. "word word word") text = re.sub(r"\b(\w+)( \1\b)+", r"\1", text, flags=re.IGNORECASE) # Collapse multiple newlines and spaces text = re.sub(r"\n{2,}", "\n", text) text = re.sub(r" {2,}", " ", text).strip() # Remove trailing model apology lines or noisy tails text = re.sub(r"I'm sorry.*", "", text, flags=re.IGNORECASE) return text.strip() # ─────────────────────────────────────────────── # Helpers for parsing, retrieval and final call # ─────────────────────────────────────────────── def parse_docs(docs: list[Document]) -> dict: """ Split retrieved documents into images, texts, and tables. Returns dict with lists: {"images": [...], "texts": [...], "tables": [...]} """ images, texts, tables = [], [], [] for doc in docs: doc_type = doc.metadata.get("type", "text") if doc_type == "image" and doc.metadata.get("is_base64", False): # store base64 string images.append(doc.page_content) elif doc_type == "table": tables.append(doc.page_content) else: texts.append(doc.page_content) return {"images": images, "texts": texts, "tables": tables} def retrieve_documents(query: str, k: int = 5) -> list[Document]: """ Retrieve documents: 1. Search vectorstore for similar summaries 2. Collect unique doc_ids from results (avoid duplicates) 3. Retrieve originals from docstore """ try: similar_docs = vectorstore.similarity_search(query, k=k) if not similar_docs: print("⚠️ No similar documents found") return [] doc_ids = [] for doc in similar_docs: doc_id = doc.metadata.get("doc_id") if doc_id and doc_id not in doc_ids: doc_ids.append(doc_id) if not doc_ids: print("⚠️ No doc_ids found in metadata") return [] print(f"🔑 Found {len(doc_ids)} unique doc_ids") original_docs = store.mget(doc_ids) original_docs = [d for d in original_docs if d is not None] print(f"📄 Retrieved {len(original_docs)} unique documents") return original_docs except Exception as e: print(f"❌ Error in retrieval: {e}") return [] def build_context_and_images(docs_by_type: dict) -> tuple[str, list[str]]: """ Build context text from texts and tables, and collect image base64 strings. Returns: (context_text, list_of_base64_images) """ context_parts = [] # Add text documents for i, text_content in enumerate(docs_by_type.get("texts", []), 1): context_parts.append(f"--- Text Document {i} ---\n{text_content}") # Add table documents for i, table_content in enumerate(docs_by_type.get("tables", []), 1): context_parts.append(f"--- Table {i} ---\n{table_content}") context_text = "\n\n".join(context_parts).strip() images_b64 = docs_by_type.get("images", []) return context_text, images_b64 def call_final_answer_endpoint(context: str, question: str, images_b64: list[str]) -> dict: """ Call the /final_answer endpoint with context, question, and images. Uses extended timeout to allow for slow multimodal inference. """ try: # Make prompt instruction clearer for concise output data = { "context": context, "question": f"Answer concisely and without hashtags or emojis.\n\nQuestion: {question}" } files = [] if images_b64: for i, img_b64 in enumerate(images_b64): try: img_bytes = base64.b64decode(img_b64) files.append(("images", (f"image_{i}.jpg", img_bytes, "image/jpeg"))) except Exception as e: print(f"⚠️ Error decoding image {i}: {e}") if files: response = requests.post(FINAL_ANSWER_URL, data=data, files=files, timeout=EXTENDED_TIMEOUT) else: response = requests.post(FINAL_ANSWER_URL, data=data, timeout=EXTENDED_TIMEOUT) if response.status_code == 200: return response.json() else: return {"error": f"API returned status {response.status_code}", "details": response.text} except Exception as e: return {"error": f"Error calling final_answer endpoint: {str(e)}"} # ─────────────────────────────────────────────── # FastAPI endpoints # ─────────────────────────────────────────────── @app.get("/") def home(): return { "message": "✅ Multimodal RAG Query Service is running", "timeout_seconds": EXTENDED_TIMEOUT, "endpoints": { "query": "/query?question=Your+Question", "query_with_details": "/query_with_details?question=Your+Question", "stats": "/stats", }, } @app.get("/stats") def get_stats(): try: vector_count = vectorstore.index.ntotal if hasattr(vectorstore, "index") else 0 docstore_files = len([f for f in os.listdir(DOCSTORE_PATH) if f.endswith(".json")]) if os.path.exists(DOCSTORE_PATH) else 0 return {"status": "ready", "vectorstore_count": vector_count, "docstore_count": docstore_files} except Exception as e: return {"status": "error", "error": str(e)} @app.post("/query") async def query_rag(question: str, k: int = 5): """ Query the Multimodal RAG system: 1. Search vectorstore for relevant summaries 2. Retrieve original documents (text + tables + images) 3. Parse into texts, tables, and images 4. Call final_answer endpoint with all content 5. Return cleaned answer """ try: print(f"\n🔍 Query: {question}") docs = retrieve_documents(question, k=k) if not docs: return {"question": question, "answer": "No relevant documents found. Please ingest documents first.", "retrieved_docs": 0} docs_by_type = parse_docs(docs) print(f"📊 Parsed: {len(docs_by_type['texts'])} texts, {len(docs_by_type['tables'])} tables, {len(docs_by_type['images'])} images") context_text, images_b64 = build_context_and_images(docs_by_type) print("🚀 Calling final_answer endpoint...") result = call_final_answer_endpoint(context_text, question, images_b64) if "error" in result: return { "question": question, "error": result["error"], "details": result.get("details"), "retrieved_docs": len(docs), "context_preview": context_text[:300] if context_text else "No context" } cleaned_answer = clean_response_text(result.get("response", "No response generated")) return { "question": question, "answer": cleaned_answer, "retrieved_docs": len(docs), "docs_info": { "texts": len(docs_by_type["texts"]), "tables": len(docs_by_type["tables"]), "images": len(docs_by_type["images"]), }, "context_preview": context_text[:300] if context_text else "No context", } except Exception as e: import traceback return {"question": question, "error": str(e), "traceback": traceback.format_exc()} @app.post("/query_with_details") async def query_with_details(question: str, k: int = 5): """Query with detailed document information""" try: print(f"\n🔍 Detailed Query: {question}") docs = retrieve_documents(question, k=k) if not docs: return {"question": question, "answer": "No relevant documents found.", "retrieved_docs": []} docs_by_type = parse_docs(docs) context_text, images_b64 = build_context_and_images(docs_by_type) result = call_final_answer_endpoint(context_text, question, images_b64) if "error" in result: return {"question": question, "error": result["error"], "details": result.get("details")} docs_info = [] for doc in docs: doc_info = { "doc_id": doc.metadata.get("doc_id"), "type": doc.metadata.get("type"), "source": doc.metadata.get("source"), "summary": doc.metadata.get("summary", "")[:200], } doc_info["content"] = "[Base64 Image Data]" if doc.metadata.get("type") == "image" else doc.page_content[:300] docs_info.append(doc_info) cleaned_answer = clean_response_text(result.get("response", "No response generated")) return { "question": question, "answer": cleaned_answer, "retrieved_docs": docs_info, "stats": { "total_retrieved": len(docs), "texts": len(docs_by_type["texts"]), "tables": len(docs_by_type["tables"]), "images": len(docs_by_type["images"]), }, } except Exception as e: import traceback return {"error": str(e), "traceback": traceback.format_exc()}