import json
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import torch
import torch.nn as nn
from datasets import DatasetDict, load_from_disk
from lightning.pytorch import seed_everything
from scipy.stats import spearmanr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

# Module-level gradient scaler used by train_one_epoch_reg whenever it trains
# with mixed precision on a CUDA device.
scaler = GradScaler(enabled=torch.cuda.is_available())

seed_everything(1986)


def load_split(dataset_path):
    """Load a dataset saved on disk and return its ``(train, val)`` splits.

    Args:
        dataset_path: Path passed to ``datasets.load_from_disk``.

    Returns:
        Tuple ``(ds["train"], ds["val"])``.

    Raises:
        ValueError: If the loaded object is not a ``DatasetDict`` or if either
            the ``"train"`` or the ``"val"`` split is missing.
    """
    ds = load_from_disk(dataset_path)
    if not isinstance(ds, DatasetDict):
        raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
    missing = [name for name in ("train", "val") if name not in ds]
    if missing:
        raise ValueError(
            "Expected DatasetDict with 'train' and 'val' splits; "
            f"missing: {missing}"
        )
    return ds["train"], ds["val"]


def collate_unpooled_reg(batch):
    """Collate per-token embeddings into padded tensors for regression.

    Each example must provide ``"embedding"`` (a 2-D, per-position array),
    ``"length"`` and ``"label"``; ``"attention_mask"`` is optional.

    Args:
        batch: List of example dicts as produced by the dataset.

    Returns:
        Tuple ``(X, M, y)`` where ``X`` is float32 of shape ``(B, Lmax, H)``
        zero-padded along the sequence axis, ``M`` is a bool mask of shape
        ``(B, Lmax)`` that is True on valid positions, and ``y`` is float32 of
        shape ``(B,)`` holding the labels.
    """
    embs = [torch.as_tensor(x["embedding"], dtype=torch.float32) for x in batch]
    lengths = [int(x["length"]) for x in batch]
    # Pad to the longest of the declared and the actual lengths: relying on
    # "length" alone makes the slice assignment below fail whenever a declared
    # length is shorter than the stored embedding.
    Lmax = max(max(lengths), max(e.shape[0] for e in embs))
    H = embs[0].shape[1]

    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, (x, emb) in enumerate(zip(batch, embs)):
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            m = torch.as_tensor(x["attention_mask"], dtype=torch.bool)[:L]
            # A mask shorter than the embedding leaves the tail as padding
            # instead of raising a shape-mismatch error.
            M[i, : m.shape[0]] = m
        else:
            M[i, :L] = True
    return X, M, y


def infer_in_dim(ds) -> int:
    """Return the embedding width, read from the first example of ``ds``."""
    first = ds[0]
    return int(len(first["embedding"][0]))


# ============================
# Metrics
# ============================
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Spearman rank correlation that falls back to 0.0 when undefined.

    Returns 0.0 when there are fewer than two samples or when scipy reports
    ``None``/NaN (e.g. one of the inputs is constant).
    """
    if np.size(y_true) < 2:
        return 0.0
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)


def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute RMSE, MAE, R^2 and Spearman rho for a set of predictions.

    Returns:
        Dict with keys ``"rmse"``, ``"mae"``, ``"r2"`` and ``"spearman_rho"``.
    """
    # ---- RMSE ----
    # Only the import is guarded: older scikit-learn releases lack
    # root_mean_squared_error, but genuine metric errors must propagate.
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))

    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))
    rho = float(safe_spearmanr(y_true, y_pred))
    return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}


# ============================
# Models
# ============================
def _masked_mean(values, mask):
    """Mean of ``values`` (B, L, C) over axis 1 using only True ``mask`` (B, L) slots."""
    weights = mask.unsqueeze(-1).float()
    # Clamp so that an all-False row divides by 1 instead of 0.
    denom = weights.sum(dim=1).clamp(min=1.0)
    return (values * weights).sum(dim=1) / denom


class MaskedMeanPool(nn.Module):
    """Parameter-free masked mean pooling over the sequence axis."""

    def forward(self, X, M):
        return _masked_mean(X, M)


class MLPRegressor(nn.Module):
    """Masked-mean pool the sequence, then regress with a one-hidden-layer MLP."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        z = self.pool(X, M)
        return self.net(z).squeeze(-1)  # y_pred, shape (B,)


class CNNRegressor(nn.Module):
    """Stack of Conv1d/GELU/Dropout blocks, masked mean pooling, linear head."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        # padding=k//2 only preserves the sequence length for odd kernels; an
        # even kernel would make the conv output one step longer than the mask.
        if k % 2 == 0:
            raise ValueError(f"kernel size k must be odd, got k={k}")
        # With no conv layer the head would receive in_ch instead of c channels.
        if layers < 1:
            raise ValueError(f"layers must be >= 1, got layers={layers}")
        blocks = []
        ch = in_ch
        for _ in range(layers):
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        Xc = X.transpose(1, 2)  # (B,H,L)
        Y = self.conv(Xc).transpose(1, 2)  # (B,L,C)
        pooled = _masked_mean(Y, M)  # (B,C)
        return self.head(pooled).squeeze(-1)


class TransformerRegressor(nn.Module):
    """Linear projection, Transformer encoder, masked mean pooling, linear head."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask marks positions to IGNORE, i.e. the inverse of M.
        pad_mask = ~M
        Z = self.proj(X)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)
        pooled = _masked_mean(Z, M)
        return self.head(pooled).squeeze(-1)


# ============================
# Train / eval
# ============================
@torch.no_grad()
def eval_preds(model, loader, device):
    """Run ``model`` in eval mode over ``loader``.

    Returns:
        Tuple ``(y_true, y_pred)`` of concatenated numpy arrays.
    """
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        pred = model(X, M).detach().cpu().numpy()
        ys.append(y.numpy())
        ps.append(pred)
    return np.concatenate(ys), np.concatenate(ps)


def train_one_epoch_reg(model, loader, optim, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Mixed precision (CUDA autocast plus the module-level ``scaler``) is used
    only when the target ``device`` is a CUDA device. Deciding from
    ``torch.cuda.is_available()`` alone would push CPU tensors through the
    CUDA gradient scaler when training on CPU on a CUDA-capable host.

    Args:
        model: Module called as ``model(X, M)``.
        loader: Iterable yielding ``(X, M, y)`` batches.
        optim: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(pred, y)``.
        device: Device (string or ``torch.device``) the batches are moved to.
    """
    model.train()
    use_amp = torch.cuda.is_available() and torch.device(device).type == "cuda"
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        if use_amp:
            with autocast(enabled=True):
                pred = model(X, M)
                loss = criterion(pred, y)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss = criterion(model(X, M), y)
            loss.backward()
            optim.step()


# ============================
# Saving + plots
# ============================
def save_predictions_csv(out_dir, split_name, y_true, y_pred, sequences=None):
    """Write ``<out_dir>/<split_name>_predictions.csv``.

    Columns are ``y_true``, ``y_pred`` and ``residual`` (``y_true - y_pred``);
    when ``sequences`` is given it is inserted as a leading ``sequence`` column.
    """
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame(
        {
            "y_true": y_true.astype(float),
            "y_pred": y_pred.astype(float),
            "residual": (y_true - y_pred).astype(float),
        }
    )
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)


def plot_regression_diagnostics(out_dir, y_true, y_pred):
    """Save three diagnostic PNGs into ``out_dir``.

    Files written: ``pred_vs_true.png``, ``residual_hist.png`` and
    ``residual_vs_pred.png``.
    """
    os.makedirs(out_dir, exist_ok=True)
    resid = y_true - y_pred

    plt.figure()
    plt.scatter(y_true, y_pred, s=8, alpha=0.5)
    plt.xlabel("y_true")
    plt.ylabel("y_pred")
    plt.title("Predicted vs True")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
    plt.close()

    plt.figure()
    plt.hist(resid, bins=50)
    plt.xlabel("residual (y_true - y_pred)")
    plt.ylabel("count")
    plt.title("Residual Histogram")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "residual_hist.png"))
    plt.close()

    plt.figure()
    plt.scatter(y_pred, resid, s=8, alpha=0.5)
    plt.xlabel("y_pred")
    plt.ylabel("residual")
    plt.title("Residuals vs Prediction")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
    plt.close()

# ============================
# Optuna objective
# ============================
def score_from_metrics(metrics: Dict[str, float], objective: str) -> float:
    """Map a metrics dict to the scalar that the study maximizes.

    Args:
        metrics: Dict as returned by ``eval_regression``.
        objective: One of ``"spearman"``, ``"r2"`` or ``"neg_rmse"``.

    Raises:
        ValueError: If ``objective`` is not one of the supported names.
    """
    if objective == "spearman":
        return metrics["spearman_rho"]
    if objective == "r2":
        return metrics["r2"]
    if objective == "neg_rmse":
        return -metrics["rmse"]
    raise ValueError(f"Unknown objective={objective}")


def _build_model(model_name, in_dim, params, dropout):
    """Instantiate the regressor named ``model_name`` from a hyper-parameter dict."""
    if model_name == "mlp":
        return MLPRegressor(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    if model_name == "cnn":
        return CNNRegressor(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=dropout,
        )
    if model_name == "transformer":
        return TransformerRegressor(
            in_dim=in_dim,
            d_model=int(params["d_model"]),
            nhead=int(params["nhead"]),
            layers=int(params["layers"]),
            ff=int(params["ff"]),
            dropout=dropout,
        )
    raise ValueError(model_name)


def _build_criterion(params):
    """Return MSELoss, or HuberLoss when ``params["loss"]`` is not ``"mse"``."""
    if params.get("loss", "mse") == "mse":
        return nn.MSELoss()
    return nn.HuberLoss(delta=float(params["huber_delta"]))


def _make_loader(ds, batch_size, shuffle, device):
    """Build a DataLoader over ``ds``; memory is pinned only for CUDA devices."""
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_unpooled_reg,
        num_workers=4,
        pin_memory=torch.device(device).type == "cuda",
    )


def objective_nn_reg(trial, model_name, train_ds, val_ds, device="cuda:0", objective="spearman"):
    """Optuna objective: train one sampled configuration and return its best score.

    Trains for at most 60 epochs with early stopping (patience 10) on the
    validation score, reporting each epoch's score to the trial for pruning.

    Args:
        trial: Optuna trial used to sample hyper-parameters.
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        train_ds: Training dataset.
        val_ds: Validation dataset.
        device: Device the model and batches are moved to.
        objective: Objective name accepted by ``score_from_metrics``.

    Returns:
        The best validation score seen during training.

    Raises:
        ValueError: If ``model_name`` is not supported.
        optuna.TrialPruned: If the pruner decides to stop the trial.
    """
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    in_dim = infer_in_dim(train_ds)

    train_loader = _make_loader(train_ds, batch_size, True, device)
    val_loader = _make_loader(val_ds, 64, False, device)

    if model_name == "mlp":
        arch = {"hidden": trial.suggest_categorical("hidden", [256, 512, 1024, 2048])}
    elif model_name == "cnn":
        arch = {
            "channels": trial.suggest_categorical("channels", [128, 256, 512]),
            "kernel": trial.suggest_categorical("kernel", [3, 5, 7]),
            "layers": trial.suggest_int("layers", 1, 4),
        }
    elif model_name == "transformer":
        arch = {
            "d_model": trial.suggest_categorical("d_model", [128, 256, 384]),
            "nhead": trial.suggest_categorical("nhead", [4, 8]),
            "layers": trial.suggest_int("layers", 1, 4),
            "ff": trial.suggest_categorical("ff", [256, 512, 1024, 1536]),
        }
    else:
        raise ValueError(model_name)

    model = _build_model(model_name, in_dim, arch, dropout).to(device)

    loss_name = trial.suggest_categorical("loss", ["mse", "huber"])
    if loss_name == "mse":
        criterion = nn.MSELoss()
    else:
        delta = trial.suggest_float("huber_delta", 0.5, 5.0, log=True)
        criterion = nn.HuberLoss(delta=delta)

    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_score = -1e18
    patience = 10
    bad = 0

    for epoch in range(1, 61):
        train_one_epoch_reg(model, train_loader, optim, criterion, device)

        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        # log attrs (metrics of the most recently evaluated epoch)
        for k, v in metrics.items():
            trial.set_user_attr(f"val_{k}", float(v))

        trial.report(score, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if score > best_score + 1e-6:
            best_score = score
            bad = 0
            # The value returned below is the best score, so also keep the
            # metrics of that epoch; "val_*" alone reflects the last epoch.
            for k, v in metrics.items():
                trial.set_user_attr(f"best_val_{k}", float(v))
        else:
            bad += 1
            if bad >= patience:
                break

    return float(best_score)


# ============================
# Main runner
# ============================
def run_optuna_and_refit_nn_reg(dataset_path, out_dir, model_name, n_trials=80, device="cuda:0", objective="spearman"):
    """Tune hyper-parameters with Optuna, refit the best configuration, save artifacts.

    Writes into ``out_dir``: ``study_trials.csv``, ``train_predictions.csv``,
    ``val_predictions.csv``, the three diagnostic plots, ``best_model.pt`` and
    ``optimization_summary.txt``. The summary is also printed.

    Args:
        dataset_path: Path to a saved DatasetDict with ``train``/``val`` splits.
        out_dir: Output directory (created if missing).
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        n_trials: Number of Optuna trials.
        device: Device used for training and inference.
        objective: Objective name accepted by ``score_from_metrics``.
    """
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(
        lambda t: objective_nn_reg(t, model_name, train_ds, val_ds, device=device, objective=objective),
        n_trials=n_trials,
    )

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)

    # rebuild model from best params
    in_dim = infer_in_dim(train_ds)
    dropout = float(best_params.get("dropout", 0.1))
    model = _build_model(model_name, in_dim, best_params, dropout).to(device)

    batch_size = int(best_params.get("batch_size", 32))
    train_loader = _make_loader(train_ds, batch_size, True, device)
    val_loader = _make_loader(val_ds, 64, False, device)

    # loss
    criterion = _build_criterion(best_params)

    optim = torch.optim.AdamW(
        model.parameters(),
        lr=float(best_params["lr"]),
        weight_decay=float(best_params["weight_decay"]),
    )

    # refit longer with early stopping on the SAME objective
    best_score, bad, patience = -1e18, 0, 15
    best_state = None
    # Bound up front so the summary below cannot hit a NameError when no epoch
    # ever registers as an improvement (e.g. the score is NaN throughout).
    best_metrics: Dict[str, float] = {}

    for epoch in range(1, 201):
        train_one_epoch_reg(model, train_loader, optim, criterion, device)

        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        if score > best_score + 1e-6:
            best_score = score
            bad = 0
            # Snapshot on CPU so later epochs do not overwrite the best weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_metrics = metrics
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # preds
    train_eval_loader = _make_loader(train_ds, 64, False, device)
    y_true_tr, y_pred_tr = eval_preds(model, train_eval_loader, device)
    y_true_va, y_pred_va = eval_preds(model, val_loader, device)

    seq_train = np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None
    seq_val = np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None

    save_predictions_csv(out_dir, "train", y_true_tr, y_pred_tr, seq_train)
    save_predictions_csv(out_dir, "val", y_true_va, y_pred_va, seq_val)

    plot_regression_diagnostics(out_dir, y_true_va, y_pred_va)

    # save model
    model_path = os.path.join(out_dir, "best_model.pt")
    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "in_dim": in_dim}, model_path)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"OPTUNA objective: {objective} (direction=maximize)",
        f"Best trial: {best.number}",
        "Best val metrics:",
        json.dumps({k: float(v) for k, v in best_metrics.items()}, indent=2),
        f"Saved model: {model_path}",
        "Best params:",
        json.dumps(best_params, indent=2),
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    parser.add_argument("--n_trials", type=int, default=80)
    parser.add_argument("--objective", type=str, default="spearman",
                        choices=["spearman", "neg_rmse", "r2"])
    parser.add_argument("--device", type=str, default="cuda:0")
    args = parser.parse_args()

    run_optuna_and_refit_nn_reg(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        model_name=args.model,
        n_trials=args.n_trials,
        device=args.device,
        objective=args.objective,
    )