import json
import os
import pickle
import warnings

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from safetensors.torch import load_file, save_file
from sklearn.preprocessing import StandardScaler


class LinearProbeTorch(PyTorchModelHubMixin):
    """Single-output linear probe with optional input standardization.

    The probe is an ``nn.Linear(input_dim, 1)`` whose sigmoid output is read
    as the probability of the positive class. The class plugs into the
    Hugging Face Hub through ``PyTorchModelHubMixin`` by overriding
    ``_save_pretrained`` and ``_from_pretrained``.

    Attributes are plain instance fields:
        normalize: whether inputs are passed through ``scaler`` before scoring.
        device: device string used for inference tensors.
        random_state: stored and serialized; not read by the code in this file.
        hidden_dim: stored and serialized; not read by the code in this file.
        model: the ``nn.Linear`` probe, or ``None`` until one is created.
        scaler: a fitted scaler exposing ``transform``, or ``None``.
    """

    def __init__(
        self,
        normalize=True,
        device="cpu",
        random_state=42,
        hidden_dim=1024,
    ):
        """Store hyper-parameters; no model or scaler is created here.

        Args:
            normalize: apply the scaler (when one is present) before scoring.
            device: requested device; replaced by ``"cpu"`` when CUDA is not
                available.
            random_state: seed value kept on the instance.
            hidden_dim: dimension value kept on the instance.
        """
        self.normalize = normalize
        # Any requested device degrades to CPU on machines without CUDA.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.random_state = random_state
        self.hidden_dim = hidden_dim
        self.model = None
        self.scaler = None

    def _create_model(self, input_dim, positive_class_proportion=0.5):
        """Build the linear layer with a prior-matching initialization.

        Weights start at zero and the bias at ``logit(p)``, so the untrained
        probe outputs the positive-class prior ``p`` for every input.

        Args:
            input_dim: number of input features.
            positive_class_proportion: prior ``p``; clipped to [0.01, 0.99]
                so the logit stays finite.

        Returns:
            The initialized ``nn.Linear(input_dim, 1)`` module.
        """
        model = nn.Linear(input_dim, 1)
        nn.init.zeros_(model.weight)
        p = np.clip(positive_class_proportion, 0.01, 0.99)
        # Convert to a plain float so torch never sees a NumPy scalar type.
        nn.init.constant_(model.bias, float(np.log(p / (1 - p))))
        return model

    def predict_proba(self, X):
        """Return class probabilities as an ``(n_samples, 2)`` array.

        Args:
            X: array-like or ``torch.Tensor`` of input features.

        Returns:
            NumPy array whose columns are ``[P(class 0), P(class 1)]``.

        Raises:
            ValueError: if no model has been created or loaded yet.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")

        if isinstance(X, torch.Tensor):
            # detach() so tensors that require grad can be converted to NumPy.
            X = X.detach().cpu().numpy()
        X = np.asarray(X)

        if self.normalize and self.scaler is not None:
            X = self.scaler.transform(X)

        self.model.eval()
        with torch.no_grad():
            inputs = torch.as_tensor(X, dtype=torch.float32).to(self.device)
            logits = self.model(inputs).squeeze(-1)
            proba_pos = torch.sigmoid(logits).cpu().numpy()
        return np.column_stack([1 - proba_pos, proba_pos])

    def predict(self, X):
        """Return hard 0/1 labels by thresholding ``P(class 1)`` at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

    def _save_pretrained(self, save_directory):
        """Write weights, config and (if present) the scaler to a directory.

        Files written: ``model.safetensors``, ``config.json`` and, when a
        scaler is set, ``scaler.pkl``.

        Args:
            save_directory: target directory; created if missing.

        Raises:
            ValueError: if there is no model to serialize.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")

        os.makedirs(save_directory, exist_ok=True)
        save_file(
            self.model.state_dict(),
            os.path.join(save_directory, "model.safetensors"),
        )

        # Only keys accepted by __init__ go in, because loading calls
        # cls(**config). The device is deliberately not persisted.
        config = {
            "hidden_dim": self.hidden_dim,
            "normalize": self.normalize,
            "random_state": self.random_state,
        }
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f)

        if self.scaler is not None:
            with open(os.path.join(save_directory, "scaler.pkl"), "wb") as f:
                pickle.dump(self.scaler, f)

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        *args,
        config=None,
        cache_dir=None,
        force_download=False,
        revision=None,
        **kwargs,
    ):
        """Rebuild a probe from a Hub repository or a local directory.

        Args:
            model_id: Hub repo id, or path to a directory produced by
                ``_save_pretrained``.
            *args: ignored.
            config: ignored; the configuration is always read from
                ``config.json``.
            cache_dir: forwarded to ``hf_hub_download``.
            force_download: forwarded to ``hf_hub_download``.
            revision: forwarded to ``hf_hub_download``.
            **kwargs: ``token`` and ``local_files_only`` are forwarded to
                ``hf_hub_download`` when present; everything else is ignored.

        Returns:
            A ``LinearProbeTorch`` in eval mode, placed on CUDA when
            available and on CPU otherwise.
        """
        from huggingface_hub import hf_hub_download
        from huggingface_hub.utils import EntryNotFoundError

        model_id = str(model_id)
        is_local = os.path.isdir(model_id)

        download_kwargs = {
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
        }
        for key in ("token", "local_files_only"):
            if key in kwargs:
                download_kwargs[key] = kwargs[key]

        def locate(filename):
            # Local checkpoints are read in place; anything else is fetched
            # from (or resolved in the cache of) the Hub.
            if is_local:
                return os.path.join(model_id, filename)
            return hf_hub_download(model_id, filename, **download_kwargs)

        with open(locate("config.json"), encoding="utf-8") as f:
            cfg = json.load(f)
        model = cls(**cfg)

        state_dict = load_file(locate("model.safetensors"))
        # The input width is not stored in the config; recover it from the
        # weight matrix, which nn.Linear lays out as (out_features, in_features).
        input_dim = state_dict["weight"].shape[1]
        model.model = model._create_model(input_dim)
        model.model.load_state_dict(state_dict)

        # The scaler is optional: only "file does not exist" is tolerated.
        # A scaler that exists but cannot be read is a real error, because
        # skipping it would silently change the predictions.
        try:
            scaler_path = locate("scaler.pkl")
            scaler_file = open(scaler_path, "rb")
        except (EntryNotFoundError, FileNotFoundError):
            if model.normalize:
                warnings.warn(
                    f"No scaler.pkl found for {model_id!r}; inputs will not "
                    "be normalized.",
                    stacklevel=2,
                )
        else:
            with scaler_file:
                # SECURITY: unpickling executes arbitrary code. Only load
                # checkpoints from sources you trust.
                model.scaler = pickle.load(scaler_file)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.device = device
        model.model = model.model.to(device)
        model.model.eval()
        return model