# Medica_DecisionSupportAI / session_rag.py
# (Repository page metadata preserved as a comment: author "Rajan Sharma",
#  commit 817d4c6 "Update session_rag.py", file size 5.72 kB — the raw
#  page header lines were not valid Python and have been commented out.)
"""
Session-level RAG with graceful FAISS fallback.
- If FAISS is installed, uses a FAISS L2 index over normalized embeddings.
- If FAISS is missing, falls back to pure NumPy cosine similarity.
- Designed to work with extract_text_from_files(...) outputs:
* list[str]
* list[dict] with keys like "text" or "content"
"""
from __future__ import annotations
import logging
import hashlib
from typing import Iterable, List, Optional, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer
# ----- Optional FAISS -----
try:
import faiss # type: ignore
_HAS_FAISS = True
except Exception:
logging.warning(
"FAISS not installed — session RAG will use a NumPy cosine-similarity fallback. "
"Install faiss-cpu or faiss-gpu for faster retrieval."
)
faiss = None # type: ignore
_HAS_FAISS = False
def _normalize_rows(x: np.ndarray) -> np.ndarray:
"""L2 normalize row vectors; avoids division by zero."""
norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10
return x / norms
def _hash_text(s: str) -> str:
return hashlib.sha256(s.encode("utf-8")).hexdigest()
def _coerce_texts(items: Iterable) -> List[str]:
"""Accept str or dict items, pull text safely, drop empties, dedupe by hash."""
out: List[str] = []
seen: set = set()
for it in items or []:
if isinstance(it, str):
txt = it.strip()
elif isinstance(it, dict):
txt = (it.get("text") or it.get("content") or "").strip()
else:
txt = ""
if not txt:
continue
h = _hash_text(txt)
if h in seen:
continue
seen.add(h)
out.append(txt)
return out
def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]:
"""Lightweight char-based chunking to improve recall on long docs."""
if len(text) <= max_chars:
return [text]
chunks = []
i = 0
while i < len(text):
chunk = text[i : i + max_chars]
chunks.append(chunk)
i += max_chars - overlap
return chunks
class SessionRAG:
    """
    Ephemeral per-session retriever.

    Holds text chunks and their sentence-transformer embeddings purely in
    memory (nothing is persisted). Uses a FAISS inner-product index over
    L2-normalized vectors when FAISS is installed, otherwise a NumPy
    cosine-similarity fallback.

    Methods:
      - add_docs(items): add strings or dicts({"text"/"content": ...})
      - retrieve(query, k=5): returns list[str] of top-k chunks
      - clear(): drop index & memory
    """
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        # NOTE(review): model load (and possibly download) happens here,
        # so construction can be slow on first use.
        self.model = SentenceTransformer(model_name)
        self.texts: List[str] = []                    # chunk store, row-parallel to embeddings
        self.embeddings: Optional[np.ndarray] = None  # shape: (N, D), float32, unnormalized
        self.index = None                             # FAISS index if available, else None
        self.dim: Optional[int] = None                # embedding dimensionality once known
    # ---------- Private helpers ----------
    def _fit_faiss(self) -> None:
        """(Re)build the FAISS index over current embeddings, if FAISS is present."""
        if not _HAS_FAISS or self.embeddings is None:
            return
        # Inner product on L2-normalized vectors == cosine similarity.
        emb = _normalize_rows(self.embeddings.astype("float32"))
        self.dim = emb.shape[1]
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)
    def _ensure_embeddings(self) -> None:
        """Recompute embeddings for all stored chunks and rebuild the index.

        Re-encodes the whole corpus on every call; acceptable for small
        per-session stores, but O(N) model work per add_docs batch.
        """
        if not self.texts:
            self.embeddings = None
            self.index = None
            return
        embs = self.model.encode(self.texts, batch_size=64, show_progress_bar=False)
        self.embeddings = np.asarray(embs, dtype="float32")
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None
    # ---------- Public API ----------
    def add_docs(self, items: Iterable) -> int:
        """
        Add a batch of texts or dicts with 'text'/'content'.

        Applies basic chunking and deduplication (by SHA-256 of each chunk,
        against both this batch and chunks already in memory).
        Returns the number of new chunks actually added.
        """
        raw_texts = _coerce_texts(items)
        if not raw_texts:
            return 0
        # Chunk long documents into overlapping windows for better recall.
        chunks: List[str] = []
        for t in raw_texts:
            chunks.extend(_simple_chunk(t))
        # Skip chunks already present in this session's memory.
        existing_hashes = {_hash_text(t) for t in self.texts}
        added = 0
        for c in chunks:
            h = _hash_text(c)
            if h in existing_hashes:
                continue
            self.texts.append(c)
            existing_hashes.add(h)
            added += 1
        # Only pay the re-embedding cost when something actually changed.
        if added > 0:
            self._ensure_embeddings()
        return added
    def retrieve(self, query: str, k: int = 5) -> List[str]:
        """Return up to k most similar chunks for the query (may be fewer).

        Fix: the emptiness checks (no texts / no embeddings) and a k <= 0
        guard now run *before* the query is encoded — the original encoded
        the query with the model and only then discovered there was nothing
        to search, wasting a slow model call.
        """
        if not query or not self.texts or self.embeddings is None or k <= 0:
            return []
        q_emb = self.model.encode([query], show_progress_bar=False)
        q = _normalize_rows(np.asarray(q_emb, dtype="float32"))
        k_eff = min(k, len(self.texts))
        # FAISS path: inner product on normalized vectors (cosine similarity).
        if _HAS_FAISS and self.index is not None:
            _, I = self.index.search(q, k_eff)
            # FAISS pads with -1 when fewer than k_eff results exist; filter.
            idxs = [i for i in I[0] if 0 <= i < len(self.texts)]
            return [self.texts[i] for i in idxs]
        # NumPy fallback: cosine similarity via dot product on normalized vectors.
        docs = _normalize_rows(self.embeddings)
        sims = (q @ docs.T)[0]  # shape: (N,)
        top_idx = np.argsort(-sims)[:k_eff]
        return [self.texts[i] for i in top_idx]
    def clear(self) -> None:
        """Drop all in-memory data for this session."""
        self.texts = []
        self.embeddings = None
        self.index = None
        self.dim = None