# session_rag.py from __future__ import annotations import logging, hashlib from typing import Iterable, List, Optional, Dict, Any import numpy as np from sentence_transformers import SentenceTransformer try: import faiss # type: ignore _HAS_FAISS = True except Exception: logging.warning("FAISS not installed — using NumPy cosine fallback.") faiss = None # type: ignore _HAS_FAISS = False def _normalize_rows(x: np.ndarray) -> np.ndarray: norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10 return x / norms def _hash_text(s: str) -> str: return hashlib.sha256(s.encode("utf-8")).hexdigest() def _coerce_texts(items: Iterable) -> List[str]: out: List[str] = [] seen = set() for it in items or []: if isinstance(it, str): txt = it.strip() elif isinstance(it, dict): txt = (it.get("text") or it.get("content") or "").strip() else: txt = "" if not txt: continue h = _hash_text(txt) if h in seen: continue seen.add(h) out.append(txt) return out def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]: if len(text) <= max_chars: return [text] chunks = [] i = 0 while i < len(text): chunks.append(text[i : i + max_chars]) i += max_chars - overlap return chunks class SessionRAG: """ Ephemeral per-session retriever with artifact registry. Public: - add_docs(items) - register_artifacts(arts) - retrieve(query, k=5) - get_latest_csv_columns() - get_csv_summaries() - clear() """ def __init__(self, model_name: str = "all-MiniLM-L6-v2"): self.model = SentenceTransformer(model_name) self.texts: List[str] = [] self.embeddings: Optional[np.ndarray] = None self.index = None self.dim: Optional[int] = None self.artifacts: List[Dict[str, Any]] = [] # keeps structured info per upload def _fit_faiss(self) -> None: if not _HAS_FAISS or self.embeddings is None: return emb = _normalize_rows(self.embeddings.astype("float32")) self.dim = emb.shape[1] self.index = faiss.IndexFlatIP(self.dim) self.index.add(emb) def _ensure_embeddings(self) -> None: if not self.texts: self.embeddings = None self.index = None return embs = self.model.encode(self.texts, batch_size=64, show_progress_bar=False) self.embeddings = np.asarray(embs, dtype="float32") if _HAS_FAISS: self._fit_faiss() else: self.index = None def add_docs(self, items: Iterable) -> int: raw_texts = _coerce_texts(items) if not raw_texts: return 0 chunks: List[str] = [] for t in raw_texts: chunks.extend(_simple_chunk(t)) existing_hashes = {_hash_text(t) for t in self.texts} added = 0 for c in chunks: h = _hash_text(c) if h in existing_hashes: continue self.texts.append(c) existing_hashes.add(h) added += 1 if added > 0: self._ensure_embeddings() return added def register_artifacts(self, arts: Iterable[Dict[str, Any]]) -> int: count = 0 for a in (arts or []): if isinstance(a, dict): self.artifacts.append(a) count += 1 return count def retrieve(self, query: str, k: int = 5) -> List[str]: if not query or not self.texts: return [] q_emb = self.model.encode([query], show_progress_bar=False) q = _normalize_rows(np.asarray(q_emb, dtype="float32")) if self.embeddings is None: return [] if _HAS_FAISS and self.index is not None: D, I = self.index.search(q, min(k, len(self.texts))) idxs = [i for i in I[0] if 0 <= i < len(self.texts)] return [self.texts[i] for i in idxs] docs = _normalize_rows(self.embeddings) sims = (q @ docs.T)[0] top_idx = np.argsort(-sims)[: min(k, len(self.texts))] return [self.texts[i] for i in top_idx] # ---------- helpers for structured data ---------- def get_latest_csv_columns(self) -> List[str]: for a in reversed(self.artifacts): if a.get("kind") == "csv" and a.get("columns"): return list(map(str, a["columns"])) return [] def get_csv_summaries(self) -> List[Dict[str, Any]]: """ Return a list of dicts with keys: - file (str) - digest (str) - summary (dict) newest-first """ out: List[Dict[str, Any]] = [] for a in reversed(self.artifacts): if a.get("kind") == "csv_summary": out.append({ "file": a.get("name"), "digest": a.get("digest"), "summary": a.get("summary"), }) return out def clear(self) -> None: self.texts = [] self.embeddings = None self.index = None self.dim = None self.artifacts = []