# Environment Setup: select the target GPU before importing torch so CUDA only sees that device
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# 1. Configuration
model_id = "Qwen/Qwen3-Embedding-4B"
lang_code = "en"  # Change to your desired language
save_path = f"/home/mshahidul/readctrl/data/vector_db/qwen_em/{lang_code}_wikipedia_qwen3_index.faiss"
batch_size = 8  # Adjust based on your GPU VRAM (4B model is heavy)

# 2. Load Model
# Note: Qwen3 might require trust_remote_code=True depending on the implementation
model = SentenceTransformer(model_id, trust_remote_code=True, model_kwargs={"torch_dtype": torch.bfloat16}) # Use bfloat16 for Qwen3

# 3. Load Dataset (Streaming)
ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=True)

def embed_wikipedia(dataset, model, batch_size):
    index = None
    metadata = []  # To store text or IDs

    def flush(texts, index):
        # Embed the current batch and add it to the index
        embeddings = model.encode(texts, show_progress_bar=False)
        embeddings = np.array(embeddings).astype('float32')

        # Initialize the FAISS index on the first batch, once the embedding dimension is known
        if index is None:
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatL2(dimension)

        index.add(embeddings)
        return index

    batch_texts = []
    n_docs = 0
    print("Starting embedding process...")

    for item in dataset:
        batch_texts.append(item['text'])
        n_docs += 1

        if len(batch_texts) == batch_size:
            index = flush(batch_texts, index)

            # Optional: Store metadata (Warning: Wikipedia is huge,
            # storing all text in RAM might crash your system)
            # metadata.extend(batch_texts)

            batch_texts = []

            # Report progress every 100 batches
            if (n_docs // batch_size) % 100 == 0:
                print(f"Processed {n_docs} documents...")

    # Embed any leftover documents that did not fill a complete batch
    if batch_texts:
        index = flush(batch_texts, index)

    return index

# 4. Run and Save
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Make sure the output directory exists
vector_index = embed_wikipedia(ds, model, batch_size)
faiss.write_index(vector_index, save_path)
print(f"Index saved to {save_path}")