Azizahalq committed
Commit b91214e · 1 Parent(s): 177fad2

Update rag_mini.py

Files changed (1)
  1. rag_mini.py +56 -268
rag_mini.py CHANGED
@@ -1,292 +1,80 @@
  # rag_mini.py
- import os, re, uuid, textwrap, hashlib, json, shutil
  from pathlib import Path
- from typing import Iterable, List, Tuple, Dict, Any
-
- # ---------------- Paths ----------------
- ROOT_DIR = Path(__file__).parent.resolve()
- DATA_ROOT = ROOT_DIR / "MaterialMind"
- DATA_DIR = DATA_ROOT / "sources"
- INDEX_DIR = DATA_ROOT / "index" / "chroma_v3"
- MANIFEST = DATA_ROOT / "index" / "manifest.json"

  DEFAULT_TOPK = 5
- EMB_MODEL = "BAAI/bge-small-en-v1.5"

- def ensure_dirs():
-     DATA_DIR.mkdir(parents=True, exist_ok=True)
      INDEX_DIR.mkdir(parents=True, exist_ok=True)
-     MANIFEST.parent.mkdir(parents=True, exist_ok=True)
-
- # ---------------- Embeddings ----------------
- _EMBED_FAST = None
- _EMBED_ST = None
-
- def init_embedder():
-     global _EMBED_FAST, _EMBED_ST
-     if _EMBED_FAST or _EMBED_ST:
-         return
-     try:
-         from fastembed import TextEmbedding
-         _EMBED_FAST = TextEmbedding(model_name=EMB_MODEL)
-         print(f"[EMB] FastEmbed ready: {EMB_MODEL}")
-     except Exception as e:
-         print(f"[EMB] FastEmbed not available ({e}). Falling back to SentenceTransformers.")
-         from sentence_transformers import SentenceTransformer
-         _EMBED_ST = SentenceTransformer(EMB_MODEL)
-
- def embed_texts(texts: List[str]) -> List[List[float]]:
-     init_embedder()
-     if _EMBED_FAST is not None:
-         return [v for v in _EMBED_FAST.embed(texts)]
-     return _EMBED_ST.encode(texts, normalize_embeddings=True).tolist()
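
embed_texts() lazily initializes one of two backends and returns plain lists of floats either way. A minimal usage sketch; the 384-entry vector size is the published dimension of bge-small-en-v1.5, not something this file asserts:

vecs = embed_texts(["perovskite stability", "copper alloy fatigue"])
print(len(vecs))     # 2, one vector per input text
print(len(vecs[0]))  # 384 for BAAI/bge-small-en-v1.5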
-
- # ---------------- Loaders ----------------
- def normalize_spaces(text: str) -> str:
-     text = text.replace("\r", "\n")
-     text = re.sub(r"[ \t]+", " ", text)
-     text = re.sub(r"\n{3,}", "\n\n", text)
-     return text.strip()
-
- def load_text_from_pdf(path: Path):
-     # try pymupdf
-     try:
-         import fitz
-         doc = fitz.open(str(path))
-         any_text = False
-         for i, page in enumerate(doc):
-             t = page.get_text("text").strip()
-             if t:
-                 any_text = True
-                 yield normalize_spaces(t), i + 1
-         doc.close()
-         if not any_text:
-             print(f"[HINT] scanned? {path.name}")
-         return
-     except Exception:
-         pass
-     # pypdf fallback
-     try:
-         from pypdf import PdfReader
-         r = PdfReader(str(path))
-         any_text = False
-         for i, p in enumerate(r.pages):
-             try:
-                 raw = p.extract_text() or ""
-             except Exception:
-                 raw = ""
-             t = normalize_spaces(raw)
-             if t:
-                 any_text = True
-                 yield t, i + 1
-         if not any_text:
-             print(f"[HINT] no extractable text: {path.name}")
-     except Exception as e:
-         print(f"[WARN] PDF read fail {path.name}: {e}")
-
- def load_text_from_md_txt(path: Path) -> str:
-     try:
-         raw = path.read_text(errors="ignore")
-     except Exception:
-         raw = ""
-     return normalize_spaces(raw)
-
- def chunk(text: str, max_chars=1200, overlap=150):
-     n = len(text)
-     if n <= max_chars:
-         if n > 0:
-             yield text
-         return
-     i = 0
-     while i < n:
-         j = min(i + max_chars, n)
-         yield text[i:j]
-         i = j - overlap if j < n else j
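
chunk() emits fixed-size character windows, stepping back overlap characters at each boundary so text at a cut appears in both neighbors. A worked sketch with small illustrative numbers instead of the 1200/150 defaults:

text = "abcdefghijklmnopqrstuvwxyz0123"  # 30 chars
pieces = list(chunk(text, max_chars=12, overlap=3))
# windows are [0:12], [9:21], [18:30]
print([len(p) for p in pieces])          # [12, 12, 12]
print(pieces[0][-3:] == pieces[1][:3])   # True: 3-char overlap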
-
- def iter_documents():
-     for f in DATA_DIR.rglob("*"):
-         if not f.is_file():
-             continue
-         ext = f.suffix.lower()
-         rel = f.relative_to(ROOT_DIR).as_posix()
-         if ext == ".pdf":
-             any_text = False
-             for page_text, page in load_text_from_pdf(f):
-                 any_text = True
-                 for c in chunk(page_text):
-                     yield {"id": str(uuid.uuid4()), "text": c, "meta": {"source": rel, "page": page}}
-             if not any_text:
-                 yield {"id": str(uuid.uuid4()), "text": f"[NO-TEXT] {f.name}", "meta": {"source": rel, "page": None}}
-         elif ext in (".md", ".txt"):
-             text = load_text_from_md_txt(f)
-             for c in chunk(text):
-                 yield {"id": str(uuid.uuid4()), "text": c, "meta": {"source": rel, "page": None}}
-
- # ---------------- Chroma ----------------
- def _client():
-     import chromadb
-     return chromadb.PersistentClient(path=str(INDEX_DIR))
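
Every record iter_documents() yields carries the same three fields, which build_index() later splits into Chroma's parallel id/document/metadata lists. An illustrative record; the file name and page number are hypothetical:

record = {
    "id": "e0c0…",  # fresh uuid4 per chunk
    "text": "…one chunk of up to 1200 characters…",
    "meta": {"source": "MaterialMind/sources/paper.pdf", "page": 3},
}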

- def get_collection(reset: bool = False):
      import chromadb
-     client = _client()
-     if reset:
-         try:
-             client.delete_collection("materialmind")
-         except Exception:
-             pass
-     # Important: name must match what you used when you built the index locally.
      return client.get_or_create_collection(name="materialmind")

- def add_batch(col, ids, docs, metas):
-     embs = embed_texts(docs)
-     col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
-
- def build_index(batch_size=256) -> int:
-     ensure_dirs()
-     col = get_collection(reset=True)
-     ids, docs, metas, total = [], [], [], 0
-     for doc in iter_documents():
-         if doc["text"].startswith("[NO-TEXT]"):
-             print(f"[INFO] skip unextractable: {doc['meta']['source']}")
-             continue
-         ids.append(doc["id"]); docs.append(doc["text"]); metas.append(doc["meta"])
-         if len(ids) >= batch_size:
-             add_batch(col, ids, docs, metas)
-             total += len(ids); ids, docs, metas = [], [], []
-     if ids:
-         add_batch(col, ids, docs, metas); total += len(ids)
-     print(f"[BUILD] complete: {total} chunks")
-     return total
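
build_index() buffers up to batch_size chunks per embedding call and flushes the final partial batch at the end. A one-shot rebuild sketch using only functions defined above:

ensure_dirs()
total = build_index(batch_size=256)  # drops and repopulates the "materialmind" collection
print(total)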
-
- # ---- Manifested incremental update (optional) ----
- def file_sig(path: Path):
-     h = hashlib.sha1()
-     try:
-         with open(path, "rb") as f:
-             for chunk in iter(lambda: f.read(1 << 20), b""):
-                 h.update(chunk)
-     except Exception:
-         return None
-     stat = path.stat()
-     return {"sha1": h.hexdigest(), "size": stat.st_size, "mtime": int(stat.st_mtime)}
-
- def load_manifest():
-     if MANIFEST.exists():
-         try:
-             return json.loads(MANIFEST.read_text())
-         except Exception:
-             return {}
-     return {}
-
- def save_manifest(m): MANIFEST.write_text(json.dumps(m, indent=2))
-
- def update_index():
-     ensure_dirs()
-     col = get_collection(reset=False)
-     manifest = load_manifest()
-     current = {f.relative_to(ROOT_DIR).as_posix(): f for f in DATA_DIR.rglob("*") if f.is_file()}
-
-     # remove deleted
-     for src in list(manifest.keys()):
-         if src not in current:
-             col.delete(where={"source": src})
-             manifest.pop(src, None)
-             print(f"[DEL] {src}")
-
-     # add/refresh changed
-     for src, path in current.items():
-         sig = file_sig(path)
-         if sig is None:
-             continue
-         if manifest.get(src) == sig:
-             continue
-         col.delete(where={"source": src})
-         added = 0
-         ext = path.suffix.lower()
-         if ext == ".pdf":
-             any_text = False
-             for page_text, page in load_text_from_pdf(path):
-                 any_text = True
-                 for c in chunk(page_text):
-                     add_batch(col, [str(uuid.uuid4())], [c], [{"source": src, "page": page}])
-                     added += 1
-             if not any_text:
-                 print(f"[INFO] skip unextractable: {src}")
-         elif ext in (".md", ".txt"):
-             text = load_text_from_md_txt(path)
-             for c in chunk(text):
-                 add_batch(col, [str(uuid.uuid4())], [c], [{"source": src, "page": None}])
-                 added += 1
-         manifest[src] = sig
-         print(f"[UPD] {src} (+{added})")
-     save_manifest(manifest)
-     print("[UPDATE] done.")
-
- # ---------------- Search ----------------
  def search(query: str, k: int = DEFAULT_TOPK) -> List[Tuple[str, str]]:
      try:
-         col = get_collection(reset=False)
      except Exception as e:
-         print(f"[ERR] Opening collection failed: {e}")
          return []
      try:
-         qvec = embed_texts([query])[0]
-         res = col.query(query_embeddings=[qvec], n_results=k, include=["documents", "metadatas"])
      except Exception as e:
-         print(f"[ERR] Query failed: {e}")
          return []
      hits = []
-     for doc, meta in zip(res.get("documents", [[]])[0], res.get("metadatas", [[]])[0]):
-         src = meta.get("source", "unknown")
-         page = meta.get("page")
          cite = f"{src}" + (f":p.{page}" if page else "")
-         hits.append((doc, cite))
      return hits

- # ---------------- Ready / Stats ----------------
- def index_stats() -> dict:
      try:
-         col = get_collection(reset=False)
          return {"count": col.count()}
      except Exception as e:
-         return {"count": 0, "error": str(e)}
-
- def ensure_ready():
-     """
-     Use the prebuilt Chroma index if it exists.
-     If no index is present but 'sources/' exists, build from sources.
-     If neither exists and CORPUS_DS is set, pull that dataset and build.
-     """
-     ensure_dirs()
-
-     # 1) If index already has data, just use it
-     has_any_file = any(INDEX_DIR.glob("**/*"))
-     if has_any_file:
-         st = index_stats()
-         print(f"[READY] Using existing index at {INDEX_DIR} — count={st.get('count')}")
-         return
-
-     # 2) If you prefer to build from sources (optional)
-     if any(DATA_DIR.glob("**/*")):
-         print("[READY] No index detected. Building from local 'sources/'.")
-         build_index()
-         return
-
-     # 3) Optional: pull a dataset then build
-     repo_id = os.getenv("CORPUS_DS", "").strip()
-     if repo_id:
-         try:
-             from huggingface_hub import snapshot_download
-             print(f"[READY] Pulling dataset {repo_id} into {DATA_DIR} …")
-             snapshot_download(
-                 repo_id=repo_id, repo_type="dataset",
-                 local_dir=DATA_DIR, local_dir_use_symlinks=False,
-                 ignore_patterns=["*.ipynb", ".*", "__pycache__/*"]
-             )
-             build_index()
-             return
-         except Exception as e:
-             print(f"[WARN] dataset bootstrap failed: {e}")
-
-     print("[READY] No index found; no sources; no dataset configured. Retrieval will be empty.")
 
  # rag_mini.py
+ import os, json, textwrap
  from pathlib import Path
+ from typing import List, Tuple
+
+ # ---------- Paths ----------
+ ROOT_DIR = Path(__file__).parent.resolve()
+ DATA_ROOT = ROOT_DIR / "MaterialMind"  # repo root for app data
  DEFAULT_TOPK = 5

+ # Allow override from env, else use repo path
+ _DEFAULT_INDEX_DIR = DATA_ROOT / "index" / "chroma_v3"
+ INDEX_DIR = Path(os.getenv("INDEX_DIR", str(_DEFAULT_INDEX_DIR))).resolve()
+
+ def _has_catalog(path: Path) -> bool:
+     if not path.exists():
+         return False
+     # sqlite catalog (most common)
+     if (path / "chroma.sqlite3").exists():
+         return True
+     # parquet/duckdb variants (older/newer chroma)
+     for n in ["chroma-collections.parquet", "chroma-embeddings.parquet",
+               "chroma.sqlite", "duckdb", "collections.parquet"]:
+         if (path / n).exists():
+             return True
+     return False
+
+ def ensure_ready() -> None:
+     """Check the persistent index exists & print a small stat to logs."""
      INDEX_DIR.mkdir(parents=True, exist_ok=True)
+     if not _has_catalog(INDEX_DIR):
+         print(f"[RAG] WARNING: No Chroma catalog found in {INDEX_DIR}")
+         print("      Upload your prebuilt DB (e.g., chroma.sqlite3) into that folder.")
+     else:
+         try:
+             stats = index_stats()
+             print(f"[RAG] Index ready at {INDEX_DIR} — count={stats.get('count')}")
+         except Exception as e:
+             print(f"[RAG] Could not read index stats: {e}")
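
The new ensure_ready() no longer builds anything; it only verifies that a prebuilt Chroma directory is present, and INDEX_DIR can be redirected through the environment before the module is imported. A minimal sketch; the /data path is hypothetical:

import os
os.environ["INDEX_DIR"] = "/data/chroma_v3"  # hypothetical mount point; set before import
import rag_mini
rag_mini.ensure_ready()  # warns if no chroma.sqlite3 (or variant) is found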
+ # ---------- Retrieval ----------
+ def get_collection():
      import chromadb
+     client = chromadb.PersistentClient(path=str(INDEX_DIR))
+     # NOTE: name must match the collection you used when building the index
      return client.get_or_create_collection(name="materialmind")
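
Since no embedding_function is passed here, Chroma opens the collection with its default embedder and uses it for the query_texts call below; if the persisted vectors came from a different model (the deleted code embedded with BAAI/bge-small-en-v1.5), retrieval quality can drift silently. A sketch of pinning a matching embedder, assuming chromadb's bundled SentenceTransformerEmbeddingFunction helper:

import chromadb
from chromadb.utils import embedding_functions

ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-small-en-v1.5")
client = chromadb.PersistentClient(path=str(INDEX_DIR))
col = client.get_or_create_collection(name="materialmind", embedding_function=ef)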
  def search(query: str, k: int = DEFAULT_TOPK) -> List[Tuple[str, str]]:
+     """
+     Returns [(snippet_text, 'source_path[:p.PAGE]'), ...]
+     """
      try:
+         col = get_collection()
      except Exception as e:
+         print(f"[RAG] get_collection() failed: {e}")
          return []
+
      try:
+         res = col.query(query_texts=[query], n_results=k, include=["documents", "metadatas"])
      except Exception as e:
+         print(f"[RAG] query failed: {e}")
          return []
+
+     docs = (res.get("documents") or [[]])[0]
+     metas = (res.get("metadatas") or [[]])[0]
      hits = []
+     for d, m in zip(docs, metas):
+         src = (m or {}).get("source") or (m or {}).get("path") or "unknown"
+         page = (m or {}).get("page")
          cite = f"{src}" + (f":p.{page}" if page else "")
+         if d:
+             hits.append((d, cite))
      return hits

+ def index_stats():
      try:
+         col = get_collection()
          return {"count": col.count()}
      except Exception as e:
+         return {"count": 0, "err": str(e)}