File size: 4,248 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
# 1. Environment & Configuration
# CLI flags for sharded embedding runs: which language dump, which shard of
# how many, the embedding batch size, and which GPU to pin the process to.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--shard_id", type=int, required=True, help="Shard ID for this run")
parser.add_argument("--num_shards", type=int, default=20, help="Total number of shards")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for embedding")
parser.add_argument("--cuda", type=str, default="none", help="CUDA device ID to use")
args = parser.parse_args()

# CUDA_VISIBLE_DEVICES must be set BEFORE torch is imported below, or it has
# no effect on device enumeration.
if args.cuda == "none":
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # NOTE(review): when no --cuda flag is given this hardcodes GPU 2 as the
    # default device — presumably a machine-specific choice; confirm intended.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
else:
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda


import gc
import torch
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer



# --- SHARDING CONFIG ---
# Each run embeds one shard of the Wikipedia dump; a separate .faiss file is
# written per (shard, language) pair so runs can proceed in parallel.
# SHARD_ID = 2  # Change this for each run (e.g., 0, 1, 2, 3...)
NUM_SHARDS = args.num_shards # Total number of parts to split Wikipedia into
SHARD_ID = args.shard_id
batch_size = args.batch_size
lang_code = args.lang
# -----------------------

model_id = "Qwen/Qwen3-Embedding-4B"
save_path = f"/home/mshahidul/readctrl/data/vector_db/qwen_em/shard_{SHARD_ID}_{lang_code}.faiss"
# batch_size = 64 #16  # Keep small for 4B model to avoid OOM

# 2. Load Model with Memory Optimizations
# bfloat16 halves VRAM versus float32; trust_remote_code is required because
# the Qwen embedding model ships custom modeling code on the Hub.
print("Loading model...")
model = SentenceTransformer(
    model_id, 
    trust_remote_code=True, 
    device="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16} # Use half-precision
)
model.max_seq_length = 1024 # Truncate long paragraphs to save VRAM

# 3. Load Full Dataset (Non-Streaming)
print(f"Loading {lang_code} Wikipedia dataset into RAM...")
ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=False)
ds_shard = ds.shard(num_shards=NUM_SHARDS, index=SHARD_ID)

# 4. Chunking Logic
# Sections that typically follow the informative body of a Wikipedia article;
# everything from the first occurring header onward is discarded.
STOP_HEADERS = [
    "\nReferences", "\nSee also", "\nExternal links", 
    "\nNotes", "\nFurther reading", "\nBibliography"
]

MAX_CHUNKS_PER_ARTICLE = 5  # Cap so very long articles don't dominate the index


def _chunk_article(text, stop_headers=STOP_HEADERS, max_chunks=MAX_CHUNKS_PER_ARTICLE):
    """Split one article into at most `max_chunks` paragraph chunks.

    Truncates the article at each stop header it contains, splits the
    remainder on blank lines, and keeps only paragraphs with more than
    20 whitespace-separated tokens (drops headings and stub lines).

    Args:
        text: raw article text.
        stop_headers: substrings at which the article tail is cut off.
        max_chunks: maximum number of paragraphs returned per article.

    Returns:
        List of stripped paragraph strings, at most `max_chunks` long.
    """
    clean_text = text
    for header in stop_headers:
        if header in clean_text:
            clean_text = clean_text.split(header)[0]
    paragraphs = [p.strip() for p in clean_text.split('\n\n') if len(p.split()) > 20]
    return paragraphs[:max_chunks]


print("Chunking articles into paragraphs...")
import tqdm  # fixed: this import was duplicated on two consecutive lines

wiki_chunks = []
for text in tqdm.tqdm(ds_shard['text']):
    wiki_chunks.extend(_chunk_article(text))

print(f"Total chunks created: {len(wiki_chunks)}")

# Clear original dataset from RAM to free up space for embeddings
del ds
gc.collect()

# 5. Embedding Function
def build_faiss_index(chunks, model, batch_size):
    index = None
    total_chunks = len(chunks)
    
    print(f"Starting embedding process for {total_chunks} chunks...")
    import tqdm
    for i in tqdm.tqdm(range(0, total_chunks, batch_size)):
        batch = chunks[i : i + batch_size]
        
        # Generate Embeddings
        with torch.no_grad():
            embeddings = model.encode(
                batch, 
                show_progress_bar=False,
                convert_to_numpy=True
            ).astype('float32')
        
        # Initialize FAISS index on first batch
        if index is None:
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatL2(dimension)
            # Optional: If you have a massive dataset, consider using faiss.IndexIVFFlat 
            # for faster search, though IndexFlatL2 is most accurate.
        
        index.add(embeddings)
        
        if i % 1000 == 0:
            print(f"Processed {i}/{total_chunks} chunks...")

    return index

# 6. Run and Save
vector_index = build_faiss_index(wiki_chunks, model, batch_size)

# Ensure the destination directory exists before writing: faiss.write_index
# raises an opaque RuntimeError if any path component is missing.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

print(f"Saving index to {save_path}...")
faiss.write_index(vector_index, save_path)
print("Done!")