File size: 3,694 Bytes
ba86059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import sys
import torch
import numpy as np
import time
from tqdm import tqdm
from datasets import load_dataset
from huggingface_hub import HfApi

# Put the parent of this script's directory at the front of sys.path so the
# `services` package imported below can be loaded.
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, backend_dir)

from services.tq_service import tq_service
from services.metadata_service import metadata_service
from services.ingestion_service import ingestion_service

def _embed_and_store(batch_texts, batch_metadata, start_index, all_embeddings):
    """Embed one batch, collect its vectors, and persist its metadata.

    The embeddings returned by ``ingestion_service.get_embeddings`` are
    appended to ``all_embeddings`` (mutated in place), and the metadata rows
    are handed to ``metadata_service.add_chunks`` starting at ``start_index``.

    Returns:
        The number of chunks in the batch.
    """
    emb = ingestion_service.get_embeddings(batch_texts, is_query=False)
    all_embeddings.append(emb)
    metadata_service.add_chunks(start_index, batch_metadata, user_id=-1)
    return len(batch_texts)


def run_worker(limit=100000, batch_size=128, push_every=10000):
    """Stream wiki_dpr passages, embed them in batches, and checkpoint the index.

    Rows are read from the streaming ``facebook/wiki_dpr`` dataset until
    ``limit`` rows have been consumed. Every full batch is embedded and its
    metadata stored; whenever the processed count reaches the next multiple of
    ``push_every``, ``save_and_push`` is called with all embeddings collected
    so far. A final save is performed at the end if any chunks were processed
    since the last checkpoint.

    Args:
        limit: Maximum number of dataset rows to process.
        batch_size: Number of passages embedded per call (must be >= 1).
        push_every: Approximate number of chunks between checkpoints
            (must be >= 1). Checkpoints happen on batch boundaries, at the
            first batch whose end reaches or passes each multiple.

    Raises:
        ValueError: If ``batch_size`` or ``push_every`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if push_every < 1:
        raise ValueError(f"push_every must be >= 1, got {push_every}")

    print(f"🚀 HF Worker Started: Target {limit} chunks")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"💻 Device: {device}")

    # 1. Load Dataset
    print("🌐 Loading Dataset from HF...")
    ds = load_dataset("facebook/wiki_dpr", "psgs_w100.multiset", split="train", streaming=True)

    count = 0
    last_saved_count = 0
    next_checkpoint = push_every
    batch_texts = []
    batch_metadata = []
    all_embeddings = []

    repo_id = os.getenv("HF_DATASET_REPO")
    token = os.getenv("HF_TOKEN")

    if not repo_id or not token:
        print("⚠️ Warning: HF_DATASET_REPO or HF_TOKEN not set. Results will only be saved locally.")

    # The context manager guarantees the progress bar is closed, even on error.
    with tqdm(total=limit) as pbar:
        for i, row in enumerate(ds):
            if i >= limit:
                break

            batch_texts.append(row['text'])
            batch_metadata.append({
                "text": row['text'],
                "source": row.get('title', 'Wikipedia'),
                "user_id": -1,
                "session_id": "system"
            })

            if len(batch_texts) < batch_size:
                continue

            processed = _embed_and_store(batch_texts, batch_metadata, count, all_embeddings)
            count += processed
            pbar.update(processed)
            batch_texts = []
            batch_metadata = []

            # Checkpoint & Push.
            # `count` only advances in steps of batch_size, so an exact
            # `count % push_every == 0` test would almost never be true;
            # trigger on crossing the threshold instead.
            if count >= next_checkpoint:
                print(f"\n💾 Saving checkpoint at {count}...")
                save_and_push(all_embeddings, repo_id, token)
                last_saved_count = count
                next_checkpoint = (count // push_every + 1) * push_every

        # Flush the trailing partial batch and account for it.
        if batch_texts:
            processed = _embed_and_store(batch_texts, batch_metadata, count, all_embeddings)
            count += processed
            pbar.update(processed)

    # Final Save: skipped when the last checkpoint already covers everything
    # (or nothing was processed), to avoid an identical save and upload.
    if count != last_saved_count:
        save_and_push(all_embeddings, repo_id, token)
    print("✅ WORKER COMPLETED!")

def save_and_push(all_embeddings, repo_id, token):
    """Index the collected embeddings, save the index, and optionally upload it.

    Returns immediately when ``all_embeddings`` is empty. Otherwise the
    embedding arrays are stacked with ``np.vstack``, indexed by
    ``tq_service.system_engine`` (with ``ivf_nlist`` set to 4096), and the
    index is saved under ``tq_service.data_dir``. When both ``repo_id`` and
    ``token`` are truthy, every entry in the saved index directory and the
    metadata database file are uploaded to that Hugging Face dataset repo.

    Args:
        all_embeddings: Sequence of embedding arrays to stack row-wise.
        repo_id: Target dataset repo id, or a falsy value to skip the upload.
        token: Hugging Face token, or a falsy value to skip the upload.
    """
    if not all_embeddings:
        return

    stacked = np.vstack(all_embeddings)
    engine = tq_service.system_engine
    engine.ivf_nlist = 4096
    engine.index(stacked)

    index_dir = os.path.join(tq_service.data_dir, "tq_index_4bit_np4096_system")
    engine.save_index(index_dir)

    # Without both credentials the result stays local only.
    if not (repo_id and token):
        return

    hub = HfApi(token=token)
    destination = {"repo_id": repo_id, "repo_type": "dataset"}

    # Push Index: one upload per entry in the saved index directory.
    for entry in os.listdir(index_dir):
        hub.upload_file(
            path_or_fileobj=os.path.join(index_dir, entry),
            path_in_repo=f"tq_index_system_e5/{entry}",
            **destination,
        )

    # Push DB
    hub.upload_file(
        path_or_fileobj=metadata_service.db_path,
        path_in_repo="metadata.db",
        **destination,
    )
    print(f"📤 Pushed to {repo_id}")

if __name__ == "__main__":
    import argparse

    # CLI entry point: expose every run_worker parameter, using the same
    # defaults as the function signature so existing invocations are unchanged.
    parser = argparse.ArgumentParser(
        description="Embed wiki_dpr passages and build/push the system index."
    )
    parser.add_argument("--limit", type=int, default=100000,
                        help="maximum number of dataset rows to process")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="number of passages embedded per batch")
    parser.add_argument("--push-every", type=int, default=10000,
                        help="number of chunks between checkpoints")
    args = parser.parse_args()
    run_worker(
        limit=args.limit,
        batch_size=args.batch_size,
        push_every=args.push_every,
    )