import os
import argparse

# ------------------------------------------------------------------
# 1. Environment & configuration
#
# NOTE: CUDA_VISIBLE_DEVICES must be set BEFORE torch is imported,
# which is why argument parsing happens ahead of the heavy imports.
# ------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--shard_id", type=int, required=True, help="Shard ID for this run")
parser.add_argument("--num_shards", type=int, default=20, help="Total number of shards")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for embedding")
parser.add_argument("--cuda", type=str, default="none", help="CUDA device ID to use")
args = parser.parse_args()

# Fail fast on an out-of-range shard id instead of crashing in ds.shard()
# only after the model and dataset have already been loaded.
if not (0 <= args.shard_id < args.num_shards):
    parser.error(f"--shard_id must be in [0, {args.num_shards - 1}], got {args.shard_id}")

if args.cuda == "none":
    # No device given: pin to physical GPU 2 (PCI bus order) by default.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

import gc
import torch
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# --- SHARDING CONFIG ---
NUM_SHARDS = args.num_shards  # Total number of parts to split Wikipedia into
SHARD_ID = args.shard_id
batch_size = args.batch_size  # Keep small for the 4B model to avoid OOM
lang_code = args.lang
# -----------------------

model_id = "Qwen/Qwen3-Embedding-4B"
save_path = f"/home/mshahidul/readctrl/data/vector_db/qwen_em/shard_{SHARD_ID}_{lang_code}.faiss"

# Create the output directory up front so the final faiss.write_index does
# not fail after hours of embedding work.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# 2. Load model with memory optimizations
print("Loading model...")
model = SentenceTransformer(
    model_id,
    trust_remote_code=True,
    device="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16},  # half-precision to fit the 4B model
)
model.max_seq_length = 1024  # Truncate long paragraphs to save VRAM

# 3. Load full dataset (non-streaming)
print(f"Loading {lang_code} Wikipedia dataset into RAM...")
ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=False)
ds_shard = ds.shard(num_shards=NUM_SHARDS, index=SHARD_ID)
# 4. Chunking logic
print("Chunking articles into paragraphs...")

import tqdm

# Sections after these headers (references, link lists, ...) carry little
# encyclopedic prose, so everything from the first match onward is dropped.
STOP_HEADERS = [
    "\nReferences", "\nSee also", "\nExternal links",
    "\nNotes", "\nFurther reading", "\nBibliography",
]
MAX_CHUNKS_PER_ARTICLE = 5  # Cap so very long articles do not dominate the index

wiki_chunks = []
for text in tqdm.tqdm(ds_shard['text']):
    # A. Clean the text: remove everything after the first stop header found.
    clean_text = text
    for header in STOP_HEADERS:
        if header in clean_text:
            clean_text = clean_text.split(header)[0]

    # B. Paragraph split; keep only paragraphs with real content (> 20 words).
    paragraphs = [p.strip() for p in clean_text.split('\n\n') if len(p.split()) > 20]

    # C. Cap the chunks per article (slice is a no-op when already short enough).
    wiki_chunks.extend(paragraphs[:MAX_CHUNKS_PER_ARTICLE])

print(f"Total chunks created: {len(wiki_chunks)}")

# Clear the original dataset from RAM to free space for embeddings.
del ds
gc.collect()


# 5. Embedding function
def build_faiss_index(chunks, model, batch_size):
    """Embed ``chunks`` in batches with ``model`` and return a flat-L2 FAISS index.

    Args:
        chunks: list of text passages to embed.
        model: SentenceTransformer (or compatible) object whose ``encode``
            returns a 2-D numpy array of shape (len(batch), dim).
        batch_size: number of passages encoded per forward pass.

    Returns:
        A ``faiss.IndexFlatL2`` holding one vector per chunk, or ``None``
        when ``chunks`` is empty (no batch ever initializes the index).
    """
    index = None
    total_chunks = len(chunks)
    print(f"Starting embedding process for {total_chunks} chunks...")

    # tqdm already reports progress; the original extra `if i % 1000 == 0`
    # print was broken (i advances in steps of batch_size, so the condition
    # almost never fired) and has been removed.
    for start in tqdm.tqdm(range(0, total_chunks, batch_size)):
        batch = chunks[start:start + batch_size]

        with torch.no_grad():  # inference only — skip autograd bookkeeping
            embeddings = model.encode(
                batch,
                show_progress_bar=False,
                convert_to_numpy=True,
            ).astype('float32')

        # Initialize the index lazily: the embedding dimension is only
        # known after the first batch has been encoded.
        if index is None:
            index = faiss.IndexFlatL2(embeddings.shape[1])
            # Optional: for a massive corpus consider faiss.IndexIVFFlat for
            # faster search, though IndexFlatL2 is exact and most accurate.

        index.add(embeddings)

    return index


# 6. Run and save
vector_index = build_faiss_index(wiki_chunks, model, batch_size)
print(f"Saving index to {save_path}...")
faiss.write_index(vector_index, save_path)
print("Done!")