"""Hyperparameter search for TabM with Hyperopt.

Runs TPE search with internal stratified cross-validation on the training
set only (no external validation/test split), scores candidates with a
sum-of-exponential-rank metric, then retrains a final model on the full
training set and saves a checkpoint plus a human-readable sidecar file.
"""

import argparse
import os
import random
from copy import deepcopy
from typing import Any, Dict

import numpy as np
import pandas as pd
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
from hyperopt.pyll.base import scope
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.optim
from torch import Tensor

import tabm
import rtdl_num_embeddings


def set_seed(seed: int) -> None:
    # Offset the seeds so the three RNG streams are not identical.
    random.seed(seed)
    np.random.seed(seed + 1)
    torch.manual_seed(seed + 2)


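# Note: on recent PyTorch versions, torch.manual_seed above also seeds the
# CUDA generators on all visible devices, so a separate
# torch.cuda.manual_seed_all call should not be needed.
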
def _dump_model_info_sidecar(model_path: str) -> None:
    """Write a human-readable .info.txt next to a saved checkpoint.

    Best-effort only: every step is wrapped in try/except so that a
    malformed checkpoint never breaks the training pipeline.
    """
    try:
        if not os.path.exists(model_path):
            return
        ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
        sidecar = os.path.splitext(model_path)[0] + ".info.txt"
        with open(sidecar, "w", encoding="utf-8") as f:

            def _p(title: str, d):
                # Print a titled key/value section for a Namespace or dict.
                try:
                    f.write(title + "\n")
                    if hasattr(d, "__dict__"):
                        items = sorted(vars(d).items())
                    elif isinstance(d, dict):
                        items = sorted(d.items())
                    else:
                        items = []
                    for k, v in items:
                        try:
                            f.write(f"- {k}: {repr(v)}\n")
                        except Exception:
                            f.write(f"- {k}: <unprintable>\n")
                    f.write("=" * len(title) + "\n")
                except Exception:
                    pass

            _p("===== checkpoint['args'] =====", ckpt.get('args'))
            _p("===== checkpoint['training_args'] =====", ckpt.get('training_args', {}))
            _p("===== checkpoint['best_params'] =====", ckpt.get('best_params', {}))
            _p("===== checkpoint['full_args'] =====", ckpt.get('full_args', {}))

            if ckpt.get("used_feature_idx") is not None:
                ufi = ckpt["used_feature_idx"]
                f.write("===== used_feature_idx =====\n")
                try:
                    f.write(f"- length: {len(ufi)}\n")
                    f.write(f"- head: {list(ufi[:10])}\n")
                except Exception:
                    f.write("<unprintable>\n")
                f.write("=" * 25 + "\n")

            try:
                f.write("===== Environment =====\n")
                f.write(f"- torch: {torch.__version__}\n")
                f.write(f"- cuda available: {torch.cuda.is_available()}\n")
                if torch.cuda.is_available():
                    f.write(f"- device: {torch.cuda.get_device_name(0)}\n")
                    f.write(f"- cuda version: {torch.version.cuda}\n")
                f.write(f"- tabm: {getattr(tabm, '__version__', 'unknown')}\n")
                f.write("========================\n")
            except Exception:
                pass
    except Exception:
        pass


def load_training_data(data_file: str) -> tuple[np.ndarray, np.ndarray]:
    """Load a TSV with one label column plus numeric feature columns.

    The label column is 'label' if present, otherwise the first column.
    All values are read as strings and coerced to numbers; anything
    non-numeric becomes 0 (labels) or 0.0 (features).
    """
    df = pd.read_csv(
        data_file,
        sep='\t',
        header=0,
        dtype=str,
        keep_default_na=False,
        na_filter=False,
        engine='python',
    )

    if df.shape[0] == 0 or df.shape[1] < 2:
        raise ValueError(
            f"Incorrect training data format: {data_file}, "
            f"requires at least 1 label column + 1 feature column, actual shape={df.shape}"
        )

    label_col = 'label' if 'label' in df.columns else df.columns[0]
    y = pd.to_numeric(df[label_col], errors='coerce').fillna(0).astype(np.int64).to_numpy()

    feature_cols = [c for c in df.columns if c != label_col]
    if len(feature_cols) == 0:
        raise ValueError("No feature columns found")

    X_df = df[feature_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0)
    X = X_df.to_numpy(dtype=np.float32)

    return X, y


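# Illustrative input layout (column names and values are hypothetical; only
# the tab-separated shape and the optional 'label' column name are assumed
# by load_training_data above):
#
#   label   feat_a  feat_b
#   1       0.13    2.5
#   0       0.02    1.1
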
def build_num_embeddings(embedding_type: str, X_fold: np.ndarray) -> tuple[Any, np.ndarray]:
    """Build the numeric-feature embedding module for one data fold.

    Returns (num_embeddings, used_idx). For 'piecewise', zero-variance
    columns are dropped first and used_idx records the surviving column
    indices; for all other types, used_idx covers every column.
    """
    used_idx = np.arange(X_fold.shape[1])
    if embedding_type == 'piecewise':
        # Drop constant columns: bins cannot be computed for features with
        # a single distinct value.
        var = X_fold.var(axis=0)
        used_idx = np.where(var > 0.0)[0]
        X_fold = X_fold[:, used_idx]
        if len(used_idx) < 1:
            return None, used_idx
        try:
            X_tensor = torch.as_tensor(X_fold, dtype=torch.float32)
            num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
                rtdl_num_embeddings.compute_bins(X_tensor, n_bins=48),
                d_embedding=16,
                activation=False,
                version='B',
            )
            return num_embeddings, used_idx
        except Exception:
            # Fall back to no embeddings if bin computation fails.
            return None, used_idx
    elif embedding_type == 'linear':
        return rtdl_num_embeddings.LinearReLUEmbeddings(X_fold.shape[1]), used_idx
    elif embedding_type == 'periodic':
        return rtdl_num_embeddings.PeriodicEmbeddings(X_fold.shape[1], lite=False), used_idx
    else:
        return None, used_idx


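# When build_num_embeddings returns None, make_model below passes
# num_embeddings=None to tabm.TabM.make, which the tabm API treats as "no
# per-feature embedding layer": raw numeric features go straight to the
# backbone.
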
def make_model(n_features: int,
               k: int,
               n_blocks: int,
               d_block: int,
               num_embeddings: Any,
               arch_type: str = 'tabm') -> nn.Module:
    """Construct a TabM binary classifier (d_out=2, no categorical features)."""
    return tabm.TabM.make(
        n_num_features=n_features,
        cat_cardinalities=[],
        d_out=2,
        k=k,
        n_blocks=n_blocks,
        d_block=d_block,
        num_embeddings=num_embeddings,
        arch_type=arch_type,
    )


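# Example (illustrative values): make_model(100, k=32, n_blocks=3,
# d_block=256, num_embeddings=None, arch_type='tabm') builds an implicit
# ensemble of 32 MLP members whose forward pass returns logits of shape
# (batch, k, 2), which is what the loss and evaluation code below assume.
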
def train_one_epoch(model: nn.Module,
                    X: torch.Tensor,
                    y: torch.Tensor,
                    optimizer: torch.optim.Optimizer,
                    batch_size: int,
                    device: torch.device) -> float:
    """Run one epoch of mini-batch training; returns the mean batch loss."""
    model.train()
    indices = torch.randperm(len(X), device=device)
    batches = indices.split(batch_size)
    total_loss = 0.0
    # All k ensemble members see the same mini-batches.
    share_training_batches = True

    def loss_fn(y_pred: Tensor, y_true: Tensor) -> Tensor:
        # y_pred: (batch, k, d_out) -> (batch * k, d_out)
        y_pred = y_pred.flatten(0, 1)
        if share_training_batches:
            # Repeat each label k times, matching the flattened logits.
            y_true = y_true.repeat_interleave(model.backbone.k)
        else:
            y_true = y_true.flatten(0, 1)
        return nn.functional.cross_entropy(y_pred, y_true)

    for batch_idx in batches:
        optimizer.zero_grad()
        logits = model(X[batch_idx])
        loss = loss_fn(logits, y[batch_idx])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += float(loss.detach().cpu())
    return total_loss / max(1, len(batches))


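# Training note: the loss above averages the cross-entropy over all k
# members, so each member is trained as an independent classifier; the
# members are only combined (by averaging) at evaluation time, see
# evaluate_sum_exp_rank below.
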
def sum_rank_correct_numpy(y_true: np.ndarray, y_prob: np.ndarray, alpha: float = 0.005) -> float:
    """Sum of exp(-alpha * rank) over the positive samples.

    Samples are ranked by descending predicted probability (rank 0 is the
    highest-scored sample), so positives near the top of the ranking
    contribute the most. Higher is better.
    """
    idx = np.argsort(-y_prob)
    y_sorted = y_true[idx]
    r = np.where(y_sorted == 1)[0]
    return float(np.sum(np.exp(-alpha * r)))


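# Worked example (hypothetical numbers): with y_true = [1, 0, 1] and
# y_prob = [0.9, 0.8, 0.1], the descending sort puts the positives at ranks
# 0 and 2, so the score is exp(0) + exp(-0.005 * 2) ~= 1.990.
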
@torch.inference_mode()
def evaluate_sum_exp_rank(model: nn.Module, X: torch.Tensor, y: torch.Tensor, device: torch.device, alpha: float = 0.005) -> float:
    """Score a model on (X, y) with the sum-of-exponential-rank metric."""
    model.eval()
    eval_bs = 8096
    # Average the logits over the k ensemble members before the softmax.
    logits = torch.cat([
        model(X[idx]).mean(1)
        for idx in torch.arange(len(X), device=device).split(eval_bs)
    ])
    probs_pos = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
    y_true = y.cpu().numpy()
    return sum_rank_correct_numpy(y_true, probs_pos, alpha)


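# Design note: this reduces the ensemble by averaging logits and then
# applying softmax. Averaging the per-member probabilities after softmax is
# the alternative reduction; the two are not identical and can produce
# slightly different rankings.
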
def objective(params: Dict[str, Any],
              X: np.ndarray,
              y: np.ndarray,
              device: torch.device,
              seed: int,
              cv_folds: int,
              epochs: int,
              batch_size: int,
              alpha: float = 0.005) -> Dict[str, Any]:
    """Hyperopt objective: mean CV sum-of-exponential-rank score, negated."""
    k = int(params.get('k', 32))
    n_blocks = int(params['n_blocks'])
    d_block = int(params['d_block'])
    lr = float(params['lr'])
    wd_choice = params['weight_decay_choice']
    weight_decay = 0.0 if wd_choice == 0 else float(params['weight_decay_val'])
    embedding_type = params['embedding_type']
    arch_type = params['arch_type']

    cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=seed)
    fold_scores: list[float] = []

    for train_idx, val_idx in cv.split(X, y):
        X_tr = X[train_idx]
        y_tr = y[train_idx]
        X_va = X[val_idx]
        y_va = y[val_idx]

        # Embeddings (and, for 'piecewise', the set of usable columns) are
        # fitted on the training fold only.
        num_embeddings, used_idx = build_num_embeddings(embedding_type, X_tr)
        X_tr_used = X_tr[:, used_idx] if embedding_type == 'piecewise' else X_tr
        X_va_used = X_va[:, used_idx] if embedding_type == 'piecewise' else X_va

        n_features = X_tr_used.shape[1]
        model = make_model(n_features, k, n_blocks, d_block, num_embeddings, arch_type).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        X_tr_t = torch.as_tensor(X_tr_used, device=device)
        y_tr_t = torch.as_tensor(y_tr, device=device)
        X_va_t = torch.as_tensor(X_va_used, device=device)
        y_va_t = torch.as_tensor(y_va, device=device)

        for _ in range(epochs):
            train_one_epoch(model, X_tr_t, y_tr_t, optimizer, batch_size, device)

        score = evaluate_sum_exp_rank(model, X_va_t, y_va_t, device, alpha)
        fold_scores.append(score)

    mean_score = float(np.mean(fold_scores))
    # hyperopt's fmin minimizes, so report the negated score as the loss.
    return {"loss": -mean_score, "status": STATUS_OK, "score": mean_score}


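# Because StratifiedKFold is seeded with the same random_state for every
# trial, all hyperparameter candidates are scored on identical CV splits,
# which keeps the comparison between trials fair.
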
def train_final(X: np.ndarray,
                y: np.ndarray,
                best_params: Dict[str, Any],
                device: torch.device,
                final_epochs: int,
                batch_size: int,
                output_path: str,
                seed: int,
                alpha: float = 0.005) -> None:
    """Retrain on the full training set with the best parameters and save a checkpoint."""
    k = int(best_params.get('k', 32))
    n_blocks = int(best_params['n_blocks'])
    d_block = int(best_params['d_block'])
    lr = float(best_params['lr'])
    wd_choice = best_params['weight_decay_choice']
    weight_decay = 0.0 if wd_choice == 0 else float(best_params['weight_decay_val'])
    embedding_type = best_params['embedding_type']
    arch_type = best_params['arch_type']

    num_embeddings, used_idx = build_num_embeddings(embedding_type, X)
    X_used = X[:, used_idx] if embedding_type == 'piecewise' else X
    n_features = X_used.shape[1]

    model = make_model(n_features, k, n_blocks, d_block, num_embeddings, arch_type).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    X_t = torch.as_tensor(X_used, device=device)
    y_t = torch.as_tensor(y, device=device)

    for _ in range(final_epochs):
        train_one_epoch(model, X_t, y_t, optimizer, batch_size, device)

    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    torch.save({
        "model_state_dict": model.state_dict(),
        "args": argparse.Namespace(
            k=k,
            n_blocks=n_blocks,
            d_block=d_block,
            use_embeddings=embedding_type in ("linear", "periodic", "piecewise"),
            embedding_type=embedding_type,
            arch_type=arch_type,
        ),
        "best_params": deepcopy(best_params),
        "training_args": {
            "lr": lr,
            "weight_decay_choice": wd_choice,
            "weight_decay_val": weight_decay,
            "batch_size": batch_size,
            "final_epochs": final_epochs,
            "seed": seed,
            "alpha": alpha,
            "device": str(device),
        },
        "used_feature_idx": used_idx,
        "full_args": dict(
            best_params=deepcopy(best_params),
            final_epochs=final_epochs, batch_size=batch_size,
            seed=seed, alpha=alpha, device=str(device),
        ),
        "search_space": "hyperopt space v1",
    }, output_path)
    print(f"Final model saved to: {output_path}")
    _dump_model_info_sidecar(output_path)


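# Sketch of reloading a checkpoint for inference (names are illustrative;
# for 'piecewise' embeddings the module must be rebuilt on the same
# training data so that the bin shapes match before load_state_dict will
# accept the weights):
#
#   ckpt = torch.load(path, map_location='cpu', weights_only=False)
#   a, ufi = ckpt['args'], ckpt['used_feature_idx']
#   emb, _ = build_num_embeddings(a.embedding_type, X_train)
#   model = make_model(len(ufi), a.k, a.n_blocks, a.d_block, emb, a.arch_type)
#   model.load_state_dict(ckpt['model_state_dict'])
#   probs = torch.softmax(model(torch.as_tensor(X_new[:, ufi])).mean(1), 1)[:, 1]
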
def hyperopt_search(X: np.ndarray,
                    y: np.ndarray,
                    device: torch.device,
                    seed: int,
                    cv_folds: int,
                    epochs: int,
                    batch_size: int,
                    alpha: float,
                    tune_k: bool,
                    max_evals: int) -> tuple[dict, float]:
    """Run a TPE search and return (best_params, best_score)."""
    space = {
        "n_blocks": scope.int(hp.quniform("n_blocks", 2, 5, 1)),
        "d_block": scope.int(hp.quniform("d_block", 64, 1024, 16)),
        "lr": hp.loguniform("lr", np.log(1e-4), np.log(5e-3)),
        "weight_decay_choice": hp.choice("weight_decay_choice", [0, 1]),
        "weight_decay_val": hp.loguniform("weight_decay_val", np.log(1e-4), np.log(1e-1)),
        "embedding_type": hp.choice("embedding_type", ["none", "linear", "periodic", "piecewise"]),
        "arch_type": hp.choice("arch_type", ["tabm", "tabm-mini"]),
    }
    if tune_k:
        space["k"] = scope.int(hp.quniform("k", 16, 32, 8))
    else:
        space["k"] = 32

    def obj_fn(hparams):
        return objective(hparams, X, y, device, seed, cv_folds, epochs, batch_size, alpha)

    trials = Trials()
    fmin(fn=obj_fn, space=space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
    best_trial = min(trials.trials, key=lambda t: t["result"]["loss"])
    best_score = -best_trial["result"]["loss"]
    best_params = best_trial["misc"]["vals"].copy()

    # trials stores each value as a one-element list, and hp.choice values
    # as the *index* into the choice list, so map everything back to the
    # concrete parameter values.
    emb_choices = ["none", "linear", "periodic", "piecewise"]
    if isinstance(best_params["embedding_type"], list):
        best_params["embedding_type"] = emb_choices[int(best_params["embedding_type"][0])]
    arch_choices = ["tabm", "tabm-mini"]
    if isinstance(best_params["arch_type"], list):
        best_params["arch_type"] = arch_choices[int(best_params["arch_type"][0])]
    if isinstance(best_params.get("k", 32), list):
        best_params["k"] = int(best_params["k"][0])
    for k_ in ["n_blocks", "d_block", "weight_decay_choice"]:
        if isinstance(best_params[k_], list):
            best_params[k_] = int(best_params[k_][0])
    for k_ in ["lr", "weight_decay_val"]:
        if isinstance(best_params[k_], list):
            best_params[k_] = float(best_params[k_][0])

    return best_params, float(best_score)


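# For reference, best_trial["misc"]["vals"] looks roughly like
# {'n_blocks': [3.0], 'd_block': [256.0], 'lr': [0.0007], ...,
#  'embedding_type': [2]} (illustrative numbers): choice parameters hold
# indices, numeric parameters hold raw floats, hence the decoding above.
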
def run_one_pipeline(rep_idx: int,
                     X: np.ndarray,
                     y: np.ndarray,
                     device_str: str,
                     args_dict: dict,
                     out_dir: str,
                     base: str,
                     ext: str) -> str:
    """One independent repetition: hyperparameter search plus final training."""
    device = torch.device(device_str)
    # Offset the seed per repetition so repetitions are decorrelated.
    rep_seed = int(args_dict["seed"]) + 997 * int(rep_idx)
    set_seed(rep_seed)

    print(f"[rep {rep_idx}] Starting hyperparameter search (max_evals={args_dict['max_evals']}) ...")
    best_params, best_score = hyperopt_search(
        X, y, device,
        seed=rep_seed,
        cv_folds=args_dict["cv_folds"],
        epochs=args_dict["epochs"],
        batch_size=args_dict["batch_size"],
        alpha=args_dict["alpha"],
        tune_k=args_dict["tune_k"],
        max_evals=args_dict["max_evals"],
    )
    print(f"[rep {rep_idx}] Best sum_exp_rank={best_score:.6f}")
    print(f"[rep {rep_idx}] Best parameters={best_params}")

    out_path = os.path.join(out_dir, f"{base}_rep{rep_idx}{ext}")
    print(f"[rep {rep_idx}] Starting final training, saving to: {out_path}")
    train_final(
        X, y, best_params, device,
        final_epochs=args_dict["final_epochs"],
        batch_size=args_dict["batch_size"],
        output_path=out_path,
        seed=rep_seed,
        alpha=args_dict["alpha"],
    )
    return out_path


def main():
    ap = argparse.ArgumentParser(
        description="TabM hyperparameter search (Hyperopt) with internal cross-validation, "
                    "target = sum-of-exponential-rank score; training set only, no external validation/test"
    )
    ap.add_argument("--data_file", type=str, default="Neopep_ml_with_labels.txt", help="Training data TSV")
    ap.add_argument("--model_out", type=str, default="tabm_results/tabm_hyperopt_best.pth", help="Final model save path (or base name within directory)")
    ap.add_argument("--max_evals", type=int, default=30, help="Number of Hyperopt evaluations per parallel repetition")
    ap.add_argument("--cv_folds", type=int, default=5, help="Number of cross-validation folds")
    ap.add_argument("--epochs", type=int, default=40, help="Training epochs per fold")
    ap.add_argument("--final_epochs", type=int, default=120, help="Final model training epochs")
    ap.add_argument("--batch_size", type=int, default=256, help="Batch size")
    ap.add_argument("--seed", type=int, default=42, help="Random seed (each repetition is offset when running in parallel)")
    ap.add_argument("--alpha", type=float, default=0.005, help="Alpha for sum_exp_rank")
    ap.add_argument("--tune_k", action="store_true", help="Also search over k (default: fixed at 32)")
    ap.add_argument("--device", type=str, default="auto", help="Device selection: auto/cuda/cpu")
    ap.add_argument("--nr_hyperopt_rep", type=int, default=1, help="Number of parallel repetitions, each an independent hyperparameter search + final training")
    args = ap.parse_args()

    set_seed(args.seed)

    if args.device == "auto":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
            print(f"  GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
            print(f"  CUDA version: {torch.version.cuda}")
        else:
            device = torch.device('cpu')
            print("Warning: no GPU detected, using CPU")
    elif args.device == "cuda":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"Forcing GPU usage: {torch.cuda.get_device_name(0)}")
        else:
            raise RuntimeError("CUDA specified but no GPU detected")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    X, y = load_training_data(args.data_file)
    print(f"Training data: {X.shape}, positive sample ratio: {np.mean(y):.5f}")

    out_dir = os.path.dirname(args.model_out) or '.'
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(args.model_out))[0]
    ext = os.path.splitext(args.model_out)[1] or '.pth'

    args_dict = {
        "seed": int(args.seed),
        "cv_folds": int(args.cv_folds),
        "epochs": int(args.epochs),
        "final_epochs": int(args.final_epochs),
        "batch_size": int(args.batch_size),
        "alpha": float(args.alpha),
        "tune_k": bool(args.tune_k),
        "max_evals": int(args.max_evals),
    }

    # Each repetition runs in its own spawned process; 'spawn' (rather than
    # 'fork') is the context that works safely with CUDA in the workers.
    from multiprocessing import get_context
    ctx = get_context('spawn')
    repeats = int(args.nr_hyperopt_rep)
    print(f"Parallel repetitions: {repeats} (each an independent hyperparameter search + final training)")

    with ctx.Pool(processes=repeats) as pool:
        paths = pool.starmap(
            run_one_pipeline,
            [(i, X, y, str(device), args_dict, out_dir, base, ext) for i in range(repeats)]
        )
    print("Saved model files:")
    for p in sorted(paths):
        print("-", p)


if __name__ == "__main__":
    main()
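
# Example invocation (the script name is illustrative; data and output
# paths shown are the argparse defaults above, not requirements):
#
#   python tabm_hyperopt_search.py --data_file Neopep_ml_with_labels.txt \
#       --model_out tabm_results/tabm_hyperopt_best.pth \
#       --max_evals 30 --cv_folds 5 --nr_hyperopt_rep 2
#
# With --nr_hyperopt_rep 2 this writes tabm_hyperopt_best_rep0.pth and
# tabm_hyperopt_best_rep1.pth, plus a .info.txt sidecar for each.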