File size: 6,595 Bytes
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51fc709
b92d96d
 
51fc709
b92d96d
 
 
 
 
 
 
51fc709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9df6ef
 
 
 
b92d96d
b9df6ef
 
 
 
 
 
 
 
 
 
b92d96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
from typing import Any, List, Optional, Tuple

import joblib
import lightgbm as lgb
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

class LearnedRouter:
    """Routes query vectors to candidate clusters via a lightweight classifier.

    Ground-truth cluster assignments come from K-Means on the FULL-dimension
    vectors (high-fidelity data); the classifier then learns to predict those
    clusters from only the first ``mrl_dims`` dimensions of each vector
    (the Matryoshka / MRL slice).
    """

    def __init__(self, model_type: str = "lightgbm", n_clusters: int = 32, mrl_dims: int = 64):
        """
        Args:
            model_type: Classifier backend — "lightgbm", "logistic", or "mlp".
            n_clusters: Number of K-Means clusters (= classifier classes).
            mrl_dims: Number of leading vector dimensions the router sees.
        """
        self.model_type = model_type
        self.n_clusters = n_clusters
        self.mrl_dims = mrl_dims
        self.kmeans = None       # fitted KMeans, or None when labels are supplied externally
        self.classifier = None   # fitted classifier after train() / load()

    @staticmethod
    def _slice_and_normalize(X: np.ndarray, dims: int) -> np.ndarray:
        """Truncate rows to the first `dims` dimensions and L2-normalize.

        MRL slices must be re-normalized: truncation changes the norm even
        when the full vectors were unit length. Zero-norm rows are guarded
        with a small epsilon to avoid division by zero.
        """
        X_sliced = X[:, :dims]
        norms = np.linalg.norm(X_sliced, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10
        return X_sliced / norms

    def train(self, X_full: np.ndarray, labels: Optional[np.ndarray] = None):
        """
        Trains the router:
        1. Cluster X_full using K-Means to generate ground-truth labels (if not provided).
        2. Slice X_full to MRL_DIMS and re-normalize.
        3. Train the specified classifier on sliced vectors to predict cluster labels.

        Args:
            X_full: (n_samples, full_dim) embedding matrix.
            labels: Optional precomputed cluster labels (shared KMeans case).
                When provided, ``self.kmeans`` stays None — predict() never
                uses it, and save() simply stores None.

        Raises:
            ValueError: If ``self.model_type`` is not a known classifier.
        """
        print(f"Training Router ({self.model_type})...")

        # 1. Ground-truth labels: clusters are derived from the FULL vectors
        # so they reflect the high-fidelity geometry, not the MRL slice.
        if labels is not None:
            print("  - Using provided ground-truth labels (Shared KMeans).")
            y_labels = labels
        else:
            print("  - Running K-Means for ground truth labels...")
            self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
            y_labels = self.kmeans.fit_predict(X_full)

        # 2. The router only ever sees the low-dim MRL slice.
        print(f"  - Slicing vectors to {self.mrl_dims} dimensions...")
        X_train = self._slice_and_normalize(X_full, self.mrl_dims)

        # 3. Fit the configured classifier.
        print(f"  - Training classifier: {self.model_type}...")
        if self.model_type == "lightgbm":
            train_data = lgb.Dataset(X_train, label=y_labels)
            params = {
                'objective': 'multiclass',
                'num_class': self.n_clusters,
                'metric': 'multi_logloss',
                'verbosity': -1,
                'seed': 42
            }
            self.classifier = lgb.train(params, train_data, num_boost_round=100)

        elif self.model_type == "logistic":
            self.classifier = LogisticRegression(max_iter=1000, multi_class='multinomial', random_state=42)
            self.classifier.fit(X_train, y_labels)

        elif self.model_type == "mlp":
            self.classifier = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, random_state=42)
            self.classifier.fit(X_train, y_labels)

        else:
            raise ValueError(f"Unknown router model type: {self.model_type}")

        print("Router training complete.")

    def predict(self, vector_full: np.ndarray, confidence_threshold: float = 0.9) -> Tuple[List[int], float]:
        """
        Predicts the target clusters for a query vector.
        1. Slice input to MRL dims and re-normalize.
        2. Predict class probabilities.
        3. Select clusters in descending probability order until the
           cumulative mass exceeds ``confidence_threshold``.

        Args:
            vector_full: Query embedding, shape (full_dim,) or (1, full_dim).
            confidence_threshold: Cumulative probability mass to cover
                before stopping (default 0.9).

        Returns:
            (selected_clusters, cumulative_confidence) — a list of cluster
            ids and the probability mass they cover.

        Raises:
            ValueError: If the model type is unknown / not trained.
        """
        # Ensure input is 2D
        if vector_full.ndim == 1:
            vector_full = vector_full.reshape(1, -1)

        # 1. Slice and normalize (same preprocessing as train()).
        X_input = self._slice_and_normalize(vector_full, self.mrl_dims)

        # 2. Predict. A lightgbm multiclass Booster returns probabilities
        # from predict(); sklearn estimators expose predict_proba().
        # Both yield shape (n_samples, n_classes).
        if self.model_type == "lightgbm":
            probs = self.classifier.predict(X_input)
        elif self.model_type in ("logistic", "mlp"):
            probs = self.classifier.predict_proba(X_input)
        else:
            raise ValueError("Model not trained or unknown type")

        # 3. Greedy selection by descending probability until the
        # cumulative confidence passes the threshold.
        probs = probs[0]  # single query -> 1D array
        sorted_indices = np.argsort(probs)[::-1]

        selected_clusters = []
        cumulative_confidence = 0.0
        for idx in sorted_indices:
            selected_clusters.append(int(idx))
            cumulative_confidence += probs[idx]
            if cumulative_confidence > confidence_threshold:
                break

        return selected_clusters, float(cumulative_confidence)

    def save(self, path: str):
        """Saves the router (KMeans + Classifier) to disk via joblib."""
        print(f"Saving router to {path}...")
        joblib.dump({
            'model_type': self.model_type,
            'n_clusters': self.n_clusters,
            'mrl_dims': self.mrl_dims,
            'kmeans': self.kmeans,  # may be None when labels were supplied externally
            'classifier': self.classifier
        }, path)
        print("Router saved.")

    @classmethod
    def load(cls, path: str):
        """Loads a router previously written by save().

        Returns:
            A LearnedRouter with its kmeans/classifier attributes restored.
        """
        print(f"Loading router from {path}...")
        data = joblib.load(path)
        router = cls(
            model_type=data['model_type'],
            n_clusters=data['n_clusters'],
            mrl_dims=data['mrl_dims']
        )
        router.kmeans = data['kmeans']
        router.classifier = data['classifier']
        print("Router loaded.")
        return router