Rajan Sharma committed on
Commit 817d4c6 · verified · 1 Parent(s): 14fa872

Update session_rag.py

Files changed (1)
  1. session_rag.py +174 -29
session_rag.py CHANGED
@@ -1,36 +1,181 @@
- from typing import List, Tuple
- from sentence_transformers import SentenceTransformer
  import numpy as np
- import faiss

  class SessionRAG:
      """
-     In-memory, per-session store for uploaded content.
-     No disk persistence to remain PHI-safe by default.
      """
-     def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
          self.model = SentenceTransformer(model_name)
-         self.docs: List[Tuple[str, str]] = []
-         self.index = None
-         self.vecs = None
-
-     def add_docs(self, items: List[Tuple[str, str]]):
-         self.docs.extend(items)
-         texts = [t for _, t in self.docs]
-         if not texts:
-             self.index = None; self.vecs=None; return
-         embs = self.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
-         self.vecs = embs
-         self.index = faiss.IndexFlatIP(embs.shape[1])
-         self.index.add(embs)
-
-     def retrieve(self, query: str, k: int = 6) -> List[str]:
-         if not self.index or self.vecs is None:
              return []
-         q = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
-         D, I = self.index.search(q, k)
-         out = []
-         for idx in I[0]:
-             if 0 <= idx < len(self.docs):
-                 out.append(self.docs[idx][1])
-         return out
+ """
+ Session-level RAG with graceful FAISS fallback.
+
+ - If FAISS is installed, uses a FAISS inner-product index over L2-normalized embeddings (cosine similarity).
+ - If FAISS is missing, falls back to pure NumPy cosine similarity.
+ - Designed to work with extract_text_from_files(...) outputs:
+     * list[str]
+     * list[dict] with keys like "text" or "content"
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import hashlib
+ from typing import Iterable, List, Optional, Tuple
+
  import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+ # ----- Optional FAISS -----
+ try:
+     import faiss  # type: ignore
+     _HAS_FAISS = True
+ except Exception:
+     logging.warning(
+         "FAISS not installed — session RAG will use a NumPy cosine-similarity fallback. "
+         "Install faiss-cpu or faiss-gpu for faster retrieval."
+     )
+     faiss = None  # type: ignore
+     _HAS_FAISS = False
+
+
+ def _normalize_rows(x: np.ndarray) -> np.ndarray:
+     """L2-normalize row vectors; avoids division by zero."""
+     norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10
+     return x / norms
+
+
+ def _hash_text(s: str) -> str:
+     return hashlib.sha256(s.encode("utf-8")).hexdigest()
+
+
+ def _coerce_texts(items: Iterable) -> List[str]:
+     """Accept str or dict items, pull text safely, drop empties, dedupe by hash."""
+     out: List[str] = []
+     seen: set = set()
+     for it in items or []:
+         if isinstance(it, str):
+             txt = it.strip()
+         elif isinstance(it, dict):
+             txt = (it.get("text") or it.get("content") or "").strip()
+         else:
+             txt = ""
+         if not txt:
+             continue
+         h = _hash_text(txt)
+         if h in seen:
+             continue
+         seen.add(h)
+         out.append(txt)
+     return out
+
+
+ def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]:
+     """Lightweight char-based chunking to improve recall on long docs."""
+     if len(text) <= max_chars:
+         return [text]
+     chunks = []
+     i = 0
+     while i < len(text):
+         chunk = text[i : i + max_chars]
+         chunks.append(chunk)
+         i += max_chars - overlap
+     return chunks
+
+
  class SessionRAG:
      """
+     Ephemeral per-session retriever.
+
+     Methods:
+     - add_docs(items): add strings or dicts({"text"/"content": ...})
+     - retrieve(query, k=5): returns list[str] of top-k chunks
+     - clear(): drop index & memory
      """
+
+     def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
          self.model = SentenceTransformer(model_name)
+         self.texts: List[str] = []
+         self.embeddings: Optional[np.ndarray] = None  # shape: (N, D)
+         self.index = None  # FAISS index if available
+         self.dim: Optional[int] = None
+
+     # ---------- Private helpers ----------
+     def _fit_faiss(self) -> None:
+         if not _HAS_FAISS or self.embeddings is None:
+             return
+         # Use inner product on normalized vectors (cosine similarity)
+         emb = _normalize_rows(self.embeddings.astype("float32"))
+         self.dim = emb.shape[1]
+         # Build IP index
+         self.index = faiss.IndexFlatIP(self.dim)
+         self.index.add(emb)
+
+     def _ensure_embeddings(self) -> None:
+         if not self.texts:
+             self.embeddings = None
+             self.index = None
+             return
+         # Compute embeddings
+         embs = self.model.encode(self.texts, batch_size=64, show_progress_bar=False)
+         self.embeddings = np.asarray(embs, dtype="float32")
+         # Build FAISS if available
+         if _HAS_FAISS:
+             self._fit_faiss()
+         else:
+             self.index = None
+
+     # ---------- Public API ----------
+     def add_docs(self, items: Iterable) -> int:
+         """
+         Add a batch of texts or dicts with 'text'/'content'.
+         Applies basic chunking and deduplication.
+         Returns the number of chunks added.
+         """
+         raw_texts = _coerce_texts(items)
+         if not raw_texts:
+             return 0
+
+         # Chunk each long text into manageable pieces
+         chunks: List[str] = []
+         for t in raw_texts:
+             chunks.extend(_simple_chunk(t))
+
+         # Deduplicate vs existing memory
+         existing_hashes = {_hash_text(t) for t in self.texts}
+         added = 0
+         for c in chunks:
+             h = _hash_text(c)
+             if h in existing_hashes:
+                 continue
+             self.texts.append(c)
+             existing_hashes.add(h)
+             added += 1
+
+         # Recompute embeddings/index
+         if added > 0:
+             self._ensure_embeddings()
+
+         return added
+
+     def retrieve(self, query: str, k: int = 5) -> List[str]:
+         """Return up to k most similar chunks for the query."""
+         if not query or not self.texts:
+             return []
+
+         # Encode query, normalize
+         q_emb = self.model.encode([query], show_progress_bar=False)
+         q = _normalize_rows(np.asarray(q_emb, dtype="float32"))
+
+         if self.embeddings is None:
              return []
+
+         # FAISS path (inner product on normalized vectors)
+         if _HAS_FAISS and self.index is not None:
+             D, I = self.index.search(q, min(k, len(self.texts)))
+             idxs = [i for i in I[0] if 0 <= i < len(self.texts)]
+             return [self.texts[i] for i in idxs]
+
+         # NumPy fallback: cosine similarity via dot product on normalized vectors
+         docs = _normalize_rows(self.embeddings)
+         sims = (q @ docs.T)[0]  # shape: (N,)
+         top_idx = np.argsort(-sims)[: min(k, len(self.texts))]
+         return [self.texts[i] for i in top_idx]
+
+     def clear(self) -> None:
+         """Drop all in-memory data for this session."""
+         self.texts = []
+         self.embeddings = None
+         self.index = None
+         self.dim = None
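
A minimal usage sketch of the updated class, assuming session_rag.py is importable as a module and sentence-transformers is installed (FAISS is optional given the fallback). The import path, sample strings, and query below are illustrative assumptions, not part of the commit:

# Hypothetical usage sketch; not part of the commit.
from session_rag import SessionRAG

rag = SessionRAG()  # loads the all-MiniLM-L6-v2 model on first use

# add_docs accepts plain strings or dicts with "text"/"content" keys;
# empty items are dropped and duplicate chunks are skipped by hash.
added = rag.add_docs([
    "Aspirin is a nonsteroidal anti-inflammatory drug (NSAID).",
    {"text": "Metformin is a first-line medication for type 2 diabetes."},
    {"content": "Lisinopril is an ACE inhibitor used to treat hypertension."},
    "",  # dropped by _coerce_texts
])
print(added)  # 3: each short text is under max_chars, so one chunk apiece

# retrieve behaves the same on the FAISS path and the NumPy fallback.
for chunk in rag.retrieve("diabetes treatment", k=2):
    print(chunk)

rag.clear()  # drop all in-memory session data

Note that _ensure_embeddings re-encodes the entire text list on every call, so batching uploads into a single add_docs call is cheaper than adding documents one at a time.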