internationalscholarsprogram committed on
Commit 5d64d36 · 1 Parent(s): 26e05e4

Update Space: retriever-only RAG API + indexing

Files changed (2)
  1. Dockerfile +4 -19
  2. ingest.py +226 -661
Dockerfile CHANGED
@@ -1,9 +1,8 @@
  # ----------------------------------------
- # Career GPT RAG API - Hugging Face Space (Docker)
+ # ISP Retrieval (RAG) API - Hugging Face Space (Docker)
  # ----------------------------------------
  FROM python:3.11-slim-bookworm

- # --- Environment settings ---
  ENV PYTHONDONTWRITEBYTECODE=1 \
      PYTHONUNBUFFERED=1 \
      PIP_NO_CACHE_DIR=1 \
@@ -14,47 +13,33 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
      RAG_CORPUS_DIR=/data/corpus \
      RAG_DATASET_ID=internationalscholarsprogram/DOC \
      RAG_DATASET_REVISION=main \
-     RAG_PORT=7860 \
      PORT=7860 \
      TOKENIZERS_PARALLELISM=false \
      HF_HUB_DISABLE_TELEMETRY=1 \
      CUDA_VISIBLE_DEVICES="" \
      OMP_NUM_THREADS=1 \
      ORT_LOG_SEVERITY_LEVEL=3 \
-     ORT_FORCE_CPU=1 \
-     TRANSFORMERS_VERBOSITY=info
+     ORT_FORCE_CPU=1

- # --- System dependencies ---
  RUN apt-get update && apt-get install -y --no-install-recommends \
-     tini wget curl ca-certificates tar git \
+     tini curl ca-certificates git \
      && rm -rf /var/lib/apt/lists/*

- # --- App user (optional) ---
- RUN useradd -m -u 1000 appuser || true
-
  WORKDIR /app

- # --- Python dependencies ---
- COPY requirements.txt .
+ COPY requirements.txt ./
  RUN python -m pip install --upgrade pip setuptools wheel \
      && pip install --no-cache-dir -r requirements.txt

- # --- Project files ---
  COPY . .

- # --- Persistent / writable directories ---
  RUN mkdir -p /tmp/chroma_db /data/.huggingface /data/corpus \
      && chmod -R 777 /tmp /data /app

- # Keep root so /data and /tmp stay writable on Spaces
-
  EXPOSE 7860

- # --- Healthcheck (FastAPI has / or /health; prefer /health) ---
  HEALTHCHECK --interval=30s --timeout=5s --start-period=20s \
      CMD curl -fsS "http://127.0.0.1:${PORT}/health" || exit 1

  ENTRYPOINT ["/usr/bin/tini", "--"]
-
- # --- Start server: bind to 0.0.0.0 and $PORT ---
  CMD ["bash","-lc","python -m uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860}"]
ingest.py CHANGED
@@ -1,701 +1,266 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
 
4
- """
5
- Robust RAG ingestion for Chroma + Embeddings (BGE / FastEmbed / HF Local / HF Inference / Ollama)
6
-
7
- Supported file types: PDF, HTML, DOCX, TXT/MD, CSV (+ watch mode)
8
-
9
- Features
10
- - Incremental ingest with stable chunk IDs (sha1(file|page|content)) -> no dup chunks
11
- - Rich metadata for citations: source, source_path, page, mtime
12
- - Clean/normalize text (page number + whitespace heuristics)
13
- - Config via CLI flags and env vars
14
- - Rebuild & dry-run modes, detailed logs
15
- - Optional watch mode (polling) for auto-reindex on file changes
16
- - Embeddings providers:
17
- * bge -> HuggingFaceBgeEmbeddings (no sklearn/scipy)
18
- * fastembed -> FastEmbedEmbeddings (tiny, fast)
19
- * hf_local -> sentence-transformers (may pull sklearn/scipy)
20
- * hf_inference -> Hugging Face Inference API (token required)
21
- * ollama -> OllamaEmbeddings
22
-
23
- BGE best practices:
24
- - L2 normalization (cosine space)
25
- - Prefix: "passage: " for docs, "query: " for queries (toggle with --no-bge-prefix)
26
- """
27
-
28
- from __future__ import annotations
29
-
30
- import argparse
31
- import hashlib
32
- import logging
33
- import os
34
- import re
35
- import signal
36
- import sys
37
- import time
38
- from pathlib import Path
39
- from typing import Any, Dict, Iterable, List, Optional
40
-
41
- import numpy as np
42
- from tqdm import tqdm
43
- from unidecode import unidecode
44
-
45
- # -------------------- Vector store --------------------
46
- try:
47
- from langchain_chroma import Chroma
48
- except ImportError:
49
- from langchain_community.vectorstores import Chroma # fallback
50
-
51
- try:
52
- # Newer splitters live here
53
- from langchain_text_splitters import RecursiveCharacterTextSplitter
54
- except ImportError:
55
- from langchain.text_splitter import RecursiveCharacterTextSplitter # fallback
56
-
57
- try:
58
- from langchain_core.documents import Document
59
- except ImportError:
60
- from langchain_community.docstore.document import Document # fallback
61
-
62
- try:
63
- from langchain_core.embeddings import Embeddings
64
- except ImportError:
65
- from langchain.embeddings.base import Embeddings # fallback
66
-
67
- from chromadb.config import Settings as ChromaSettings
68
-
69
- # Loaders
70
- from langchain_community.document_loaders import (
71
- PyMuPDFLoader, # PDF
72
- BSHTMLLoader, # HTML (BeautifulSoup)
73
- Docx2txtLoader, # DOCX
74
- TextLoader, # TXT/MD
75
- CSVLoader, # CSV
76
- )
77
 
78
- # Embedding impls (lazy-imported in builder too)
79
- from langchain_community.embeddings import (
80
- HuggingFaceBgeEmbeddings,
81
- FastEmbedEmbeddings,
82
- )
83
 
84
- # -------------------- Defaults (overridable via env) --------------------
85
- ENV = os.getenv
86
- DEFAULT_DOCS_DIR = ENV("RAG_DOCS_DIR", "docs")
87
- DEFAULT_DB_DIR = ENV("RAG_DB_DIR", "/data/chroma_db") # Spaces persistent storage by default
88
 
89
- # Provider: "bge" | "fastembed" | "hf_local" | "hf_inference" | "ollama"
90
- DEFAULT_EMBED_PROVIDER = ENV("RAG_EMBED_PROVIDER", "bge").lower()
91
- DEFAULT_EMBED_MODEL = ENV("RAG_EMBED_MODEL", "BAAI/bge-small-en-v1.5")
92
 
93
- # Device for local providers
94
- DEFAULT_DEVICE = ENV("RAG_DEVICE", "cuda" if os.getenv("CUDA_VISIBLE_DEVICES") else "cpu")
 
 
95
 
96
- # HF token for hf_inference (or for gated/private models if needed)
97
- DEFAULT_HF_TOKEN = ENV("HUGGINGFACEHUB_API_TOKEN", ENV("HF_TOKEN", ""))
98
 
99
- DEFAULT_USE_PREFIX = ENV("RAG_BGE_PREFIX", "1") not in ("0", "false", "False")
100
 
101
- DEFAULT_CHUNK_SIZE = int(ENV("RAG_CHUNK_SIZE", "900"))
102
- DEFAULT_CHUNK_OVERLAP = int(ENV("RAG_CHUNK_OVERLAP", "180"))
103
- DEFAULT_MIN_CHARS = int(ENV("RAG_MIN_CHARS", "200")) # drop tiny chunks
104
 
105
- DEFAULT_WATCH_INTERVAL = int(ENV("RAG_WATCH_INTERVAL", "5")) # seconds
106
- DEFAULT_BATCH_SIZE = int(ENV("RAG_EMBED_BATCH", "32")) # embedding batch size (hf_inference wrapper)
107
 
108
- # -------------------- Logging --------------------
109
- logging.basicConfig(
110
- level=logging.INFO,
111
- format="%(asctime)s | %(levelname)s | %(message)s",
112
- )
113
- log = logging.getLogger("rag_ingest")
114
- for _noisy in ["httpx", "chromadb", "langchain", "asyncio"]:
115
- logging.getLogger(_noisy).setLevel(logging.ERROR)
116
-
117
- # -------------------- Helpers --------------------
118
- def sha1(text: str) -> str:
119
- return hashlib.sha1(text.encode("utf-8")).hexdigest()
120
-
121
- def normalize_text(txt: str) -> str:
122
- """
123
- Normalize/clean text for better retrieval.
124
- - ascii transliteration to reduce unicode noise
125
- - strip trailing spaces
126
- - drop lines that are just numbers (page numbers)
127
- - collapse excessive blank lines/spaces
128
- """
129
- if not txt:
130
  return ""
131
- txt = unidecode(txt)
132
- lines = [re.sub(r"\s+$", "", line) for line in txt.splitlines()]
133
- lines = [l for l in lines if not re.fullmatch(r"\d{1,4}", l) and l.strip()]
134
- txt = "\n".join(lines)
135
- txt = re.sub(r"\n{3,}", "\n\n", txt)
136
- txt = re.sub(r"[ \t]{2,}", " ", txt)
137
- return txt.strip()
138
-
139
- def assign_common_metadata(doc: Document, path: Path, page: Optional[int] = None) -> None:
140
- doc.metadata = dict(doc.metadata or {})
141
- doc.metadata["source"] = path.name
142
- doc.metadata["source_path"] = str(path.resolve())
143
- if page is not None and doc.metadata.get("page") is None:
144
- doc.metadata["page"] = page
145
- try:
146
- doc.metadata["mtime"] = int(path.stat().st_mtime)
147
- except OSError:
148
- doc.metadata["mtime"] = 0
149
-
150
- def load_pdf(path: Path) -> List[Document]:
151
- loader = PyMuPDFLoader(str(path))
152
- docs = loader.load()
153
- out: List[Document] = []
154
- for d in docs:
155
- d.page_content = normalize_text(d.page_content)
156
- assign_common_metadata(d, path, d.metadata.get("page"))
157
- if d.page_content:
158
- out.append(d)
159
  return out
160
 
161
- def load_html(path: Path) -> List[Document]:
162
- loader = BSHTMLLoader(str(path))
163
- docs = loader.load()
164
- out: List[Document] = []
165
- for d in docs:
166
- d.page_content = normalize_text(d.page_content)
167
- assign_common_metadata(d, path, None)
168
- if d.page_content:
169
- out.append(d)
170
- return out
171
 
172
- def load_docx(path: Path) -> List[Document]:
173
- loader = Docx2txtLoader(str(path))
174
- docs = loader.load()
175
- out: List[Document] = []
176
- for d in docs:
177
- d.page_content = normalize_text(d.page_content)
178
- assign_common_metadata(d, path, None)
179
- if d.page_content:
180
- out.append(d)
181
- return out
182
 
183
- def load_text_like(path: Path) -> List[Document]:
184
- loader = TextLoader(str(path), autodetect_encoding=True)
185
- docs = loader.load()
186
- out: List[Document] = []
187
- for d in docs:
188
- d.page_content = normalize_text(d.page_content)
189
- assign_common_metadata(d, path, None)
190
- if d.page_content:
191
- out.append(d)
192
- return out
193
 
194
- def load_csv(path: Path) -> List[Document]:
195
- """Load CSV as one Document per row, including header mapping in content."""
196
- loader = CSVLoader(str(path))
197
- docs = loader.load()
198
- out: List[Document] = []
199
- for d in docs:
200
- d.page_content = normalize_text(d.page_content)
201
- assign_common_metadata(d, path, None)
202
- if d.page_content:
203
- out.append(d)
204
- return out
205
 
206
- SUPPORTED_SUFFIXES = {".pdf", ".html", ".htm", ".docx", ".txt", ".md", ".markdown", ".csv"}
207
-
208
- def discover_files(docs_dir: Path) -> List[Path]:
209
- files: List[Path] = []
210
- for p in docs_dir.rglob("*"):
211
- if p.is_file() and p.suffix.lower() in SUPPORTED_SUFFIXES:
212
- files.append(p)
213
- return files
214
-
215
- def chunk_documents(
216
- raw_docs: List[Document],
217
- chunk_size: int,
218
- chunk_overlap: int,
219
- min_chars: int,
220
- ) -> List[Document]:
221
- splitter = RecursiveCharacterTextSplitter(
222
- chunk_size=chunk_size,
223
- chunk_overlap=chunk_overlap,
224
- separators=["\n\n", "\n", " ", ""],
225
- )
226
- chunks = splitter.split_documents(raw_docs)
227
- return [c for c in chunks if len(c.page_content.strip()) >= min_chars]
228
-
229
- def make_chunk_id(doc: Document) -> str:
230
- src = doc.metadata.get("source_path", doc.metadata.get("source", "unknown"))
231
- page = str(doc.metadata.get("page"))
232
- basis = f"{src}|{page}|{doc.page_content}"
233
- return sha1(basis)
234
-
235
- def ensure_dirs(path: Path) -> None:
236
- path.mkdir(parents=True, exist_ok=True)
237
-
238
- def batched(iterable: Iterable[Any], n: int) -> Iterable[List[Any]]:
239
- batch: List[Any] = []
240
- for item in iterable:
241
- batch.append(item)
242
- if len(batch) >= n:
243
- yield batch
244
- batch = []
245
- if batch:
246
- yield batch
247
-
248
- # -------------------- Embedding Adapters --------------------
249
- class BGEAdapter(Embeddings):
250
- """Wraps any LangChain Embeddings and applies BGE prefixes."""
251
- def __init__(self, base: Embeddings, use_prefixes: bool = True):
252
- self.base = base
253
- self.use_prefixes = use_prefixes
254
-
255
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
256
- if self.use_prefixes:
257
- texts = [f"passage: {t}" for t in texts]
258
- return self.base.embed_documents(texts)
259
-
260
- def embed_query(self, text: str) -> List[float]:
261
- if self.use_prefixes:
262
- text = f"query: {text}"
263
- return self.base.embed_query(text)
264
-
265
- class HFInferenceEmbeddings(Embeddings):
266
- """
267
- Minimal embeddings wrapper using Hugging Face Inference API feature-extraction.
268
- - Mean-pools token embeddings
269
- - L2-normalizes vectors
270
- """
271
- def __init__(
272
- self,
273
- model: str,
274
- token: str,
275
- timeout: float = 60.0,
276
- max_retries: int = 5,
277
- batch_size: int = 32,
278
- ):
279
- from huggingface_hub import InferenceClient # lazy import
280
- if not token:
281
- raise ValueError("HF Inference API requires a token. Set HUGGINGFACEHUB_API_TOKEN or --hf-token.")
282
- self.client = InferenceClient(token=token, timeout=timeout)
283
- self.model = model
284
- self.max_retries = max_retries
285
- self.batch_size = max(1, batch_size)
286
-
287
- @staticmethod
288
- def _mean_pool(mat: List[List[float]]) -> List[float]:
289
- arr = np.asarray(mat, dtype=np.float32)
290
- v = arr.mean(axis=0)
291
- norm = np.linalg.norm(v) + 1e-12
292
- return (v / norm).tolist()
293
-
294
- def _fe(self, text: str) -> List[float]:
295
- for i in range(self.max_retries):
296
- try:
297
- mat = self.client.feature_extraction(model=self.model, inputs=text)
298
- return self._mean_pool(mat)
299
- except Exception as e:
300
- if i == self.max_retries - 1:
301
- raise
302
- sleep_s = max(0.5, 2 ** i * 0.5)
303
- log.warning(f"HF Inference backoff ({i+1}/{self.max_retries}): {e}. Sleeping {sleep_s:.1f}s")
304
- time.sleep(sleep_s)
305
- return []
306
-
307
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
308
- out: List[List[float]] = []
309
- for batch in batched(texts, self.batch_size):
310
- for t in batch:
311
- out.append(self._fe(t))
312
- return out
313
-
314
- def embed_query(self, text: str) -> List[float]:
315
- return self._fe(text)
316
-
317
- def build_embeddings(
318
- provider: str,
319
- model: str,
320
- device: str,
321
- use_prefixes: bool,
322
- hf_token: str,
323
- batch_size: int,
324
- ) -> Embeddings:
325
- provider = (provider or "").lower()
326
-
327
- if provider in ("bge", "hf_bge", "bge_small"):
328
- base = HuggingFaceBgeEmbeddings(
329
- model_name=model,
330
- model_kwargs={"device": device},
331
- encode_kwargs={"normalize_embeddings": True},
332
- )
333
- log.info(f"Embedding provider: BGE ({model}) on {device}")
334
- return BGEAdapter(base, use_prefixes=use_prefixes)
335
-
336
- if provider in ("fastembed", "fe"):
337
- log.info("Embedding provider: FastEmbed")
338
- return FastEmbedEmbeddings()
339
-
340
- if provider == "hf_inference":
341
- base = HFInferenceEmbeddings(model=model, token=hf_token, batch_size=batch_size)
342
- log.info("Embedding provider: HF Inference API")
343
- return BGEAdapter(base, use_prefixes=use_prefixes)
344
-
345
- if provider == "ollama":
346
- from langchain_ollama import OllamaEmbeddings # lazy import
347
- base = OllamaEmbeddings(model=model)
348
- log.info("Embedding provider: Ollama")
349
- return BGEAdapter(base, use_prefixes=use_prefixes)
350
-
351
- # hf_local (sentence-transformers)
352
- from langchain_community.embeddings import HuggingFaceEmbeddings # lazy import
353
- base = HuggingFaceEmbeddings(
354
- model_name=model,
355
- model_kwargs={"device": device},
356
- encode_kwargs={"normalize_embeddings": True},
357
- )
358
- log.info(f"Embedding provider: HF local (sentence-transformers) on {device}")
359
- # Use prefixes automatically if model name looks like BGE
360
- return BGEAdapter(base, use_prefixes=("bge" in model.lower() and use_prefixes))
361
-
362
- # -------------------- Ingest Core --------------------
363
- def _wipe_dir(path: Path) -> None:
364
- if not path.exists():
365
- return
366
- for p in sorted(path.glob("**/*"), reverse=True):
367
- try:
368
- if p.is_file():
369
- p.unlink()
370
- elif p.is_dir():
371
- p.rmdir()
372
- except Exception as e:
373
- log.debug(f"Skipping removal for {p}: {e}")
374
-
375
- def _build_vectordb(db_dir: Path, embeddings: Embeddings) -> Chroma:
376
- client_settings = ChromaSettings(
377
- is_persistent=True,
378
- persist_directory=str(db_dir),
379
- anonymized_telemetry=False,
380
- )
381
- return Chroma(
382
- persist_directory=str(db_dir),
383
- embedding_function=embeddings,
384
- collection_metadata={"hnsw:space": "cosine"},
385
- client_settings=client_settings,
386
- )
387
 
388
- def _load_docs_for_paths(files: List[Path]) -> List[Document]:
389
- loaders = {
390
- ".pdf": load_pdf,
391
- ".html": load_html,
392
- ".htm": load_html,
393
- ".docx": load_docx,
394
- ".txt": load_text_like,
395
- ".md": load_text_like,
396
- ".markdown": load_text_like,
397
- ".csv": load_csv,
398
- }
399
- raw_docs: List[Document] = []
400
- for path in tqdm(files, desc="Loading files", unit="file"):
401
  try:
402
- fn = loaders.get(path.suffix.lower())
403
- if fn:
404
- raw_docs.extend(fn(path))
405
- except KeyboardInterrupt:
406
- raise
407
- except Exception as e:
408
- log.error(f"Failed to load {path}: {e}")
409
- return raw_docs
410
-
411
- def ingest_once(
412
- docs_dir: Path,
413
- db_dir: Path,
414
- embed_provider: str,
415
- embed_model: str,
416
- device: str,
417
- use_prefixes: bool,
418
- hf_token: str,
419
- batch_size: int,
420
- rebuild: bool = False,
421
- dry_run: bool = False,
422
- chunk_size: int = DEFAULT_CHUNK_SIZE,
423
- chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
424
- min_chars: int = DEFAULT_MIN_CHARS,
425
- ) -> Dict[str, Any]:
426
-
427
- ensure_dirs(docs_dir)
428
- ensure_dirs(db_dir)
429
-
430
- if rebuild and not dry_run:
431
- _wipe_dir(db_dir)
432
- ensure_dirs(db_dir)
433
- log.warning("Rebuild mode: existing DB wiped.")
434
-
435
- log.info(f"Using embeddings: model={embed_model} provider={embed_provider}")
436
- embeddings = build_embeddings(
437
- provider=embed_provider,
438
- model=embed_model,
439
- device=device,
440
- use_prefixes=use_prefixes,
441
- hf_token=hf_token,
442
- batch_size=batch_size,
443
- )
444
 
445
- vectordb = _build_vectordb(db_dir, embeddings)
 
 
446
 
447
- files = discover_files(docs_dir)
448
- if not files:
449
- log.warning(f"No supported files found in {docs_dir.resolve()}")
450
- return {"added": 0, "skipped": 0, "total_chunks": 0, "files": 0}
451
 
452
- raw_docs = _load_docs_for_paths(files)
453
- if not raw_docs:
454
- log.warning("No documents loaded after parsing.")
455
- return {"added": 0, "skipped": 0, "total_chunks": 0, "files": len(files)}
456
 
457
- chunks = chunk_documents(raw_docs, chunk_size, chunk_overlap, min_chars)
458
- if not chunks:
459
- log.warning("No chunks produced (check chunking params / min_chars).")
460
- return {"added": 0, "skipped": 0, "total_chunks": 0, "files": len(files)}
461
 
462
- ids = [make_chunk_id(c) for c in chunks]
463
 
464
- # Find existing ids (batched), use underlying collection to be explicit/robust
465
- existing: set[str] = set()
466
- for batch in batched(ids, 500):
467
- try:
468
- # Prefer the underlying Chroma collection to avoid wrapper differences
469
- res = vectordb._collection.get(ids=batch) # type: ignore[attr-defined]
470
- if res and res.get("ids"):
471
- existing.update(res["ids"])
472
- except Exception:
473
- # If collection or ids don't exist, just continue
474
- pass
475
-
476
- to_add_docs: List[Document] = []
477
- to_add_ids: List[str] = []
478
- skipped = 0
479
-
480
- for doc, _id in zip(chunks, ids):
481
- if _id in existing:
482
- skipped += 1
483
- continue
484
- to_add_docs.append(doc)
485
- to_add_ids.append(_id)
486
-
487
- log.info(f"Total chunks: {len(chunks)} | To add: {len(to_add_docs)} | Skipped (dups): {skipped}")
488
-
489
- if dry_run:
490
- log.info("Dry-run mode: not writing to DB.")
491
- return {
492
- "added": len(to_add_docs),
493
- "skipped": skipped,
494
- "total_chunks": len(chunks),
495
- "files": len(files),
496
- "dry_run": True,
497
- }
498
 
499
- added = 0
500
- for batch_docs, batch_ids in zip(batched(to_add_docs, 256), batched(to_add_ids, 256)):
501
- try:
502
- # Correct instance call (not Chroma.add_documents(...))
503
- vectordb.add_documents(documents=batch_docs, ids=batch_ids)
504
- added += len(batch_docs)
505
- except Exception as e:
506
- log.error(f"Error adding batch ({len(batch_docs)} docs): {e}")
507
 
 
 
508
  try:
509
- vectordb.persist()
 
 
 
510
  except Exception as e:
511
- log.error(f"Persist error: {e}")
512
-
513
  return {
514
- "added": added,
515
- "skipped": skipped,
516
- "total_chunks": len(chunks),
517
- "files": len(files),
518
- "db_dir": str(db_dir.resolve()),
519
- "embed_model": embed_model,
520
- "embed_provider": embed_provider,
 
 
 
521
  }
522
 
523
- # -------------------- Watch Mode (polling) --------------------
524
- def build_mtime_index(docs_dir: Path) -> Dict[str, float]:
525
- idx: Dict[str, float] = {}
526
- for f in discover_files(docs_dir):
527
- try:
528
- idx[str(f.resolve())] = f.stat().st_mtime
529
- except Exception:
530
- pass
531
- return idx
532
-
533
- def watch_and_ingest(
534
- docs_dir: Path,
535
- db_dir: Path,
536
- embed_provider: str,
537
- embed_model: str,
538
- device: str,
539
- use_prefixes: bool,
540
- hf_token: str,
541
- batch_size: int,
542
- interval: int,
543
- chunk_size: int,
544
- chunk_overlap: int,
545
- min_chars: int,
546
- ) -> None:
547
- log.info(f"Watching {docs_dir.resolve()} every {interval}s for changes...")
548
- baseline = build_mtime_index(docs_dir)
549
- while True:
550
- time.sleep(interval)
551
- curr = build_mtime_index(docs_dir)
552
-
553
- added_paths = [p for p in curr.keys() if p not in baseline]
554
- changed_paths = [p for p, mt in curr.items() if p in baseline and mt > baseline[p]]
555
- removed_paths = [p for p in baseline.keys() if p not in curr]
556
-
557
- if not (added_paths or changed_paths or removed_paths):
558
- continue
559
-
560
- if removed_paths:
561
- log.warning(f"{len(removed_paths)} files removed since last scan (not deleting existing vectors).")
562
-
563
- if added_paths or changed_paths:
564
- log.info(f"Detected {len(added_paths)} new and {len(changed_paths)} modified files. Re-ingesting incrementally...")
565
- summary = ingest_once(
566
- docs_dir=Path(docs_dir),
567
- db_dir=Path(db_dir),
568
- embed_provider=embed_provider,
569
- embed_model=embed_model,
570
- device=device,
571
- use_prefixes=use_prefixes,
572
- hf_token=hf_token,
573
- batch_size=batch_size,
574
- rebuild=False,
575
- dry_run=False,
576
- chunk_size=chunk_size,
577
- chunk_overlap=chunk_overlap,
578
- min_chars=min_chars,
579
- )
580
- log.info(f"Watch ingest summary: {summary}")
581
-
582
- baseline = curr
583
-
584
- # -------------------- CLI --------------------
585
- def parse_args() -> argparse.Namespace:
586
- p = argparse.ArgumentParser(description="Ingest documents (PDF/HTML/DOCX/TXT/MD/CSV) into Chroma for RAG.")
587
- # I/O
588
- p.add_argument("--docs", default=DEFAULT_DOCS_DIR, help=f"Docs directory (default: {DEFAULT_DOCS_DIR})")
589
- p.add_argument("--db", default=DEFAULT_DB_DIR, help=f"Chroma DB directory (default: {DEFAULT_DB_DIR})")
590
-
591
- # Embeddings
592
- p.add_argument("--embed-provider", default=DEFAULT_EMBED_PROVIDER,
593
- choices=["bge", "fastembed", "hf_local", "hf_inference", "ollama"],
594
- help=f"Embedding provider (default: {DEFAULT_EMBED_PROVIDER})")
595
- p.add_argument("--embed-model", default=DEFAULT_EMBED_MODEL,
596
- help=f"Embedding model name (default: {DEFAULT_EMBED_MODEL})")
597
- p.add_argument("--device", default=DEFAULT_DEVICE, help=f"'cpu' or 'cuda' (local providers, default: {DEFAULT_DEVICE})")
598
- p.add_argument("--hf-token", default=DEFAULT_HF_TOKEN, help="HF token (hf_inference or gated/private models)")
599
- p.add_argument("--no-bge-prefix", action="store_true", help="Disable 'passage:/query:' prefixes for embeddings")
600
- p.add_argument("--embed-batch", type=int, default=DEFAULT_BATCH_SIZE, help=f"Embedding batch size (hf_inference): default {DEFAULT_BATCH_SIZE}")
601
-
602
- # Ingest
603
- p.add_argument("--rebuild", action="store_true", help="Wipe and rebuild the DB")
604
- p.add_argument("--dry-run", action="store_true", help="Do everything except write to DB")
605
- p.add_argument("--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE, help=f"Chunk size (default: {DEFAULT_CHUNK_SIZE})")
606
- p.add_argument("--chunk-overlap", type=int, default=DEFAULT_CHUNK_OVERLAP, help=f"Chunk overlap (default: {DEFAULT_CHUNK_OVERLAP})")
607
- p.add_argument("--min-chars", type=int, default=DEFAULT_MIN_CHARS, help=f"Drop chunks shorter than this (default: {DEFAULT_MIN_CHARS})")
608
-
609
- # Watch
610
- p.add_argument("--watch", action="store_true", help="Watch for file changes and ingest incrementally (polling)")
611
- p.add_argument("--interval", type=int, default=DEFAULT_WATCH_INTERVAL, help=f"Watch poll interval seconds (default: {DEFAULT_WATCH_INTERVAL})")
612
- return p.parse_args()
613
-
614
- def _install_signal_handlers() -> None:
615
- def _handler(signum, _frame):
616
- names = {signal.SIGINT: "SIGINT", signal.SIGTERM: "SIGTERM"}
617
- log.warning(f"Received {names.get(signum, signum)}. Exiting gracefully...")
618
- sys.exit(130 if signum == signal.SIGINT else 143)
619
-
620
  try:
621
- signal.signal(signal.SIGINT, _handler)
622
- signal.signal(signal.SIGTERM, _handler)
623
- except Exception:
624
- # Not all environments allow signal hooks (e.g., Windows threads)
625
- pass
626
-
627
- def main() -> None:
628
- _install_signal_handlers()
629
- args = parse_args()
630
- docs_dir = Path(args.docs)
631
- db_dir = Path(args.db)
632
-
633
- use_prefixes = not args.no_bge_prefix
634
 
635
- log.info(f"Docs dir: {docs_dir.resolve()}")
636
- log.info(f"DB dir: {db_dir.resolve()}")
637
- log.info(f"Embed: provider={args.embed_provider} model={args.embed_model}")
638
 
639
  try:
640
- if args.watch:
641
- # Prime the DB once before watching
642
- summary = ingest_once(
643
- docs_dir=docs_dir,
644
- db_dir=db_dir,
645
- embed_provider=args.embed_provider,
646
- embed_model=args.embed_model,
647
- device=args.device,
648
- use_prefixes=use_prefixes,
649
- hf_token=args.hf_token,
650
- batch_size=args.embed_batch,
651
- rebuild=args.rebuild,
652
- dry_run=args.dry_run,
653
- chunk_size=args.chunk_size,
654
- chunk_overlap=args.chunk_overlap,
655
- min_chars=args.min_chars,
656
- )
657
- log.info(f"Initial ingest summary: {summary}")
658
- if args.dry_run:
659
- log.info("Dry-run set; skipping watch loop.")
660
- return
661
- watch_and_ingest(
662
- docs_dir=docs_dir,
663
- db_dir=db_dir,
664
- embed_provider=args.embed_provider,
665
- embed_model=args.embed_model,
666
- device=args.device,
667
- use_prefixes=use_prefixes,
668
- hf_token=args.hf_token,
669
- batch_size=args.embed_batch,
670
- interval=args.interval,
671
- chunk_size=args.chunk_size,
672
- chunk_overlap=args.chunk_overlap,
673
- min_chars=args.min_chars,
674
- )
675
- else:
676
- summary = ingest_once(
677
- docs_dir=docs_dir,
678
- db_dir=db_dir,
679
- embed_provider=args.embed_provider,
680
- embed_model=args.embed_model,
681
- device=args.device,
682
- use_prefixes=use_prefixes,
683
- hf_token=args.hf_token,
684
- batch_size=args.embed_batch,
685
- rebuild=args.rebuild,
686
- dry_run=args.dry_run,
687
- chunk_size=args.chunk_size,
688
- chunk_overlap=args.chunk_overlap,
689
- min_chars=args.min_chars,
690
- )
691
- log.info(f"Ingest summary: {summary}")
692
- except KeyboardInterrupt:
693
- log.warning("Interrupted.")
694
- sys.exit(130)
695
- except Exception:
696
- log.exception("Ingest failed")
697
- sys.exit(1)
698
-
699
 
 
700
  if __name__ == "__main__":
701
- main()
 
 
  #!/usr/bin/env python3
  # -*- coding: utf-8 -*-

+ import os, threading, logging, warnings, json, re
+ from typing import Optional, List, Dict, Any

+ from fastapi import FastAPI, HTTPException, Header
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel

+ from huggingface_hub import snapshot_download, HfApi
+ from langchain_chroma import Chroma
+ from chromadb import PersistentClient

+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFLoader

+ # --------------------- Setup & Logging ---------------------
+ warnings.filterwarnings("ignore")
+ os.environ["ORT_LOG_SEVERITY_LEVEL"] = "3"
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""

+ log = logging.getLogger("rag_api")
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

+ # --------------------- Env ---------------------
+ ENV = os.getenv
+ DB_DIR = ENV("RAG_DB_DIR", "/tmp/chroma_db")
+ COLLECTION_NAME = ENV("RAG_COLLECTION", "isp_rag")
+ DATASET_ID = ENV("RAG_DATASET_ID", "internationalscholarsprogram/DOC")
+ DATA_REV = ENV("RAG_DATASET_REVISION", "main")
+ CORPUS_DIR = ENV("RAG_CORPUS_DIR", "/data/corpus")
+ STATE_FILE = "/data/.state.json"
+
+ PORT = int(ENV("PORT", "7860"))
+ HOST = "0.0.0.0"
+
+ # Optional: protect reindex endpoint (set in HF Space secrets)
+ ADMIN_REINDEX_TOKEN = ENV("ADMIN_REINDEX_TOKEN", "").strip()
+
+ # --------------------- Embeddings + Vector DB ---------------------
+ # BGE recommended settings
+ embeddings = HuggingFaceBgeEmbeddings(
+     model_name="BAAI/bge-small-en-v1.5",
+     encode_kwargs={"normalize_embeddings": True},
+ )

+ os.makedirs(DB_DIR, exist_ok=True)
+ client = PersistentClient(path=DB_DIR)
+ vectordb = Chroma(collection_name=COLLECTION_NAME, embedding_function=embeddings, client=client)

+ def build_retriever(k: int = 4):
+     return vectordb.as_retriever(search_type="mmr", search_kwargs={"k": k})
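A quick sanity-check sketch for the retriever wiring (not part of the commit); it assumes the module is importable as ingest, that the collection has already been populated, and the sample query is made up:

from ingest import build_retriever

# MMR search returns up to k diverse chunks; metadata carries the citation fields set at indexing time
docs = build_retriever(k=4).invoke("How does the ISP program work?")
for d in docs:
    print(d.metadata.get("source"), d.metadata.get("page"), d.page_content[:80])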

+ # --------------------- Text cleanup ---------------------
+ CONTROL_CHARS_RE = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F]")
+ def clean_text(s: str) -> str:
+     if not s:
          return ""
+     s = s.replace("\r", "")
+     s = CONTROL_CHARS_RE.sub(" ", s)
+     s = re.sub(r"[ \t]+$", "", s, flags=re.M)
+     s = re.sub(r"\n{3,}", "\n\n", s)
+     return s.strip()
+
+ # --------------------- Dataset sync + indexing ---------------------
+ def sync_pdfs() -> str:
+     os.makedirs(CORPUS_DIR, exist_ok=True)
+     snapshot_download(
+         repo_id=DATASET_ID,
+         repo_type="dataset",
+         revision=DATA_REV,
+         local_dir=CORPUS_DIR,
+         local_dir_use_symlinks=False
+     )
+     info = HfApi().repo_info(repo_id=DATASET_ID, repo_type="dataset", revision=DATA_REV)
+     return info.sha
+
+ def list_pdfs(root: str) -> List[str]:
+     out = []
+     for r, _, fs in os.walk(root):
+         for f in fs:
+             if f.lower().endswith(".pdf"):
+                 out.append(os.path.join(r, f))
      return out

+ def load_docs(pdf_paths: List[str]):
+     docs = []
+     splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)
+     for p in pdf_paths:
+         for pg in PyPDFLoader(p).load():
+             pg.page_content = clean_text(pg.page_content)
+             # ensure useful metadata for citations
+             pg.metadata = dict(pg.metadata or {})
+             pg.metadata["source_path"] = p
+             pg.metadata["source"] = os.path.basename(p)
+             # pg.metadata["page"] typically exists from loader
+             docs += splitter.split_documents([pg])
+     return docs
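A small standalone sketch (not in the commit) of what the splitter settings above do; the sample text is made up:

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)
sample = "ISP document text. " * 300   # a few thousand characters of toy input
chunks = splitter.split_text(sample)

# Each chunk is at most ~1200 characters and repeats ~200 characters of its neighbour,
# so content cut at a boundary still appears intact in one of the chunks.
print(len(chunks), [len(c) for c in chunks])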
+
+ def rebuild_index(docs):
+     # delete existing collection
+     try:
+         client.delete_collection(COLLECTION_NAME)
+     except Exception:
+         pass

+     new_client = PersistentClient(path=DB_DIR)
+     new_db = Chroma(collection_name=COLLECTION_NAME, embedding_function=embeddings, client=new_client)

+     for i in range(0, len(docs), 32):
+         new_db.add_documents(docs[i:i+32])

+     return new_db

+ def reindex(force: bool = False) -> Dict[str, Any]:
+     os.makedirs(CORPUS_DIR, exist_ok=True)
+     new_sha = sync_pdfs()

+     old_sha = None
+     if os.path.exists(STATE_FILE):
          try:
+             old_sha = json.load(open(STATE_FILE, "r"))["dataset_sha"]
+         except Exception:
+             old_sha = None

+     if force or new_sha != old_sha:
+         pdfs = list_pdfs(CORPUS_DIR)
+         docs = load_docs(pdfs)

+         global vectordb
+         vectordb = rebuild_index(docs)

+         os.makedirs(os.path.dirname(STATE_FILE), exist_ok=True)
+         json.dump({"dataset_sha": new_sha}, open(STATE_FILE, "w"))

+         return {"reindexed": True, "commit": new_sha, "chunks": len(docs), "pdfs": len(pdfs)}

+     return {"reindexed": False, "commit": new_sha}

+ # --------------------- FastAPI app ---------------------
+ app = FastAPI(title="ISP Retriever API (RAG only)", version="2.0.0")

+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
+ )
+
+ INDEX_STATUS = {"state": "idle", "detail": "", "last_commit": None}

+ def warmup():
+     global INDEX_STATUS
      try:
+         INDEX_STATUS.update({"state": "syncing", "detail": "starting"})
+         info = reindex(force=False)
+         INDEX_STATUS.update({"state": "ready", "detail": str(info), "last_commit": info.get("commit")})
+         log.info(f"Index ready {info}")
      except Exception as e:
+         INDEX_STATUS.update({"state": "error", "detail": str(e)})
+         log.exception("Index warmup failed")
+
+ @app.on_event("startup")
+ def _startup():
+     threading.Thread(target=warmup, daemon=True).start()
+
+ # --------------------- Schemas ---------------------
+ class AskIn(BaseModel):
+     question: Optional[str] = None
+     query: Optional[str] = None
+     k: Optional[int] = 4
+
+     @property
+     def text(self) -> str:
+         t = (self.question or self.query or "").strip()
+         t = clean_text(t)
+         if not t:
+             raise ValueError("Provide 'question' or 'query'.")
+         return t
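A tiny sketch (not in the commit) of how the request model normalizes input; it assumes the module is importable as ingest and the questions are made up:

from ingest import AskIn

print(AskIn(question="  What are the program fees?  ").text)   # whitespace stripped by clean_text
print(AskIn(query="Visa support?").text)                        # 'query' is accepted as an alias for 'question'
try:
    AskIn(k=6).text                                             # neither field given
except ValueError as e:
    print("rejected:", e)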
+
+ class SourceOut(BaseModel):
+     source: str
+     source_path: str
+     page: Optional[int] = None
+     snippet: str
+
+ class AskOut(BaseModel):
+     question: str
+     context: str
+     sources: List[SourceOut]
+
+ # --------------------- Routes ---------------------
+ @app.get("/health")
+ def health():
      return {
+         "status": "ok",
+         "config": {
+             "dataset_id": DATASET_ID,
+             "dataset_rev": DATA_REV,
+             "collection": COLLECTION_NAME,
+             "db_dir": DB_DIR,
+             "corpus_dir": CORPUS_DIR,
+             "index_status": INDEX_STATUS,
+             "reindex_protected": bool(ADMIN_REINDEX_TOKEN),
+         }
      }

+ @app.post("/ask", response_model=AskOut)
+ def ask(payload: AskIn):
+     # Retrieval only (NO LLM generation here)
      try:
+         k = int(payload.k or 4)
+         k = max(1, min(k, 12))
+         docs = build_retriever(k).invoke(payload.text)
+
+         # Build context string
+         ctx_parts = []
+         sources: List[SourceOut] = []
+
+         for d in docs:
+             text = clean_text(d.page_content)
+             if not text:
+                 continue
+             ctx_parts.append(text)
+
+             md = d.metadata or {}
+             sources.append(SourceOut(
+                 source=str(md.get("source", "")),
+                 source_path=str(md.get("source_path", "")),
+                 page=md.get("page", None),
+                 snippet=(text[:300] + "…") if len(text) > 300 else text
+             ))
+
+         context = "\n\n---\n\n".join(ctx_parts).strip()
+
+         return AskOut(
+             question=payload.text,
+             context=context,
+             sources=sources
+         )

+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Retriever error: {str(e)[:500]}")
+
+ @app.post("/reindex")
+ def reindex_route(
+     x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token"),
+     force: Optional[bool] = True
+ ):
+     # Optional protection
+     if ADMIN_REINDEX_TOKEN:
+         if not x_admin_token or x_admin_token.strip() != ADMIN_REINDEX_TOKEN:
+             raise HTTPException(status_code=401, detail="Unauthorized")

      try:
+         info = reindex(force=bool(force))
+         INDEX_STATUS.update({"state": "ready", "detail": str(info), "last_commit": info.get("commit")})
+         return {"ok": True, "info": info}
+     except Exception as e:
+         INDEX_STATUS.update({"state": "error", "detail": str(e)})
+         raise HTTPException(status_code=500, detail=f"Reindex failed: {str(e)[:500]}")

+ # --------------------- Entrypoint ---------------------
  if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host=HOST, port=PORT)