# dashVectorSpace / src/router.py
# Deploy dashVectorspace v1 (Full) — commit b92d96d
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
import lightgbm as lgb
from typing import Tuple, Any
import joblib
import os
class LearnedRouter:
    """Routes query vectors to K-Means clusters via a learned classifier.

    Ground-truth cluster labels are produced by K-Means on the *full*
    high-fidelity vectors; the classifier then learns to predict those labels
    from only the first ``mrl_dims`` components (MRL slice), re-normalized to
    unit length. At query time the router sees only the low-dim slice.
    """

    def __init__(self, model_type: str = "lightgbm", n_clusters: int = 32, mrl_dims: int = 64):
        """
        Args:
            model_type: One of "lightgbm", "logistic", or "mlp".
            n_clusters: Number of K-Means clusters (= classifier classes).
            mrl_dims: Number of leading vector dimensions used for routing.
        """
        self.model_type = model_type
        self.n_clusters = n_clusters
        self.mrl_dims = mrl_dims
        self.kmeans = None       # fitted sklearn KMeans, set by train()/load()
        self.classifier = None   # fitted classifier, set by train()/load()

    def _mrl_slice(self, X: np.ndarray) -> np.ndarray:
        """Truncate rows of X to the first ``mrl_dims`` dims and L2-normalize.

        Shared by train() and predict() so both paths apply the identical
        transform. Zero-norm rows are guarded with a tiny epsilon so the
        division is always defined.
        """
        X_sliced = X[:, :self.mrl_dims]
        norms = np.linalg.norm(X_sliced, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10  # avoid division by zero for all-zero rows
        return X_sliced / norms

    def train(self, X_full: np.ndarray):
        """
        Trains the router:
        1. Cluster X_full using K-Means to generate ground-truth labels.
        2. Slice X_full to MRL_DIMS and re-normalize.
        3. Train the specified classifier on sliced vectors to predict cluster labels.

        Raises:
            ValueError: If ``model_type`` is not a supported value.
        """
        print(f"Training Router ({self.model_type})...")
        # 1. Ground-truth labels come from K-Means on the FULL vectors —
        # we want the clusters to reflect the high-fidelity data.
        print(" - Running K-Means for ground truth labels...")
        self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
        y_labels = self.kmeans.fit_predict(X_full)

        # 2. The router itself only ever sees the low-dim MRL slice.
        print(f" - Slicing vectors to {self.mrl_dims} dimensions...")
        X_train = self._mrl_slice(X_full)

        # 3. Train Classifier
        print(f" - Training classifier: {self.model_type}...")
        if self.model_type == "lightgbm":
            train_data = lgb.Dataset(X_train, label=y_labels)
            params = {
                'objective': 'multiclass',
                'num_class': self.n_clusters,
                'metric': 'multi_logloss',
                'verbosity': -1,
                'seed': 42
            }
            self.classifier = lgb.train(params, train_data, num_boost_round=100)
        elif self.model_type == "logistic":
            self.classifier = LogisticRegression(max_iter=1000, multi_class='multinomial', random_state=42)
            self.classifier.fit(X_train, y_labels)
        elif self.model_type == "mlp":
            self.classifier = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, random_state=42)
            self.classifier.fit(X_train, y_labels)
        else:
            raise ValueError(f"Unknown router model type: {self.model_type}")
        print("Router training complete.")

    def predict(self, vector_full: np.ndarray) -> Tuple[int, float]:
        """
        Predicts the target cluster for a query vector.

        Args:
            vector_full: Full-dimension query vector, shape (dim,) or (n, dim).

        Returns:
            (best_cluster, confidence_score) for the first row: the argmax
            cluster id and its predicted probability.

        Raises:
            ValueError: If the router has not been trained/loaded, or the
                model type is unknown.
        """
        # Guard: previously an untrained router failed with an opaque
        # AttributeError on self.classifier; fail fast with a clear error.
        if self.classifier is None:
            raise ValueError("Model not trained or unknown type")
        # Ensure input is 2D
        if vector_full.ndim == 1:
            vector_full = vector_full.reshape(1, -1)
        # 1. Slice and normalize exactly as in training.
        X_input = self._mrl_slice(vector_full)
        # 2. Predict class probabilities, shape (n_samples, n_classes).
        if self.model_type == "lightgbm":
            # lgb.Booster.predict returns per-class probabilities for
            # multiclass objectives.
            probs = self.classifier.predict(X_input)
        elif self.model_type in ["logistic", "mlp"]:
            probs = self.classifier.predict_proba(X_input)
        else:
            raise ValueError("Model not trained or unknown type")
        # 3. Best cluster and its confidence for the (first) query row.
        best_cluster = np.argmax(probs, axis=1)[0]
        confidence = np.max(probs, axis=1)[0]
        return int(best_cluster), float(confidence)

    def save(self, path: str):
        """Saves the router (KMeans + Classifier) to disk via joblib."""
        print(f"Saving router to {path}...")
        joblib.dump({
            'model_type': self.model_type,
            'n_clusters': self.n_clusters,
            'mrl_dims': self.mrl_dims,
            'kmeans': self.kmeans,
            'classifier': self.classifier
        }, path)
        print("Router saved.")

    @classmethod
    def load(cls, path: str):
        """Loads a router previously written by save() and returns it."""
        print(f"Loading router from {path}...")
        data = joblib.load(path)
        router = cls(
            model_type=data['model_type'],
            n_clusters=data['n_clusters'],
            mrl_dims=data['mrl_dims']
        )
        router.kmeans = data['kmeans']
        router.classifier = data['classifier']
        print("Router loaded.")
        return router