Spaces:
Running
Running
| """ | |
| rag_engine.py — Multimodal RAG Engine with Multi-File Support, Reranking & Memory | |
| Supports: PDF, TXT, DOCX, CSV, XLSX, Images (JPG/PNG/WEBP) | |
| Features: Up to 5 simultaneous files, per-file removal, additive indexing | |
| Memory: sliding window of last 6 exchanges | |
| KEY CHANGES (v5 — Cross-Encoder Reranking): | |
| 1. Cross-encoder reranker (ms-marco-MiniLM-L-6-v2) scores every retrieved | |
| chunk for true semantic relevance to the query — not just embedding distance. | |
| 2. Over-fetches 12+ candidates from the vectorstore, then reranks to pick | |
| the top-k most relevant chunks for the LLM context. | |
| 3. Graceful fallback — if the reranker fails to load, uses original order. | |
| Previous features preserved: | |
| - Additive indexing, per-file removal, MAX_FILES=5 | |
| - Multi-file aware generation, cross-doc coverage | |
| - OCR, color analysis, BLIP raw bytes, VLM descriptions for images | |
| - Conversation memory (6-exchange sliding window) | |
| """ | |
| import os | |
| import re | |
| import io | |
| import json | |
| import time | |
| import base64 | |
| import hashlib | |
| import tempfile | |
| import requests | |
| import logging | |
| from pathlib import Path | |
| from typing import Tuple, List, Optional, Dict | |
| from collections import Counter | |
| from chromadb.config import Settings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.schema import Document | |
| import monitor | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Constants ────────────────────────────────────────────────────────────────
# Embedding model used for the vectorstore and cross-encoder used for reranking.
EMBED_MODEL = "all-MiniLM-L6-v2"
RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # ~80MB, CPU-friendly

# Text chunking parameters (characters).
CHUNK_SIZE, CHUNK_OVERLAP = 600, 100

# Retrieval sizes: over-fetch RERANK_FETCH_K candidates, keep TOP_K after reranking.
TOP_K = 4
RERANK_FETCH_K = 12

COLLECTION_NAME = "docmind_multimodal"
HF_API_URL = "https://router.huggingface.co/v1/chat/completions"

MEMORY_WINDOW = 6  # past question/answer pairs retained in conversation memory
MAX_FILES = 5      # upper bound on simultaneously loaded documents

# Accepted upload extensions, grouped by the loader family that handles them.
SUPPORTED_EXTENSIONS = (
    {".pdf", ".txt"}
    | {".docx", ".doc"}
    | {".csv", ".xlsx", ".xls"}
    | {".jpg", ".jpeg", ".png", ".webp"}
)

# Chat models, tried in order until one returns a usable reply.
CANDIDATE_MODELS = [
    "meta-llama/Llama-3.1-8B-Instruct:cerebras",
    "meta-llama/Llama-3.3-70B-Instruct:cerebras",
    "mistralai/Mistral-7B-Instruct-v0.3:fireworks-ai",
    "HuggingFaceTB/SmolLM3-3B:hf-inference",
]

# Vision-language models for detailed image descriptions, tried in order.
VLM_MODELS = [
    "Qwen/Qwen2.5-VL-7B-Instruct",
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
]
def get_suffix(name: str) -> str:
    """Return the lower-cased file extension of *name*, defaulting to ".txt".

    Only the final extension is considered (``"a.tar.GZ"`` -> ``".gz"``).
    """
    extension = Path(name).suffix.lower()
    if extension:
        return extension
    return ".txt"
| def _classify_color(r: int, g: int, b: int) -> str: | |
| """Classify an RGB pixel into a human-readable color name.""" | |
| if r > 220 and g > 220 and b > 220: | |
| return "white" | |
| if r < 35 and g < 35 and b < 35: | |
| return "black" | |
| if max(r, g, b) - min(r, g, b) < 30: | |
| if r > 170: | |
| return "light gray" | |
| if r > 100: | |
| return "gray" | |
| return "dark gray" | |
| if r > 180 and g > 180 and b < 100: | |
| return "yellow" | |
| if r > 180 and g > 100 and g < 180 and b < 80: | |
| return "orange" | |
| if r > 150 and g < 80 and b < 80: | |
| return "red" | |
| if r > 150 and g < 100 and b > 100: | |
| return "pink" if r > 200 else "purple" | |
| if g > 150 and r < 100 and b < 100: | |
| return "green" | |
| if g > 120 and r < 80 and b < 80: | |
| return "dark green" | |
| if b > 150 and r < 100 and g < 100: | |
| return "blue" | |
| if b > 150 and g > 100 and r < 100: | |
| return "cyan" if g > 150 else "teal" | |
| if r > 100 and g > 100 and b < 80: | |
| return "olive" | |
| if r > 150 and g < 80 and b > 150: | |
| return "magenta" | |
| if g >= r and g >= b: | |
| return "green" | |
| if r >= g and r >= b: | |
| return "red" | |
| return "blue" | |
| class RAGEngine: | |
def __init__(self):
    """Create an empty engine; models and the vectorstore are built lazily."""
    # Heavy resources, populated on first use.
    self._embeddings: Optional[HuggingFaceEmbeddings] = None
    self._reranker = None  # None = not loaded yet; False = loading failed
    self._vectorstore: Optional[Chroma] = None  # created when the first file is indexed

    # Split preferring paragraph breaks, then lines, sentences, words, characters.
    boundaries = ["\n\n", "\n", ". ", " ", ""]
    self._splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=boundaries,
    )

    # Alternating user/assistant chat messages (see add_to_memory).
    self._memory: List[dict] = []
    # filename -> {"chunk_count": int, "chunk_ids": list of str, "type": extension}
    self._documents: Dict[str, dict] = {}

    monitor.log_startup()
@property
def embeddings(self):
    """Return the sentence-embedding model, loading it on first access.

    This must be a property: callers read it as an attribute
    (``embedding_function=self.embeddings`` when the Chroma store is built).
    Without ``@property`` they would get a bound method instead of the
    embeddings object.
    """
    if self._embeddings is None:
        logger.info("Loading embedding model...")
        self._embeddings = HuggingFaceEmbeddings(
            model_name=EMBED_MODEL,
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
    return self._embeddings
@property
def reranker(self):
    """Lazy-load the cross-encoder reranker (~80MB, CPU-friendly).

    This must be a property: ``_rerank_documents`` reads ``self.reranker`` and
    tests the result against ``None``. A plain method would hand back a bound
    method that is never ``None``, so reranking would always fail.

    Returns:
        The loaded CrossEncoder, or None if loading failed. A failure is
        remembered via the ``False`` sentinel, so it is not retried.
    """
    if self._reranker is None:
        try:
            from sentence_transformers import CrossEncoder
            logger.info(f"Loading reranker model: {RERANK_MODEL}...")
            self._reranker = CrossEncoder(RERANK_MODEL, max_length=512)
            logger.info("Reranker loaded successfully.")
        except Exception as e:
            logger.warning(f"Failed to load reranker: {e}. Will skip reranking.")
            self._reranker = False  # sentinel: don't retry
    return self._reranker if self._reranker is not False else None
def _rerank_documents(self, question: str, docs: List[Document], top_k: int) -> List[Document]:
    """Reorder retrieved chunks by cross-encoder relevance and keep the best ``top_k``.

    Args:
        question: Query each chunk is scored against.
        docs: Candidate chunks in retrieval order.
        top_k: Number of chunks to keep.

    Returns:
        ``docs`` itself when empty. Otherwise at most ``top_k`` chunks,
        highest reranker score first. If the reranker is unavailable or
        scoring raises, returns the first ``top_k`` chunks in retrieval order.
    """
    if not docs:
        return docs

    model = self.reranker
    if model is None:
        # Reranker unavailable — fall back to original order
        logger.info("Reranker not available, using original retrieval order.")
        return docs[:top_k]

    # One (query, passage) pair per candidate for the cross-encoder.
    query_passage_pairs = [(question, candidate.page_content) for candidate in docs]
    try:
        relevance = model.predict(query_passage_pairs)
        # Stable sort, best score first.
        ranked = sorted(zip(docs, relevance), key=lambda pair: pair[1], reverse=True)
        kept = ranked[:top_k]
        reranked = [candidate for candidate, _ in kept]

        # Log how reranking changed the ordering.
        before = [candidate.metadata.get("source", "?")[:30] for candidate in docs[:top_k]]
        after = [candidate.metadata.get("source", "?")[:30] for candidate in reranked]
        score_labels = [f"{value:.3f}" for _, value in kept]
        logger.info(
            f"Reranked {len(docs)} candidates → top {top_k}. "
            f"Scores: {score_labels}. "
            f"Before: {before}, After: {after}"
        )
        return reranked
    except Exception as e:
        logger.warning(f"Reranking failed: {e}. Using original order.")
        return docs[:top_k]
| # ── Memory ─────────────────────────────────────────────────────────────── | |
def clear_memory(self):
    """Forget the entire conversation history."""
    self._memory = list()
def add_to_memory(self, question: str, answer: str):
    """Record one question/answer exchange, keeping the last MEMORY_WINDOW exchanges."""
    self._memory.extend([
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ])
    # Each exchange is two messages; keep only the newest ones.
    limit = 2 * MEMORY_WINDOW
    if len(self._memory) > limit:
        self._memory = self._memory[-limit:]
def get_memory_messages(self) -> List[dict]:
    """Return a shallow copy of the stored chat messages (oldest first)."""
    return list(self._memory)
def get_memory_count(self) -> int:
    """Return how many complete exchanges (message pairs) are in memory."""
    exchanges, _ = divmod(len(self._memory), 2)
    return exchanges
| # ── Document Management ────────────────────────────────────────────────── | |
def get_documents(self) -> List[dict]:
    """Summarise every loaded document as ``{"name", "type", "chunk_count"}``."""
    summaries = []
    for doc_name, meta in self._documents.items():
        summaries.append({
            "name": doc_name,
            "type": meta["type"],
            "chunk_count": meta["chunk_count"],
        })
    return summaries
def get_total_chunks(self) -> int:
    """Return the number of indexed chunks summed over all loaded files."""
    total = 0
    for meta in self._documents.values():
        total += meta["chunk_count"]
    return total
def get_file_count(self) -> int:
    """Return how many distinct files are currently loaded."""
    loaded = len(self._documents)
    return loaded
def remove_file(self, filename: str) -> bool:
    """Remove one file's chunks from the vectorstore and stop tracking it.

    Deletion is best-effort. If the vectorstore delete raises, the failure
    is logged and the file is still untracked, so its chunks may remain in
    the collection.

    Args:
        filename: Name the file was ingested under (key of ``self._documents``).

    Returns:
        True if the file was loaded and has been removed; False if unknown.
    """
    if filename not in self._documents:
        logger.warning(f"Cannot remove '{filename}' — not found in loaded documents.")
        return False
    chunk_ids = self._documents[filename]["chunk_ids"]
    if self._vectorstore and chunk_ids:
        try:
            # Public LangChain API instead of the private ``_collection`` handle.
            self._vectorstore.delete(ids=chunk_ids)
            logger.info(f"Removed {len(chunk_ids)} chunks for '{filename}'")
        except Exception as e:
            logger.warning(f"Failed to delete chunks for '{filename}': {e}")
    del self._documents[filename]
    # If no documents left, drop our handle to the vectorstore entirely
    if not self._documents:
        self._vectorstore = None
        logger.info("All documents removed — vectorstore cleared.")
    return True
def reset(self):
    """Reset everything — all documents, vectorstore, and memory.

    The Chroma collection is dropped with ``delete_collection()``. The
    previous ``_client.reset()`` call only works when Chroma is configured
    with ``allow_reset=True``, which the ``Settings`` used to create the store
    do not set. The call therefore failed, and because the failure was
    swallowed, the old chunks stayed in the collection. A store recreated
    later under the same collection name could then see them. Failures are
    now logged instead of silently ignored.
    """
    self._documents = {}
    if self._vectorstore:
        try:
            self._vectorstore.delete_collection()
        except Exception as e:
            logger.warning(f"Failed to delete vectorstore collection during reset: {e}")
    self._vectorstore = None
    self._memory = []
    logger.info("Full reset: all documents, vectorstore, and memory cleared.")
| # ── Ingestion ──────────────────────────────────────────────────────────── | |
def ingest_file(self, uploaded_file) -> int:
    """Index an uploaded file alongside the documents already loaded (additive).

    Accepts a FastAPI/Starlette ``UploadFile`` or a Streamlit ``UploadedFile``.
    A file whose name matches an already-loaded document replaces it. The old
    version is dropped only after the new bytes parse successfully, so a
    broken re-upload no longer wipes the working copy.

    Returns:
        Number of chunks indexed.

    Raises:
        ValueError: Unsupported extension, or MAX_FILES documents already loaded.
        Exception: Any read/parse/index error is logged, recorded through
            ``monitor.log_ingestion`` and re-raised.
    """
    t0 = time.time()
    filename = getattr(uploaded_file, "name", None) or getattr(uploaded_file, "filename", "file")
    suffix = get_suffix(filename)
    error = ""
    chunks = 0
    if suffix not in SUPPORTED_EXTENSIONS:
        raise ValueError(
            f"Unsupported: {suffix}. Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
        )
    # Enforce file limit (replacement of same name doesn't count as new)
    if filename not in self._documents and len(self._documents) >= MAX_FILES:
        raise ValueError(
            f"Maximum {MAX_FILES} files reached. Remove a file before uploading more."
        )
    try:
        # Starlette's UploadFile.read() is async (calling it yields a coroutine,
        # not bytes); its synchronous file object is exposed as ``.file``.
        # Objects without a readable ``.file`` (e.g. Streamlit's BytesIO-based
        # UploadedFile) are read directly.
        raw_file = getattr(uploaded_file, "file", None)
        if raw_file is not None and hasattr(raw_file, "read"):
            data = raw_file.read()
            if hasattr(raw_file, "seek"):
                raw_file.seek(0)
        else:
            data = uploaded_file.read()
            if hasattr(uploaded_file, "seek"):
                uploaded_file.seek(0)
        docs = self._route(data, filename, suffix)
        # Drop the previous version only once the new upload has parsed.
        if filename in self._documents:
            logger.info(f"Replacing existing file: {filename}")
            self.remove_file(filename)
        chunks = self._index(docs, filename)
    except Exception as e:
        error = str(e)
        logger.error(f"Ingestion error: {e}")
        raise
    finally:
        monitor.log_ingestion(filename, chunks, (time.time() - t0) * 1000, error)
    return chunks
def ingest_path(self, path: str, name: str = "") -> int:
    """Ingest a file from a local path (additive, like ``ingest_file``).

    Args:
        path: Filesystem path to read.
        name: Name to register the document under; defaults to the path's basename.

    Returns:
        Number of chunks indexed.

    Raises:
        ValueError: Unsupported extension (same check as ``ingest_file``), or
            MAX_FILES documents already loaded.
        OSError: If ``path`` cannot be read. Any previously loaded version is
            kept in that case, because the old version is removed only after
            the new file has been read and parsed.
    """
    filename = name or Path(path).name
    suffix = get_suffix(filename)
    if suffix not in SUPPORTED_EXTENSIONS:
        raise ValueError(
            f"Unsupported: {suffix}. Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
        )
    if filename not in self._documents and len(self._documents) >= MAX_FILES:
        raise ValueError(
            f"Maximum {MAX_FILES} files reached. Remove a file before uploading more."
        )
    with open(path, "rb") as f:
        data = f.read()
    docs = self._route(data, filename, suffix)
    # Replace the old version only after the new one parsed successfully.
    if filename in self._documents:
        self.remove_file(filename)
    return self._index(docs, filename)
def _route(self, data: bytes, filename: str, suffix: str) -> List[Document]:
    """Dispatch raw file bytes to the loader registered for ``suffix``.

    Raises:
        ValueError: If no loader handles ``suffix``.
    """
    loader_for = {
        ".pdf": self._load_pdf,
        ".txt": self._load_text,
        ".docx": self._load_docx,
        ".doc": self._load_docx,
        ".csv": self._load_csv,
        ".xlsx": self._load_excel,
        ".xls": self._load_excel,
        ".jpg": self._load_image,
        ".jpeg": self._load_image,
        ".png": self._load_image,
        ".webp": self._load_image,
    }
    loader = loader_for.get(suffix)
    if loader is None:
        raise ValueError(f"No loader for {suffix}")
    return loader(data, filename)
| # ── Loaders ────────────────────────────────────────────────────────────── | |
def _load_pdf(self, data: bytes, filename: str) -> List[Document]:
    """Parse PDF bytes with PyPDFLoader and tag each Document with its origin.

    PyPDFLoader reads from a path, so the bytes are written to a temporary
    file, which is deleted once loading finishes (or fails).
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as handle:
        handle.write(data)
        pdf_path = handle.name
    try:
        pages = PyPDFLoader(pdf_path).load()
        for page in pages:
            page.metadata["source"] = filename
            page.metadata["type"] = "pdf"
        return pages
    finally:
        os.unlink(pdf_path)
def _load_text(self, data: bytes, filename: str) -> List[Document]:
    """Wrap a plain-text upload in one Document (invalid UTF-8 becomes U+FFFD)."""
    body = data.decode("utf-8", errors="replace")
    meta = {"source": filename, "type": "text"}
    return [Document(page_content=body, metadata=meta)]
def _load_docx(self, data: bytes, filename: str) -> List[Document]:
    """Extract text from a Word upload (.docx, and .doc, which is routed here too).

    Uses docx2txt when it is installed. If it is missing or fails, the raw
    bytes are decoded as UTF-8 only when they are not a binary Word
    container. Decoding a ZIP (.docx) or OLE2 (.doc) file would index
    meaningless binary noise. When no text is recovered, a placeholder line
    is returned instead.
    """
    def _raw_text_fallback() -> str:
        # .docx files are ZIP archives; legacy .doc files are OLE2 compound files.
        if data.startswith((b"PK\x03\x04", b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1")):
            return ""
        return data.decode("utf-8", errors="replace")

    text = ""
    try:
        import docx2txt
        with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tmp:
            tmp.write(data)
            tmp_path = tmp.name
        try:
            text = docx2txt.process(tmp_path)
        finally:
            os.unlink(tmp_path)
    except ImportError:
        logger.warning("docx2txt not installed — falling back to raw text extraction")
        text = _raw_text_fallback()
    except Exception as e:
        logger.warning(f"docx2txt failed ({e}) — falling back to raw text extraction")
        text = _raw_text_fallback()
    if not text or not text.strip():
        text = f"[Document: {filename}] — Could not extract text content."
    return [Document(page_content=text, metadata={"source": filename, "type": "docx"})]
def _load_csv(self, data: bytes, filename: str) -> List[Document]:
    """Turn a CSV upload into overview, statistics and row-window Documents.

    Produces:
      * an overview (shape, column names, first 10 rows),
      * a ``describe(include="all")`` statistics block (skipped on failure),
      * one Document per 50-row window over the first 500 rows.

    Bytes that are not valid UTF-8 are re-read as Latin-1 instead of failing
    the whole upload.
    """
    import pandas as pd

    try:
        df = pd.read_csv(io.BytesIO(data))
    except UnicodeDecodeError:
        # Spreadsheet exports are often Latin-1/Windows-1252. Latin-1 maps
        # every byte, so this retry cannot raise another decode error.
        logger.warning(f"CSV '{filename}' is not valid UTF-8 — retrying with latin-1")
        df = pd.read_csv(io.BytesIO(data), encoding="latin-1")

    docs = []
    summary = (
        f"File: {filename}\n"
        f"Shape: {df.shape[0]} rows × {df.shape[1]} columns\n"
        # str() so non-string column labels cannot break the join.
        f"Columns: {', '.join(str(c) for c in df.columns)}\n\n"
        f"First 10 rows:\n{df.head(10).to_string(index=False)}"
    )
    docs.append(Document(page_content=summary, metadata={"source": filename, "type": "csv_summary"}))
    try:
        stats = "Statistical summary:\n" + df.describe(include="all").to_string()
        docs.append(Document(page_content=stats, metadata={"source": filename, "type": "csv_stats"}))
    except Exception as e:
        logger.warning(f"CSV stats failed: {e}")
    try:
        total_rows = len(df)
        for start in range(0, min(total_rows, 500), 50):
            # Half-open [start, stop), clamped so the last window is not overstated.
            stop = min(start + 50, total_rows)
            chunk = f"Rows {start}–{stop}:\n{df.iloc[start:stop].to_string(index=False)}"
            docs.append(Document(page_content=chunk, metadata={"source": filename, "type": "csv_rows"}))
    except Exception as e:
        logger.warning(f"CSV row chunking failed: {e}")
    return docs
def _load_excel(self, data: bytes, filename: str) -> List[Document]:
    """Turn each sheet of an Excel workbook into Documents.

    For every sheet this produces an overview Document (shape, columns,
    first 10 rows). It also produces 50-row window Documents covering up to
    the first 500 rows, matching ``_load_csv``, so values beyond the preview
    can be retrieved. Sheets that fail to parse are logged and skipped.
    """
    import pandas as pd
    docs = []
    # Context manager releases the workbook handle once all sheets are parsed.
    with pd.ExcelFile(io.BytesIO(data)) as xl:
        for sheet in xl.sheet_names:
            try:
                df = xl.parse(sheet)
                text = (
                    f"Sheet: {sheet} | {df.shape[0]} rows × {df.shape[1]} cols\n"
                    f"Columns: {', '.join(str(c) for c in df.columns)}\n\n"
                    f"{df.head(10).to_string(index=False)}"
                )
                docs.append(Document(page_content=text, metadata={
                    "source": filename, "type": "excel", "sheet": sheet
                }))
                total_rows = len(df)
                for start in range(0, min(total_rows, 500), 50):
                    stop = min(start + 50, total_rows)
                    rows = (
                        f"Sheet: {sheet} | Rows {start}–{stop}:\n"
                        f"{df.iloc[start:stop].to_string(index=False)}"
                    )
                    docs.append(Document(page_content=rows, metadata={
                        "source": filename, "type": "excel_rows", "sheet": sheet
                    }))
            except Exception as e:
                logger.warning(f"Excel sheet '{sheet}' failed: {e}")
    return docs
| # ══════════════════════════════════════════════════════════════════════════ | |
| # IMAGE LOADING — v3: OCR + Color Analysis + BLIP + VLM | |
| # ══════════════════════════════════════════════════════════════════════════ | |
def _load_image(self, data: bytes, filename: str) -> List[Document]:
    """Build one searchable Document describing an image.

    Runs, in order: OCR, color analysis, a short caption, and a detailed
    vision-language description. The results are laid out as titled
    sections, followed by a summary paragraph. OCR text (first 500 chars),
    caption and color summary are also copied into the metadata.
    """
    logger.info(f"Processing image: {filename}")
    ocr_text = self._ocr_image(data, filename)
    color_info = self._analyze_colors(data, filename)
    blip_caption = self._caption_image_blip(data, filename)
    vlm_description = self._describe_image_with_vlm(data, filename, blip_caption, ocr_text)

    lines = [f"Image file: {filename}", ""]

    def add_section(header, body):
        # Each section is: header line, body, blank separator.
        lines.extend((header, body, ""))

    if ocr_text:
        add_section("=== TEXT FOUND IN IMAGE (OCR) ===", ocr_text)
    if color_info:
        add_section("=== COLOR ANALYSIS ===", color_info)
    add_section("=== SHORT CAPTION ===", blip_caption)
    add_section("=== DETAILED VISUAL DESCRIPTION ===", vlm_description)

    clauses = [f"This image ({filename})"]
    if ocr_text:
        clauses.append(f'contains the text: "{ocr_text}"')
    if color_info:
        clauses.append(f"has {color_info.lower()}")
    clauses.append(f"and shows: {blip_caption}")
    lines += [
        "=== SUMMARY ===",
        ". ".join(clauses) + ".",
        f"Detailed: {vlm_description}",
    ]

    text = "\n".join(lines)
    logger.info(f"Image document length: {len(text)} chars "
                f"(OCR: {len(ocr_text)} chars, colors: {bool(color_info)}, "
                f"BLIP: {len(blip_caption)} chars, VLM: {len(vlm_description)} chars)")
    metadata = {
        "source": filename,
        "type": "image",
        "ocr_text": ocr_text[:500] if ocr_text else "",
        "caption": blip_caption,
        "colors": color_info,
    }
    return [Document(page_content=text, metadata=metadata)]
| # ── OCR ────────────────────────────────────────────────────────────────── | |
def _ocr_image(self, data: bytes, filename: str) -> str:
    """Extract visible text from an image with local Tesseract OCR.

    The image is first OCR'd as-is. If that yields fewer than two
    characters, a grayscale copy is upscaled so both sides approach 1000 px
    and OCR'd again. The upscale factor is capped so the longer side stays
    within 4000 px. Previously the factor was unbounded, so very thin images
    ballooned: a 10x2000 px strip became 1000x200000 px (200 million pixels).

    If pytesseract/Pillow is not installed, or local OCR raises (e.g. on an
    unreadable image), the hosted fallback ``_ocr_image_api`` is used.

    Returns:
        The recognised text with runs of 3+ newlines collapsed, or "".
    """
    target_side = 1000  # the retry upscales small images toward this side length
    max_side = 4000     # hard cap on the longer side of the upscaled image
    try:
        import pytesseract
        from PIL import Image
        img = Image.open(io.BytesIO(data))
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        text = pytesseract.image_to_string(img).strip()
        if len(text) < 2:
            gray = img.convert("L")
            w, h = gray.size
            if w < target_side or h < target_side:
                scale = max(target_side / w, target_side / h, 1)
                scale = min(scale, max_side / max(w, h))
                if scale > 1:
                    gray = gray.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
            text = pytesseract.image_to_string(gray).strip()
        if text:
            text = re.sub(r'\n{3,}', '\n\n', text).strip()
            logger.info(f"OCR extracted text from {filename}: '{text[:100]}...'")
            return text
        logger.info(f"OCR found no text in {filename}")
        return ""
    except ImportError:
        logger.warning("pytesseract not installed — skipping OCR.")
        return self._ocr_image_api(data, filename)
    except Exception as e:
        logger.warning(f"OCR failed for {filename}: {e}")
        return self._ocr_image_api(data, filename)
| def _ocr_image_api(self, data: bytes, filename: str) -> str: | |
| hf_token = os.environ.get("HF_TOKEN", "") | |
| if not hf_token: | |
| return "" | |
| ocr_models = [ | |
| "microsoft/trocr-large-printed", | |
| "microsoft/trocr-base-printed", | |
| ] | |
| for model_id in ocr_models: | |
| try: | |
| resp = requests.post( | |
| f"https://api-inference.huggingface.co/models/{model_id}", | |
| headers={"Authorization": f"Bearer {hf_token}"}, | |
| data=data, | |
| timeout=30, | |
| ) | |
| if resp.status_code == 200: | |
| result = resp.json() | |
| if isinstance(result, list) and result: | |
| text = result[0].get("generated_text", "").strip() | |
| if text: | |
| return text | |
| except Exception as e: | |
| logger.warning(f"OCR API failed ({model_id}): {e}") | |
| return "" | |
| # ── Color Analysis ─────────────────────────────────────────────────────── | |
def _analyze_colors(self, data: bytes, filename: str) -> str:
    """Summarise an image's dominant colors.

    The image is downsampled to 80x80 and every pixel is named with
    ``_classify_color``. Up to five names covering at least 3% of the pixels
    are reported, and the most common one is reported as the background.

    Returns:
        A string like "dominant colors: white (70%), black (20%). The
        background appears to be white.", or "" if no color qualifies or
        the analysis fails.
    """
    try:
        from PIL import Image
        thumbnail = Image.open(io.BytesIO(data)).convert("RGB").resize((80, 80), Image.LANCZOS)
        samples = list(thumbnail.getdata())
        pixel_count = len(samples)
        tallies = Counter(_classify_color(r, g, b) for r, g, b in samples)

        dominant = []
        for color_name, hits in tallies.most_common(5):
            share = hits / pixel_count * 100
            if share >= 3:
                dominant.append((color_name, share))
        if not dominant:
            return ""

        background = dominant[0][0]
        summary = "dominant colors: " + ", ".join(
            f"{name} ({pct:.0f}%)" for name, pct in dominant
        )
        summary += f". The background appears to be {background}."
        logger.info(f"Color analysis for {filename}: {summary}")
        return summary
    except Exception as e:
        logger.warning(f"Color analysis failed for {filename}: {e}")
        return ""
| # ── BLIP Caption ───────────────────────────────────────────────────────── | |
| def _caption_image_blip(self, data: bytes, filename: str) -> str: | |
| hf_token = os.environ.get("HF_TOKEN", "") | |
| if not hf_token: | |
| return f"[Image: {filename}] — Add HF_TOKEN to enable captioning." | |
| caption_models = [ | |
| "Salesforce/blip-image-captioning-large", | |
| "Salesforce/blip-image-captioning-base", | |
| "nlpconnect/vit-gpt2-image-captioning", | |
| ] | |
| for model_id in caption_models: | |
| try: | |
| resp = requests.post( | |
| f"https://api-inference.huggingface.co/models/{model_id}", | |
| headers={"Authorization": f"Bearer {hf_token}"}, | |
| data=data, # raw bytes, NOT json | |
| timeout=30, | |
| ) | |
| if resp.status_code == 200: | |
| result = resp.json() | |
| if isinstance(result, list) and result: | |
| caption = result[0].get("generated_text", "") | |
| if caption: | |
| logger.info(f"BLIP caption ({model_id}): {caption[:80]}") | |
| return caption | |
| elif resp.status_code == 503: | |
| logger.info(f"{model_id} is loading, waiting 10s...") | |
| time.sleep(10) | |
| resp2 = requests.post( | |
| f"https://api-inference.huggingface.co/models/{model_id}", | |
| headers={"Authorization": f"Bearer {hf_token}"}, | |
| data=data, | |
| timeout=45, | |
| ) | |
| if resp2.status_code == 200: | |
| result = resp2.json() | |
| if isinstance(result, list) and result: | |
| caption = result[0].get("generated_text", "") | |
| if caption: | |
| return caption | |
| else: | |
| logger.warning(f"BLIP {model_id}: {resp.status_code}: {resp.text[:100]}") | |
| except Exception as e: | |
| logger.warning(f"BLIP caption failed ({model_id}): {e}") | |
| continue | |
| return f"An image named {filename} was uploaded." | |
| # ── VLM Detailed Description ───────────────────────────────────────────── | |
| def _describe_image_with_vlm(self, data: bytes, filename: str, | |
| short_caption: str, ocr_text: str) -> str: | |
| hf_token = os.environ.get("HF_TOKEN", "") | |
| if not hf_token: | |
| return short_caption | |
| mime = "image/jpeg" | |
| if data[:8] == b'\x89PNG\r\n\x1a\n': | |
| mime = "image/png" | |
| elif data[:4] == b'RIFF' and data[8:12] == b'WEBP': | |
| mime = "image/webp" | |
| b64_image = base64.b64encode(data).decode("utf-8") | |
| image_url = f"data:{mime};base64,{b64_image}" | |
| headers = { | |
| "Authorization": f"Bearer {hf_token}", | |
| "Content-Type": "application/json", | |
| } | |
| ocr_hint = "" | |
| if ocr_text: | |
| ocr_hint = ( | |
| f"\n\nNote: An OCR system already detected this text in the image: " | |
| f'"{ocr_text}". Please confirm or correct this text reading.' | |
| ) | |
| prompt_text = ( | |
| "Analyze this image thoroughly and provide a detailed description. " | |
| "You MUST address ALL of the following:\n\n" | |
| "1. TEXT: Read and transcribe ALL text visible in the image, " | |
| "character by character, word by word. Include any titles, labels, " | |
| "captions, watermarks, or writing of any kind. If there is text, " | |
| "quote it exactly.\n\n" | |
| "2. COLORS: What are the exact colors visible? What is the " | |
| "background color? What color is the text (if any)? List all " | |
| "significant colors.\n\n" | |
| "3. OBJECTS & LAYOUT: What objects, shapes, people, or elements " | |
| "are in the image? Where are they positioned?\n\n" | |
| "4. CONTEXT: What type of image is this (photo, screenshot, " | |
| "diagram, logo, meme, sign, document, etc.)?\n\n" | |
| "Be specific and factual. Do not guess or make assumptions about " | |
| "things you cannot see." | |
| f"{ocr_hint}" | |
| ) | |
| for model_id in VLM_MODELS: | |
| try: | |
| logger.info(f"Trying VLM description with {model_id}...") | |
| payload = { | |
| "model": model_id, | |
| "messages": [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| { | |
| "type": "image_url", | |
| "image_url": {"url": image_url}, | |
| }, | |
| { | |
| "type": "text", | |
| "text": prompt_text, | |
| }, | |
| ], | |
| } | |
| ], | |
| "max_tokens": 700, | |
| "temperature": 0.1, | |
| "stream": False, | |
| } | |
| resp = requests.post( | |
| HF_API_URL, | |
| headers=headers, | |
| data=json.dumps(payload), | |
| timeout=60, | |
| ) | |
| if resp.status_code == 200: | |
| raw = resp.json()["choices"][0]["message"]["content"].strip() | |
| description = _strip_thinking(raw) | |
| if description and len(description) > 20: | |
| logger.info(f"VLM description ({model_id}): {description[:100]}...") | |
| return description | |
| else: | |
| logger.warning(f"VLM {model_id}: {resp.status_code}: {resp.text[:150]}") | |
| except Exception as e: | |
| logger.warning(f"VLM description failed ({model_id}): {e}") | |
| continue | |
| return self._expand_caption_with_llm(short_caption, ocr_text, filename) | |
| def _expand_caption_with_llm(self, caption: str, ocr_text: str, filename: str) -> str: | |
| hf_token = os.environ.get("HF_TOKEN", "") | |
| if not hf_token: | |
| parts = [caption] | |
| if ocr_text: | |
| parts.append(f'Text found in image: "{ocr_text}"') | |
| return " ".join(parts) | |
| headers = { | |
| "Authorization": f"Bearer {hf_token}", | |
| "Content-Type": "application/json", | |
| } | |
| ocr_section = "" | |
| if ocr_text: | |
| ocr_section = f'\nOCR text extracted from the image: "{ocr_text}"' | |
| messages = [ | |
| { | |
| "role": "system", | |
| "content": ( | |
| "You are an image description assistant. You are given information " | |
| "extracted from an image (a short AI caption and OCR text). " | |
| "Combine this information into a clear, factual description. " | |
| "If OCR text was found, make sure to include it prominently. " | |
| "Do NOT invent details that aren't supported by the provided info." | |
| ), | |
| }, | |
| { | |
| "role": "user", | |
| "content": ( | |
| f"Image file: '{filename}'\n" | |
| f"AI caption: \"{caption}\"\n" | |
| f"{ocr_section}\n\n" | |
| f"Please provide a consolidated description of this image." | |
| ), | |
| }, | |
| ] | |
| for model_id in CANDIDATE_MODELS: | |
| try: | |
| resp = requests.post( | |
| HF_API_URL, | |
| headers=headers, | |
| data=json.dumps({ | |
| "model": model_id, | |
| "messages": messages, | |
| "max_tokens": 400, | |
| "temperature": 0.2, | |
| "stream": False, | |
| }), | |
| timeout=45, | |
| ) | |
| if resp.status_code == 200: | |
| raw = resp.json()["choices"][0]["message"]["content"].strip() | |
| expanded = _strip_thinking(raw) | |
| if expanded and len(expanded) > 30: | |
| return expanded | |
| except Exception as e: | |
| logger.warning(f"Caption expansion failed ({model_id}): {e}") | |
| continue | |
| parts = [caption] | |
| if ocr_text: | |
| parts.append(f'Text visible in image: "{ocr_text}"') | |
| return " ".join(parts) | |
| # ══════════════════════════════════════════════════════════════════════════ | |
| # INDEXING — ADDITIVE (does NOT destroy existing data) | |
| # ══════════════════════════════════════════════════════════════════════════ | |
def _index(self, docs: List[Document], filename: str) -> int:
    """Split ``docs`` into chunks and add them to the shared Chroma store.

    Chunks from other files are left untouched; the store is created on
    first use. Chunk IDs combine a sanitised filename, the first 8 hex digits
    of the filename's MD5, and the chunk position. They are recorded in
    ``self._documents`` so ``remove_file`` can delete exactly these chunks.

    Returns:
        Number of chunks indexed (0 when splitting produced nothing).
    """
    pieces = self._splitter.split_documents(docs)
    if not pieces:
        logger.warning(f"No chunks produced from {filename}")
        return 0

    safe_name = re.sub(r'[^a-zA-Z0-9_.-]', '_', filename)
    name_hash = hashlib.md5(filename.encode()).hexdigest()[:8]
    chunk_ids = [f"{safe_name}_{name_hash}__chunk__{i}" for i, _ in enumerate(pieces)]

    if self._vectorstore is None:
        self._vectorstore = Chroma(
            collection_name=COLLECTION_NAME,
            embedding_function=self.embeddings,
            client_settings=Settings(anonymized_telemetry=False),
        )

    self._vectorstore.add_texts(
        texts=[piece.page_content for piece in pieces],
        metadatas=[piece.metadata for piece in pieces],
        ids=chunk_ids,
    )

    self._documents[filename] = {
        "chunk_count": len(pieces),
        "chunk_ids": chunk_ids,
        "type": get_suffix(filename),
    }
    logger.info(
        f"Indexed {len(pieces)} chunks from '{filename}' "
        f"(total files: {len(self._documents)}, total chunks: {self.get_total_chunks()})"
    )
    return len(pieces)
| # ── Query ──────────────────────────────────────────────────────────────── | |
def query(self, question: str) -> Tuple[str, List[str]]:
    """Answer ``question`` from the indexed documents.

    Pipeline:
      1. MMR retrieval over-fetches candidates, 2 more per additional loaded file.
      2. The cross-encoder reranks them down to ``final_k`` chunks.
      3. The LLM answers from that context plus conversation memory.

    Returns:
        ``(answer, sources)``. ``sources`` lists each contributing file once,
        ordered by its best-ranked chunk, so the order is stable across runs.
        The old set-based de-duplication produced an arbitrary order. On
        failure the answer is ``"Error: <message>"`` and nothing is raised.
    """
    if not self._documents or self._vectorstore is None:
        return "Please upload a document first.", []

    t0 = time.time()
    error = answer = model_used = ""
    sources = []
    try:
        # ── Step 1: Over-fetch candidates ────────────────────────────────
        # num_files >= 1 here (guarded above), so this is never below RERANK_FETCH_K.
        num_files = len(self._documents)
        fetch_k = RERANK_FETCH_K + (num_files - 1) * 2
        retriever = self._vectorstore.as_retriever(
            search_type="mmr",
            # MMR selects k diverse results from the fetch_k nearest neighbours.
            search_kwargs={"k": fetch_k, "fetch_k": fetch_k * 2},
        )
        candidate_docs = retriever.invoke(question)

        # ── Step 2: Rerank with cross-encoder ────────────────────────────
        final_k = min(TOP_K + num_files - 1, 6)
        docs = self._rerank_documents(question, candidate_docs, top_k=final_k)

        context = "\n\n---\n\n".join(
            f"[Chunk {i+1} | source: {d.metadata.get('source', '?')} | type: {d.metadata.get('type','text')}]\n{d.page_content}"
            for i, d in enumerate(docs)
        )
        # dict.fromkeys de-duplicates while keeping rank order.
        sources = list(dict.fromkeys(d.metadata.get("source", "Document") for d in docs))

        answer, model_used = self._generate(question, context)
        self.add_to_memory(question, answer)
    except Exception as e:
        error = str(e)
        answer = f"Error: {error}"
        logger.error(f"Query error: {e}")
    finally:
        monitor.log_query(question, answer, sources, (time.time() - t0) * 1000, model_used, TOP_K, error)
    return answer, sources
| # ── LLM ────────────────────────────────────────────────────────────────── | |
| def _generate(self, question: str, context: str) -> Tuple[str, str]: | |
| hf_token = os.environ.get("HF_TOKEN", "") | |
| if not hf_token: | |
| return ( | |
| "HF_TOKEN not set. Add it as a Secret in Space Settings.\n\n" | |
| "Best matching excerpt:\n\n" + _extract_best(question, context), | |
| "none" | |
| ) | |
| # ── Build doc-type hints from ALL loaded files ──────────────────────── | |
| loaded_types = set(info["type"] for info in self._documents.values()) | |
| all_names = list(self._documents.keys()) | |
| hints = [] | |
| image_types = {".jpg", ".jpeg", ".png", ".webp"} | |
| table_types = {".csv", ".xlsx", ".xls"} | |
| if loaded_types & image_types: | |
| hints.append( | |
| "Some documents are IMAGES. Their context contains:\n" | |
| " - OCR-extracted text (actual text visible in the image)\n" | |
| " - Color analysis (dominant colors detected)\n" | |
| " - AI-generated visual descriptions\n" | |
| "When asked about text in an image, refer to the OCR section. " | |
| "When asked about colors, refer to the color analysis. " | |
| "When asked what an image shows, use the descriptions. " | |
| "Be specific and quote the actual text/colors from the context." | |
| ) | |
| if loaded_types & table_types: | |
| hints.append( | |
| "Some documents are tabular data (spreadsheet/CSV). " | |
| "Refer to column names and values precisely." | |
| ) | |
| if loaded_types & {".docx", ".doc"}: | |
| hints.append("Some documents are Word documents.") | |
| doc_type_hint = "\n".join(hints) | |
| # ── File list for the system prompt ─────────────────────────────────── | |
| if len(all_names) == 1: | |
| files_str = f"You are analyzing: '{all_names[0]}'." | |
| else: | |
| files_list = ", ".join(f"'{n}'" for n in all_names) | |
| files_str = f"You have {len(all_names)} documents loaded: {files_list}." | |
| system_prompt = ( | |
| f"You are DocMind AI, an expert document analyst built by Ryan Farahani.\n" | |
| f"{files_str}\n" | |
| f"{doc_type_hint}\n" | |
| "Answer using ONLY the provided document context. " | |
| "When the context contains chunks from multiple files, indicate which file " | |
| "the information comes from if relevant.\n" | |
| "Be concise and precise. No preamble. No reasoning out loud. Just answer.\n" | |
| "If asked a follow-up question, use the conversation history for context." | |
| ) | |
| messages = [{"role": "system", "content": system_prompt}] | |
| memory = self.get_memory_messages() | |
| if memory: | |
| messages.append({ | |
| "role": "system", | |
| "content": f"Current document context:\n{context}" | |
| }) | |
| messages.extend(memory) | |
| messages.append({"role": "user", "content": question}) | |
| else: | |
| messages.append({ | |
| "role": "user", | |
| "content": f"Document context:\n{context}\n\n---\nQuestion: {question}" | |
| }) | |
| headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"} | |
| last_error = "" | |
| for model_id in CANDIDATE_MODELS: | |
| try: | |
| resp = requests.post( | |
| HF_API_URL, | |
| headers=headers, | |
| data=json.dumps({ | |
| "model": model_id, | |
| "messages": messages, | |
| "max_tokens": 500, | |
| "temperature": 0.1, | |
| "stream": False, | |
| }), | |
| timeout=60, | |
| ) | |
| if resp.status_code == 200: | |
| raw = resp.json()["choices"][0]["message"]["content"].strip() | |
| answer = _strip_thinking(raw) | |
| if answer: | |
| return answer, model_id | |
| else: | |
| last_error = f"{model_id} → {resp.status_code}: {resp.text[:150]}" | |
| logger.warning(last_error) | |
| except Exception as e: | |
| last_error = str(e) | |
| logger.warning(f"Exception on {model_id}: {e}") | |
| continue | |
| return ( | |
| "AI unavailable. Most relevant excerpt:\n\n" | |
| + _extract_best(question, context) | |
| + f"\n\n(Error: {last_error})", | |
| "fallback" | |
| ) | |
| # ── Helpers ────────────────────────────────────────────────────────────────── | |
| def _strip_thinking(text: str) -> str: | |
| text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip() | |
| starters = [ | |
| "okay", "ok,", "alright", "let me", "let's", "i need", "i will", | |
| "i'll", "first,", "so,", "the user", "looking at", "going through", | |
| "based on the chunk", "parsing", "to answer", "in order to", | |
| ] | |
| lines = text.split("\n") | |
| clean, found = [], False | |
| for line in lines: | |
| lower = line.strip().lower() | |
| if not found: | |
| if line.strip() and not any(lower.startswith(p) for p in starters): | |
| found = True | |
| clean.append(line) | |
| else: | |
| clean.append(line) | |
| return "\n".join(clean).strip() or text | |
| def _extract_best(question: str, context: str) -> str: | |
| keywords = set(re.findall(r'\b\w{4,}\b', question.lower())) | |
| best, score = "", 0 | |
| for chunk in context.split("---"): | |
| s = len(keywords & set(re.findall(r'\b\w{4,}\b', chunk.lower()))) | |
| if s > score: | |
| score, best = s, chunk.strip() | |
| return (best[:600] + "...") if len(best) > 600 else best or "No relevant content found." | |