Spaces:
Running
Running
| import numpy as np | |
| from sklearn.cluster import KMeans | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.neural_network import MLPClassifier | |
| import lightgbm as lgb | |
| from typing import Tuple, Any | |
| import joblib | |
| import os | |
class LearnedRouter:
    """Routes query vectors to K-Means clusters via a lightweight classifier.

    Training pipeline:
      1. K-Means on the FULL-dimensional vectors produces ground-truth labels.
      2. Vectors are truncated to their first ``mrl_dims`` components
         (Matryoshka-style prefix) and re-normalized.
      3. A classifier (LightGBM / logistic regression / MLP) learns to map the
         low-dimensional slice to the cluster label.

    At query time only the low-dimensional slice is needed (see ``predict``).
    """

    def __init__(self, model_type: str = "lightgbm", n_clusters: int = 32, mrl_dims: int = 64):
        self.model_type = model_type  # one of: "lightgbm", "logistic", "mlp"
        self.n_clusters = n_clusters  # K-Means cluster count (= number of classes)
        self.mrl_dims = mrl_dims      # prefix length used for the MRL slice
        self.kmeans = None            # fitted KMeans; set by train()/load()
        self.classifier = None        # fitted classifier; set by train()/load()

    @staticmethod
    def _slice_and_normalize(X: np.ndarray, dims: int) -> np.ndarray:
        """Truncate rows of ``X`` to ``dims`` components and L2-normalize each row.

        Zero-norm rows are guarded with a tiny epsilon to avoid division by zero.
        """
        X_sliced = X[:, :dims]
        norms = np.linalg.norm(X_sliced, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10
        return X_sliced / norms

    def train(self, X_full: np.ndarray):
        """
        Trains the router:
        1. Cluster X_full using K-Means to generate ground-truth labels.
        2. Slice X_full to MRL_DIMS and re-normalize.
        3. Train the specified classifier on sliced vectors to predict cluster labels.

        Raises:
            ValueError: if ``self.model_type`` is not a known classifier type.
        """
        print(f"Training Router ({self.model_type})...")

        # 1. Ground-truth labels come from K-Means on the FULL vectors, so the
        #    clusters reflect the high-fidelity data.
        print(" - Running K-Means for ground truth labels...")
        self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
        y_labels = self.kmeans.fit_predict(X_full)

        # 2. The router itself only ever sees the low-dim MRL prefix; the slice
        #    must be re-normalized because truncation changes the vector norm.
        print(f" - Slicing vectors to {self.mrl_dims} dimensions...")
        X_train = self._slice_and_normalize(X_full, self.mrl_dims)

        # 3. Fit the chosen classifier: sliced vector -> cluster label.
        print(f" - Training classifier: {self.model_type}...")
        if self.model_type == "lightgbm":
            train_data = lgb.Dataset(X_train, label=y_labels)
            params = {
                'objective': 'multiclass',
                'num_class': self.n_clusters,
                'metric': 'multi_logloss',
                'verbosity': -1,
                'seed': 42
            }
            self.classifier = lgb.train(params, train_data, num_boost_round=100)
        elif self.model_type == "logistic":
            # NOTE: `multi_class='multinomial'` was deprecated in sklearn 1.5 and
            # removed in 1.7; multinomial is already the default for lbfgs, so
            # omitting it preserves behavior.
            self.classifier = LogisticRegression(max_iter=1000, random_state=42)
            self.classifier.fit(X_train, y_labels)
        elif self.model_type == "mlp":
            self.classifier = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, random_state=42)
            self.classifier.fit(X_train, y_labels)
        else:
            raise ValueError(f"Unknown router model type: {self.model_type}")
        print("Router training complete.")

    def predict(self, vector_full: np.ndarray) -> Tuple[int, float]:
        """
        Predicts the target cluster for a query vector.
        1. Slice input to MRL dims and re-normalize.
        2. Predict class probabilities.
        3. Return (best_cluster, confidence_score).

        Raises:
            ValueError: if the router has not been trained/loaded, or the
                model type is unknown.
        """
        # Guard: without this, an untrained router dies with an opaque
        # AttributeError on `None` below.
        if self.classifier is None:
            raise ValueError("Model not trained or unknown type")

        # Accept a single flat vector as well as a (1, d) batch.
        if vector_full.ndim == 1:
            vector_full = vector_full.reshape(1, -1)

        X_input = self._slice_and_normalize(vector_full, self.mrl_dims)

        if self.model_type == "lightgbm":
            # A lgb Booster trained with 'multiclass' returns probabilities
            # directly from predict(): shape (n_samples, n_classes).
            probs = self.classifier.predict(X_input)
        elif self.model_type in ["logistic", "mlp"]:
            probs = self.classifier.predict_proba(X_input)
        else:
            raise ValueError("Model not trained or unknown type")

        best_cluster = np.argmax(probs, axis=1)[0]
        confidence = np.max(probs, axis=1)[0]
        return int(best_cluster), float(confidence)

    def save(self, path: str):
        """Saves the router (KMeans + Classifier) to disk."""
        print(f"Saving router to {path}...")
        joblib.dump({
            'model_type': self.model_type,
            'n_clusters': self.n_clusters,
            'mrl_dims': self.mrl_dims,
            'kmeans': self.kmeans,
            'classifier': self.classifier
        }, path)
        print("Router saved.")

    @classmethod
    def load(cls, path: str):
        """Loads the router from disk and returns a ready-to-use instance.

        FIX: was missing @classmethod — calling ``LearnedRouter.load(path)``
        previously bound the path string to ``cls`` and crashed.
        """
        print(f"Loading router from {path}...")
        data = joblib.load(path)
        router = cls(
            model_type=data['model_type'],
            n_clusters=data['n_clusters'],
            mrl_dims=data['mrl_dims']
        )
        router.kmeans = data['kmeans']
        router.classifier = data['classifier']
        print("Router loaded.")
        return router