Spaces:
Sleeping
Sleeping
File size: 5,719 Bytes
817d4c6 f051f2e 817d4c6 f051f2e 1e3e62c 817d4c6 1e3e62c 817d4c6 f051f2e 817d4c6 f051f2e 817d4c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
"""
Session-level RAG with graceful FAISS fallback.
- If FAISS is installed, uses a FAISS inner-product (cosine) index over normalized embeddings.
- If FAISS is missing, falls back to pure NumPy cosine similarity.
- Designed to work with extract_text_from_files(...) outputs:
* list[str]
* list[dict] with keys like "text" or "content"
"""
from __future__ import annotations
import logging
import hashlib
from typing import Iterable, List, Optional, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer
# ----- Optional FAISS -----
# FAISS is an optional accelerator: when present, retrieval uses an exact
# inner-product index; when absent, SessionRAG falls back to NumPy cosine
# similarity. `_HAS_FAISS` is the single feature flag checked elsewhere.
try:
    import faiss  # type: ignore
    _HAS_FAISS = True
except Exception:
    # NOTE(review): catches Exception rather than ImportError — presumably to
    # also survive broken faiss installs that raise at import time; confirm
    # this breadth is intentional.
    logging.warning(
        "FAISS not installed — session RAG will use a NumPy cosine-similarity fallback. "
        "Install faiss-cpu or faiss-gpu for faster retrieval."
    )
    faiss = None  # type: ignore
    _HAS_FAISS = False
def _normalize_rows(x: np.ndarray) -> np.ndarray:
"""L2 normalize row vectors; avoids division by zero."""
norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10
return x / norms
def _hash_text(s: str) -> str:
return hashlib.sha256(s.encode("utf-8")).hexdigest()
def _coerce_texts(items: Iterable) -> List[str]:
    """Accept str or dict items, extract their text, drop empties, and
    deduplicate by content hash while preserving first-seen order."""
    texts: List[str] = []
    seen_hashes: set = set()
    for item in items or []:
        if isinstance(item, dict):
            # Prefer "text", fall back to "content"; anything else is empty.
            candidate = (item.get("text") or item.get("content") or "").strip()
        elif isinstance(item, str):
            candidate = item.strip()
        else:
            candidate = ""
        if not candidate:
            continue
        digest = _hash_text(candidate)
        if digest not in seen_hashes:
            seen_hashes.add(digest)
            texts.append(candidate)
    return texts
def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]:
"""Lightweight char-based chunking to improve recall on long docs."""
if len(text) <= max_chars:
return [text]
chunks = []
i = 0
while i < len(text):
chunk = text[i : i + max_chars]
chunks.append(chunk)
i += max_chars - overlap
return chunks
class SessionRAG:
    """
    Ephemeral per-session retriever.
    Methods:
      - add_docs(items): add strings or dicts({"text"/"content": ...})
      - retrieve(query, k=5): returns list[str] of top-k chunks
      - clear(): drop index & memory
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        # Embedding model shared by documents and queries.
        self.model = SentenceTransformer(model_name)
        self.texts: List[str] = []  # chunked, deduplicated corpus
        self.embeddings: Optional[np.ndarray] = None  # shape: (N, D) float32
        self.index = None  # FAISS index if available
        self.dim: Optional[int] = None  # embedding dimensionality

    # ---------- Private helpers ----------
    def _fit_faiss(self) -> None:
        """(Re)build the FAISS index over the current embeddings; no-op
        when FAISS is unavailable or there is nothing to index."""
        if not _HAS_FAISS or self.embeddings is None:
            return
        # Inner product over L2-normalized rows == cosine similarity.
        emb = _normalize_rows(self.embeddings.astype("float32"))
        self.dim = emb.shape[1]
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)

    def _ensure_embeddings(self) -> None:
        """Recompute embeddings (and the index) for the whole corpus.

        NOTE: re-encodes *every* stored chunk on each call — O(N) model
        work per add_docs batch. Acceptable for small per-session corpora.
        """
        if not self.texts:
            self.embeddings = None
            self.index = None
            return
        embs = self.model.encode(self.texts, batch_size=64, show_progress_bar=False)
        self.embeddings = np.asarray(embs, dtype="float32")
        if _HAS_FAISS:
            self._fit_faiss()
        else:
            self.index = None

    # ---------- Public API ----------
    def add_docs(self, items: Iterable) -> int:
        """
        Add a batch of texts or dicts with 'text'/'content'.
        Applies basic chunking and deduplication (against both the incoming
        batch and chunks already in memory).
        Returns the number of chunks added.
        """
        raw_texts = _coerce_texts(items)
        if not raw_texts:
            return 0
        # Chunk each long text into manageable pieces.
        chunks: List[str] = []
        for t in raw_texts:
            chunks.extend(_simple_chunk(t))
        # Deduplicate vs existing memory.
        existing_hashes = {_hash_text(t) for t in self.texts}
        added = 0
        for c in chunks:
            h = _hash_text(c)
            if h in existing_hashes:
                continue
            self.texts.append(c)
            existing_hashes.add(h)
            added += 1
        # Only re-embed when the corpus actually changed.
        if added > 0:
            self._ensure_embeddings()
        return added

    def retrieve(self, query: str, k: int = 5) -> List[str]:
        """Return up to k most similar chunks for the query.

        Returns [] for an empty query, an empty/unindexed corpus, or k <= 0.
        """
        # Fix: bail out *before* the expensive query encode when there is
        # nothing to search (original encoded first, then checked
        # self.embeddings), and guard against non-positive k, which FAISS's
        # search() does not handle meaningfully.
        if not query or not self.texts or self.embeddings is None or k <= 0:
            return []
        q_emb = self.model.encode([query], show_progress_bar=False)
        q = _normalize_rows(np.asarray(q_emb, dtype="float32"))
        top_k = min(k, len(self.texts))
        # FAISS path (inner product on normalized vectors).
        if _HAS_FAISS and self.index is not None:
            # FAISS pads the id array with -1 when fewer than top_k hits
            # exist, so filter ids to the valid range.
            _scores, ids = self.index.search(q, top_k)
            valid = [i for i in ids[0] if 0 <= i < len(self.texts)]
            return [self.texts[i] for i in valid]
        # NumPy fallback: cosine similarity via dot product on normalized vectors.
        docs = _normalize_rows(self.embeddings)
        sims = (q @ docs.T)[0]  # shape: (N,)
        order = np.argsort(-sims)[:top_k]
        return [self.texts[i] for i in order]

    def clear(self) -> None:
        """Drop all in-memory data for this session."""
        self.texts = []
        self.embeddings = None
        self.index = None
        self.dim = None
|