# readctrl / code / vectordb_build / qwen_embed_v3.py
# Uploaded by shahidul034 ("Add files using upload-large-folder tool", commit c7a6fe6, verified)
import os

# 1. Environment & Configuration
# Parse CLI options that select which shard of the corpus this process embeds.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--shard_id", type=int, required=True, help="Shard ID for this run")
parser.add_argument("--num_shards", type=int, default=20, help="Total number of shards")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for embedding")
parser.add_argument("--cuda", type=str, default="none", help="CUDA device ID to use")
args = parser.parse_args()

# NOTE: CUDA_VISIBLE_DEVICES must be set BEFORE `import torch` below,
# otherwise the restriction has no effect on device enumeration.
if args.cuda == "none":
    # "none" is the sentinel default: fall back to pinning GPU 2 by PCI bus order.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
import gc
import torch
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import pandas as pd
# --- SHARDING CONFIG ---
# SHARD_ID = 2 # Change this for each run (e.g., 0, 1, 2, 3...)
NUM_SHARDS = args.num_shards # Total number of parts to split Wikipedia into
SHARD_ID = args.shard_id     # which of the NUM_SHARDS parts this process handles
batch_size = args.batch_size # chunks embedded per forward pass
lang_code = args.lang        # language code used in the shard file names
# -----------------------
# Embedding model id and the per-shard FAISS output path (hardcoded cluster path).
model_id = "Qwen/Qwen3-Embedding-4B"
save_path = f"/home/mshahidul/readctrl/data/vector_db/qwen_em/shard_{SHARD_ID}_{lang_code}.faiss"
# batch_size = 64 #16 # Keep small for 4B model to avoid OOM
# 2. Load Model with Memory Optimizations
print("Loading model...")
model = SentenceTransformer(
    model_id,
    trust_remote_code=True,
    device="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16} # Use half-precision
)
model.max_seq_length = 1024 # Truncate long paragraphs to save VRAM
# Load this shard's pre-chunked text from parquet; one chunk per row in the
# 'text' column (schema assumed from usage here — verify against the chunker).
load_path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{lang_code}_shard_{SHARD_ID}.parquet"
df = pd.read_parquet(load_path)
wiki_chunks = df['text'].tolist()
# 5. Embedding Function
def build_faiss_index(chunks, model, batch_size):
    """Embed text chunks in batches and accumulate them into a flat L2 FAISS index.

    Args:
        chunks: List of text strings to embed.
        model: Embedding model exposing ``encode`` (a SentenceTransformer here).
        batch_size: Number of chunks per forward pass (kept small to avoid OOM).

    Returns:
        A ``faiss.IndexFlatL2`` holding one float32 vector per chunk.

    Raises:
        ValueError: If ``chunks`` is empty — the embedding dimension would be
            unknown and the caller would otherwise receive ``None``, crashing
            later in ``faiss.write_index``.
    """
    if not chunks:
        raise ValueError("chunks is empty: cannot build a FAISS index from no data")

    import tqdm  # local import, as in the original script layout

    index = None
    total_chunks = len(chunks)
    print(f"Starting embedding process for {total_chunks} chunks...")
    for i in tqdm.tqdm(range(0, total_chunks, batch_size)):
        batch = chunks[i : i + batch_size]
        # Generate embeddings; no_grad avoids building an autograd graph.
        with torch.no_grad():
            embeddings = model.encode(
                batch,
                show_progress_bar=False,
                convert_to_numpy=True,
            ).astype('float32')
        # Lazily create the index on the first batch, once the dim is known.
        if index is None:
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatL2(dimension)
            # Optional: for massive datasets consider faiss.IndexIVFFlat for
            # faster search, though IndexFlatL2 is exact/most accurate.
        index.add(embeddings)
        if i % 1000 == 0:
            print(f"Processed {i}/{total_chunks} chunks...")
    return index
# 6. Run and Save
# Embed every chunk of this shard and persist the resulting index to disk.
vector_index = build_faiss_index(wiki_chunks, model, batch_size)
print(f"Saving index to {save_path}...")
faiss.write_index(vector_index, save_path)
print("Done!")