# overflow_probe_xrag_full / mlp_probe.py
# wexumin's picture
# Create mlp_probe.py
# e884210 verified
import json
import os
import pickle
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
# Install a process-wide "ignore" filter for every warning category as soon as
# this module is imported.
# NOTE(review): this also hides warnings raised by unrelated libraries in the
# same process — confirm that a global filter (rather than a scoped
# `warnings.catch_warnings()` block) is intended.
warnings.filterwarnings("ignore")
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator used by this module.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the CPU generator and, when CUDA is available, every CUDA device), then
    asks cuDNN for deterministic kernels and disables its autotuner.

    Args:
        seed: Value handed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Prefer run-to-run reproducibility over cuDNN's autotuned kernel choice.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class MLPProbeTorch(BaseEstimator, ClassifierMixin, PyTorchModelHubMixin):
    """One-hidden-layer MLP probe for binary classification.

    The estimator follows the scikit-learn ``fit`` / ``predict`` /
    ``predict_proba`` protocol and can be stored on / restored from the
    Hugging Face Hub through the ``PyTorchModelHubMixin`` hooks
    ``_save_pretrained`` and ``_from_pretrained``.

    Training minimises binary cross-entropy plus optional L2 / L1 penalties
    on every parameter whose name does not contain ``"bias"``, each penalty
    being averaged over the number of penalised parameters.

    Attributes set by the constructor and updated by ``fit``:
        model: The ``nn.Sequential`` network, ``None`` until fitted/loaded.
        scaler: ``StandardScaler`` fitted on the training features when
            ``normalize`` is true, otherwise ``None``.
        classes_: Sorted unique training labels.
        history: Dict of per-epoch lists, one list per key in
            ``_HISTORY_KEYS``.
        best_val_auc_: Best validation ROC-AUC seen during the last ``fit``.
        best_epoch_: 1-based epoch at which ``best_val_auc_`` was reached.
    """

    # Keys of the per-epoch log kept in ``self.history``.
    _HISTORY_KEYS = (
        "train_loss",
        "train_bce_loss",
        "train_l2_reg",
        "train_l1_reg",
        "val_loss",
        "val_bce_loss",
        "val_l2_reg",
        "val_l1_reg",
        "val_auc",
    )

    def __init__(
        self,
        l2_lambda=1.0,
        l1_lambda=0.0,
        epochs=10,
        batch_size=256,
        normalize=True,
        device="cuda",
        random_state=42,
        verbose=True,
        early_stopping_patience=None,
        hidden_dim=1024,
    ):
        """Store the hyper-parameters and initialise the unfitted state.

        Args:
            l2_lambda: Weight of the L2 penalty; ``<= 0`` disables it.
            l1_lambda: Weight of the L1 penalty; ``<= 0`` disables it.
            epochs: Maximum number of passes over the training data.
            batch_size: Mini-batch size of the training loader.
            normalize: Standardise features with ``StandardScaler``.
            device: Requested torch device; replaced by ``"cpu"`` whenever
                CUDA is not available.
            random_state: Seed passed to ``set_seed`` at the start of ``fit``.
            verbose: Show a tqdm bar per epoch and print validation AUC.
            early_stopping_patience: Stop after this many consecutive epochs
                without a validation-AUC improvement; ``None`` disables it.
            hidden_dim: Width of the hidden layer.
        """
        self.l2_lambda = l2_lambda
        self.l1_lambda = l1_lambda
        self.epochs = epochs
        self.batch_size = batch_size
        self.normalize = normalize
        # Any requested device degrades to CPU on a machine without CUDA.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.random_state = random_state
        self.verbose = verbose
        self.early_stopping_patience = early_stopping_patience
        self.hidden_dim = hidden_dim

        self.model = None
        self.scaler = None
        self.classes_ = None
        self.history = {key: [] for key in self._HISTORY_KEYS}
        self.best_val_auc_ = -np.inf
        self.best_epoch_ = 0

    @staticmethod
    def _to_numpy(data):
        """Return ``data`` as a NumPy array, moving torch tensors to host."""
        if isinstance(data, torch.Tensor):
            # detach() so tensors that require grad can be converted too.
            data = data.detach().cpu().numpy()
        return np.asarray(data)

    def _create_model(self, input_dim, positive_class_proportion=0.5):
        """Build the probe network and initialise its parameters.

        Layout: Dropout(0.1) -> Linear -> SiLU -> BatchNorm1d -> Dropout(0.1)
        -> Linear(hidden_dim, 1).

        Args:
            input_dim: Number of input features.
            positive_class_proportion: Class prior used for the output bias;
                clipped to ``[0.01, 0.99]`` before taking the logit.

        Returns:
            The freshly initialised ``nn.Sequential`` (on CPU).
        """
        model = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(input_dim, self.hidden_dim),
            nn.SiLU(),
            nn.BatchNorm1d(self.hidden_dim),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_dim, 1),
        )
        hidden_layer, output_layer = model[1], model[5]
        nn.init.xavier_uniform_(hidden_layer.weight, gain=0.8)
        nn.init.zeros_(hidden_layer.bias)
        # Zero output weights + logit(prior) bias: the untrained network
        # predicts the clipped class prior for every input.
        nn.init.zeros_(output_layer.weight)
        prior = np.clip(positive_class_proportion, 0.01, 0.99)
        nn.init.constant_(output_layer.bias, float(np.log(prior / (1 - prior))))
        return model

    def _create_dataloader(self, X, y):
        """Wrap features and labels in a shuffling float32 ``DataLoader``.

        A trailing batch holding a single sample is dropped, because
        ``BatchNorm1d`` raises in training mode when it receives only one
        value per channel. All other batch layouts are left untouched.
        """
        dataset = TensorDataset(
            torch.tensor(np.asarray(X), dtype=torch.float32),
            torch.tensor(np.asarray(y), dtype=torch.float32),
        )
        n_samples = len(dataset)
        drop_singleton_tail = (
            n_samples > self.batch_size and n_samples % self.batch_size == 1
        )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=drop_singleton_tail,
        )

    def _penalty_terms(self):
        """Return the ``(l2_term, l1_term)`` tensors for the current weights.

        Each term is already scaled by its lambda and averaged over the
        number of penalised parameters (every parameter whose name does not
        contain ``"bias"``). A term is ``None`` when its lambda is ``<= 0``.
        """
        weights = [
            param
            for name, param in self.model.named_parameters()
            if "bias" not in name
        ]
        n_weights = sum(param.numel() for param in weights)

        l2_term = l1_term = None
        if self.l2_lambda > 0:
            squared = sum(param.pow(2).sum() for param in weights)
            l2_term = 0.5 * self.l2_lambda * squared / n_weights
        if self.l1_lambda > 0:
            absolute = sum(param.abs().sum() for param in weights)
            l1_term = self.l1_lambda * absolute / n_weights
        return l2_term, l1_term

    def fit(self, X, y, X_val=None, y_val=None):
        """Train the probe, optionally tracking a validation set.

        Args:
            X: Training features, array-like or tensor of shape
                ``(n_samples, n_features)``.
            y: Binary training labels (0/1).
            X_val: Optional validation features.
            y_val: Optional validation labels. Validation is only used when
                both ``X_val`` and ``y_val`` are given.

        Returns:
            ``self``. With validation data the weights of the epoch with the
            best validation ROC-AUC are restored before returning.
        """
        set_seed(self.random_state)

        X, y = self._to_numpy(X), self._to_numpy(y)
        has_validation = X_val is not None and y_val is not None
        if has_validation:
            X_val, y_val = self._to_numpy(X_val), self._to_numpy(y_val)

        self.classes_ = np.unique(y)
        # Start every fit from a clean log so refitting does not append to
        # (or compare against) a previous run.
        self.history = {key: [] for key in self._HISTORY_KEYS}
        self.best_val_auc_, self.best_epoch_ = -np.inf, 0

        if self.normalize:
            self.scaler = StandardScaler()
            X = self.scaler.fit_transform(X)
            if has_validation:
                X_val = self.scaler.transform(X_val)

        # NOTE(review): the class prior for the output bias is taken from the
        # validation labels whenever a validation set is supplied — confirm
        # this is intended rather than using the training labels.
        pos = np.mean(y_val if has_validation else y)
        self.model = self._create_model(X.shape[1], pos).to(self.device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
        train_loader = self._create_dataloader(X, y)
        loss_fn = nn.BCEWithLogitsLoss()

        if has_validation:
            # Built once; the validation data does not change between epochs.
            X_val_t = torch.tensor(X_val, dtype=torch.float32, device=self.device)
            y_val_t = torch.tensor(y_val, dtype=torch.float32, device=self.device)

        best_val_auc, best_model_state, patience_counter = -np.inf, None, 0
        self.model.train()
        for epoch in range(self.epochs):
            totals = {"loss": 0.0, "bce": 0.0, "l2": 0.0, "l1": 0.0}
            batches = (
                tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")
                if self.verbose
                else train_loader
            )
            for X_batch, y_batch in batches:
                X_batch = X_batch.to(self.device)
                y_batch = y_batch.to(self.device)

                optimizer.zero_grad()
                logits = self.model(X_batch).squeeze(-1)
                bce = loss_fn(logits, y_batch)
                l2_term, l1_term = self._penalty_terms()
                loss = bce
                if l2_term is not None:
                    loss = loss + l2_term
                if l1_term is not None:
                    loss = loss + l1_term
                loss.backward()
                optimizer.step()

                totals["loss"] += loss.item()
                totals["bce"] += bce.item()
                totals["l2"] += l2_term.item() if l2_term is not None else 0.0
                totals["l1"] += l1_term.item() if l1_term is not None else 0.0

            n_batches = len(train_loader)
            self.history["train_loss"].append(totals["loss"] / n_batches)
            self.history["train_bce_loss"].append(totals["bce"] / n_batches)
            self.history["train_l2_reg"].append(totals["l2"] / n_batches)
            self.history["train_l1_reg"].append(totals["l1"] / n_batches)

            if not has_validation:
                continue

            self.model.eval()
            with torch.no_grad():
                val_logits = self.model(X_val_t).squeeze(-1)
                val_bce = loss_fn(val_logits, y_val_t).item()
                val_proba = torch.sigmoid(val_logits).cpu().numpy()
                val_l2_term, val_l1_term = self._penalty_terms()
            val_l2 = val_l2_term.item() if val_l2_term is not None else 0.0
            val_l1 = val_l1_term.item() if val_l1_term is not None else 0.0
            val_auc = roc_auc_score(y_val, val_proba)

            self.history["val_bce_loss"].append(val_bce)
            self.history["val_l2_reg"].append(val_l2)
            self.history["val_l1_reg"].append(val_l1)
            self.history["val_loss"].append(val_bce + val_l2 + val_l1)
            self.history["val_auc"].append(val_auc)
            self.model.train()

            if val_auc > best_val_auc:
                best_val_auc = val_auc
                # Snapshot on CPU so the checkpoint survives later updates.
                best_model_state = {
                    key: value.cpu().clone()
                    for key, value in self.model.state_dict().items()
                }
                patience_counter = 0
                self.best_val_auc_, self.best_epoch_ = best_val_auc, epoch + 1
            else:
                patience_counter += 1

            if self.verbose:
                print(
                    f"Epoch {epoch+1}/{self.epochs} - Val AUC: {val_auc:.4f} {'*' if val_auc == best_val_auc else ''}"
                )

            if (
                self.early_stopping_patience is not None
                and patience_counter >= self.early_stopping_patience
            ):
                break

        if has_validation and best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        return self

    def predict_proba(self, X):
        """Return class probabilities of shape ``(n_samples, 2)``.

        Column 0 is ``1 - p`` and column 1 is ``p``, the sigmoid of the
        network logit. The fitted scaler is applied first when ``normalize``
        is true and a scaler is present.

        Raises:
            ValueError: If the model has not been fitted or loaded.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")

        X = self._to_numpy(X)
        if self.normalize and self.scaler is not None:
            X = self.scaler.transform(X)

        self.model.eval()
        with torch.no_grad():
            features = torch.tensor(X, dtype=torch.float32, device=self.device)
            logits = self.model(features).squeeze(-1)
            proba_pos = torch.sigmoid(logits).cpu().numpy()
        return np.column_stack([1 - proba_pos, proba_pos])

    def predict(self, X):
        """Return hard 0/1 predictions using a 0.5 probability threshold."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

    def _save_pretrained(self, save_directory):
        """Write weights, config, scaler and classes into ``save_directory``.

        Files written: ``pytorch_model.bin`` (state dict), ``config.json``
        (``hidden_dim``, ``l1_lambda``, ``l2_lambda``, ``normalize``), and,
        when available, ``scaler.pkl`` and ``classes.npy``.

        Raises:
            ValueError: If the model has not been fitted or loaded.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")

        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            self.model.state_dict(),
            os.path.join(save_directory, "pytorch_model.bin"),
        )

        config = {
            "hidden_dim": self.hidden_dim,
            "l1_lambda": self.l1_lambda,
            "l2_lambda": self.l2_lambda,
            "normalize": self.normalize,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

        if self.scaler is not None:
            with open(os.path.join(save_directory, "scaler.pkl"), "wb") as f:
                pickle.dump(self.scaler, f)
        if self.classes_ is not None:
            np.save(os.path.join(save_directory, "classes.npy"), self.classes_)

    @staticmethod
    def _locate_file(model_id, filename, required, download_kwargs):
        """Resolve ``filename`` of ``model_id`` to a local path.

        ``model_id`` may be a local directory (as written by
        ``_save_pretrained``) or a Hub repo id. Returns ``None`` when an
        optional file does not exist; a missing required file raises.
        """
        if os.path.isdir(model_id):
            local_path = os.path.join(model_id, filename)
            if os.path.isfile(local_path):
                return local_path
            if required:
                raise FileNotFoundError(f"{filename} not found in {model_id}")
            return None

        from huggingface_hub import hf_hub_download
        from huggingface_hub.utils import EntryNotFoundError

        try:
            return hf_hub_download(model_id, filename, **download_kwargs)
        except EntryNotFoundError:
            # Only "file is not in the repo" is tolerated for optional files;
            # network or authentication failures still propagate.
            if required:
                raise
            return None

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        *args,
        config=None,
        cache_dir=None,
        force_download=False,
        **kwargs,
    ):
        """Rebuild a probe from a Hub repo or a local save directory.

        Args:
            model_id: Hub repo id or path of a directory produced by
                ``_save_pretrained``.
            *args: Ignored.
            config: Ignored; ``config.json`` is read from ``model_id``.
            cache_dir: Forwarded to ``hf_hub_download``.
            force_download: Forwarded to ``hf_hub_download``.
            **kwargs: ``revision``, ``token`` and ``local_files_only`` are
                forwarded to ``hf_hub_download`` when set; the rest is
                ignored.

        Returns:
            An instance in eval mode, placed on CUDA when available and on
            CPU otherwise. ``scaler`` / ``classes_`` stay ``None`` when the
            corresponding file is absent.
        """
        download_kwargs = {"cache_dir": cache_dir, "force_download": force_download}
        for key in ("revision", "token", "local_files_only"):
            if kwargs.get(key) is not None:
                download_kwargs[key] = kwargs[key]

        weights_path = cls._locate_file(
            model_id, "pytorch_model.bin", True, download_kwargs
        )
        config_path = cls._locate_file(model_id, "config.json", True, download_kwargs)
        with open(config_path) as f:
            cfg = json.load(f)

        model = cls(**cfg)
        # The file holds a plain state dict, so restrict unpickling to tensors.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        # Recover the feature dimension from the first Linear layer's weight.
        input_dim = state_dict["1.weight"].shape[1]
        model.model = model._create_model(input_dim)
        model.model.load_state_dict(state_dict)

        scaler_path = cls._locate_file(model_id, "scaler.pkl", False, download_kwargs)
        if scaler_path is not None:
            # NOTE(security): unpickling runs arbitrary code from the file;
            # only load repositories you trust.
            with open(scaler_path, "rb") as f:
                model.scaler = pickle.load(f)

        classes_path = cls._locate_file(
            model_id, "classes.npy", False, download_kwargs
        )
        if classes_path is not None:
            model.classes_ = np.load(classes_path)

        model.model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.device = device
        model.model = model.model.to(device)
        return model