import os

import argparse

# CLI is parsed *before* the heavy ML imports below so that
# CUDA_VISIBLE_DEVICES is set before torch initializes the CUDA runtime.
parser = argparse.ArgumentParser(
    description="Embed one shard of a Wikipedia dump into a FAISS index."
)
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--shard_id", type=int, required=True, help="Shard ID for this run")
parser.add_argument("--num_shards", type=int, default=20, help="Total number of shards")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for embedding")
parser.add_argument("--cuda", type=str, default="none", help="CUDA device ID to use")
args = parser.parse_args()

if args.cuda == "none":
    # NOTE(review): "none" does NOT mean CPU here — it falls back to a
    # hard-coded default GPU (physical device 2). Confirm this is intended
    # for the target machine; other hosts may not have a device 2.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
else:
    # Restrict torch to the user-selected GPU(s).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
|
|
|
import gc
import torch
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# --- Run configuration ---------------------------------------------------
NUM_SHARDS = args.num_shards    # total number of shards the dataset is split into
SHARD_ID = args.shard_id        # which shard this process handles
batch_size = args.batch_size    # passages per encode() call
lang_code = args.lang           # Wikipedia language edition (e.g. "en")

model_id = "Qwen/Qwen3-Embedding-4B"
# One FAISS file per (shard, language) pair so shards can run in parallel.
save_path = f"/home/mshahidul/readctrl/data/vector_db/qwen_em/shard_{SHARD_ID}_{lang_code}.faiss"
| |
|
|
| |
| print("Loading model...") |
| model = SentenceTransformer( |
| model_id, |
| trust_remote_code=True, |
| device="cuda", |
| model_kwargs={"torch_dtype": torch.bfloat16} |
| ) |
| model.max_seq_length = 1024 |
|
|
| |
| print(f"Loading {lang_code} Wikipedia dataset into RAM...") |
| ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=False) |
| ds_shard = ds.shard(num_shards=NUM_SHARDS, index=SHARD_ID) |
| |
| print("Chunking articles into paragraphs...") |
| STOP_HEADERS = [ |
| "\nReferences", "\nSee also", "\nExternal links", |
| "\nNotes", "\nFurther reading", "\nBibliography" |
| ] |
|
|
| MAX_CHUNKS_PER_ARTICLE = 5 |
| wiki_chunks = [] |
| import tqdm |
| import tqdm |
| for text in tqdm.tqdm(ds_shard['text']): |
| |
| clean_text = text |
| for header in STOP_HEADERS: |
| if header in clean_text: |
| clean_text = clean_text.split(header)[0] |
| |
| |
| paragraphs = [p.strip() for p in clean_text.split('\n\n') if len(p.split()) > 20] |
| |
| |
| |
| if len(paragraphs) > MAX_CHUNKS_PER_ARTICLE: |
| paragraphs = paragraphs[:MAX_CHUNKS_PER_ARTICLE] |
| |
| wiki_chunks.extend(paragraphs) |
|
|
| print(f"Total chunks created: {len(wiki_chunks)}") |
|
|
| |
| del ds |
| gc.collect() |
|
|
| |
def build_faiss_index(chunks, model, batch_size):
    """Embed ``chunks`` with ``model`` and collect them in a flat-L2 FAISS index.

    Args:
        chunks: list of text passages to embed.
        model: SentenceTransformer-like object exposing ``encode``.
        batch_size: number of passages encoded per forward pass.

    Returns:
        A ``faiss.IndexFlatL2`` holding one vector per chunk, in input order.

    Raises:
        ValueError: if ``chunks`` is empty. (Previously this returned None,
            which crashed later inside ``faiss.write_index`` with an opaque
            error; fail fast with a clear message instead.)
    """
    if not chunks:
        raise ValueError("build_faiss_index() received no chunks to embed")

    import tqdm

    index = None
    total_chunks = len(chunks)
    print(f"Starting embedding process for {total_chunks} chunks...")

    # tqdm already reports progress, so the old `if i % 1000 == 0` print —
    # which only fired when i was a multiple of lcm(batch_size, 1000) —
    # has been dropped.
    for start in tqdm.tqdm(range(0, total_chunks, batch_size)):
        batch = chunks[start:start + batch_size]

        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            embeddings = model.encode(
                batch,
                show_progress_bar=False,
                convert_to_numpy=True
            ).astype('float32')  # faiss requires float32 input

        # Lazily create the index once the embedding dimension is known.
        if index is None:
            index = faiss.IndexFlatL2(embeddings.shape[1])

        index.add(embeddings)

    return index
|
|
| |
# --- Build and persist the index -----------------------------------------
vector_index = build_faiss_index(wiki_chunks, model, batch_size)

print(f"Saving index to {save_path}...")
# Ensure the output directory exists so a missing path doesn't discard
# hours of embedding work at the very last step.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
faiss.write_index(vector_index, save_path)
print("Done!")