Spaces:
Sleeping
Sleeping
| # session_rag.py | |
from __future__ import annotations

import hashlib
import logging
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
| try: | |
| import faiss # type: ignore | |
| _HAS_FAISS = True | |
| except Exception: | |
| logging.warning("FAISS not installed — using NumPy cosine fallback.") | |
| faiss = None # type: ignore | |
| _HAS_FAISS = False | |
| def _normalize_rows(x: np.ndarray) -> np.ndarray: | |
| norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10 | |
| return x / norms | |
| def _hash_text(s: str) -> str: | |
| return hashlib.sha256(s.encode("utf-8")).hexdigest() | |
| def _coerce_texts(items: Iterable) -> List[str]: | |
| out: List[str] = [] | |
| seen = set() | |
| for it in items or []: | |
| if isinstance(it, str): | |
| txt = it.strip() | |
| elif isinstance(it, dict): | |
| txt = (it.get("text") or it.get("content") or "").strip() | |
| else: | |
| txt = "" | |
| if not txt: | |
| continue | |
| h = _hash_text(txt) | |
| if h in seen: | |
| continue | |
| seen.add(h) | |
| out.append(txt) | |
| return out | |
| def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]: | |
| if len(text) <= max_chars: | |
| return [text] | |
| chunks = [] | |
| i = 0 | |
| while i < len(text): | |
| chunks.append(text[i : i + max_chars]) | |
| i += max_chars - overlap | |
| return chunks | |
class SessionRAG:
    """
    Ephemeral per-session retriever with artifact registry.

    Texts are chunked, de-duplicated, embedded with a SentenceTransformer
    and searched by cosine similarity (FAISS inner-product index when
    available, NumPy fallback otherwise).

    Public:
        - add_docs(items)
        - register_artifacts(arts)
        - retrieve(query, k=5)
        - get_latest_csv_columns()
        - clear()
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        self.texts: List[str] = []
        # float32 matrix, one row per chunk; None while the corpus is empty.
        self.embeddings: Optional[np.ndarray] = None
        # FAISS index, or None when using the NumPy fallback.
        self.index = None
        self.dim: Optional[int] = None
        self.artifacts: List[Dict[str, Any]] = []  # keeps structured info per upload

    def _encode(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* and return a float32 matrix."""
        embs = self.model.encode(texts, batch_size=64, show_progress_bar=False)
        return np.asarray(embs, dtype="float32")

    def _fit_faiss(self) -> None:
        """(Re)build the FAISS inner-product index over normalized embeddings."""
        if not _HAS_FAISS or self.embeddings is None:
            return
        emb = _normalize_rows(self.embeddings.astype("float32"))
        self.dim = emb.shape[1]
        # Inner product over unit vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)

    def _ensure_embeddings(self) -> None:
        """Fully re-embed the corpus (rebuild-from-scratch fallback)."""
        if not self.texts:
            self.embeddings = None
            self.index = None
            return
        self.embeddings = self._encode(self.texts)
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None

    def add_docs(self, items: Iterable) -> int:
        """Chunk, dedupe and index new documents; return chunks added.

        Only the newly added chunks are embedded — the previous version
        re-encoded the whole corpus on every call, which is quadratic over
        the lifetime of a session.
        """
        raw_texts = _coerce_texts(items)
        if not raw_texts:
            return 0
        chunks: List[str] = []
        for t in raw_texts:
            chunks.extend(_simple_chunk(t))
        existing_hashes = {_hash_text(t) for t in self.texts}
        new_chunks: List[str] = []
        for c in chunks:
            h = _hash_text(c)
            if h in existing_hashes:
                continue
            existing_hashes.add(h)
            new_chunks.append(c)
        if not new_chunks:
            return 0
        new_embs = self._encode(new_chunks)
        self.texts.extend(new_chunks)
        if self.embeddings is None:
            self.embeddings = new_embs
        else:
            self.embeddings = np.vstack([self.embeddings, new_embs])
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None
        return len(new_chunks)

    def register_artifacts(self, arts: Iterable[Dict[str, Any]]) -> int:
        """Append dict artifacts to the registry (non-dicts ignored); return count added."""
        count = 0
        for a in (arts or []):
            if isinstance(a, dict):
                self.artifacts.append(a)
                count += 1
        return count

    def retrieve(self, query: str, k: int = 5) -> List[str]:
        """Return up to *k* stored chunks most similar to *query* (cosine)."""
        # Cheap emptiness checks first — the old code encoded the query
        # before discovering there was nothing to search.
        if not query or not self.texts or self.embeddings is None:
            return []
        top = min(k, len(self.texts))
        if top <= 0:
            return []
        q = _normalize_rows(self._encode([query]))
        if _HAS_FAISS and self.index is not None:
            _scores, ids = self.index.search(q, top)
            # FAISS pads with -1 when fewer results exist; drop invalid ids.
            return [self.texts[i] for i in ids[0] if 0 <= i < len(self.texts)]
        docs = _normalize_rows(self.embeddings)
        sims = (q @ docs.T)[0]
        order = np.argsort(-sims)[:top]
        return [self.texts[i] for i in order]

    # ---------- helpers for structured Qs ----------
    def get_latest_csv_columns(self) -> List[str]:
        """Column names of the most recently registered CSV artifact, else []."""
        # Scan artifacts in reverse insertion order (newest first).
        for a in reversed(self.artifacts):
            if a.get("kind") == "csv" and a.get("columns"):
                return list(map(str, a["columns"]))
        return []

    def clear(self) -> None:
        """Reset all session state: texts, embeddings, index, artifacts."""
        self.texts = []
        self.embeddings = None
        self.index = None
        self.dim = None
        self.artifacts = []