#!/usr/bin/env python3
# finetune_xgb_halflife_cv_optuna.py
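"""
Finetune an XGBoost regressor on peptide half-life (hours) from cached ESM-2 embeddings.

Pipeline (all steps are implemented below):
  1. Load a CSV with `sequence` and `half_life_hours` columns, clean and deduplicate.
  2. Compute mean-pooled ESM-2 embeddings, cached on disk under a dataset fingerprint.
  3. Run an Optuna search that maximizes mean K-fold CV Spearman rho, optionally
     continuing boosting ("finetuning") from an existing XGBoost .json model.
  4. Refit on all data with the CV-averaged number of rounds and save the model,
     fold metrics, and out-of-fold predictions.

Illustrative invocation (paths are placeholders, not the real defaults):
  python finetune_xgb_halflife_cv_optuna.py \
      --csv_path data/halflife.csv --out_dir runs/halflife_xgb \
      --base_model_json models/stability_xgb.json --n_trials 50 --device cuda
"""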
import os
import json
import math
import hashlib
from dataclasses import dataclass
from typing import Dict, Any, Optional, Tuple, List
import numpy as np
import pandas as pd
import optuna
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import spearmanr
import torch
from transformers import AutoTokenizer, AutoModel
import xgboost as xgb
# -----------------------------
# Repro
# -----------------------------
SEED = 1986
np.random.seed(SEED)
torch.manual_seed(SEED)
# -----------------------------
# Metrics (mirrors the stability script's style)
# -----------------------------
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)


def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))
    rho = float(safe_spearmanr(y_true, y_pred))
    return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
# -----------------------------
# ESM-2 embeddings (cached)
# -----------------------------
@dataclass
class ESMEmbedderConfig:
    model_name: str = "facebook/esm2_t33_650M_UR50D"
    batch_size: int = 8
    max_length: int = 1024  # truncate very long proteins
    fp16: bool = True
class ESM2Embedder:
    """
    Mean-pooled last hidden state (excluding special tokens) -> (H,) per sequence.
    """
    def __init__(self, cfg: ESMEmbedderConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device if (device == "cuda" and torch.cuda.is_available()) else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, do_lower_case=False)
        self.model = AutoModel.from_pretrained(cfg.model_name)
        self.model.eval()
        self.model.to(self.device)
        # Turn off gradients
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.inference_mode()
    def embed(self, seqs: List[str]) -> np.ndarray:
        out = []
        bs = self.cfg.batch_size
        use_amp = (self.cfg.fp16 and self.device == "cuda")
        autocast = torch.cuda.amp.autocast if use_amp else torch.cpu.amp.autocast  # safe fallback
        for i in range(0, len(seqs), bs):
            batch = [s.strip().upper() for s in seqs[i:i+bs]]
            toks = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.cfg.max_length,
                add_special_tokens=True,
            )
            toks = {k: v.to(self.device) for k, v in toks.items()}
            attn = toks["attention_mask"]  # (B, L)
            with autocast(enabled=use_amp):
                h = self.model(**toks).last_hidden_state  # (B, L, H)
            # mask out special tokens: first token is <cls>; last non-pad token is usually <eos>
            mask = attn.clone()
            mask[:, 0] = 0
            lengths = attn.sum(dim=1)  # includes special tokens
            # zero out last real token position per sequence
            eos_pos = (lengths - 1).clamp(min=0)
            mask[torch.arange(mask.size(0), device=mask.device), eos_pos] = 0
            denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)  # (B,1)
            pooled = (h * mask.unsqueeze(-1)).sum(dim=1) / denom  # (B,H)
            out.append(pooled.float().detach().cpu().numpy())
        return np.concatenate(out, axis=0).astype(np.float32)
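# Example (illustrative, not part of the training flow): embed a couple of peptides
# directly, bypassing the on-disk cache used by load_or_compute_embeddings():
#   embedder = ESM2Embedder(ESMEmbedderConfig(batch_size=2), device="cpu")
#   X = embedder.embed(["ACDEFGHIK", "MKTAYIAKQR"])  # -> (2, 1280) for the default 650M checkpoint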
def dataset_fingerprint(seqs: List[str], y: np.ndarray, extra: str = "") -> str:
    h = hashlib.sha256()
    for s in seqs:
        h.update(s.encode("utf-8"))
        h.update(b"\n")
    h.update(np.asarray(y, dtype=np.float32).tobytes())
    h.update(extra.encode("utf-8"))
    return h.hexdigest()[:16]
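# Note: the fingerprint covers the sequences, the targets, and the "extra" string
# (ESM model name + max_length below), so changing the embedding model or truncation
# length produces a new cache file instead of silently reusing stale embeddings.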
def load_or_compute_embeddings(
    df: pd.DataFrame,
    out_dir: str,
    embed_cfg: ESMEmbedderConfig,
    device: str,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    os.makedirs(out_dir, exist_ok=True)
    seqs = df["sequence"].astype(str).tolist()
    y = df["half_life_hours"].astype(float).to_numpy(dtype=np.float32)
    fp = dataset_fingerprint(seqs, y, extra=f"{embed_cfg.model_name}|{embed_cfg.max_length}")
    emb_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.npy")
    meta_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.json")
    if os.path.exists(emb_path) and os.path.exists(meta_path):
        X = np.load(emb_path).astype(np.float32)
        return X, y, np.asarray(seqs)
    embedder = ESM2Embedder(embed_cfg, device=device)
    X = embedder.embed(seqs)  # (N,H)
    np.save(emb_path, X)
    with open(meta_path, "w") as f:
        json.dump(
            {
                "fingerprint": fp,
                "model_name": embed_cfg.model_name,
                "max_length": embed_cfg.max_length,
                "n": len(seqs),
                "dim": int(X.shape[1]),
            },
            f,
            indent=2,
        )
    return X, y, np.asarray(seqs)
# -----------------------------
# XGBoost training (supports "finetune" via xgb_model)
# -----------------------------
def train_xgb_reg(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    params: Dict[str, Any],
    base_model_json: Optional[str] = None,
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray, int]:
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)
    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))
    # Important: load a fresh base model each fold (avoid leakage)
    xgb_model = None
    if base_model_json is not None:
        booster0 = xgb.Booster()
        booster0.load_model(base_model_json)
        xgb_model = booster0
    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
        xgb_model=xgb_model,  # <-- "finetune": continue boosting from base model
    )
    p_train = booster.predict(dtrain)
    p_val = booster.predict(dval)
    best_iter = int(getattr(booster, "best_iteration", num_boost_round - 1))
    return booster, p_train, p_val, best_iter
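# Note: when xgb_model is provided, xgb.train() continues adding boosting rounds on
# top of the loaded booster's existing trees, so the base model must have been trained
# on features with the same dimensionality as the ESM-2 embeddings used here.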
# -----------------------------
# Optuna objective: 5-fold mean Spearman rho
# -----------------------------
def make_cv_objective(
    X: np.ndarray,
    y: np.ndarray,
    n_splits: int,
    device: str,
    base_model_json: Optional[str],
    target_transform: str,
):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    # Optional target transform (sometimes helps with heavy-tailed half-life)
    if target_transform == "log1p":
        y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
    elif target_transform == "none":
        y_used = y.astype(np.float32)
    else:
        raise ValueError(f"Unknown target_transform: {target_transform}")

    def objective(trial: optuna.Trial) -> float:
        # Hyperparameter ranges patterned after the stability script
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
            "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
            "gamma": trial.suggest_float("gamma", 0.0, 10.0),
            "max_depth": trial.suggest_int("max_depth", 2, 12),
            "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 200.0, log=True),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
            "tree_method": "hist",
            "device": "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu",
        }
        params["num_boost_round"] = trial.suggest_int("num_boost_round", 30, 1500)
        params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 10, 150)
        fold_metrics = []
        fold_best_iters = []
        for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
            Xtr, ytr = X[tr_idx], y_used[tr_idx]
            Xva, yva = X[va_idx], y_used[va_idx]
            _, _, p_va, best_iter = train_xgb_reg(
                Xtr, ytr, Xva, yva, params.copy(),
                base_model_json=base_model_json,
            )
            m = eval_regression(yva, p_va)
            fold_metrics.append(m)
            fold_best_iters.append(best_iter)
        mean_rho = float(np.mean([m["spearman_rho"] for m in fold_metrics]))
        mean_rmse = float(np.mean([m["rmse"] for m in fold_metrics]))
        mean_mae = float(np.mean([m["mae"] for m in fold_metrics]))
        mean_r2 = float(np.mean([m["r2"] for m in fold_metrics]))
        mean_best_iter = float(np.mean(fold_best_iters))
        trial.set_user_attr("cv_spearman_rho", mean_rho)
        trial.set_user_attr("cv_rmse", mean_rmse)
        trial.set_user_attr("cv_mae", mean_mae)
        trial.set_user_attr("cv_r2", mean_r2)
        trial.set_user_attr("cv_mean_best_iter", mean_best_iter)
        # maximize Spearman rho (same objective as the stability workflow)
        return mean_rho

    return objective
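# Illustrative standalone use (mirrors main()): a quick CPU-only search on
# precomputed embeddings X and targets y, without a base model:
#   study = optuna.create_study(direction="maximize")
#   study.optimize(
#       make_cv_objective(X, y, n_splits=5, device="cpu",
#                         base_model_json=None, target_transform="none"),
#       n_trials=10,
#   )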
def refit_and_save(
    X: np.ndarray,
    y: np.ndarray,
    seqs: np.ndarray,
    out_dir: str,
    best_params: Dict[str, Any],
    n_splits: int,
    device: str,
    base_model_json: Optional[str],
    target_transform: str,
):
    os.makedirs(out_dir, exist_ok=True)
    # Transform target consistently with the CV objective
    if target_transform == "log1p":
        y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
    else:
        y_used = y.astype(np.float32)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    # 1) get OOF preds + average best_iteration
    oof_pred = np.zeros_like(y_used, dtype=np.float32)
    best_iters = []
    fold_rows = []
    for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
        Xtr, ytr = X[tr_idx], y_used[tr_idx]
        Xva, yva = X[va_idx], y_used[va_idx]
        _, _, p_va, best_iter = train_xgb_reg(
            Xtr, ytr, Xva, yva, best_params.copy(),
            base_model_json=base_model_json,
        )
        oof_pred[va_idx] = p_va.astype(np.float32)
        best_iters.append(best_iter)
        m = eval_regression(yva, p_va)
        fold_rows.append({"fold": fold, **m, "best_iter": int(best_iter)})
    fold_df = pd.DataFrame(fold_rows)
    fold_df.to_csv(os.path.join(out_dir, "cv_fold_metrics.csv"), index=False)
    cv_metrics = eval_regression(y_used, oof_pred)
    with open(os.path.join(out_dir, "cv_oof_summary.json"), "w") as f:
        json.dump(cv_metrics, f, indent=2)
    oof_df = pd.DataFrame({
        "sequence": seqs,
        "y_true_used": y_used.astype(float),
        "y_pred_oof": oof_pred.astype(float),
        "residual": (y_used - oof_pred).astype(float),
    })
    oof_df.to_csv(os.path.join(out_dir, "cv_oof_predictions.csv"), index=False)
    mean_best_iter = int(round(float(np.mean(best_iters))))
    final_rounds = max(mean_best_iter + 1, 10)
    # 2) train final model on ALL data (no early stopping here; use final_rounds)
    dtrain_all = xgb.DMatrix(X, label=y_used)
    xgb_model = None
    if base_model_json is not None:
        booster0 = xgb.Booster()
        booster0.load_model(base_model_json)
        xgb_model = booster0
    final_params = best_params.copy()
    final_params.pop("early_stopping_rounds", None)
    final_params.pop("num_boost_round", None)  # rounds come from the CV-averaged best_iteration
    final_params["device"] = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
    booster = xgb.train(
        params=final_params,
        dtrain=dtrain_all,
        num_boost_round=final_rounds,
        evals=[],
        verbose_eval=False,
        xgb_model=xgb_model,
    )
    model_path = os.path.join(out_dir, "best_model_finetuned.json")
    booster.save_model(model_path)
    with open(os.path.join(out_dir, "final_training_notes.json"), "w") as f:
        json.dump(
            {
                "target_transform": target_transform,
                "final_rounds_used": int(final_rounds),
                "cv_oof_metrics_on_used_target": cv_metrics,
                "model_path": model_path,
            },
            f,
            indent=2,
        )
    print("=" * 72)
    print("[Final] CV OOF metrics (on transformed target if enabled):")
    print(json.dumps(cv_metrics, indent=2))
    print(f"[Final] Saved finetuned model -> {model_path}")
    print("=" * 72)
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_path", type=str, default="/scratch/pranamlab/tong/data/halflife/wt_halflife_merged_dedup.csv")
    parser.add_argument("--out_dir", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb")
    # If provided, we will "finetune" by continuing boosting from this model
    parser.add_argument("--base_model_json", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/stability/xgboost/best_model.json",
                        help="Path to an existing XGBoost .json model to continue training from")
    # ESM embedding config
    parser.add_argument("--esm_model", type=str, default="facebook/esm2_t33_650M_UR50D")
    parser.add_argument("--esm_batch_size", type=int, default=8)
    parser.add_argument("--esm_max_length", type=int, default=1024)
    parser.add_argument("--no_fp16", action="store_true")
    # Training config
    parser.add_argument("--n_trials", type=int, default=200)
    parser.add_argument("--n_splits", type=int, default=5)
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--target_transform", type=str, default="none", choices=["none", "log1p"])
    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    # Load data
    df = pd.read_csv(args.csv_path)
    if "sequence" not in df.columns or "half_life_hours" not in df.columns:
        raise ValueError("CSV must contain columns: sequence, half_life_hours")
    df = df.dropna(subset=["sequence", "half_life_hours"]).copy()
    df["sequence"] = df["sequence"].astype(str).str.strip()
    df = df[df["sequence"].str.len() > 0]
    df = df.drop_duplicates(subset=["sequence"], keep="first").reset_index(drop=True)
    print(f"[Data] N={len(df)} from {args.csv_path}")
    # Embeddings (cached)
    embed_cfg = ESMEmbedderConfig(
        model_name=args.esm_model,
        batch_size=args.esm_batch_size,
        max_length=args.esm_max_length,
        fp16=(not args.no_fp16),
    )
    X, y, seqs = load_or_compute_embeddings(df, args.out_dir, embed_cfg, device=args.device)
    print(f"[Embeddings] X={X.shape} (float32)")
    # Optuna study
    sampler = optuna.samplers.TPESampler(seed=SEED)
    study = optuna.create_study(
        direction="maximize",  # maximize CV Spearman rho, like the stability script
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(),
    )
    objective = make_cv_objective(
        X=X,
        y=y,
        n_splits=args.n_splits,
        device=args.device,
        base_model_json=args.base_model_json,
        target_transform=args.target_transform,
    )
    study.optimize(objective, n_trials=args.n_trials)
    # Save trials
    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(args.out_dir, "study_trials.csv"), index=False)
    best = study.best_trial
    best_params = dict(best.params)
    # Build full param dict for refit
    best_xgb_params = {
        "objective": "reg:squarederror",
        "eval_metric": "rmse",
        "lambda": best_params["lambda"],
        "alpha": best_params["alpha"],
        "gamma": best_params["gamma"],
        "max_depth": best_params["max_depth"],
        "min_child_weight": best_params["min_child_weight"],
        "subsample": best_params["subsample"],
        "colsample_bytree": best_params["colsample_bytree"],
        "learning_rate": best_params["learning_rate"],
        "tree_method": "hist",
        "device": "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu",
        "num_boost_round": best_params["num_boost_round"],
        "early_stopping_rounds": best_params["early_stopping_rounds"],
    }
    # Summary
    summary = {
        "best_trial_number": int(best.number),
        "best_value_cv_spearman_rho": float(best.value),
        "best_user_attrs": best.user_attrs,
        "best_params": best_params,
        "best_xgb_params_full": best_xgb_params,
        "base_model_json": args.base_model_json,
        "target_transform": args.target_transform,
        "esm_model": args.esm_model,
        "esm_max_length": args.esm_max_length,
    }
    with open(os.path.join(args.out_dir, "optimization_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    print("=" * 72)
    print("[Optuna] Best CV Spearman rho:", float(best.value))
    print("[Optuna] Best params:\n", json.dumps(best_params, indent=2))
    print("=" * 72)
    # Refit + save final finetuned model + OOF predictions
    refit_and_save(
        X=X,
        y=y,
        seqs=seqs,
        out_dir=args.out_dir,
        best_params=best_xgb_params,
        n_splits=args.n_splits,
        device=args.device,
        base_model_json=args.base_model_json,
        target_transform=args.target_transform,
    )


if __name__ == "__main__":
    main()