| import os |
| |
| import argparse |
# --- CLI arguments ----------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--shard_id", type=int, required=True, help="Shard ID for this run")
parser.add_argument("--num_shards", type=int, default=20, help="Total number of shards")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for embedding")
parser.add_argument("--cuda", type=str, default="none", help="CUDA device ID to use")
args = parser.parse_args()


# GPU selection must happen BEFORE torch/faiss are imported below:
# CUDA_VISIBLE_DEVICES is read once at CUDA initialization time.
if args.cuda == "none":
    # NOTE(review): "--cuda none" does NOT disable CUDA -- it falls back to a
    # hard-coded device "2". Presumably a machine-specific default; confirm.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
else:
    # Expose only the requested device; inside the process it becomes cuda:0.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
|
|
|
| import gc |
| import torch |
| import faiss |
| import numpy as np |
| from datasets import load_dataset |
| from sentence_transformers import SentenceTransformer |
| import pandas as pd |
|
|
|
|
| |
| |
# --- Run configuration (taken from the CLI) ---------------------------------
NUM_SHARDS = args.num_shards  # total shard count (informational; not read below)
SHARD_ID = args.shard_id      # which shard this process embeds
batch_size = args.batch_size  # chunks encoded per model.encode() call
lang_code = args.lang         # wiki language code, e.g. "en"

# Embedding model and the destination for this shard's FAISS index.
model_id = "Qwen/Qwen3-Embedding-4B"
save_path = (
    "/home/mshahidul/readctrl/data/vector_db/qwen_em/"
    f"shard_{SHARD_ID}_{lang_code}.faiss"
)
| |
|
|
| |
print("Loading model...")
# bfloat16 halves memory vs fp32; trust_remote_code is required because the
# Qwen embedding model ships custom modeling code on the Hub.
model = SentenceTransformer(
    model_id,
    trust_remote_code=True,
    device="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16}
)
model.max_seq_length = 1024  # truncate inputs to 1024 tokens




# Load this shard's pre-chunked wiki text.
# Assumes the parquet file has a 'text' column -- TODO confirm schema.
load_path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{lang_code}_shard_{SHARD_ID}.parquet"
df = pd.read_parquet(load_path)
wiki_chunks = df['text'].tolist()
|
|
| |
def build_faiss_index(chunks, model, batch_size):
    """Embed text chunks in batches and collect them in a flat L2 FAISS index.

    Args:
        chunks: non-empty list of strings to embed.
        model: SentenceTransformer (or any object with a compatible
            ``encode(batch, show_progress_bar=..., convert_to_numpy=...)``
            method returning a 2-D array).
        batch_size: number of chunks encoded per call; must be > 0.

    Returns:
        faiss.IndexFlatL2 with one vector per chunk, in input order.

    Raises:
        ValueError: if ``chunks`` is empty. (The old code returned None here,
            which crashed later inside ``faiss.write_index``.)
    """
    if not chunks:
        raise ValueError("build_faiss_index: 'chunks' is empty")

    # Local import: tqdm is only needed here, and keeping heavy imports out of
    # module scope preserves the CUDA env setup ordering at the top of the file.
    import tqdm

    index = None
    total_chunks = len(chunks)
    print(f"Starting embedding process for {total_chunks} chunks...")

    for start in tqdm.tqdm(range(0, total_chunks, batch_size)):
        batch = chunks[start : start + batch_size]

        # SentenceTransformer.encode already runs without gradients, but the
        # explicit guard is harmless and makes the intent obvious.
        with torch.no_grad():
            embeddings = model.encode(
                batch,
                show_progress_bar=False,
                convert_to_numpy=True,
            ).astype("float32")  # faiss requires float32

        # Create the index lazily, once the embedding dimension is known.
        if index is None:
            index = faiss.IndexFlatL2(embeddings.shape[1])

        index.add(embeddings)
        # (dropped the old `if i % 1000 == 0` print: tqdm already reports
        # progress, and with batch_size 16 the modulo only hit every 2000.)

    return index
|
|
| |
# Build the index for this shard and persist it to disk.
vector_index = build_faiss_index(wiki_chunks, model, batch_size)

# Robustness fix: faiss.write_index fails if the target directory is missing;
# create it up front (no-op when it already exists).
os.makedirs(os.path.dirname(save_path), exist_ok=True)

print(f"Saving index to {save_path}...")
faiss.write_index(vector_index, save_path)
print("Done!")