"""MLP probe: a small binary classifier with a scikit-learn style API.

The module exposes :func:`set_seed` and :class:`MLPProbeTorch`, a
one-hidden-layer PyTorch network that can be trained with ``fit``, queried
with ``predict`` / ``predict_proba`` and persisted through the Hugging Face
Hub mixin hooks.
"""

import json
import os
import pickle
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

warnings.filterwarnings("ignore")


def set_seed(seed: int = 42):
    """Set random seeds for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    CUDA) and switches cuDNN to deterministic, non-benchmarking mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class MLPProbeTorch(BaseEstimator, ClassifierMixin, PyTorchModelHubMixin):
    """Binary MLP probe trained with BCE loss plus optional L1/L2 penalties.

    Architecture (see ``_create_model``): Dropout -> Linear -> SiLU ->
    BatchNorm1d -> Dropout -> Linear(1). Inputs are optionally standardised
    with a ``StandardScaler`` fitted on the training data.

    Args:
        l2_lambda: Weight of the mean-squared-weight penalty (0 disables it).
        l1_lambda: Weight of the mean-absolute-weight penalty (0 disables it).
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size used for training.
        normalize: Whether to standardise features before training/inference.
        device: Requested device; falls back to ``"cpu"`` without CUDA.
        random_state: Seed passed to :func:`set_seed` at the start of ``fit``.
        verbose: Show a progress bar and per-epoch validation AUC.
        early_stopping_patience: Stop after this many epochs without a
            validation-AUC improvement; ``None`` disables early stopping.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(
        self,
        l2_lambda=1.0,
        l1_lambda=0.0,
        epochs=10,
        batch_size=256,
        normalize=True,
        device="cuda",
        random_state=42,
        verbose=True,
        early_stopping_patience=None,
        hidden_dim=1024,
    ):
        self.l2_lambda = l2_lambda
        self.l1_lambda = l1_lambda
        self.epochs = epochs
        self.batch_size = batch_size
        self.normalize = normalize
        # Any requested device silently degrades to CPU when CUDA is absent.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.random_state = random_state
        self.verbose = verbose
        self.early_stopping_patience = early_stopping_patience
        self.hidden_dim = hidden_dim
        self.model = None
        self.scaler = None
        self.classes_ = None
        self.history = self._empty_history()
        self.best_val_auc_ = -np.inf
        self.best_epoch_ = 0

    @staticmethod
    def _empty_history():
        """Return a fresh per-epoch metrics dictionary with empty lists."""
        return {
            "train_loss": [],
            "train_bce_loss": [],
            "train_l2_reg": [],
            "train_l1_reg": [],
            "val_loss": [],
            "val_bce_loss": [],
            "val_l2_reg": [],
            "val_l1_reg": [],
            "val_auc": [],
        }

    @staticmethod
    def _to_numpy(data):
        """Convert a tensor or array-like to a NumPy array on the host."""
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        return np.asarray(data)

    def _create_model(self, input_dim, positive_class_proportion=0.5):
        """Build the network and initialise it to predict the class prior.

        The output weights start at zero and the output bias at the log-odds
        of ``positive_class_proportion`` (clipped to [0.01, 0.99]), so the
        untrained model outputs that proportion for every input.
        """
        model = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(input_dim, self.hidden_dim),
            nn.SiLU(),
            nn.BatchNorm1d(self.hidden_dim),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_dim, 1),
        )
        nn.init.xavier_uniform_(model[1].weight, 0.8)
        nn.init.zeros_(model[1].bias)
        nn.init.zeros_(model[5].weight)
        p = np.clip(positive_class_proportion, 0.01, 0.99)
        nn.init.constant_(model[5].bias, np.log(p / (1 - p)))
        return model

    def _create_dataloader(self, X, y):
        """Wrap ``X``/``y`` in a shuffling ``DataLoader``.

        BatchNorm1d raises in training mode when a batch holds a single
        sample, so a trailing batch of exactly one sample is dropped. All
        other dataset sizes keep every sample, as before.
        """
        dataset = TensorDataset(torch.FloatTensor(X), torch.FloatTensor(y))
        n_samples = len(dataset)
        lone_trailing_sample = (
            n_samples > self.batch_size and n_samples % self.batch_size == 1
        )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=lone_trailing_sample,
        )

    def _penalty_terms(self, reg_params, n_reg_params):
        """Return the scaled ``(l2, l1)`` penalties for ``reg_params``.

        Each term is a 0-dim tensor when its lambda is positive and the plain
        float ``0.0`` otherwise. Both are normalised by ``n_reg_params``.
        """
        l2_term = l1_term = 0.0
        if self.l2_lambda > 0:
            l2_reg = sum(p.pow(2).sum() for p in reg_params)
            l2_term = 0.5 * self.l2_lambda * l2_reg / n_reg_params
        if self.l1_lambda > 0:
            l1_reg = sum(p.abs().sum() for p in reg_params)
            l1_term = self.l1_lambda * l1_reg / n_reg_params
        return l2_term, l1_term

    def fit(self, X, y, X_val=None, y_val=None):
        """Train the probe, optionally tracking validation AUC.

        Args:
            X: Training features, array-like or tensor of shape (n, d).
            y: Binary training targets (0/1), array-like or tensor.
            X_val: Optional validation features.
            y_val: Optional validation targets. Validation is used only when
                both ``X_val`` and ``y_val`` are given.

        Returns:
            ``self``. With validation data the weights of the epoch with the
            best validation AUC are restored before returning.
        """
        set_seed(self.random_state)
        X, y = self._to_numpy(X), self._to_numpy(y)
        has_validation = X_val is not None and y_val is not None
        if has_validation:
            X_val, y_val = self._to_numpy(X_val), self._to_numpy(y_val)

        # Reset everything a previous fit may have left behind so that the
        # recorded history always describes the current model only.
        self.classes_ = np.unique(y)
        self.history = self._empty_history()
        self.best_val_auc_ = -np.inf
        self.best_epoch_ = 0
        self.scaler = None

        if self.normalize:
            self.scaler = StandardScaler()
            X = self.scaler.fit_transform(X)
            if has_validation:
                X_val = self.scaler.transform(X_val)

        # NOTE(review): the output-bias prior comes from the validation
        # labels when they are supplied; confirm this is intended rather
        # than using the training labels.
        pos = np.mean(y_val if has_validation else y)
        input_dim = X.shape[1]
        self.model = self._create_model(input_dim, pos).to(self.device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
        train_loader = self._create_dataloader(X, y)
        loss_fn = nn.BCEWithLogitsLoss()

        # Every parameter whose name does not contain "bias" is regularised.
        # The list is loop-invariant because the optimizer updates in place.
        reg_params = [
            p for name, p in self.model.named_parameters() if "bias" not in name
        ]
        n_reg_params = sum(p.numel() for p in reg_params)

        best_val_auc, best_model_state, patience_counter = -np.inf, None, 0
        self.model.train()
        for epoch in range(self.epochs):
            epoch_loss = epoch_bce = epoch_l2 = epoch_l1 = 0.0
            it = (
                tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")
                if self.verbose
                else train_loader
            )
            for X_batch, y_batch in it:
                X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                optimizer.zero_grad()
                out = self.model(X_batch).squeeze(-1)
                bce = loss_fn(out, y_batch)
                l2_term, l1_term = self._penalty_terms(reg_params, n_reg_params)
                loss = bce + l2_term + l1_term
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_bce += bce.item()
                epoch_l2 += float(l2_term)
                epoch_l1 += float(l1_term)

            n_batches = len(train_loader)
            self.history["train_loss"].append(epoch_loss / n_batches)
            self.history["train_bce_loss"].append(epoch_bce / n_batches)
            self.history["train_l2_reg"].append(epoch_l2 / n_batches)
            self.history["train_l1_reg"].append(epoch_l1 / n_batches)

            if not has_validation:
                continue

            self.model.eval()
            with torch.no_grad():
                X_v = torch.FloatTensor(X_val).to(self.device)
                y_v = torch.FloatTensor(y_val).to(self.device)
                logits = self.model(X_v).squeeze(-1)
                val_bce = loss_fn(logits, y_v).item()
                val_proba = torch.sigmoid(logits).cpu().numpy()
                val_l2, val_l1 = (
                    float(term)
                    for term in self._penalty_terms(reg_params, n_reg_params)
                )
            val_auc = roc_auc_score(y_val, val_proba)
            self.history["val_auc"].append(val_auc)
            self.history["val_bce_loss"].append(val_bce)
            self.history["val_l2_reg"].append(val_l2)
            self.history["val_l1_reg"].append(val_l1)
            self.history["val_loss"].append(val_bce + val_l2 + val_l1)
            self.model.train()

            if val_auc > best_val_auc:
                best_val_auc = val_auc
                # Snapshot on the CPU so the copy does not hold device memory.
                best_model_state = {
                    k: v.cpu().clone() for k, v in self.model.state_dict().items()
                }
                patience_counter = 0
                self.best_val_auc_, self.best_epoch_ = best_val_auc, epoch + 1
            else:
                patience_counter += 1

            if self.verbose:
                print(
                    f"Epoch {epoch+1}/{self.epochs} - Val AUC: {val_auc:.4f} {'*' if val_auc == best_val_auc else ''}"
                )
            if (
                self.early_stopping_patience is not None
                and patience_counter >= self.early_stopping_patience
            ):
                break

        if has_validation and best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        return self

    def predict_proba(self, X):
        """Return class probabilities as an ``(n, 2)`` array.

        Column 0 is ``1 - p`` and column 1 is ``p``, the sigmoid of the
        model logit.

        Raises:
            ValueError: If the model has not been fitted or loaded.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")
        X = self._to_numpy(X)
        if self.normalize and self.scaler is not None:
            X = self.scaler.transform(X)
        self.model.eval()
        with torch.no_grad():
            logits = self.model(torch.FloatTensor(X).to(self.device)).squeeze(-1)
            proba_pos = torch.sigmoid(logits).cpu().numpy()
        return np.column_stack([1 - proba_pos, proba_pos])

    def predict(self, X):
        """Return 0/1 predictions by thresholding the positive probability at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

    def _save_pretrained(self, save_directory):
        """Write weights, config and optional scaler/classes to a directory.

        Files written: ``pytorch_model.bin``, ``config.json`` and, when
        present, ``scaler.pkl`` and ``classes.npy``.

        Raises:
            ValueError: If there is no trained model to save.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")
        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            self.model.state_dict(), os.path.join(save_directory, "pytorch_model.bin")
        )
        config = {
            "hidden_dim": self.hidden_dim,
            "l1_lambda": self.l1_lambda,
            "l2_lambda": self.l2_lambda,
            "normalize": self.normalize,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)
        if self.scaler is not None:
            with open(os.path.join(save_directory, "scaler.pkl"), "wb") as f:
                pickle.dump(self.scaler, f)
        if self.classes_ is not None:
            np.save(os.path.join(save_directory, "classes.npy"), self.classes_)

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        *args,
        config=None,
        cache_dir=None,
        force_download=False,
        **kwargs,
    ):
        """Rebuild a probe from a Hub repository or a local directory.

        ``pytorch_model.bin`` and ``config.json`` are required;
        ``scaler.pkl`` and ``classes.npy`` are loaded when they can be
        fetched. ``cache_dir``, ``force_download`` and, when supplied in
        ``kwargs``, ``revision`` / ``token`` / ``local_files_only`` are
        forwarded to ``hf_hub_download``. The ``config`` argument is not
        used; hyper-parameters come from ``config.json``.
        """
        from huggingface_hub import hf_hub_download

        hub_kwargs = {
            key: kwargs[key]
            for key in ("revision", "token", "local_files_only")
            if key in kwargs
        }
        is_local = os.path.isdir(str(model_id))

        def fetch(filename):
            # A directory produced by _save_pretrained is read directly.
            if is_local:
                path = os.path.join(str(model_id), filename)
                if not os.path.isfile(path):
                    raise FileNotFoundError(path)
                return path
            return hf_hub_download(
                model_id,
                filename,
                cache_dir=cache_dir,
                force_download=force_download,
                **hub_kwargs,
            )

        def fetch_optional(filename):
            # Optional artifacts are best-effort: an unavailable file yields
            # None, but KeyboardInterrupt/SystemExit still propagate.
            try:
                return fetch(filename)
            except Exception:
                return None

        weights_path = fetch("pytorch_model.bin")
        config_path = fetch("config.json")
        with open(config_path) as f:
            cfg = json.load(f)
        model = cls(**cfg)

        # SECURITY: torch.load and pickle.load can execute arbitrary code;
        # only load repositories/directories you trust.
        state_dict = torch.load(weights_path, map_location="cpu")
        # The first Linear layer's weight has shape (hidden_dim, input_dim).
        input_dim = state_dict["1.weight"].shape[1]
        model.model = model._create_model(input_dim)
        model.model.load_state_dict(state_dict)

        scaler_path = fetch_optional("scaler.pkl")
        if scaler_path is not None:
            with open(scaler_path, "rb") as f:
                model.scaler = pickle.load(f)

        classes_path = fetch_optional("classes.npy")
        if classes_path is not None:
            model.classes_ = np.load(classes_path)

        model.model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.device = device
        model.model = model.model.to(device)
        return model