# Source: PeptiVerse / training_classifiers / train_nn_regression.py
# Repository page header: tag "Joblib"; author ynuozhang; commit baf3373 ("update code")
import os, json, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
import optuna
from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import spearmanr
from torch.cuda.amp import autocast
from torch.cuda.amp import autocast, GradScaler
# Module-level gradient scaler used by train_one_epoch_reg for every training
# run in this process; it is enabled only when CUDA is available.
scaler = GradScaler(enabled=torch.cuda.is_available())
from lightning.pytorch import seed_everything
# Executed at import time with a fixed seed (1986).
# NOTE(review): presumably seeds the Python / NumPy / PyTorch RNGs — confirm
# against the Lightning documentation.
seed_everything(1986)
def load_split(dataset_path):
    """Load a dataset saved on disk and return its train/val splits.

    Args:
        dataset_path: Path handed to ``load_from_disk``.

    Returns:
        Tuple ``(train, val)`` taken from the ``"train"`` and ``"val"`` keys.

    Raises:
        ValueError: If the loaded object is not a ``DatasetDict``.
    """
    loaded = load_from_disk(dataset_path)
    # Reject anything that is not a split dictionary up front.
    if not isinstance(loaded, DatasetDict):
        raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
    return loaded["train"], loaded["val"]
def collate_unpooled_reg(batch):
    """Collate variable-length per-token embeddings into padded tensors.

    Each example must provide ``"embedding"`` (a 2-D ``(L, H)`` array-like),
    ``"length"`` and ``"label"``; ``"attention_mask"`` is optional.

    Args:
        batch: Non-empty list of example mappings.

    Returns:
        Tuple ``(X, M, y)``:
            X: float32 tensor ``(B, Lmax, H)`` with zero padding.
            M: bool tensor ``(B, Lmax)``, True on valid positions.
            y: float32 tensor ``(B,)`` of labels.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_unpooled_reg received an empty batch")

    # Convert every embedding exactly once; the tensors are reused below.
    embs = [torch.as_tensor(x["embedding"], dtype=torch.float32) for x in batch]
    lengths = [int(x["length"]) for x in batch]

    # Size the padding from both the declared lengths and the real row counts,
    # so a "length" field that understates the embedding cannot make the
    # copy into X fail with a shape mismatch.
    Lmax = max(max(lengths), max(int(e.shape[0]) for e in embs))
    H = embs[0].shape[1]

    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, (x, emb) in enumerate(zip(batch, embs)):
        L = emb.shape[0]  # rows actually present for this example
        X[i, :L] = emb
        if "attention_mask" in x:
            # Honour the stored mask, truncated to the embedding length.
            m = torch.as_tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            # No stored mask: every embedding row counts as valid.
            M[i, :L] = True
    return X, M, y
def infer_in_dim(ds) -> int:
    """Return the embedding width, read from the first row of ``ds``.

    The width is the length of the first token vector in
    ``ds[0]["embedding"]``.
    """
    first_token = ds[0]["embedding"][0]
    return int(len(first_token))
# ============================
# Metrics
# ============================
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Spearman rank correlation that never returns NaN/None.

    Returns:
        The correlation as a float, or ``0.0`` when SciPy reports ``None``
        or NaN (for example when one input is constant).
    """
    coef = spearmanr(y_true, y_pred).correlation
    undefined = coef is None or np.isnan(coef)
    return 0.0 if undefined else float(coef)
def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute the regression metrics used for model selection and reporting.

    Args:
        y_true: Ground-truth targets.
        y_pred: Model predictions, same shape as ``y_true``.

    Returns:
        Dict with float entries ``"rmse"``, ``"mae"``, ``"r2"`` and
        ``"spearman_rho"``.
    """
    # ---- RMSE ----
    # Only the import is guarded: scikit-learn releases without
    # root_mean_squared_error fall back to sqrt(MSE). Genuine computation
    # errors (bad shapes, NaNs) are no longer swallowed and retried.
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))

    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))
    rho = float(safe_spearmanr(y_true, y_pred))
    return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
# ============================
# Models
# ============================
class MaskedMeanPool(nn.Module):
    """Average token vectors over the positions where the mask is True."""

    def forward(self, X, M):
        """Pool ``X`` along dim 1 using mask ``M``.

        Args:
            X: Tensor whose dim 1 is the sequence axis.
            M: Bool mask over the first two dims of ``X``.

        Returns:
            Masked mean over dim 1; rows with an all-False mask yield zeros
            because the divisor is clamped to at least 1.
        """
        weights = M.unsqueeze(-1).float()
        summed = (X * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts
class MLPRegressor(nn.Module):
    """Masked-mean pooling followed by a one-hidden-layer MLP head."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        """Build the pooling layer and the MLP.

        Args:
            in_dim: Width of each token embedding.
            hidden: Width of the hidden layer.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        self.pool = MaskedMeanPool()
        stages = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, X, M):
        """Return one scalar prediction per example (trailing dim squeezed)."""
        pooled = self.pool(X, M)
        scores = self.net(pooled)
        return scores.squeeze(-1)  # y_pred
class CNNRegressor(nn.Module):
    """Stack of 1-D convolutions over the sequence axis, masked-mean pooled."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        """Build ``layers`` Conv1d -> GELU -> Dropout stages and a linear head.

        Args:
            in_ch: Number of input channels (embedding width).
            c: Number of channels produced by every conv stage.
            k: Kernel size; padding is ``k // 2``.
            layers: Number of conv stages.
            dropout: Dropout probability after each activation.
        """
        super().__init__()
        # Channel widths seen by consecutive stages: in_ch -> c -> c -> ...
        widths = [in_ch] + [c] * layers
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv1d(n_in, n_out, kernel_size=k, padding=k // 2))
            stages.append(nn.GELU())
            stages.append(nn.Dropout(dropout))
        self.conv = nn.Sequential(*stages)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        """Return one scalar prediction per example.

        Args:
            X: Input of shape (B, L, H).
            M: Bool mask of shape (B, L), True on valid positions.
        """
        channels_first = X.transpose(1, 2)  # (B,H,L)
        feats = self.conv(channels_first).transpose(1, 2)  # (B,L,C)
        weights = M.unsqueeze(-1).float()
        counts = weights.sum(dim=1).clamp(min=1.0)
        pooled = (feats * weights).sum(dim=1) / counts  # (B,C)
        return self.head(pooled).squeeze(-1)
class TransformerRegressor(nn.Module):
    """Linear projection, Transformer encoder, masked-mean pool, linear head."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        """Build the projection, encoder stack and regression head.

        Args:
            in_dim: Width of each input token embedding.
            d_model: Encoder model width.
            nhead: Number of attention heads.
            layers: Number of encoder layers.
            ff: Feed-forward width inside each encoder layer.
            dropout: Dropout probability inside the encoder layers.
        """
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        layer_cfg = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        template = nn.TransformerEncoderLayer(**layer_cfg)
        self.enc = nn.TransformerEncoder(template, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        """Return one scalar prediction per example.

        Args:
            X: Input embeddings, batch first.
            M: Bool mask, True on valid positions; its inverse is handed to
                the encoder as the key-padding mask.
        """
        hidden = self.proj(X)
        # The encoder expects True on padding, the opposite of M.
        hidden = self.enc(hidden, src_key_padding_mask=~M)
        weights = M.unsqueeze(-1).float()
        counts = weights.sum(dim=1).clamp(min=1.0)
        pooled = (hidden * weights).sum(dim=1) / counts
        return self.head(pooled).squeeze(-1)
# ============================
# Train / eval
# ============================
def eval_preds(model, loader, device):
    """Run ``model`` over ``loader`` without gradients and gather the results.

    The model is switched to eval mode and left there.

    Args:
        model: Module called as ``model(X, M)``.
        loader: Iterable of ``(X, M, y)`` batches; ``y`` stays on the CPU.
        device: Device the inputs are moved to.

    Returns:
        Tuple ``(y_true, y_pred)`` of concatenated NumPy arrays.
    """
    model.eval()
    targets, outputs = [], []
    with torch.no_grad():
        for feats, mask, labels in loader:
            feats, mask = feats.to(device), mask.to(device)
            batch_out = model(feats, mask)
            targets.append(labels.numpy())
            outputs.append(batch_out.detach().cpu().numpy())
    return np.concatenate(targets), np.concatenate(outputs)
def train_one_epoch_reg(model, loader, optim, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Mixed precision (autocast plus the module-level ``scaler``) is used only
    when the target ``device`` is a CUDA device. Previously it was keyed on
    CUDA merely being installed, so CPU training on a GPU machine still went
    through the enabled gradient scaler.

    Args:
        model: Module called as ``model(X, M)``; put into train mode.
        loader: Iterable of ``(X, M, y)`` batches.
        optim: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(pred, y)``.
        device: Device string or ``torch.device`` the batches are moved to.

    Returns:
        None.
    """
    # Decide once per epoch whether this run can use CUDA mixed precision.
    use_amp = torch.cuda.is_available() and torch.device(device).type == "cuda"

    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        with autocast(enabled=use_amp):
            pred = model(X, M)
            loss = criterion(pred, y)
        if use_amp:
            # Scaled backward/step keeps fp16 gradients from underflowing.
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            # Full-precision path: nothing to scale.
            loss.backward()
            optim.step()
# ============================
# Saving + plots
# ============================
def save_predictions_csv(out_dir, split_name, y_true, y_pred, sequences=None):
    """Write targets, predictions and residuals to ``<split_name>_predictions.csv``.

    Args:
        out_dir: Output directory, created if it does not exist.
        split_name: Prefix of the CSV file name.
        y_true: NumPy array of targets.
        y_pred: NumPy array of predictions.
        sequences: Optional values written as a leading ``sequence`` column.
    """
    os.makedirs(out_dir, exist_ok=True)
    target_path = os.path.join(out_dir, f"{split_name}_predictions.csv")

    columns = {
        "y_true": y_true.astype(float),
        "y_pred": y_pred.astype(float),
        # Residual is taken on the raw arrays and cast afterwards.
        "residual": (y_true - y_pred).astype(float),
    }
    table = pd.DataFrame(columns)
    if sequences is not None:
        table.insert(0, "sequence", sequences)
    table.to_csv(target_path, index=False)
def plot_regression_diagnostics(out_dir, y_true, y_pred):
    """Save three diagnostic figures for a set of regression predictions.

    Files written into ``out_dir`` (created if missing):
    ``pred_vs_true.png``, ``residual_hist.png`` and ``residual_vs_pred.png``.

    Args:
        out_dir: Output directory.
        y_true: Array of targets.
        y_pred: Array of predictions.
    """
    os.makedirs(out_dir, exist_ok=True)

    def _finish(title, filename):
        # Shared tail of every figure: title, layout, save, release.
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, filename))
        plt.close()

    # 1) predictions against ground truth
    plt.figure()
    plt.scatter(y_true, y_pred, s=8, alpha=0.5)
    plt.xlabel("y_true")
    plt.ylabel("y_pred")
    _finish("Predicted vs True", "pred_vs_true.png")

    errors = y_true - y_pred

    # 2) distribution of residuals
    plt.figure()
    plt.hist(errors, bins=50)
    plt.xlabel("residual (y_true - y_pred)")
    plt.ylabel("count")
    _finish("Residual Histogram", "residual_hist.png")

    # 3) residuals as a function of the prediction
    plt.figure()
    plt.scatter(y_pred, errors, s=8, alpha=0.5)
    plt.xlabel("y_pred")
    plt.ylabel("residual")
    _finish("Residuals vs Prediction", "residual_vs_pred.png")
# ============================
# Optuna objective
# ============================
def score_from_metrics(metrics: Dict[str, float], objective: str) -> float:
    """Turn a metrics dict into a single score where higher is better.

    Args:
        metrics: Metrics dict; only the key needed by ``objective`` is read.
        objective: One of ``"spearman"``, ``"r2"`` or ``"neg_rmse"``.

    Returns:
        The selected metric, negated for ``"neg_rmse"``.

    Raises:
        ValueError: If ``objective`` is not one of the supported names.
    """
    # (objective name, metrics key, negate?)
    rules = (
        ("spearman", "spearman_rho", False),
        ("r2", "r2", False),
        ("neg_rmse", "rmse", True),
    )
    for name, key, negate in rules:
        if objective == name:
            value = metrics[key]
            return -value if negate else value
    raise ValueError(f"Unknown objective={objective}")
def objective_nn_reg(trial, model_name, train_ds, val_ds, device="cuda:0", objective="spearman"):
    """Optuna objective: train one sampled configuration and score it.

    Samples optimizer, regularisation, batch-size, architecture and loss
    hyper-parameters from ``trial``, trains for at most 60 epochs with early
    stopping (patience 10) and reports the validation score every epoch so
    the study's pruner can stop the trial.

    The ``val_*`` user attributes describe the best epoch, i.e. the epoch
    that produced the returned score. They used to be overwritten on every
    epoch, so they described the last epoch instead.

    Args:
        trial: Optuna trial used for sampling, reporting and pruning.
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        train_ds: Training dataset.
        val_ds: Validation dataset.
        device: Device the model and batches are placed on.
        objective: Score name understood by ``score_from_metrics``.

    Returns:
        Best validation score reached during the trial (higher is better).

    Raises:
        ValueError: If ``model_name`` is not recognised.
        optuna.TrialPruned: If the pruner stops the trial.
    """
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    in_dim = infer_in_dim(train_ds)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)

    if model_name == "mlp":
        hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
        model = MLPRegressor(in_dim=in_dim, hidden=hidden, dropout=dropout)
    elif model_name == "cnn":
        c = trial.suggest_categorical("channels", [128, 256, 512])
        k = trial.suggest_categorical("kernel", [3, 5, 7])
        layers = trial.suggest_int("layers", 1, 4)
        model = CNNRegressor(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
    elif model_name == "transformer":
        d = trial.suggest_categorical("d_model", [128, 256, 384])
        nhead = trial.suggest_categorical("nhead", [4, 8])
        layers = trial.suggest_int("layers", 1, 4)
        ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
        model = TransformerRegressor(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
    else:
        raise ValueError(model_name)
    model = model.to(device)

    loss_name = trial.suggest_categorical("loss", ["mse", "huber"])
    if loss_name == "mse":
        criterion = nn.MSELoss()
    else:
        delta = trial.suggest_float("huber_delta", 0.5, 5.0, log=True)
        criterion = nn.HuberLoss(delta=delta)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_score = -1e18
    patience = 10
    bad = 0
    for epoch in range(1, 61):
        train_one_epoch_reg(model, train_loader, optim, criterion, device)
        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        if score > best_score + 1e-6:
            best_score = score
            bad = 0
            # Record metrics only for the best epoch so the attrs line up
            # with the value this function returns.
            for metric_name, metric_value in metrics.items():
                trial.set_user_attr(f"val_{metric_name}", float(metric_value))
        else:
            bad += 1

        # Intermediate report drives the study's pruner.
        trial.report(score, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if bad >= patience:
            break
    return float(best_score)
# ============================
# Main runner
# ============================
def run_optuna_and_refit_nn_reg(dataset_path, out_dir, model_name, n_trials=80, device="cuda:0",
                                objective="spearman"):
    """Tune hyper-parameters with Optuna, refit the best model, save artifacts.

    Steps:
        1. Load the train/val splits and run an Optuna study (maximize).
        2. Rebuild the model from the best trial's parameters.
        3. Refit for up to 200 epochs with early stopping (patience 15) on
           the same objective, keeping the best weights.
        4. Write ``study_trials.csv``, per-split prediction CSVs, diagnostic
           plots, ``best_model.pt`` and ``optimization_summary.txt``.

    Args:
        dataset_path: Path of the dataset saved on disk.
        out_dir: Output directory, created if missing.
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        n_trials: Number of Optuna trials.
        device: Device used for training and evaluation.
        objective: Score name understood by ``score_from_metrics``.

    Raises:
        ValueError: If ``model_name`` is not recognised.
    """
    os.makedirs(out_dir, exist_ok=True)
    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_nn_reg(t, model_name, train_ds, val_ds, device=device, objective=objective),
                   n_trials=n_trials)
    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)

    # rebuild model from best params
    in_dim = infer_in_dim(train_ds)
    dropout = float(best_params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPRegressor(in_dim=in_dim, hidden=int(best_params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        model = CNNRegressor(in_ch=in_dim, c=int(best_params["channels"]),
                             k=int(best_params["kernel"]), layers=int(best_params["layers"]),
                             dropout=dropout)
    elif model_name == "transformer":
        model = TransformerRegressor(in_dim=in_dim, d_model=int(best_params["d_model"]),
                                     nhead=int(best_params["nhead"]), layers=int(best_params["layers"]),
                                     ff=int(best_params["ff"]), dropout=dropout)
    else:
        raise ValueError(model_name)
    model = model.to(device)

    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)

    # loss
    if best_params.get("loss", "mse") == "mse":
        criterion = nn.MSELoss()
    else:
        criterion = nn.HuberLoss(delta=float(best_params["huber_delta"]))
    optim = torch.optim.AdamW(model.parameters(), lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    # refit longer with early stopping on the SAME objective
    best_score, bad, patience = -1e18, 0, 15
    best_state = None
    best_metrics = None
    metrics = {}
    for epoch in range(1, 201):
        train_one_epoch_reg(model, train_loader, optim, criterion, device)
        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)
        if score > best_score + 1e-6:
            best_score = score
            bad = 0
            # Snapshot on the CPU so later epochs cannot mutate the copy.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_metrics = metrics
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    if best_metrics is None:
        # No epoch ever counted as an improvement (e.g. the score was NaN
        # throughout). The model keeps its final weights, so report the
        # final metrics instead of hitting an unbound name in the summary.
        best_metrics = metrics

    # preds
    train_eval_loader = DataLoader(train_ds, batch_size=64, shuffle=False,
                                   collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
    y_true_tr, y_pred_tr = eval_preds(model, train_eval_loader, device)
    y_true_va, y_pred_va = eval_preds(model, val_loader, device)

    seq_train = np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None
    seq_val = np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None
    save_predictions_csv(out_dir, "train", y_true_tr, y_pred_tr, seq_train)
    save_predictions_csv(out_dir, "val", y_true_va, y_pred_va, seq_val)
    plot_regression_diagnostics(out_dir, y_true_va, y_pred_va)

    # save model
    model_path = os.path.join(out_dir, "best_model.pt")
    # model_name and objective are stored alongside the weights so the
    # architecture can be rebuilt from the checkpoint alone; the original
    # keys are unchanged.
    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "in_dim": in_dim,
                "model_name": model_name, "objective": objective}, model_path)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"OPTUNA objective: {objective} (direction=maximize)",
        f"Best trial: {best.number}",
        "Best val metrics:",
        json.dumps({k: float(v) for k, v in best_metrics.items()}, indent=2),
        f"Saved model: {model_path}",
        "Best params:",
        json.dumps(best_params, indent=2),
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))
# Command-line entry point: parse the options and launch tuning + refit.
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_path", type=str, required=True)
    cli.add_argument("--out_dir", type=str, required=True)
    cli.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    cli.add_argument("--n_trials", type=int, default=80)
    cli.add_argument(
        "--objective",
        type=str,
        default="spearman",
        choices=["spearman", "neg_rmse", "r2"],
    )
    cli.add_argument("--device", type=str, default="cuda:0")
    opts = cli.parse_args()

    # --model maps onto the runner's model_name parameter.
    run_optuna_and_refit_nn_reg(
        dataset_path=opts.dataset_path,
        out_dir=opts.out_dir,
        model_name=opts.model,
        n_trials=opts.n_trials,
        device=opts.device,
        objective=opts.objective,
    )