File size: 3,125 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
# 1. Environment & Configuration
# CLI arguments controlling which shard of the corpus this process embeds
# and which GPU it runs on.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--shard_id", type=int, required=True, help="Shard ID for this run")
parser.add_argument("--num_shards", type=int, default=20, help="Total number of shards")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for embedding")
parser.add_argument("--cuda", type=str, default="none", help="CUDA device ID to use")
args = parser.parse_args()

# GPU selection MUST happen before `torch` is imported below: CUDA reads
# CUDA_VISIBLE_DEVICES at initialization time, so setting it later has no
# effect. This is why the heavy imports are deferred until after this block.
if args.cuda == "none":
    # NOTE(review): when no device is given, this pins the job to GPU "2" —
    # presumably a machine-specific default; confirm before running elsewhere.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
else:
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda


import gc
import torch
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import pandas as pd


# --- Sharding configuration (all taken from the CLI arguments) ---
lang_code = args.lang
SHARD_ID = args.shard_id
NUM_SHARDS = args.num_shards  # total number of parts Wikipedia is split into
batch_size = args.batch_size  # kept small for the 4B model to avoid OOM
# -----------------------------------------------------------------

# Embedding model and the output location for this shard's FAISS index.
model_id = "Qwen/Qwen3-Embedding-4B"
save_path = f"/home/mshahidul/readctrl/data/vector_db/qwen_em/shard_{SHARD_ID}_{lang_code}.faiss"

# 2. Load the embedding model with memory optimizations.
print("Loading model...")
model = SentenceTransformer(
    model_id,
    trust_remote_code=True,
    device="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16},  # bf16 halves the weight footprint
)
# Cap input length so very long paragraphs don't blow up activation memory.
model.max_seq_length = 1024

# Read this shard's pre-chunked Wikipedia text into a plain list of strings.
load_path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{lang_code}_shard_{SHARD_ID}.parquet"
wiki_chunks = pd.read_parquet(load_path)['text'].tolist()

# 5. Embedding Function
def build_faiss_index(chunks, model, batch_size):
    """Embed ``chunks`` in batches and return a flat L2 FAISS index.

    Args:
        chunks: List of text strings to embed.
        model: SentenceTransformer-compatible object exposing
            ``encode(batch, ...) -> np.ndarray``.
        batch_size: Number of chunks encoded per forward pass.

    Returns:
        A populated ``faiss.IndexFlatL2`` over all chunk embeddings.

    Raises:
        ValueError: If ``chunks`` is empty. (Previously the function
            returned None in that case, which crashed later inside
            ``faiss.write_index`` with an opaque error.)
    """
    import tqdm  # local import, hoisted out of the loop body

    if not chunks:
        raise ValueError("chunks is empty; nothing to index")

    total_chunks = len(chunks)
    print(f"Starting embedding process for {total_chunks} chunks...")

    index = None
    for start in tqdm.tqdm(range(0, total_chunks, batch_size)):
        batch = chunks[start : start + batch_size]

        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            embeddings = model.encode(
                batch,
                show_progress_bar=False,  # the tqdm bar above tracks progress
                convert_to_numpy=True
            ).astype('float32')  # FAISS requires float32 input

        # Lazily create the index once the embedding dimension is known
        # (i.e. after the first batch has been encoded).
        if index is None:
            index = faiss.IndexFlatL2(embeddings.shape[1])
            # NOTE: IndexFlatL2 is exact (brute-force) search; for a massive
            # corpus consider faiss.IndexIVFFlat for faster queries.

        index.add(embeddings)

    return index

# 6. Build the index over this shard's chunks, then persist it to disk.
index = build_faiss_index(wiki_chunks, model, batch_size)

print(f"Saving index to {save_path}...")
faiss.write_index(index, save_path)
print("Done!")