File size: 5,297 Bytes
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
import lightgbm as lgb
from typing import Tuple, Any
import joblib
import os

class LearnedRouter:
    """Routes query vectors to K-Means clusters via a cheap learned classifier.

    Ground-truth cluster labels come from K-Means run on the FULL-dimensional
    embeddings (high-fidelity clustering); the classifier is then trained to
    predict those labels from only the first ``mrl_dims`` dimensions
    (re-normalized), so routing at query time is fast and low-memory.
    """

    def __init__(self, model_type: str = "lightgbm", n_clusters: int = 32, mrl_dims: int = 64):
        # model_type: one of "lightgbm", "logistic", "mlp"
        self.model_type = model_type
        self.n_clusters = n_clusters
        self.mrl_dims = mrl_dims
        self.kmeans = None       # fitted KMeans after train()/load(), else None
        self.classifier = None   # fitted classifier after train()/load(), else None

    def _mrl_slice(self, X: np.ndarray) -> np.ndarray:
        """Slice rows of X to the first ``mrl_dims`` dims and L2-normalize each row.

        Shared by train() and predict() so training and inference always see
        vectors prepared identically. All-zero rows are guarded against
        divide-by-zero by substituting a tiny norm.
        """
        X_sliced = X[:, :self.mrl_dims]
        norms = np.linalg.norm(X_sliced, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10  # avoid 0/0 for all-zero rows
        return X_sliced / norms

    def train(self, X_full: np.ndarray):
        """
        Trains the router:
        1. Cluster X_full using K-Means to generate ground-truth labels.
        2. Slice X_full to MRL_DIMS (re-normalized).
        3. Train the specified classifier on sliced vectors to predict cluster labels.

        Raises ValueError for an unknown ``model_type``.
        """
        print(f"Training Router ({self.model_type})...")

        # 1. Ground-truth labels from K-Means on the FULL vectors so the
        #    clusters are based on the high-fidelity data, not the slice.
        print("  - Running K-Means for ground truth labels...")
        self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
        y_labels = self.kmeans.fit_predict(X_full)

        # 2. The router itself only ever sees the low-dim MRL slice.
        print(f"  - Slicing vectors to {self.mrl_dims} dimensions...")
        X_train = self._mrl_slice(X_full)

        # 3. Train the classifier on (sliced vector -> cluster id).
        print(f"  - Training classifier: {self.model_type}...")
        if self.model_type == "lightgbm":
            train_data = lgb.Dataset(X_train, label=y_labels)
            params = {
                'objective': 'multiclass',
                'num_class': self.n_clusters,
                'metric': 'multi_logloss',
                'verbosity': -1,
                'seed': 42
            }
            self.classifier = lgb.train(params, train_data, num_boost_round=100)

        elif self.model_type == "logistic":
            # NOTE: multi_class='multinomial' is already the default for the
            # lbfgs solver and the parameter is deprecated in recent
            # scikit-learn (removed in 1.7), so it is intentionally omitted.
            self.classifier = LogisticRegression(max_iter=1000, random_state=42)
            self.classifier.fit(X_train, y_labels)

        elif self.model_type == "mlp":
            self.classifier = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, random_state=42)
            self.classifier.fit(X_train, y_labels)

        else:
            raise ValueError(f"Unknown router model type: {self.model_type}")

        print("Router training complete.")

    def predict(self, vector_full: np.ndarray) -> Tuple[int, float]:
        """
        Predicts the target cluster for a query vector.

        Accepts a 1-D vector or a (1, d) batch; slices to MRL dims,
        re-normalizes, and returns (best_cluster, confidence) where
        confidence is the maximum class probability.

        Raises ValueError if the router has not been trained/loaded or the
        model type is unknown.
        """
        # Fail fast with a clear message instead of an AttributeError on None.
        if self.classifier is None:
            raise ValueError("Model not trained or unknown type")

        # Ensure input is 2D (single-row batch).
        if vector_full.ndim == 1:
            vector_full = vector_full.reshape(1, -1)

        # 1. Slice and normalize exactly as during training.
        X_input = self._mrl_slice(vector_full)

        # 2. Predict class probabilities.
        if self.model_type == "lightgbm":
            # lgb.Booster.predict returns (n_samples, n_classes) probabilities
            # directly for multiclass objectives.
            probs = self.classifier.predict(X_input)
        elif self.model_type in ["logistic", "mlp"]:
            probs = self.classifier.predict_proba(X_input)
        else:
            raise ValueError("Model not trained or unknown type")

        # 3. Best cluster id and its probability as the confidence score.
        best_cluster = np.argmax(probs, axis=1)[0]
        confidence = np.max(probs, axis=1)[0]

        return int(best_cluster), float(confidence)

    def save(self, path: str):
        """Saves the router (KMeans + Classifier) to disk via joblib."""
        print(f"Saving router to {path}...")
        # Create the parent directory if the path contains one; a bare
        # filename has an empty dirname, which makedirs would reject.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        joblib.dump({
            'model_type': self.model_type,
            'n_clusters': self.n_clusters,
            'mrl_dims': self.mrl_dims,
            'kmeans': self.kmeans,
            'classifier': self.classifier
        }, path)
        print("Router saved.")

    @classmethod
    def load(cls, path: str):
        """Loads a router previously written by save() and returns it."""
        print(f"Loading router from {path}...")
        data = joblib.load(path)
        router = cls(
            model_type=data['model_type'],
            n_clusters=data['n_clusters'],
            mrl_dims=data['mrl_dims']
        )
        router.kmeans = data['kmeans']
        router.classifier = data['classifier']
        print("Router loaded.")
        return router