import json
import os
import pickle
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
|
|
# Silence every warning category. The filter is installed in the global
# ``warnings`` registry, so it affects the whole process (any code that
# imports this module), not just this file.
warnings.simplefilter("ignore")
|
|
|
|
def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs so runs are reproducible.

    When CUDA is available, also seeds every CUDA device and puts cuDNN in
    deterministic, non-benchmarking mode (trading some GPU speed for
    repeatable results).

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers the current device too, so a separate
        # torch.cuda.manual_seed call would be redundant.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
|
|
class MLPProbeTorch(BaseEstimator, ClassifierMixin, PyTorchModelHubMixin):
    """Binary-classification probe: a one-hidden-layer MLP trained with PyTorch.

    Exposes the scikit-learn ``fit`` / ``predict`` / ``predict_proba`` API and
    plugs into ``PyTorchModelHubMixin`` through the ``_save_pretrained`` /
    ``_from_pretrained`` hooks defined below.

    Parameters
    ----------
    l2_lambda : float, default=1.0
        L2 penalty strength. The penalty is
        ``0.5 * l2_lambda * sum(w**2) / n_penalized`` over every parameter
        whose name does not contain ``"bias"`` (both Linear weights and the
        BatchNorm scale). Disabled when not positive.
    l1_lambda : float, default=0.0
        L1 penalty strength, ``l1_lambda * sum(|w|) / n_penalized`` over the
        same parameters. Disabled when not positive.
    epochs : int, default=10
        Maximum number of passes over the training data.
    batch_size : int, default=256
        Training mini-batch size.
    normalize : bool, default=True
        Standardize features with a ``StandardScaler`` fitted on the
        training data.
    device : str, default="cuda"
        Torch device. Replaced by ``"cpu"`` when CUDA is unavailable.
    random_state : int, default=42
        Seed passed to ``set_seed`` at the start of ``fit``.
    verbose : bool, default=True
        Show a tqdm bar per epoch and print the validation AUC.
    early_stopping_patience : int or None, default=None
        Stop after this many consecutive epochs without a validation-AUC
        improvement. Only used when validation data is passed to ``fit``.
    hidden_dim : int, default=1024
        Width of the hidden layer.
    """

    def __init__(
        self,
        l2_lambda=1.0,
        l1_lambda=0.0,
        epochs=10,
        batch_size=256,
        normalize=True,
        device="cuda",
        random_state=42,
        verbose=True,
        early_stopping_patience=None,
        hidden_dim=1024,
    ):
        self.l2_lambda = l2_lambda
        self.l1_lambda = l1_lambda
        self.epochs = epochs
        self.batch_size = batch_size
        self.normalize = normalize
        # Resolved eagerly so a CPU-only machine never tries to use CUDA.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.random_state = random_state
        self.verbose = verbose
        self.early_stopping_patience = early_stopping_patience
        self.hidden_dim = hidden_dim
        self.model = None  # nn.Sequential, built by fit() or _from_pretrained()
        self.scaler = None  # StandardScaler when normalize=True, set by fit()
        self.classes_ = None
        self.history = self._empty_history()
        self.best_val_auc_ = -np.inf
        self.best_epoch_ = 0

    @staticmethod
    def _empty_history():
        """Return a fresh per-epoch metrics log (one empty list per metric)."""
        return {
            key: []
            for key in (
                "train_loss",
                "train_bce_loss",
                "train_l2_reg",
                "train_l1_reg",
                "val_loss",
                "val_bce_loss",
                "val_l2_reg",
                "val_l1_reg",
                "val_auc",
            )
        }

    @staticmethod
    def _to_numpy(a):
        """Convert a torch tensor (on any device) or an array-like to a NumPy array."""
        if isinstance(a, torch.Tensor):
            a = a.detach().cpu().numpy()
        return np.asarray(a)

    def _create_model(self, input_dim, positive_class_proportion=0.5):
        """Build the MLP: Dropout -> Linear -> SiLU -> BatchNorm1d -> Dropout -> Linear.

        The output layer starts with zero weights and its bias set to the
        log-odds of ``positive_class_proportion`` (clipped to [0.01, 0.99]).
        The untrained network therefore predicts that prior for every input.
        """
        model = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(input_dim, self.hidden_dim),
            nn.SiLU(),
            nn.BatchNorm1d(self.hidden_dim),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_dim, 1),
        )
        nn.init.xavier_uniform_(model[1].weight, 0.8)
        nn.init.zeros_(model[1].bias)
        nn.init.zeros_(model[5].weight)
        p = np.clip(positive_class_proportion, 0.01, 0.99)
        nn.init.constant_(model[5].bias, np.log(p / (1 - p)))
        return model

    def _create_dataloader(self, X, y):
        """Wrap training arrays in a shuffling DataLoader of float32 tensors.

        If the last mini-batch would contain exactly one sample, it is dropped.
        ``BatchNorm1d`` raises on a size-1 batch in training mode, and
        shuffling means a different sample is skipped each epoch.
        """
        dataset = TensorDataset(torch.FloatTensor(X), torch.FloatTensor(y))
        n_samples = len(dataset)
        drop_last = n_samples > self.batch_size and n_samples % self.batch_size == 1
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True, drop_last=drop_last
        )

    def _regularization_terms(self):
        """Return the (L2, L1) penalty terms as scalar tensors.

        Both terms cover every parameter whose name lacks ``"bias"`` and are
        normalised by the number of such elements. A term is a zero tensor
        when its lambda is not positive.
        """
        weights = [p for n, p in self.model.named_parameters() if "bias" not in n]
        n_params = sum(p.numel() for p in weights)
        zero = weights[0].new_zeros(())
        l2_term = (
            0.5 * self.l2_lambda * sum(p.pow(2).sum() for p in weights) / n_params
            if self.l2_lambda > 0
            else zero
        )
        l1_term = (
            self.l1_lambda * sum(p.abs().sum() for p in weights) / n_params
            if self.l1_lambda > 0
            else zero
        )
        return l2_term, l1_term

    def _train_epoch(self, train_loader, optimizer, loss_fn, epoch):
        """Run one optimisation pass and append mean per-batch losses to history."""
        totals = {"loss": 0.0, "bce": 0.0, "l2": 0.0, "l1": 0.0}
        batches = (
            tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")
            if self.verbose
            else train_loader
        )
        for X_batch, y_batch in batches:
            X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
            optimizer.zero_grad()
            bce = loss_fn(self.model(X_batch).squeeze(-1), y_batch)
            l2_term, l1_term = self._regularization_terms()
            loss = bce + l2_term + l1_term
            loss.backward()
            optimizer.step()
            totals["loss"] += loss.item()
            totals["bce"] += bce.item()
            totals["l2"] += l2_term.item()
            totals["l1"] += l1_term.item()
        n_batches = len(train_loader)
        self.history["train_loss"].append(totals["loss"] / n_batches)
        self.history["train_bce_loss"].append(totals["bce"] / n_batches)
        self.history["train_l2_reg"].append(totals["l2"] / n_batches)
        self.history["train_l1_reg"].append(totals["l1"] / n_batches)

    def _evaluate(self, X_v, y_v, y_val, loss_fn):
        """Score the model on validation data, log the metrics and return the AUC.

        Switches to eval mode (dropout off, BatchNorm running statistics) for
        the forward pass, then restores train mode.
        """
        self.model.eval()
        with torch.no_grad():
            logits = self.model(X_v).squeeze(-1)
            val_bce = loss_fn(logits, y_v).item()
            val_proba = torch.sigmoid(logits).cpu().numpy()
            val_l2, val_l1 = (term.item() for term in self._regularization_terms())
        self.model.train()
        val_auc = roc_auc_score(y_val, val_proba)
        self.history["val_bce_loss"].append(val_bce)
        self.history["val_l2_reg"].append(val_l2)
        self.history["val_l1_reg"].append(val_l1)
        self.history["val_loss"].append(val_bce + val_l2 + val_l1)
        self.history["val_auc"].append(val_auc)
        return val_auc

    def fit(self, X, y, X_val=None, y_val=None):
        """Train the probe, optionally selecting the best epoch by validation AUC.

        Args:
            X: Training features (NumPy array-like or torch tensor). Rows are
                samples; ``X.shape[1]`` sets the network's input width.
            y: Training labels. They are used directly as BCE targets, so they
                are expected to be 0/1.
            X_val, y_val: Optional validation data. When both are given, the
                weights from the epoch with the highest validation AUC are
                restored at the end, and early stopping is enabled.

        Returns:
            self

        Raises:
            ValueError: if fewer than 2 training samples are given
                (BatchNorm cannot train on them). The ``ValueError`` from
                ``roc_auc_score`` also propagates when ``y_val`` holds a
                single class.
        """
        set_seed(self.random_state)
        X, y = self._to_numpy(X), self._to_numpy(y)
        has_validation = X_val is not None and y_val is not None
        if has_validation:
            X_val, y_val = self._to_numpy(X_val), self._to_numpy(y_val)
        if X.shape[0] < 2:
            raise ValueError(
                "MLPProbeTorch needs at least 2 training samples (BatchNorm1d)."
            )

        # Refitting starts from a clean slate instead of extending an old run's log.
        self.history = self._empty_history()
        self.best_val_auc_, self.best_epoch_ = -np.inf, 0
        self.classes_ = np.unique(y)
        if self.normalize:
            self.scaler = StandardScaler()
            X = self.scaler.fit_transform(X)
            if has_validation:
                X_val = self.scaler.transform(X_val)
        else:
            self.scaler = None

        # The output-bias prior comes from the *training* labels; using the
        # validation labels here would leak validation data into the model.
        self.model = self._create_model(X.shape[1], float(np.mean(y))).to(self.device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
        train_loader = self._create_dataloader(X, y)
        loss_fn = nn.BCEWithLogitsLoss()
        if has_validation:
            X_v = torch.FloatTensor(X_val).to(self.device)
            y_v = torch.FloatTensor(y_val).to(self.device)

        best_val_auc, best_model_state, patience_counter = -np.inf, None, 0
        self.model.train()
        for epoch in range(self.epochs):
            self._train_epoch(train_loader, optimizer, loss_fn, epoch)
            if not has_validation:
                continue
            val_auc = self._evaluate(X_v, y_v, y_val, loss_fn)
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                # Snapshot on CPU so the best weights survive later updates.
                best_model_state = {
                    k: v.cpu().clone() for k, v in self.model.state_dict().items()
                }
                patience_counter = 0
                self.best_val_auc_, self.best_epoch_ = best_val_auc, epoch + 1
            else:
                patience_counter += 1
            if self.verbose:
                print(
                    f"Epoch {epoch+1}/{self.epochs} - Val AUC: {val_auc:.4f} {'*' if val_auc == best_val_auc else ''}"
                )
            if (
                self.early_stopping_patience is not None
                and patience_counter >= self.early_stopping_patience
            ):
                break
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        return self

    def predict_proba(self, X):
        """Return an ``(n_samples, 2)`` array of ``[P(class 0), P(class 1)]``.

        Raises:
            ValueError: if the model has not been fitted or loaded.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")
        X = self._to_numpy(X)
        if self.normalize and self.scaler is not None:
            X = self.scaler.transform(X)
        self.model.eval()
        with torch.no_grad():
            logits = self.model(torch.FloatTensor(X).to(self.device)).squeeze(-1)
            proba_pos = torch.sigmoid(logits).cpu().numpy()
        return np.column_stack([1 - proba_pos, proba_pos])

    def predict(self, X):
        """Return 0/1 predictions by thresholding P(class 1) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

    def _save_pretrained(self, save_directory):
        """Write the fitted probe to ``save_directory``.

        Files written:
        - ``pytorch_model.bin``: the state_dict.
        - ``config.json``: hidden_dim, l1_lambda, l2_lambda, normalize.
        - ``scaler.pkl``: only if a scaler exists.
        - ``classes.npy``: only if ``classes_`` is set.

        Raises:
            ValueError: if the model has not been fitted.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")
        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            self.model.state_dict(), os.path.join(save_directory, "pytorch_model.bin")
        )
        config = {
            "hidden_dim": self.hidden_dim,
            "l1_lambda": self.l1_lambda,
            "l2_lambda": self.l2_lambda,
            "normalize": self.normalize,
        }
        with open(
            os.path.join(save_directory, "config.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(config, f)
        if self.scaler is not None:
            with open(os.path.join(save_directory, "scaler.pkl"), "wb") as f:
                pickle.dump(self.scaler, f)
        if self.classes_ is not None:
            np.save(os.path.join(save_directory, "classes.npy"), self.classes_)

    @staticmethod
    def _resolve_file(model_id, filename, download_kwargs, optional=False):
        """Return a local path to ``filename`` from a directory or a Hub repo.

        If ``model_id`` is an existing directory, the file is looked up there.
        Otherwise it is fetched with ``hf_hub_download(**download_kwargs)``.
        When ``optional`` is true, a missing file yields ``None`` instead of
        an error.
        """
        from huggingface_hub import hf_hub_download
        from huggingface_hub.utils import EntryNotFoundError

        if os.path.isdir(model_id):
            path = os.path.join(model_id, filename)
            if os.path.isfile(path):
                return path
            if optional:
                return None
            raise FileNotFoundError(f"{filename} not found in {model_id}")
        try:
            return hf_hub_download(model_id, filename, **download_kwargs)
        except EntryNotFoundError:
            if optional:
                return None
            raise

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        *args,
        config=None,
        cache_dir=None,
        force_download=False,
        **kwargs,
    ):
        """Rebuild a fitted probe from a local directory or a Hub repository.

        Download options go to ``hf_hub_download``: ``cache_dir`` and
        ``force_download``, plus any of ``revision``, ``token``, ``proxies``
        and ``local_files_only`` found in ``kwargs``. Other keyword arguments
        that name constructor parameters override the stored config (e.g.
        ``device="cpu"``); the rest are ignored. ``scaler.pkl`` and
        ``classes.npy`` are optional.
        """
        import inspect

        download_kwargs = {"cache_dir": cache_dir, "force_download": force_download}
        for key in ("revision", "token", "proxies", "local_files_only"):
            value = kwargs.pop(key, None)
            if value is not None:
                download_kwargs[key] = value
        kwargs.pop("resume_download", None)  # deprecated Hub option; not forwarded

        config_path = cls._resolve_file(model_id, "config.json", download_kwargs)
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        init_params = set(inspect.signature(cls.__init__).parameters) - {"self"}
        init_kwargs = {k: v for k, v in cfg.items() if k in init_params}
        init_kwargs.update({k: v for k, v in kwargs.items() if k in init_params})
        model = cls(**init_kwargs)

        weights_path = cls._resolve_file(model_id, "pytorch_model.bin", download_kwargs)
        # weights_only=True restricts unpickling to tensors/containers.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        # The first Linear layer's weight is stored as (hidden_dim, input_dim).
        input_dim = state_dict["1.weight"].shape[1]
        model.model = model._create_model(input_dim)
        model.model.load_state_dict(state_dict)

        scaler_path = cls._resolve_file(
            model_id, "scaler.pkl", download_kwargs, optional=True
        )
        if scaler_path is not None:
            # NOTE(security): unpickling executes arbitrary code; only load
            # repositories you trust.
            with open(scaler_path, "rb") as f:
                model.scaler = pickle.load(f)

        classes_path = cls._resolve_file(
            model_id, "classes.npy", download_kwargs, optional=True
        )
        if classes_path is not None:
            model.classes_ = np.load(classes_path)

        model.model.eval()
        # model.device was resolved by __init__ (CUDA if available, else CPU,
        # unless overridden by a ``device`` keyword argument).
        model.model = model.model.to(model.device)
        return model
|
|