# Source: Medica_DecisionSupportAI / session_rag.py
# Last commit: "Update session_rag.py" by Rajan Sharma (5daa3d4, verified)
# Captured from the repository web viewer (raw / history / blame links); file size 5.31 kB.
# session_rag.py
from __future__ import annotations
import logging, hashlib
from typing import Iterable, List, Optional, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
try:
import faiss # type: ignore
_HAS_FAISS = True
except Exception:
logging.warning("FAISS not installed — using NumPy cosine fallback.")
faiss = None # type: ignore
_HAS_FAISS = False
def _normalize_rows(x: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10
return x / norms
def _hash_text(s: str) -> str:
return hashlib.sha256(s.encode("utf-8")).hexdigest()
def _coerce_texts(items: Iterable) -> List[str]:
out: List[str] = []
seen = set()
for it in items or []:
if isinstance(it, str):
txt = it.strip()
elif isinstance(it, dict):
txt = (it.get("text") or it.get("content") or "").strip()
else:
txt = ""
if not txt:
continue
h = _hash_text(txt)
if h in seen:
continue
seen.add(h)
out.append(txt)
return out
def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]:
if len(text) <= max_chars:
return [text]
chunks = []
i = 0
while i < len(text):
chunks.append(text[i : i + max_chars])
i += max_chars - overlap
return chunks
class SessionRAG:
    """
    Ephemeral per-session retriever with artifact registry.

    Text chunks are embedded with a SentenceTransformer model and ranked by
    cosine similarity: through a FAISS inner-product index over normalized
    vectors when FAISS was imported, otherwise through a NumPy matrix product.
    Artifacts are plain dicts kept in registration order next to the text
    store.

    Public:
      - add_docs(items)
      - register_artifacts(arts)
      - retrieve(query, k=5)
      - get_latest_csv_columns()
      - get_csv_summaries()
      - clear()
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the embedding model and start with an empty store."""
        self.model = SentenceTransformer(model_name)
        self.texts: List[str] = []
        # Row i of `embeddings` is the (un-normalized) vector for texts[i].
        self.embeddings: Optional[np.ndarray] = None
        self.index = None
        self.dim: Optional[int] = None
        self.artifacts: List[Dict[str, Any]] = []  # keeps structured info per upload

    # ---------- embedding / index maintenance ----------
    def _encode(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* with the session model and return a float32 array."""
        embs = self.model.encode(texts, batch_size=64, show_progress_bar=False)
        return np.asarray(embs, dtype="float32")

    def _fit_faiss(self) -> None:
        """Rebuild the FAISS index from ``self.embeddings`` (no-op without FAISS)."""
        if not _HAS_FAISS or self.embeddings is None:
            return
        emb = _normalize_rows(self.embeddings.astype("float32"))
        self.dim = emb.shape[1]
        # Inner product over unit-length rows equals cosine similarity.
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)

    def _refresh_index(self) -> None:
        """Bring ``self.index`` in line with the current embeddings."""
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None

    def _ensure_embeddings(self) -> None:
        """Re-embed every stored text from scratch and rebuild the index."""
        if not self.texts:
            self.embeddings = None
            self.index = None
            return
        self.embeddings = self._encode(self.texts)
        self._refresh_index()

    # ---------- public API ----------
    def add_docs(self, items: Iterable) -> int:
        """Chunk, de-duplicate and embed new documents.

        Only chunks not already stored are embedded; existing vectors are
        reused as long as they are still aligned with ``self.texts``. Texts
        and embeddings are committed together after encoding succeeds, so a
        failing encode call leaves the store unchanged.

        Args:
            items: Iterable of strings and/or dicts, as accepted by
                ``_coerce_texts``.

        Returns:
            Number of new chunks added.
        """
        raw_texts = _coerce_texts(items)
        if not raw_texts:
            return 0

        known = {_hash_text(t) for t in self.texts}
        fresh: List[str] = []
        for text in raw_texts:
            for chunk in _simple_chunk(text):
                digest = _hash_text(chunk)
                if digest in known:
                    continue
                known.add(digest)
                fresh.append(chunk)
        if not fresh:
            return 0

        aligned = (
            self.embeddings is not None
            and self.embeddings.shape[0] == len(self.texts)
        )
        if aligned:
            # Embed just the new chunks instead of the whole corpus again.
            combined = np.vstack([self.embeddings, self._encode(fresh)])
        else:
            # First insert, or texts were modified externally: full rebuild.
            combined = self._encode(self.texts + fresh)

        self.texts.extend(fresh)
        self.embeddings = combined
        self._refresh_index()
        return len(fresh)

    def register_artifacts(self, arts: Iterable[Dict[str, Any]]) -> int:
        """Append every dict in *arts* to the registry; return how many were kept.

        Non-dict entries are ignored and ``None`` is treated as empty.
        """
        kept = [a for a in (arts or []) if isinstance(a, dict)]
        self.artifacts.extend(kept)
        return len(kept)

    def retrieve(self, query: str, k: int = 5) -> List[str]:
        """Return up to *k* stored chunks most similar to *query*.

        Args:
            query: Search text; an empty query yields no results.
            k: Maximum number of results; non-positive values yield no results.

        Returns:
            Matching chunks, best match first; ``[]`` when nothing is indexed.
        """
        # Bail out before the (comparatively expensive) query encoding when
        # there is nothing to search or nothing was asked for.
        if not query or not self.texts or self.embeddings is None or k <= 0:
            return []
        top_k = min(k, len(self.texts))

        q_emb = self.model.encode([query], show_progress_bar=False)
        q = _normalize_rows(np.asarray(q_emb, dtype="float32"))

        if _HAS_FAISS and self.index is not None:
            _, hits = self.index.search(q, top_k)
            # Drop out-of-range ids (e.g. negative placeholders).
            return [self.texts[i] for i in hits[0] if 0 <= i < len(self.texts)]

        docs = _normalize_rows(self.embeddings)
        sims = (q @ docs.T)[0]
        order = np.argsort(-sims)[:top_k]
        return [self.texts[i] for i in order]

    # ---------- helpers for structured data ----------
    def get_latest_csv_columns(self) -> List[str]:
        """Return the column names of the newest ``kind == "csv"`` artifact.

        Artifacts with a missing or empty ``"columns"`` entry are skipped;
        names are converted with ``str``. Returns ``[]`` when none qualifies.
        """
        for a in reversed(self.artifacts):
            if a.get("kind") == "csv" and a.get("columns"):
                return [str(c) for c in a["columns"]]
        return []

    def get_csv_summaries(self) -> List[Dict[str, Any]]:
        """
        Return a list of dicts with keys:
          - file (str)
          - digest (str)
          - summary (dict)
        newest-first

        Built from artifacts whose ``kind`` is ``"csv_summary"``; ``file`` is
        taken from the artifact's ``"name"`` entry.
        """
        return [
            {
                "file": a.get("name"),
                "digest": a.get("digest"),
                "summary": a.get("summary"),
            }
            for a in reversed(self.artifacts)
            if a.get("kind") == "csv_summary"
        ]

    def clear(self) -> None:
        """Drop all texts, embeddings, the index and registered artifacts."""
        self.texts = []
        self.embeddings = None
        self.index = None
        self.dim = None
        self.artifacts = []