import base64
import json
import os
import re

import requests
from fastapi import FastAPI
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

# ───────────────────────────────────────────────
# Configuration
# ───────────────────────────────────────────────
VECTOR_PATH = "./vectorstore/faiss_index"
DOCSTORE_PATH = "./docstore"
FINAL_ANSWER_URL = "https://sameer-handsome173-multi-modal.hf.space/final_answer"
EXTENDED_TIMEOUT = int(os.getenv("FINAL_ANSWER_TIMEOUT", "150"))

app = FastAPI(title="Multimodal RAG Query Service")

# ───────────────────────────────────────────────
# JSONFileStore
# ───────────────────────────────────────────────
class JSONFileStore:
    """Minimal docstore that keeps one JSON file per document key."""

    def __init__(self, store_path: str):
        self.store_path = store_path
        os.makedirs(self.store_path, exist_ok=True)

    def mget(self, keys: list[str]) -> list[Document | None]:
        """Retrieve multiple documents by their keys; None for missing or unreadable entries."""
        documents = []
        for key in keys:
            file_path = os.path.join(self.store_path, f"{key}.json")
            if os.path.exists(file_path):
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        doc_dict = json.load(f)
                    documents.append(
                        Document(page_content=doc_dict["page_content"], metadata=doc_dict["metadata"])
                    )
                except Exception as e:
                    print(f"Error loading {key}: {e}")
                    documents.append(None)
            else:
                documents.append(None)
        return documents
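
# A quick illustration of the mget contract (the key "abc123" is hypothetical):
#   store = JSONFileStore("./docstore")
#   docs = store.mget(["abc123", "does-not-exist"])
#   # -> [Document(...), None]   callers must filter out the Nones
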
# ───────────────────────────────────────────────
# Initialize embeddings, vectorstore, docstore
# ───────────────────────────────────────────────
print("Loading embedding model...")
try:
    embedding_fn = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    print("✅ Embedding model loaded")
except Exception as e:
    print(f"❌ Error loading embeddings: {e}")
    raise

try:
    if os.path.exists(VECTOR_PATH):
        vectorstore = FAISS.load_local(VECTOR_PATH, embedding_fn, allow_dangerous_deserialization=True)
        print("✅ Loaded FAISS vectorstore")
    else:
        raise FileNotFoundError("Vectorstore not found")
except Exception as e:
    print(f"❌ Error loading vectorstore: {e}")
    raise

try:
    if not os.path.exists(DOCSTORE_PATH):
        raise FileNotFoundError("Docstore not found")
    store = JSONFileStore(DOCSTORE_PATH)
    print("✅ Loaded JSONFileStore")
except Exception as e:
    print(f"❌ Error loading docstore: {e}")
    raise

# ───────────────────────────────────────────────
# Response cleaning helper
# ───────────────────────────────────────────────
def clean_response_text(text: str) -> str:
    """Clean the model's response: strip hashtags, URLs, emojis, repetitions, and noisy tails."""
    if not text:
        return text
    # Remove hashtags and URLs
    text = re.sub(r"#\S+", "", text)
    text = re.sub(r"http\S+", "", text)
    # Remove non-ASCII characters (emojis, special symbols)
    text = text.encode("ascii", "ignore").decode()
    # Collapse repeated word sequences (e.g. "word word word" -> "word")
    text = re.sub(r"\b(\w+)( \1\b)+", r"\1", text, flags=re.IGNORECASE)
    # Collapse multiple newlines and spaces
    text = re.sub(r"\n{2,}", "\n", text)
    text = re.sub(r" {2,}", " ", text).strip()
    # Drop trailing model apology lines or noisy tails
    text = re.sub(r"I'm sorry.*", "", text, flags=re.IGNORECASE)
    return text.strip()
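
# For example (illustrative input, not real model output):
#   clean_response_text("See #promo results results results 🚀")
#   -> "See results"
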
# ───────────────────────────────────────────────
# Helpers for parsing, retrieval and final call
# ───────────────────────────────────────────────
def parse_docs(docs: list[Document]) -> dict:
    """
    Split retrieved documents into images, texts, and tables.
    Returns dict with lists: {"images": [...], "texts": [...], "tables": [...]}
    """
    images, texts, tables = [], [], []
    for doc in docs:
        doc_type = doc.metadata.get("type", "text")
        if doc_type == "image" and doc.metadata.get("is_base64", False):
            # Store the base64 string as-is; it is decoded just before upload
            images.append(doc.page_content)
        elif doc_type == "table":
            tables.append(doc.page_content)
        else:
            texts.append(doc.page_content)
    return {"images": images, "texts": texts, "tables": tables}
def retrieve_documents(query: str, k: int = 5) -> list[Document]:
    """
    Retrieve documents:
    1. Search the vectorstore for similar summaries
    2. Collect unique doc_ids from the results (avoid duplicates)
    3. Fetch the original documents from the docstore
    """
    try:
        similar_docs = vectorstore.similarity_search(query, k=k)
        if not similar_docs:
            print("⚠️ No similar documents found")
            return []
        doc_ids = []
        for doc in similar_docs:
            doc_id = doc.metadata.get("doc_id")
            if doc_id and doc_id not in doc_ids:
                doc_ids.append(doc_id)
        if not doc_ids:
            print("⚠️ No doc_ids found in metadata")
            return []
        print(f"Found {len(doc_ids)} unique doc_ids")
        original_docs = store.mget(doc_ids)
        original_docs = [d for d in original_docs if d is not None]
        print(f"Retrieved {len(original_docs)} unique documents")
        return original_docs
    except Exception as e:
        print(f"❌ Error in retrieval: {e}")
        return []

def build_context_and_images(docs_by_type: dict) -> tuple[str, list[str]]:
    """
    Build context text from texts and tables, and collect image base64 strings.
    Returns: (context_text, list_of_base64_images)
    """
    context_parts = []
    # Add text documents
    for i, text_content in enumerate(docs_by_type.get("texts", []), 1):
        context_parts.append(f"--- Text Document {i} ---\n{text_content}")
    # Add table documents
    for i, table_content in enumerate(docs_by_type.get("tables", []), 1):
        context_parts.append(f"--- Table {i} ---\n{table_content}")
    context_text = "\n\n".join(context_parts).strip()
    images_b64 = docs_by_type.get("images", [])
    return context_text, images_b64
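
# The resulting context looks like this (illustrative placeholders):
#
#   --- Text Document 1 ---
#   <first text chunk>
#
#   --- Table 1 ---
#   <table rendered as text>
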
def call_final_answer_endpoint(context: str, question: str, images_b64: list[str]) -> dict:
    """
    Call the /final_answer endpoint with context, question, and images.
    Uses an extended timeout to allow for slow multimodal inference.
    """
    try:
        # Make the prompt instruction explicit so the model answers concisely
        data = {
            "context": context,
            "question": f"Answer concisely and without hashtags or emojis.\n\nQuestion: {question}",
        }
        files = []
        for i, img_b64 in enumerate(images_b64):
            try:
                img_bytes = base64.b64decode(img_b64)
                files.append(("images", (f"image_{i}.jpg", img_bytes, "image/jpeg")))
            except Exception as e:
                print(f"⚠️ Error decoding image {i}: {e}")
        if files:
            response = requests.post(FINAL_ANSWER_URL, data=data, files=files, timeout=EXTENDED_TIMEOUT)
        else:
            response = requests.post(FINAL_ANSWER_URL, data=data, timeout=EXTENDED_TIMEOUT)
        if response.status_code == 200:
            return response.json()
        return {"error": f"API returned status {response.status_code}", "details": response.text}
    except Exception as e:
        return {"error": f"Error calling final_answer endpoint: {str(e)}"}
# ───────────────────────────────────────────────
# FastAPI endpoints
# ───────────────────────────────────────────────
@app.get("/")
def home():
    return {
        "message": "✅ Multimodal RAG Query Service is running",
        "timeout_seconds": EXTENDED_TIMEOUT,
        "endpoints": {
            "query": "/query?question=Your+Question",
            "query_with_details": "/query_with_details?question=Your+Question",
            "stats": "/stats",
        },
    }

@app.get("/stats")
def get_stats():
    try:
        vector_count = vectorstore.index.ntotal if hasattr(vectorstore, "index") else 0
        docstore_files = len([f for f in os.listdir(DOCSTORE_PATH) if f.endswith(".json")]) if os.path.exists(DOCSTORE_PATH) else 0
        return {"status": "ready", "vectorstore_count": vector_count, "docstore_count": docstore_files}
    except Exception as e:
        return {"status": "error", "error": str(e)}

@app.get("/query")
async def query_rag(question: str, k: int = 5):
    """
    Query the multimodal RAG system:
    1. Search the vectorstore for relevant summaries
    2. Retrieve the original documents (text + tables + images)
    3. Parse them into texts, tables, and images
    4. Call the final_answer endpoint with all content
    5. Return the cleaned answer
    """
    try:
        print(f"\nQuery: {question}")
        docs = retrieve_documents(question, k=k)
        if not docs:
            return {"question": question, "answer": "No relevant documents found. Please ingest documents first.", "retrieved_docs": 0}
        docs_by_type = parse_docs(docs)
        print(f"Parsed: {len(docs_by_type['texts'])} texts, {len(docs_by_type['tables'])} tables, {len(docs_by_type['images'])} images")
        context_text, images_b64 = build_context_and_images(docs_by_type)
        print("Calling final_answer endpoint...")
        result = call_final_answer_endpoint(context_text, question, images_b64)
        if "error" in result:
            return {
                "question": question,
                "error": result["error"],
                "details": result.get("details"),
                "retrieved_docs": len(docs),
                "context_preview": context_text[:300] if context_text else "No context",
            }
        cleaned_answer = clean_response_text(result.get("response", "No response generated"))
        return {
            "question": question,
            "answer": cleaned_answer,
            "retrieved_docs": len(docs),
            "docs_info": {
                "texts": len(docs_by_type["texts"]),
                "tables": len(docs_by_type["tables"]),
                "images": len(docs_by_type["images"]),
            },
            "context_preview": context_text[:300] if context_text else "No context",
        }
    except Exception as e:
        import traceback
        return {"question": question, "error": str(e), "traceback": traceback.format_exc()}

@app.get("/query_with_details")
async def query_with_details(question: str, k: int = 5):
    """Query with detailed per-document information."""
    try:
        print(f"\nDetailed Query: {question}")
        docs = retrieve_documents(question, k=k)
        if not docs:
            return {"question": question, "answer": "No relevant documents found.", "retrieved_docs": []}
        docs_by_type = parse_docs(docs)
        context_text, images_b64 = build_context_and_images(docs_by_type)
        result = call_final_answer_endpoint(context_text, question, images_b64)
        if "error" in result:
            return {"question": question, "error": result["error"], "details": result.get("details")}
        docs_info = []
        for doc in docs:
            doc_info = {
                "doc_id": doc.metadata.get("doc_id"),
                "type": doc.metadata.get("type"),
                "source": doc.metadata.get("source"),
                "summary": doc.metadata.get("summary", "")[:200],
            }
            doc_info["content"] = "[Base64 Image Data]" if doc.metadata.get("type") == "image" else doc.page_content[:300]
            docs_info.append(doc_info)
        cleaned_answer = clean_response_text(result.get("response", "No response generated"))
        return {
            "question": question,
            "answer": cleaned_answer,
            "retrieved_docs": docs_info,
            "stats": {
                "total_retrieved": len(docs),
                "texts": len(docs_by_type["texts"]),
                "tables": len(docs_by_type["tables"]),
                "images": len(docs_by_type["images"]),
            },
        }
    except Exception as e:
        import traceback
        return {"error": str(e), "traceback": traceback.format_exc()}