"""Train every supported router variant on MS MARCO embeddings and save them."""

import sys
import os

import numpy as np
from tqdm import tqdm

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import (
    NUM_CLUSTERS,
    MRL_DIMS,
    EMBEDDING_MODELS
)
from src.data_pipeline import get_embeddings, load_ms_marco
from src.router import LearnedRouter


def train_routers(n_samples: int = 25000, output_dir: str = "models") -> None:
    """Train one ``LearnedRouter`` per model type and save each to disk.

    The pipeline loads ``n_samples`` texts with ``load_ms_marco``, embeds
    them once with the ``EMBEDDING_MODELS["minilm"]`` model, then trains a
    logistic, a LightGBM and an MLP router on those shared embeddings.

    Args:
        n_samples: Number of MS MARCO samples to request from the loader.
        output_dir: Directory that receives the ``router_*.pkl`` files. A
            relative path is resolved against the current working
            directory (not the project root).

    Raises:
        OSError: If ``output_dir`` cannot be created.
    """
    print(">>> Starting Router Training Pipeline...")

    # Create the output directory before any expensive work so that an
    # unwritable location fails immediately instead of after training.
    output_dir = os.path.abspath(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load Data
    print(f"Loading {n_samples} samples from MS MARCO...")
    raw_texts = load_ms_marco(n_samples)

    # 2. Generate Embeddings (MiniLM-L6-v2)
    model_name = EMBEDDING_MODELS["minilm"]
    print(f"Generating embeddings using {model_name}...")
    embeddings = get_embeddings(model_name, raw_texts)

    # 3. Train & Save Routers
    models_to_train = [
        ("logistic", "router_logistic.pkl"),
        ("lightgbm", "router_lightgbm.pkl"),
        ("mlp", "router_mlp.pkl"),
    ]

    for model_type, filename in models_to_train:
        print(f"\n--- Training {model_type.upper()} Router ---")
        router = LearnedRouter(
            model_type=model_type,
            n_clusters=NUM_CLUSTERS,
            mrl_dims=MRL_DIMS,
        )
        router.train(embeddings)

        save_path = os.path.join(output_dir, filename)
        print(f"Saving to {save_path}...")
        router.save(save_path)
        print(f"{model_type.upper()} Router saved.")

    print("\n>>> All Routers Trained & Saved!")


if __name__ == "__main__":
    train_routers()