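"""Fine-tune an XGBoost half-life regressor on frozen ESM-2 embeddings.

Pipeline: embed peptide sequences with a frozen ESM-2 encoder (mean-pooled
hidden states, special tokens excluded), tune XGBoost hyperparameters with
Optuna under K-fold CV (maximizing Spearman rho), then refit on the full
dataset, optionally warm-starting from an existing XGBoost model.
"""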
import os
import json
import hashlib
from dataclasses import dataclass
from typing import Dict, Any, Optional, Tuple, List

import numpy as np
import pandas as pd
import optuna

from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import spearmanr

import torch
from transformers import AutoTokenizer, AutoModel

import xgboost as xgb

# ---------------------------------------------------------------------------
# Reproducibility
# ---------------------------------------------------------------------------
SEED = 1986
np.random.seed(SEED)
torch.manual_seed(SEED)


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Spearman rho, with NaN/degenerate cases mapped to 0.0."""
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)


def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Standard regression metrics plus rank correlation."""
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))
    rho = safe_spearmanr(y_true, y_pred)
    return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}

# ---------------------------------------------------------------------------
# ESM-2 embedding
# ---------------------------------------------------------------------------
@dataclass
class ESMEmbedderConfig:
    model_name: str = "facebook/esm2_t33_650M_UR50D"
    batch_size: int = 8
    max_length: int = 1024
    fp16: bool = True


class ESM2Embedder:
    """
    Mean-pooled last hidden state (excluding special tokens) -> (H,) per sequence.
    """
    def __init__(self, cfg: ESMEmbedderConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device if (device == "cuda" and torch.cuda.is_available()) else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, do_lower_case=False)
        self.model = AutoModel.from_pretrained(cfg.model_name)
        self.model.eval()
        self.model.to(self.device)

        # Freeze the encoder: only the downstream XGBoost head is trained.
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.inference_mode()
    def embed(self, seqs: List[str]) -> np.ndarray:
        out = []
        bs = self.cfg.batch_size

        # Mixed precision only makes sense on GPU; CPU runs in full precision.
        use_amp = self.cfg.fp16 and self.device == "cuda"

        for i in range(0, len(seqs), bs):
            batch = [s.strip().upper() for s in seqs[i:i + bs]]
            toks = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.cfg.max_length,
                add_special_tokens=True,
            )
            toks = {k: v.to(self.device) for k, v in toks.items()}
            attn = toks["attention_mask"]

            with torch.autocast(device_type=self.device, enabled=use_amp):
                h = self.model(**toks).last_hidden_state

            # Mean-pool over residue positions only: zero out BOS (position 0)
            # and EOS (last non-padding position) before averaging.
            mask = attn.clone()
            mask[:, 0] = 0
            lengths = attn.sum(dim=1)
            eos_pos = (lengths - 1).clamp(min=0)
            mask[torch.arange(mask.size(0), device=mask.device), eos_pos] = 0

            denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
            pooled = (h * mask.unsqueeze(-1)).sum(dim=1) / denom
            out.append(pooled.float().cpu().numpy())

        return np.concatenate(out, axis=0).astype(np.float32)
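
# Usage sketch (hypothetical sequences; CPU fallback):
#   cfg = ESMEmbedderConfig(batch_size=2)
#   X = ESM2Embedder(cfg, device="cpu").embed(["ACDEFGHIK", "MKTAYIAKQR"])
#   X.shape == (2, hidden_size)  # hidden_size = 1280 for esm2_t33_650M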


def dataset_fingerprint(seqs: List[str], y: np.ndarray, extra: str = "") -> str:
    """Short, stable hash of (sequences, targets, extra) used to key the embedding cache."""
    h = hashlib.sha256()
    for s in seqs:
        h.update(s.encode("utf-8"))
        h.update(b"\n")
    h.update(np.asarray(y, dtype=np.float32).tobytes())
    h.update(extra.encode("utf-8"))
    return h.hexdigest()[:16]


def load_or_compute_embeddings(
    df: pd.DataFrame,
    out_dir: str,
    embed_cfg: ESMEmbedderConfig,
    device: str,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    os.makedirs(out_dir, exist_ok=True)

    seqs = df["sequence"].astype(str).tolist()
    y = df["half_life_hours"].astype(float).to_numpy(dtype=np.float32)

    fp = dataset_fingerprint(seqs, y, extra=f"{embed_cfg.model_name}|{embed_cfg.max_length}")
    emb_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.npy")
    meta_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.json")

    # Reuse cached embeddings when both the array and its metadata are present.
    if os.path.exists(emb_path) and os.path.exists(meta_path):
        X = np.load(emb_path).astype(np.float32)
        return X, y, np.asarray(seqs)

    embedder = ESM2Embedder(embed_cfg, device=device)
    X = embedder.embed(seqs)

    np.save(emb_path, X)
    with open(meta_path, "w") as f:
        json.dump(
            {
                "fingerprint": fp,
                "model_name": embed_cfg.model_name,
                "max_length": embed_cfg.max_length,
                "n": len(seqs),
                "dim": int(X.shape[1]),
            },
            f,
            indent=2,
        )
    return X, y, np.asarray(seqs)
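
# Note: the cache key covers sequences, targets, model name, and max_length,
# so changing any of them triggers a fresh embedding pass.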


# ---------------------------------------------------------------------------
# XGBoost training
# ---------------------------------------------------------------------------
def train_xgb_reg(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    params: Dict[str, Any],
    base_model_json: Optional[str] = None,
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray, int]:
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))

    # Optionally warm-start from an existing booster (continued training).
    xgb_model = None
    if base_model_json is not None:
        booster0 = xgb.Booster()
        booster0.load_model(base_model_json)
        xgb_model = booster0

    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
        xgb_model=xgb_model,
    )

    p_train = booster.predict(dtrain)
    p_val = booster.predict(dval)
    best_iter = int(getattr(booster, "best_iteration", num_boost_round - 1))
    return booster, p_train, p_val, best_iter
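
# When warm-starting, xgb.train appends new boosting rounds to the loaded
# booster; the tuned params above govern only the newly added trees.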


# ---------------------------------------------------------------------------
# Optuna CV objective
# ---------------------------------------------------------------------------
def make_cv_objective(
    X: np.ndarray,
    y: np.ndarray,
    n_splits: int,
    device: str,
    base_model_json: Optional[str],
    target_transform: str,
):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    # Optional variance-stabilizing transform; half-lives are non-negative.
    if target_transform == "log1p":
        y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
    elif target_transform == "none":
        y_used = y.astype(np.float32)
    else:
        raise ValueError(f"Unknown target_transform: {target_transform}")

    def objective(trial: optuna.Trial) -> float:
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            # Regularization.
            "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
            "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
            "gamma": trial.suggest_float("gamma", 0.0, 10.0),
            # Tree complexity and sampling.
            "max_depth": trial.suggest_int("max_depth", 2, 12),
            "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 200.0, log=True),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
            "tree_method": "hist",
            "device": "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu",
        }
        params["num_boost_round"] = trial.suggest_int("num_boost_round", 30, 1500)
        params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 10, 150)

        fold_metrics = []
        fold_best_iters = []

        for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
            Xtr, ytr = X[tr_idx], y_used[tr_idx]
            Xva, yva = X[va_idx], y_used[va_idx]

            _, _, p_va, best_iter = train_xgb_reg(
                Xtr, ytr, Xva, yva, params.copy(),
                base_model_json=base_model_json,
            )

            m = eval_regression(yva, p_va)
            fold_metrics.append(m)
            fold_best_iters.append(best_iter)

            # Report the running mean so the MedianPruner can stop weak trials early.
            trial.report(float(np.mean([fm["spearman_rho"] for fm in fold_metrics])), step=fold)
            if trial.should_prune():
                raise optuna.TrialPruned()

        mean_rho = float(np.mean([m["spearman_rho"] for m in fold_metrics]))
        mean_rmse = float(np.mean([m["rmse"] for m in fold_metrics]))
        mean_mae = float(np.mean([m["mae"] for m in fold_metrics]))
        mean_r2 = float(np.mean([m["r2"] for m in fold_metrics]))
        mean_best_iter = float(np.mean(fold_best_iters))

        trial.set_user_attr("cv_spearman_rho", mean_rho)
        trial.set_user_attr("cv_rmse", mean_rmse)
        trial.set_user_attr("cv_mae", mean_mae)
        trial.set_user_attr("cv_r2", mean_r2)
        trial.set_user_attr("cv_mean_best_iter", mean_best_iter)

        # The study maximizes mean CV Spearman rho.
        return mean_rho

    return objective
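
# Smoke-test sketch (assumes X, y already in memory):
#   study = optuna.create_study(direction="maximize")
#   study.optimize(make_cv_objective(X, y, 3, "cpu", None, "none"), n_trials=5)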


# ---------------------------------------------------------------------------
# Final refit
# ---------------------------------------------------------------------------
def refit_and_save(
    X: np.ndarray,
    y: np.ndarray,
    seqs: np.ndarray,
    out_dir: str,
    best_params: Dict[str, Any],
    n_splits: int,
    device: str,
    base_model_json: Optional[str],
    target_transform: str,
):
    os.makedirs(out_dir, exist_ok=True)

    # Same target transform as used during the search.
    if target_transform == "log1p":
        y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
    else:
        y_used = y.astype(np.float32)

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    # Out-of-fold predictions with the best params, for honest CV reporting.
    oof_pred = np.zeros_like(y_used, dtype=np.float32)
    best_iters = []
    fold_rows = []

    for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
        Xtr, ytr = X[tr_idx], y_used[tr_idx]
        Xva, yva = X[va_idx], y_used[va_idx]

        _, _, p_va, best_iter = train_xgb_reg(
            Xtr, ytr, Xva, yva, best_params.copy(),
            base_model_json=base_model_json,
        )
        oof_pred[va_idx] = p_va.astype(np.float32)
        best_iters.append(best_iter)

        m = eval_regression(yva, p_va)
        fold_rows.append({"fold": fold, **m, "best_iter": int(best_iter)})

    fold_df = pd.DataFrame(fold_rows)
    fold_df.to_csv(os.path.join(out_dir, "cv_fold_metrics.csv"), index=False)

    cv_metrics = eval_regression(y_used, oof_pred)
    with open(os.path.join(out_dir, "cv_oof_summary.json"), "w") as f:
        json.dump(cv_metrics, f, indent=2)

    oof_df = pd.DataFrame({
        "sequence": seqs,
        "y_true_used": y_used.astype(float),
        "y_pred_oof": oof_pred.astype(float),
        "residual": (y_used - oof_pred).astype(float),
    })
    oof_df.to_csv(os.path.join(out_dir, "cv_oof_predictions.csv"), index=False)
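
    # Note: with target_transform="log1p", y_true_used / y_pred_oof are on the
    # log1p scale; apply np.expm1 to recover half-life in hours.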

    mean_best_iter = int(round(float(np.mean(best_iters))))
    final_rounds = max(mean_best_iter + 1, 10)

    # Refit on all data for the CV-derived number of rounds.
    dtrain_all = xgb.DMatrix(X, label=y_used)

    xgb_model = None
    if base_model_json is not None:
        booster0 = xgb.Booster()
        booster0.load_model(base_model_json)
        xgb_model = booster0

    final_params = best_params.copy()
    final_params.pop("early_stopping_rounds", None)
    # No validation set remains to early-stop on, so drop the tuned maximum and
    # train for final_rounds (mean best iteration across folds + 1).
    final_params.pop("num_boost_round", None)
    final_params["device"] = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"

    booster = xgb.train(
        params=final_params,
        dtrain=dtrain_all,
        num_boost_round=int(final_rounds),
        verbose_eval=False,
        xgb_model=xgb_model,
    )

    model_path = os.path.join(out_dir, "best_model_finetuned.json")
    booster.save_model(model_path)

    with open(os.path.join(out_dir, "final_training_notes.json"), "w") as f:
        json.dump(
            {
                "target_transform": target_transform,
                "final_rounds_used": int(final_rounds),
                "cv_oof_metrics_on_used_target": cv_metrics,
                "model_path": model_path,
            },
            f,
            indent=2,
        )

    print("=" * 72)
    print("[Final] CV OOF metrics (on transformed target if enabled):")
    print(json.dumps(cv_metrics, indent=2))
    print(f"[Final] Saved finetuned model -> {model_path}")
    print("=" * 72)
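
# Artifacts written by refit_and_save: cv_fold_metrics.csv, cv_oof_summary.json,
# cv_oof_predictions.csv, best_model_finetuned.json, final_training_notes.json.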


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_path", type=str, default="/scratch/pranamlab/tong/data/halflife/wt_halflife_merged_dedup.csv")
    parser.add_argument("--out_dir", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb")

    # Warm start: continue training from an existing stability model.
    parser.add_argument("--base_model_json", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/stability/xgboost/best_model.json", help="Path to an existing XGBoost .json model to continue training from")

    # Embedding options.
    parser.add_argument("--esm_model", type=str, default="facebook/esm2_t33_650M_UR50D")
    parser.add_argument("--esm_batch_size", type=int, default=8)
    parser.add_argument("--esm_max_length", type=int, default=1024)
    parser.add_argument("--no_fp16", action="store_true")

    # Search options.
    parser.add_argument("--n_trials", type=int, default=200)
    parser.add_argument("--n_splits", type=int, default=5)
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--target_transform", type=str, default="none", choices=["none", "log1p"])

    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Load and clean the dataset.
    df = pd.read_csv(args.csv_path)
    if "sequence" not in df.columns or "half_life_hours" not in df.columns:
        raise ValueError("CSV must contain columns: sequence, half_life_hours")

    df = df.dropna(subset=["sequence", "half_life_hours"]).copy()
    df["sequence"] = df["sequence"].astype(str).str.strip()
    df = df[df["sequence"].str.len() > 0]
    df = df.drop_duplicates(subset=["sequence"], keep="first").reset_index(drop=True)

    print(f"[Data] N={len(df)} from {args.csv_path}")

    # Embed (or load cached embeddings).
    embed_cfg = ESMEmbedderConfig(
        model_name=args.esm_model,
        batch_size=args.esm_batch_size,
        max_length=args.esm_max_length,
        fp16=(not args.no_fp16),
    )
    X, y, seqs = load_or_compute_embeddings(df, args.out_dir, embed_cfg, device=args.device)
    print(f"[Embeddings] X={X.shape} (float32)")

    # Hyperparameter search.
    sampler = optuna.samplers.TPESampler(seed=SEED)
    study = optuna.create_study(
        direction="maximize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(),
    )

    objective = make_cv_objective(
        X=X,
        y=y,
        n_splits=args.n_splits,
        device=args.device,
        base_model_json=args.base_model_json,
        target_transform=args.target_transform,
    )
    study.optimize(objective, n_trials=args.n_trials)

    # Persist the full trial history.
    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(args.out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)

    # Rebuild the full XGBoost param dict from the best trial.
    best_xgb_params = {
        "objective": "reg:squarederror",
        "eval_metric": "rmse",
        "lambda": best_params["lambda"],
        "alpha": best_params["alpha"],
        "gamma": best_params["gamma"],
        "max_depth": best_params["max_depth"],
        "min_child_weight": best_params["min_child_weight"],
        "subsample": best_params["subsample"],
        "colsample_bytree": best_params["colsample_bytree"],
        "learning_rate": best_params["learning_rate"],
        "tree_method": "hist",
        "device": "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu",
        "num_boost_round": best_params["num_boost_round"],
        "early_stopping_rounds": best_params["early_stopping_rounds"],
    }

    # Save an optimization summary.
    summary = {
        "best_trial_number": int(best.number),
        "best_value_cv_spearman_rho": float(best.value),
        "best_user_attrs": best.user_attrs,
        "best_params": best_params,
        "best_xgb_params_full": best_xgb_params,
        "base_model_json": args.base_model_json,
        "target_transform": args.target_transform,
        "esm_model": args.esm_model,
        "esm_max_length": args.esm_max_length,
    }
    with open(os.path.join(args.out_dir, "optimization_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print("=" * 72)
    print("[Optuna] Best CV Spearman rho:", float(best.value))
    print("[Optuna] Best params:\n", json.dumps(best_params, indent=2))
    print("=" * 72)

    # Refit with the best params and save the final model.
    refit_and_save(
        X=X,
        y=y,
        seqs=seqs,
        out_dir=args.out_dir,
        best_params=best_xgb_params,
        n_splits=args.n_splits,
        device=args.device,
        base_model_json=args.base_model_json,
        target_transform=args.target_transform,
    )


if __name__ == "__main__":
    main()
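
# Example invocation (paths are illustrative; all flags exist in main()):
#   python finetune_stability_xgb.py --csv_path data.csv --out_dir runs/xgb \
#       --n_trials 50 --device cpu --target_transform log1p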