# session_rag.py
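"""Ephemeral, per-session retrieval with a small artifact registry.

Texts are embedded with sentence-transformers and searched with FAISS when it
is installed; otherwise retrieval falls back to a NumPy cosine-similarity
scan over normalized embeddings.
"""
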
from __future__ import annotations

import hashlib
import logging
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

try:
    import faiss  # type: ignore
    _HAS_FAISS = True
except Exception:
    logging.warning("FAISS not installed — using NumPy cosine fallback.")
    faiss = None  # type: ignore
    _HAS_FAISS = False


def _normalize_rows(x: np.ndarray) -> np.ndarray:
    # L2-normalize each row; the epsilon guards against division by zero.
    norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10
    return x / norms


def _hash_text(s: str) -> str:
    return hashlib.sha256(s.encode("utf-8")).hexdigest()


def _coerce_texts(items: Iterable) -> List[str]:
    # Accept plain strings or dicts with a "text"/"content" field; drop
    # empties and exact duplicates (keyed by SHA-256 of the text).
    out: List[str] = []
    seen = set()
    for it in items or []:
        if isinstance(it, str):
            txt = it.strip()
        elif isinstance(it, dict):
            txt = (it.get("text") or it.get("content") or "").strip()
        else:
            txt = ""
        if not txt:
            continue
        h = _hash_text(txt)
        if h in seen:
            continue
        seen.add(h)
        out.append(txt)
    return out


def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]:
    # Fixed-size character chunks with overlap: a 2,500-char text yields
    # chunks of 1,200, 1,200, and 400 chars, each sharing its first 150
    # chars with the tail of the previous chunk.
    if len(text) <= max_chars:
        return [text]
    chunks: List[str] = []
    i = 0
    while i < len(text):
        chunks.append(text[i : i + max_chars])
        i += max_chars - overlap
    return chunks


class SessionRAG:
    """
    Ephemeral per-session retriever with an artifact registry.

    Public API:
        - add_docs(items)
        - register_artifacts(arts)
        - retrieve(query, k=5)
        - get_latest_csv_columns()
        - clear()

    A usage sketch appears at the bottom of this file.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        self.texts: List[str] = []
        self.embeddings: Optional[np.ndarray] = None
        self.index = None  # FAISS index when available, else None
        self.dim: Optional[int] = None
        self.artifacts: List[Dict[str, Any]] = []  # structured info per upload

    def _fit_faiss(self) -> None:
        if not _HAS_FAISS or self.embeddings is None:
            return
        # Inner product over normalized vectors equals cosine similarity.
        emb = _normalize_rows(self.embeddings.astype("float32"))
        self.dim = emb.shape[1]
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)

    def _ensure_embeddings(self) -> None:
        if not self.texts:
            self.embeddings = None
            self.index = None
            return
        # Re-encode the full corpus on every change; simple and adequate
        # for small per-session corpora.
        embs = self.model.encode(self.texts, batch_size=64, show_progress_bar=False)
        self.embeddings = np.asarray(embs, dtype="float32")
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None
    def add_docs(self, items: Iterable) -> int:
        """Chunk, deduplicate, and index new texts; returns chunks added."""
        raw_texts = _coerce_texts(items)
        if not raw_texts:
            return 0
        chunks: List[str] = []
        for t in raw_texts:
            chunks.extend(_simple_chunk(t))
        existing_hashes = {_hash_text(t) for t in self.texts}
        added = 0
        for c in chunks:
            h = _hash_text(c)
            if h in existing_hashes:
                continue
            self.texts.append(c)
            existing_hashes.add(h)
            added += 1
        if added > 0:
            self._ensure_embeddings()
        return added

    def register_artifacts(self, arts: Iterable[Dict[str, Any]]) -> int:
        """Record structured upload metadata (dicts); returns count accepted."""
        count = 0
        for a in arts or []:
            if isinstance(a, dict):
                self.artifacts.append(a)
                count += 1
        return count
    def retrieve(self, query: str, k: int = 5) -> List[str]:
        """Return up to k stored chunks ranked by cosine similarity to the query."""
        if not query or not self.texts or self.embeddings is None:
            return []
        q_emb = self.model.encode([query], show_progress_bar=False)
        q = _normalize_rows(np.asarray(q_emb, dtype="float32"))
        if _HAS_FAISS and self.index is not None:
            D, I = self.index.search(q, min(k, len(self.texts)))
            idxs = [i for i in I[0] if 0 <= i < len(self.texts)]
            return [self.texts[i] for i in idxs]
        # NumPy fallback: cosine similarity via dot product of normalized rows.
        docs = _normalize_rows(self.embeddings)
        sims = (q @ docs.T)[0]
        top_idx = np.argsort(-sims)[: min(k, len(self.texts))]
        return [self.texts[i] for i in top_idx]
    # ---------- helpers for structured Qs ----------
    def get_latest_csv_columns(self) -> List[str]:
        """Column names of the most recently registered CSV artifact, else []."""
        # Scan artifacts in reverse insertion order (newest first).
        for a in reversed(self.artifacts):
            if a.get("kind") == "csv" and a.get("columns"):
                return list(map(str, a["columns"]))
        return []
    def clear(self) -> None:
        """Drop all texts, embeddings, the index, and registered artifacts."""
        self.texts = []
        self.embeddings = None
        self.index = None
        self.dim = None
        self.artifacts = []
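

# Minimal usage sketch. Assumes sentence-transformers can fetch the model on
# first run; the file name and column list below are illustrative only, not
# part of any real session.
if __name__ == "__main__":
    rag = SessionRAG()
    rag.add_docs([
        "FAISS builds an inner-product index over normalized embeddings.",
        {"text": "The NumPy fallback ranks chunks by cosine similarity."},
    ])
    rag.register_artifacts([
        {"kind": "csv", "name": "sales.csv", "columns": ["date", "region", "revenue"]}
    ])
    print(rag.retrieve("How is similarity computed?", k=2))
    print(rag.get_latest_csv_columns())  # -> ['date', 'region', 'revenue']
    rag.clear()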