Spaces:
Running
Running
| import sys | |
| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| # Add project root to path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from config import ( | |
| NUM_CLUSTERS, MRL_DIMS, EMBEDDING_MODELS | |
| ) | |
| from src.data_pipeline import get_embeddings, load_ms_marco | |
| from src.router import LearnedRouter | |
def train_routers(n_samples: int = 25000, output_dir: str = "models") -> None:
    """Train one router per supported model type and save each to disk.

    The pipeline loads text samples via ``load_ms_marco``, embeds them with
    the model registered under the ``"minilm"`` key of ``EMBEDDING_MODELS``,
    then trains a ``LearnedRouter`` for each of the ``logistic``,
    ``lightgbm`` and ``mlp`` model types on those same embeddings and saves
    it as ``router_<model_type>.pkl`` inside ``output_dir``.

    Args:
        n_samples: Number of samples requested from ``load_ms_marco``.
            Defaults to 25000, the value that used to be hard-coded.
        output_dir: Directory the router files are written to. A relative
            path is resolved against the current working directory.
            Defaults to ``"models"``, the previously hard-coded location.
    """
    print(">>> Starting Router Training Pipeline...")

    # Resolve and create the output directory before any expensive work so
    # that an unwritable location fails immediately instead of after the
    # data has been loaded, embedded and the first router trained.
    output_dir = os.path.abspath(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load Data
    print(f"Loading {n_samples} samples from MS MARCO...")
    raw_texts = load_ms_marco(n_samples)

    # 2. Generate Embeddings
    model_name = EMBEDDING_MODELS["minilm"]
    print(f"Generating embeddings using {model_name}...")
    embeddings = get_embeddings(model_name, raw_texts)

    # 3. Train & Save Routers
    # Each entry pairs a LearnedRouter model type with its output file name.
    models_to_train = [
        ("logistic", "router_logistic.pkl"),
        ("lightgbm", "router_lightgbm.pkl"),
        ("mlp", "router_mlp.pkl"),
    ]

    for model_type, filename in models_to_train:
        print(f"\n--- Training {model_type.upper()} Router ---")
        router = LearnedRouter(
            model_type=model_type,
            n_clusters=NUM_CLUSTERS,
            mrl_dims=MRL_DIMS,
        )
        # Every router is trained on the same embedding set computed above.
        router.train(embeddings)

        save_path = os.path.join(output_dir, filename)
        print(f"Saving to {save_path}...")
        router.save(save_path)
        print(f"{model_type.upper()} Router saved.")

    print("\n>>> All Routers Trained & Saved!")
# Script entry point: run the full training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_routers()