# dashVectorSpace/scripts/train_routers_only.py
# Trains the learned routers (logistic, LightGBM, MLP) on MS MARCO embeddings.
# (Header reconstructed from Hugging Face file-viewer chrome: commit 9a9f1fb,
# "Deploy 9-Row Benchmark (via API)", 1.56 kB.)
import sys
import os
import numpy as np
from tqdm import tqdm
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import (
NUM_CLUSTERS, MRL_DIMS, EMBEDDING_MODELS
)
from src.data_pipeline import get_embeddings, load_ms_marco
from src.router import LearnedRouter
def train_routers():
    """Train and persist one learned router per model type.

    Pipeline:
        1. Load N_SAMPLES raw texts from MS MARCO.
        2. Embed them with the MiniLM model from EMBEDDING_MODELS.
        3. For each router type (logistic, lightgbm, mlp): fit a
           LearnedRouter on the embeddings and save it under models/.

    Side effects: downloads/loads data, prints progress, creates the
    models/ directory, and writes one .pkl file per router type.
    """
    print(">>> Starting Router Training Pipeline...")

    # 1. Load Data
    N_SAMPLES = 25000
    print(f"Loading {N_SAMPLES} samples from MS MARCO...")
    raw_texts = load_ms_marco(N_SAMPLES)

    # 2. Generate Embeddings (MiniLM-L6-v2)
    MODEL_NAME = EMBEDDING_MODELS["minilm"]
    print(f"Generating embeddings using {MODEL_NAME}...")
    embeddings = get_embeddings(MODEL_NAME, raw_texts)

    # 3. Train & Save Routers
    models_to_train = [
        ("logistic", "router_logistic.pkl"),
        ("lightgbm", "router_lightgbm.pkl"),
        ("mlp", "router_mlp.pkl"),
    ]
    for model_type, filename in models_to_train:
        print(f"\n--- Training {model_type.upper()} Router ---")
        router = LearnedRouter(model_type=model_type, n_clusters=NUM_CLUSTERS, mrl_dims=MRL_DIMS)
        router.train(embeddings)
        # BUG FIX: the original hard-coded the path as f"models/(unknown)",
        # ignoring the per-model `filename`, so every router was written to
        # the same garbled file and overwrote the previous one. Build the
        # path from `filename` instead.
        save_path = os.path.abspath(os.path.join("models", filename))
        print(f"Saving to {save_path}...")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        router.save(save_path)
        print(f"{model_type.upper()} Router saved.")

    print("\n>>> All Routers Trained & Saved!")
# Script entry point: run the full router-training pipeline when executed
# directly (not when imported as a module).
if __name__ == "__main__":
    train_routers()