# readCtrl_lambda / code / vectordb_build / qwen_embed_v2.py
# mshahidul
# Initial commit of readCtrl code without large models
# 030876e
import os
import argparse

# 1. Environment & Configuration
# CLI options are parsed before torch is imported (further down) because
# CUDA_VISIBLE_DEVICES must be in the environment before CUDA initializes.
parser = argparse.ArgumentParser(
    description="Embed one shard of Wikipedia with Qwen3-Embedding and save a FAISS index."
)
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--shard_id", type=int, required=True, help="Shard ID for this run")
parser.add_argument("--num_shards", type=int, default=20, help="Total number of shards")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for embedding")
parser.add_argument("--cuda", type=str, default="none", help="CUDA device ID to use")
args = parser.parse_args()

# Fail fast: without these checks an invalid shard or batch size only shows
# up after the model and the full dataset split have already been loaded.
if args.num_shards < 1:
    parser.error(f"--num_shards must be >= 1 (got {args.num_shards})")
if not 0 <= args.shard_id < args.num_shards:
    parser.error(
        f"--shard_id must be in [0, {args.num_shards - 1}] (got {args.shard_id})"
    )
if args.batch_size < 1:
    parser.error(f"--batch_size must be >= 1 (got {args.batch_size})")

if args.cuda == "none":
    # Default device: GPU 2, with device IDs numbered in PCI bus order.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
else:
    # NOTE(review): CUDA_DEVICE_ORDER is not set on this path, so the given ID
    # uses CUDA's default device ordering, which may not be PCI bus order.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
import gc
import torch
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# --- SHARDING CONFIG ---
NUM_SHARDS = args.num_shards  # Total number of parts to split Wikipedia into
SHARD_ID = args.shard_id  # Which part (0-based) this run processes
batch_size = args.batch_size  # Chunks per embedding call (keep small for a 4B model)
lang_code = args.lang
# -----------------------
model_id = "Qwen/Qwen3-Embedding-4B"
save_path = f"/home/mshahidul/readctrl/data/vector_db/qwen_em/shard_{SHARD_ID}_{lang_code}.faiss"
# Create the output directory now: a missing or unwritable path should fail
# here, not when the index is written after the whole shard has been embedded.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# 2. Load Model with Memory Optimizations
print("Loading model...")
model = SentenceTransformer(
    model_id,
    trust_remote_code=True,
    device="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16},  # 16-bit weights: half the VRAM of float32
)
model.max_seq_length = 1024  # Truncate inputs beyond 1024 tokens to bound VRAM use

# 3. Load Full Dataset (Non-Streaming)
print(f"Loading {lang_code} Wikipedia dataset into RAM...")
ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=False)
# Keep only this run's slice: shard SHARD_ID out of NUM_SHARDS.
ds_shard = ds.shard(num_shards=NUM_SHARDS, index=SHARD_ID)
# 4. Chunking Logic
print("Chunking articles into paragraphs...")
# Everything from the first of these markers onward is discarded.
# NOTE(review): these are plain substring matches, so e.g. "\nNotes" also
# matches a body line that merely starts with "Notes ..."; consider matching
# whole heading lines if that truncates real content.
STOP_HEADERS = [
    "\nReferences", "\nSee also", "\nExternal links",
    "\nNotes", "\nFurther reading", "\nBibliography",
]
MAX_CHUNKS_PER_ARTICLE = 5  # Adjust this to cap the size
MAX_SHORT_PARAGRAPH_WORDS = 20  # Paragraphs with this many words or fewer are dropped

import tqdm


def _strip_trailing_sections(text):
    """Return *text* cut just before the earliest occurrence of any STOP_HEADERS entry.

    If no header occurs, *text* is returned unchanged.
    """
    cut = len(text)
    for header in STOP_HEADERS:
        pos = text.find(header)
        if pos != -1 and pos < cut:
            cut = pos
    return text[:cut]


def _chunk_article(text):
    """Split one article into at most MAX_CHUNKS_PER_ARTICLE paragraph chunks.

    The text is truncated at the first stop header and split on blank lines
    ("\\n\\n"). Paragraphs are stripped and kept only if they have more than
    MAX_SHORT_PARAGRAPH_WORDS whitespace-separated words. Order is preserved.
    """
    body = _strip_trailing_sections(text)
    paragraphs = [
        p.strip()
        for p in body.split('\n\n')
        if len(p.split()) > MAX_SHORT_PARAGRAPH_WORDS
    ]
    # Keep only the leading paragraphs so very long articles don't dominate the index.
    return paragraphs[:MAX_CHUNKS_PER_ARTICLE]


wiki_chunks = []
for text in tqdm.tqdm(ds_shard['text']):
    wiki_chunks.extend(_chunk_article(text))
print(f"Total chunks created: {len(wiki_chunks)}")

# Only wiki_chunks is needed from here on. Drop both dataset handles: deleting
# `ds` alone would leave the data reachable through `ds_shard`.
del ds, ds_shard
gc.collect()
# 5. Embedding Function
def build_faiss_index(chunks, model, batch_size):
    """Embed *chunks* batch by batch and collect the vectors in a FAISS index.

    Args:
        chunks: Sequence of text passages. Chunk ``k`` becomes vector ``k``
            in the index, in order.
        model: A loaded SentenceTransformer; only its ``encode`` method is used.
        batch_size: Number of chunks per ``model.encode`` call. It is also
            passed as encode's own ``batch_size``, so that is the real GPU batch.

    Returns:
        A ``faiss.IndexFlatL2`` containing one float32 vector per chunk. Its
        dimension is taken from the first batch of embeddings.

    Raises:
        ValueError: If *batch_size* < 1 or *chunks* is empty (with no
            embeddings there is no dimension to build an index from).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    total_chunks = len(chunks)
    if total_chunks == 0:
        raise ValueError("No chunks to embed; cannot build an empty index")

    import tqdm

    log_every = 1000  # Chunks between progress prints
    index = None
    print(f"Starting embedding process for {total_chunks} chunks...")
    for start in tqdm.tqdm(range(0, total_chunks, batch_size)):
        batch = chunks[start : start + batch_size]
        # Generate Embeddings
        with torch.no_grad():
            embeddings = model.encode(
                batch,
                # Without this, encode() re-splits the input at its default
                # batch size (32), ignoring a larger --batch_size.
                batch_size=len(batch),
                show_progress_bar=False,
                convert_to_numpy=True,
            ).astype('float32')  # FAISS indexes take float32 vectors
        # Initialize FAISS index on first batch
        if index is None:
            index = faiss.IndexFlatL2(embeddings.shape[1])
            # Optional: for a massive dataset consider faiss.IndexIVFFlat for
            # faster (approximate) search; IndexFlatL2 is exact.
        index.add(embeddings)
        done = start + len(batch)
        # Print whenever a multiple of log_every is crossed (and at the end),
        # so the cadence does not depend on batch_size.
        if done // log_every > start // log_every or done == total_chunks:
            print(f"Processed {done}/{total_chunks} chunks...")
    return index
# 6. Run and Save
# Embed every chunk of this shard, then write the resulting index to save_path.
vector_index = build_faiss_index(
    chunks=wiki_chunks,
    model=model,
    batch_size=batch_size,
)
print(f"Saving index to {save_path}...")
faiss.write_index(vector_index, save_path)
print("Done!")