# NOTE: removed web-scrape artifacts (page header, file-size line, commit hash,
# and line-number gutter) that were not part of the original Python source.
import sys
import os
import numpy as np
from tqdm import tqdm
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import (
NUM_CLUSTERS, MRL_DIMS, EMBEDDING_MODELS
)
from src.data_pipeline import get_embeddings, load_ms_marco
from src.router import LearnedRouter
def train_routers():
    """Train and persist one LearnedRouter per configured classifier type.

    Pipeline:
      1. Load ``N_SAMPLES`` raw texts from MS MARCO.
      2. Embed them with the MiniLM model named in ``EMBEDDING_MODELS``.
      3. Train a router for each model type and save it under ``models/``.

    Side effects: creates the ``models/`` directory if missing and writes
    one ``.pkl`` file per router. Returns None.
    """
    print(">>> Starting Router Training Pipeline...")

    # 1. Load Data
    N_SAMPLES = 25000
    print(f"Loading {N_SAMPLES} samples from MS MARCO...")
    raw_texts = load_ms_marco(N_SAMPLES)

    # 2. Generate Embeddings (MiniLM-L6-v2)
    MODEL_NAME = EMBEDDING_MODELS["minilm"]
    print(f"Generating embeddings using {MODEL_NAME}...")
    embeddings = get_embeddings(MODEL_NAME, raw_texts)

    # 3. Train & Save Routers
    models_to_train = [
        ("logistic", "router_logistic.pkl"),
        ("lightgbm", "router_lightgbm.pkl"),
        ("mlp", "router_mlp.pkl"),
    ]
    for model_type, filename in models_to_train:
        print(f"\n--- Training {model_type.upper()} Router ---")
        router = LearnedRouter(
            model_type=model_type, n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS
        )
        router.train(embeddings)

        # BUG FIX: the original built f"models/(unknown)" — the per-model
        # `filename` was never interpolated, so every router would have been
        # written to the same bogus path, each save overwriting the last.
        save_path = os.path.abspath(os.path.join("models", filename))
        print(f"Saving to {save_path}...")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        router.save(save_path)
        print(f"{model_type.upper()} Router saved.")

    print("\n>>> All Routers Trained & Saved!")
if __name__ == "__main__":
train_routers()