Spaces:
Sleeping
Sleeping
| """ | |
| Session-level RAG with graceful FAISS fallback. | |
| - If FAISS is installed, uses a FAISS L2 index over normalized embeddings. | |
| - If FAISS is missing, falls back to pure NumPy cosine similarity. | |
| - Designed to work with extract_text_from_files(...) outputs: | |
| * list[str] | |
| * list[dict] with keys like "text" or "content" | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import hashlib | |
| from typing import Iterable, List, Optional, Tuple | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| # ----- Optional FAISS ----- | |
| try: | |
| import faiss # type: ignore | |
| _HAS_FAISS = True | |
| except Exception: | |
| logging.warning( | |
| "FAISS not installed — session RAG will use a NumPy cosine-similarity fallback. " | |
| "Install faiss-cpu or faiss-gpu for faster retrieval." | |
| ) | |
| faiss = None # type: ignore | |
| _HAS_FAISS = False | |
| def _normalize_rows(x: np.ndarray) -> np.ndarray: | |
| """L2 normalize row vectors; avoids division by zero.""" | |
| norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10 | |
| return x / norms | |
| def _hash_text(s: str) -> str: | |
| return hashlib.sha256(s.encode("utf-8")).hexdigest() | |
| def _coerce_texts(items: Iterable) -> List[str]: | |
| """Accept str or dict items, pull text safely, drop empties, dedupe by hash.""" | |
| out: List[str] = [] | |
| seen: set = set() | |
| for it in items or []: | |
| if isinstance(it, str): | |
| txt = it.strip() | |
| elif isinstance(it, dict): | |
| txt = (it.get("text") or it.get("content") or "").strip() | |
| else: | |
| txt = "" | |
| if not txt: | |
| continue | |
| h = _hash_text(txt) | |
| if h in seen: | |
| continue | |
| seen.add(h) | |
| out.append(txt) | |
| return out | |
| def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]: | |
| """Lightweight char-based chunking to improve recall on long docs.""" | |
| if len(text) <= max_chars: | |
| return [text] | |
| chunks = [] | |
| i = 0 | |
| while i < len(text): | |
| chunk = text[i : i + max_chars] | |
| chunks.append(chunk) | |
| i += max_chars - overlap | |
| return chunks | |
class SessionRAG:
    """
    Ephemeral per-session retriever over in-memory text chunks.

    Methods:
        - add_docs(items): add strings or dicts({"text"/"content": ...})
        - retrieve(query, k=5): returns list[str] of top-k chunks
        - clear(): drop index & memory
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        # Encoder shared by documents and queries so similarities are comparable.
        self.model = SentenceTransformer(model_name)
        self.texts: List[str] = []
        self.embeddings: Optional[np.ndarray] = None  # shape: (N, D), float32
        self.index = None  # FAISS index if available, else None
        self.dim: Optional[int] = None

    # ---------- Private helpers ----------
    def _encode(self, texts: List[str]) -> np.ndarray:
        """Embed a batch of texts as a float32 array of shape (len(texts), D)."""
        embs = self.model.encode(texts, batch_size=64, show_progress_bar=False)
        return np.asarray(embs, dtype="float32")

    def _fit_faiss(self) -> None:
        """(Re)build the FAISS index over the current embeddings.

        Uses inner product on L2-normalized vectors, which equals cosine
        similarity. Rebuilding an IndexFlatIP is cheap (a flat copy of the
        vectors), so we rebuild rather than track incremental adds.
        """
        if not _HAS_FAISS or self.embeddings is None:
            return
        emb = _normalize_rows(self.embeddings.astype("float32"))
        self.dim = emb.shape[1]
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)

    def _ensure_embeddings(self) -> None:
        """Recompute embeddings for ALL stored texts and rebuild the index.

        Kept as a full-recompute path (e.g. after external mutation of
        self.texts); add_docs() uses the cheaper incremental path instead.
        """
        if not self.texts:
            self.embeddings = None
            self.index = None
            return
        self.embeddings = self._encode(self.texts)
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None

    # ---------- Public API ----------
    def add_docs(self, items: Iterable) -> int:
        """
        Add a batch of texts or dicts with 'text'/'content'.
        Applies basic chunking and deduplication.
        Returns the number of chunks added.
        """
        raw_texts = _coerce_texts(items)
        if not raw_texts:
            return 0
        # Chunk each long text into manageable pieces.
        chunks: List[str] = []
        for t in raw_texts:
            chunks.extend(_simple_chunk(t))
        # Deduplicate vs existing memory (and within this batch).
        existing_hashes = {_hash_text(t) for t in self.texts}
        new_chunks: List[str] = []
        for c in chunks:
            h = _hash_text(c)
            if h in existing_hashes:
                continue
            existing_hashes.add(h)
            new_chunks.append(c)
        if not new_chunks:
            return 0
        self.texts.extend(new_chunks)
        # Encode only the NEW chunks and append, instead of re-embedding the
        # whole corpus on every call (previously O(total) encode per add).
        new_embs = self._encode(new_chunks)
        if self.embeddings is None:
            self.embeddings = new_embs
        else:
            self.embeddings = np.vstack([self.embeddings, new_embs])
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None
        return len(new_chunks)

    def retrieve(self, query: str, k: int = 5) -> List[str]:
        """Return up to k most similar chunks for the query (possibly fewer)."""
        # Guard k <= 0 as well: FAISS search with non-positive k is invalid.
        if not query or not self.texts or self.embeddings is None or k <= 0:
            return []
        # Encode and normalize the query so dot product == cosine similarity.
        q_emb = self.model.encode([query], show_progress_bar=False)
        q = _normalize_rows(np.asarray(q_emb, dtype="float32"))
        top_k = min(k, len(self.texts))
        # FAISS path (inner product on normalized vectors).
        if _HAS_FAISS and self.index is not None:
            _, ids = self.index.search(q, top_k)  # distances unused
            valid = [i for i in ids[0] if 0 <= i < len(self.texts)]
            return [self.texts[i] for i in valid]
        # NumPy fallback: cosine similarity via dot product on normalized vectors.
        docs = _normalize_rows(self.embeddings)
        sims = (q @ docs.T)[0]  # shape: (N,)
        top_idx = np.argsort(-sims)[:top_k]
        return [self.texts[i] for i in top_idx]

    def clear(self) -> None:
        """Drop all in-memory data for this session."""
        self.texts = []
        self.embeddings = None
        self.index = None
        self.dim = None