File size: 1,555 Bytes
9a9f1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
import os
import numpy as np
from tqdm import tqdm

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import (
    NUM_CLUSTERS, MRL_DIMS, EMBEDDING_MODELS
)
from src.data_pipeline import get_embeddings, load_ms_marco
from src.router import LearnedRouter

def train_routers():
    """Train and persist all learned-router variants on MS MARCO embeddings.

    Pipeline:
      1. Load N_SAMPLES raw texts from MS MARCO.
      2. Embed them with the MiniLM model configured in EMBEDDING_MODELS.
      3. Train one LearnedRouter per model type (logistic, lightgbm, mlp)
         and save each to models/<filename>.

    Side effects: creates the ``models/`` directory if missing and writes
    one pickle file per router.
    """
    print(">>> Starting Router Training Pipeline...")

    # 1. Load Data
    N_SAMPLES = 25000
    print(f"Loading {N_SAMPLES} samples from MS MARCO...")
    raw_texts = load_ms_marco(N_SAMPLES)

    # 2. Generate Embeddings (MiniLM-L6-v2)
    MODEL_NAME = EMBEDDING_MODELS["minilm"]
    print(f"Generating embeddings using {MODEL_NAME}...")
    embeddings = get_embeddings(MODEL_NAME, raw_texts)

    # 3. Train & Save Routers — (model_type, output filename) pairs
    models_to_train = [
        ("logistic", "router_logistic.pkl"),
        ("lightgbm", "router_lightgbm.pkl"),
        ("mlp", "router_mlp.pkl")
    ]

    for model_type, filename in models_to_train:
        print(f"\n--- Training {model_type.upper()} Router ---")
        router = LearnedRouter(model_type=model_type, n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
        router.train(embeddings)

        # BUG FIX: the path previously used the literal "models/(unknown)"
        # instead of interpolating `filename`, so every router overwrote
        # the same file. Use the per-model filename from the tuple.
        save_path = os.path.abspath(f"models/{filename}")
        print(f"Saving to {save_path}...")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        router.save(save_path)
        print(f"{model_type.upper()} Router saved.")

    print("\n>>> All Routers Trained & Saved!")

# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    train_routers()