Azizahalq commited on
Commit
76eee36
·
verified ·
1 Parent(s): 9221408

Create build_index_from_hf.py

Browse files
Files changed (1) hide show
  1. build_index_from_hf.py +142 -0
build_index_from_hf.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Rebuild Chroma index from a Hugging Face dataset using BGE-small (384-d) embeddings.
4
+ - Dataset: Azizahalq/materialmind-corpus (override with --repo)
5
+ - Output: MaterialMind/index/chroma_v3/<uuid> (override with --out_dir / --uuid)
6
+ - Collection: materialmind (override with --collection)
7
+ """
8
+ import os, argparse, uuid, math
9
+ from pathlib import Path
10
+ from typing import Dict, List, Any, Iterable
11
+
12
+ from datasets import load_dataset, concatenate_datasets
13
+ from tqdm import tqdm
14
+
15
+ EMB_MODEL = "BAAI/bge-small-en-v1.5"
16
+
17
+ def pick_text(row: Dict[str, Any]) -> str:
18
+ candidates = ["text","content","chunk","page_text","passage","body","abstract"]
19
+ for k in candidates:
20
+ if k in row and isinstance(row[k], str) and row[k].strip():
21
+ return row[k]
22
+ return " ".join([str(v) for v in row.values() if isinstance(v, str)])
23
+
24
+ def chunk_text(text: str, max_chars: int = 900, overlap: int = 120) -> List[str]:
25
+ text = " ".join(text.split())
26
+ if len(text) <= max_chars:
27
+ return [text] if text else []
28
+ chunks, i = [], 0
29
+ while i < len(text):
30
+ j = min(len(text), i + max_chars)
31
+ cut = text.rfind(". ", i, j)
32
+ if cut == -1 or cut <= i + 200:
33
+ cut = j
34
+ chunk = text[i:cut].strip()
35
+ if chunk:
36
+ chunks.append(chunk)
37
+ i = max(cut - overlap, i + 1)
38
+ return chunks
39
+
40
+ def l2norm(vec: List[float]) -> List[float]:
41
+ s = math.sqrt(sum(x*x for x in vec)) or 1.0
42
+ return [x/s for x in vec]
43
+
44
+ def embed_bge_small(texts: List[str]) -> List[List[float]]:
45
+ try:
46
+ from fastembed import TextEmbedding
47
+ emb = TextEmbedding(model_name=EMB_MODEL)
48
+ return [l2norm(v) for v in emb.embed(texts)]
49
+ except Exception:
50
+ from sentence_transformers import SentenceTransformer
51
+ model = SentenceTransformer(EMB_MODEL)
52
+ arr = model.encode(texts, normalize_embeddings=True)
53
+ return [l2norm(v.tolist()) for v in arr]
54
+
55
+ def batched(iterable, batch_size: int):
56
+ buf = []
57
+ for x in iterable:
58
+ buf.append(x)
59
+ if len(buf) >= batch_size:
60
+ yield buf
61
+ buf = []
62
+ if buf:
63
+ yield buf
64
+
65
+ def main():
66
+ ap = argparse.ArgumentParser()
67
+ ap.add_argument("--repo", default="Azizahalq/materialmind-corpus")
68
+ ap.add_argument("--split", default="train", help="train/test/all")
69
+ ap.add_argument("--out_dir", default="MaterialMind/index/chroma_v3")
70
+ ap.add_argument("--uuid", default=str(uuid.uuid4())[:8])
71
+ ap.add_argument("--collection", default="materialmind")
72
+ ap.add_argument("--batch", type=int, default=64)
73
+ args = ap.parse_args()
74
+
75
+ out_root = Path(args.out_dir).resolve()
76
+ index_dir = (out_root / args.uuid).resolve()
77
+ index_dir.mkdir(parents=True, exist_ok=True)
78
+ print(f"[BUILD] Index path: {index_dir}")
79
+
80
+ # Load dataset
81
+ try:
82
+ if args.split == "all":
83
+ ds_map = load_dataset(args.repo)
84
+ data = concatenate_datasets(list(ds_map.values()))
85
+ else:
86
+ data = load_dataset(args.repo, split=args.split)
87
+ except Exception as e:
88
+ raise SystemExit(f"[BUILD] Failed to load dataset {args.repo}: {e}")
89
+
90
+ # Chroma
91
+ import chromadb
92
+ client = chromadb.PersistentClient(path=str(index_dir))
93
+ col = client.get_or_create_collection(
94
+ name=args.collection,
95
+ metadata={"hnsw:space": "cosine"},
96
+ )
97
+
98
+ docs, metas, ids = [], [], []
99
+ total_rows = len(data)
100
+ print(f"[BUILD] Rows in split '{args.split}': {total_rows}")
101
+
102
+ for ridx in tqdm(range(total_rows), desc="Chunking"):
103
+ row = data[ridx]
104
+ text = pick_text(row)
105
+ if not text:
106
+ continue
107
+ meta = {}
108
+ for key in ("source","path","file","url","title","page"):
109
+ if key in row:
110
+ meta[key] = row[key]
111
+ parts = chunk_text(text, max_chars=900, overlap=120)
112
+ for pidx, chunk in enumerate(parts):
113
+ docs.append(chunk)
114
+ metas.append(meta.copy())
115
+ ids.append(f"r{ridx}-p{pidx}")
116
+
117
+ if not docs:
118
+ raise SystemExit("[BUILD] No text to index. Check your dataset fields.")
119
+
120
+ added = 0
121
+ for bi in tqdm(list(batched(list(zip(ids, docs, metas)), args.batch)), desc="Embedding+Add"):
122
+ b_ids = [b[0] for b in bi]
123
+ b_docs = [b[1] for b in bi]
124
+ b_meta = [b[2] for b in bi]
125
+ vecs = embed_bge_small(b_docs)
126
+ col.add(ids=b_ids, documents=b_docs, metadatas=b_meta, embeddings=vecs)
127
+ added += len(b_ids)
128
+
129
+ try:
130
+ count = col.count()
131
+ except Exception:
132
+ count = added
133
+
134
+ print(f"[BUILD] Done. Added {added} chunks. Collection count = {count}")
135
+ print(f"[BUILD] Set env vars for the app:")
136
+ print(f" EMB_PROVIDER=hf")
137
+ print(f" EMB_MODEL={EMB_MODEL}")
138
+ print(f" INDEX_DIR=MaterialMind/index/chroma_v3/{args.uuid}")
139
+ print(f" INDEX_COLLECTION={args.collection}")
140
+
141
+ if __name__ == "__main__":
142
+ main()