TabM / src /tabm_eval.py
NeoDiscoveryAdmin's picture
add test model and the train, test files
20eb53e
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import numpy as np
import pandas as pd
import torch
import tabm
from sklearn.metrics import precision_recall_curve, auc
def normalize_rt(s: pd.Series) -> pd.Series:
    """Canonicalize response-type labels so they can be compared reliably.

    Each value is cast to text, stripped of surrounding whitespace and
    upper-cased, e.g. ``" cd8 "`` -> ``"CD8"``.  Index and name of ``s`` are
    kept.
    """
    as_text = s.astype(str)
    trimmed = as_text.str.strip()
    return trimmed.str.upper()
def compute_patient_metrics(df_p: pd.DataFrame, y_prob: np.ndarray, alpha: float = 0.005) -> tuple:
    """Rank one patient's peptides by predicted probability and score the ranking.

    Rows are sorted by ``y_prob`` (descending).  A row counts as immunogenic
    when its normalized ``response_type`` is ``"CD8"`` and as tested-negative
    when it is ``"NEGATIVE"``; any other label counts for neither.

    Args:
        df_p: rows of a single patient.  ``response_type`` is always read;
            ``mutant_seq`` and ``gene`` are read only when at least one row is CD8.
        y_prob: predicted positive-class probability for each row of ``df_p``,
            in the same positional order.
        alpha: decay rate of the rank-weighted score ``sum(exp(-alpha * r))``
            over the 0-based ranks ``r`` of the CD8 rows (default 0.005, the
            value previously hard-coded here).

    Returns:
        A 14-tuple ``(y_pred_sorted, X_sorted, nr_correct20, nr_tested20,
        nr_correct50, nr_tested50, nr_correct100, nr_tested100, nr_immuno,
        idx_pos, score, ranks_str, mut_seqs_str, genes_str)``:

        * ``X_sorted`` is a sorted copy of ``df_p`` with added ``ML_pred`` and
          ``response`` (0/1) columns and a fresh 0-based index;
          ``y_pred_sorted`` is its ``ML_pred`` column as an array.
        * ``nr_correct{k}`` is the number of CD8 rows in the top ``k``;
          ``nr_tested{k}`` adds the NEGATIVE rows in the top ``k``.
        * ``nr_immuno`` is the number of CD8 rows; ``idx_pos`` their ascending
          0-based positions in the ranking.
        * The three strings are the comma-joined 1-based ranks, mutant
          sequences and genes of the CD8 rows in rank order (all empty when the
          patient has no CD8 row).
    """
    X_r = df_p.copy()
    X_r['ML_pred'] = y_prob
    X_r['response'] = (normalize_rt(X_r['response_type']) == 'CD8').astype(int)
    # Stable sort: peptides with identical scores keep their input order, so the
    # reported ranks do not depend on the sort algorithm's tie handling.
    X_r = X_r.sort_values(by=['ML_pred'], ascending=False, kind='mergesort').reset_index(drop=True)

    # np.where yields positions in ascending order, i.e. already in rank order.
    idx_pos = np.where(X_r['response'].to_numpy() == 1)[0]
    idx_neg = np.where(normalize_rt(X_r['response_type']) == 'NEGATIVE')[0]

    def topk_counts(k: int) -> tuple[int, int]:
        # Positions are always < len(X_r), so "< k" already caps k at the patient size.
        nr_correct = int(np.count_nonzero(idx_pos < k))
        return nr_correct, nr_correct + int(np.count_nonzero(idx_neg < k))

    nr_correct20, nr_tested20 = topk_counts(20)
    nr_correct50, nr_tested50 = topk_counts(50)
    nr_correct100, nr_tested100 = topk_counts(100)
    nr_immuno = int(idx_pos.size)

    score = float(np.sum(np.exp(-alpha * idx_pos)))

    if nr_immuno > 0:
        ranks_str = ",".join(str(int(r) + 1) for r in idx_pos)
        mut_seqs_str = ",".join(str(s) for s in X_r['mutant_seq'].to_numpy()[idx_pos])
        genes_str = ",".join(str(g) for g in X_r['gene'].to_numpy()[idx_pos])
    else:
        ranks_str = mut_seqs_str = genes_str = ""

    return (X_r['ML_pred'].to_numpy(), X_r,
            nr_correct20, nr_tested20,
            nr_correct50, nr_tested50,
            nr_correct100, nr_tested100,
            nr_immuno, idx_pos, score,
            ranks_str, mut_seqs_str, genes_str)
def predict_in_batches(model, X_all, device, batch_size=1024):
    """Return the class-1 probability for every row of ``X_all``, batch by batch.

    For each batch the model output is averaged over dimension 1 (assumed to be
    the k ensemble members of a TabM model, i.e. output shape
    ``(batch, k, n_classes)`` — TODO confirm), softmaxed over the class
    dimension, and column 1 is kept.

    Args:
        model: a ``torch.nn.Module``; it is switched to eval mode.
        X_all: 2-D feature tensor, or anything ``torch.as_tensor`` accepts
            (e.g. a numpy array), which is then converted to float32.
        device: ``torch.device`` or device string to run inference on.
        batch_size: rows per forward pass; must be >= 1.

    Returns:
        1-D numpy array with one probability per row (empty for empty input).

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    device = torch.device(device)
    if not torch.is_tensor(X_all):
        X_all = torch.as_tensor(X_all, dtype=torch.float32)

    model.eval()
    chunks = []
    with torch.inference_mode():
        for start in range(0, len(X_all), batch_size):
            batch_X = X_all[start:start + batch_size].to(device)
            logits = model(batch_X).mean(1)
            chunks.append(torch.softmax(logits, dim=1)[:, 1].cpu())
            del batch_X, logits
    # Release cached GPU blocks once at the end instead of after every batch;
    # pointless (and previously done anyway) when running on the CPU.
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    if not chunks:  # torch.cat([]) would raise on empty input
        return np.empty(0, dtype=np.float32)
    return torch.cat(chunks, dim=0).numpy()
# Header of the main per-patient TSV (kept identical to the original format).
_MAIN_HEADER = (
    "Patient\tNr_correct_top20\tNr_tested_top20\tNr_correct_top50\tNr_tested_top50\t"
    "Nr_correct_top100\tNr_tested_top100\tNr_immunogenic\tNr_peptides\tClf_score\t"
    "CD8_ranks\tCD8_mut_seqs\tCD8_genes\n"
)
# Header of the per-patient TESLA TSV.
_TESLA_HEADER = "Dataset\tPatient\tTTIF\tFR\tAUPRC\n"


def _select_device(choice: str) -> torch.device:
    """Translate the ``--device`` option into a torch device and report the choice.

    ``"auto"`` uses ``cuda:0`` when CUDA is available and the CPU otherwise,
    ``"cuda"`` insists on ``cuda:0``, anything else selects the CPU.

    Raises:
        RuntimeError: ``choice == "cuda"`` but no GPU is available.
    """
    if choice == "auto":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"🚀 Auto-selected GPU: {torch.cuda.get_device_name(0)}")
            print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        else:
            device = torch.device('cpu')
            print("⚠️ No GPU detected, using CPU")
    elif choice == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA specified but no GPU detected")
        device = torch.device('cuda:0')
        print(f"🚀 Force using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device('cpu')
        print("️ Using CPU")
    return device


def _load_features(data_file: str) -> tuple[pd.DataFrame, np.ndarray]:
    """Read the tab-separated input and derive the numeric feature matrix.

    Every column except the metadata columns ``patient``, ``response_type``,
    ``gene`` and ``mutant_seq`` is treated as a feature; non-numeric cells are
    coerced to NaN and then filled with 0.0.

    Returns:
        ``(df, X_all)``: the frame as read (default RangeIndex) and the feature
        matrix with one row per row of ``df``.

    Raises:
        KeyError: a metadata column is missing.
    """
    df = pd.read_csv(data_file, sep="\t", header=0, low_memory=False)
    print(f"📈 Data shape: {df.shape}")
    required_cols = ["patient", "response_type", "gene", "mutant_seq"]
    for c in required_cols:
        if c not in df.columns:
            raise KeyError(f"Missing required column: {c}")
    # NOTE(review): an optional "dataset" column is not excluded, so it is coerced to
    # numeric (typically all 0.0) and fed to the model — confirm this matches training.
    feature_cols = [c for c in df.columns if c not in required_cols]
    X_all = df[feature_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0).to_numpy()
    print(f" Number of features: {X_all.shape[1]}")
    return df, X_all


def _resolve_model_paths(args: argparse.Namespace) -> list[str]:
    """Collect checkpoint paths from --model_files and --model_glob (--model_file as fallback).

    Raises:
        FileNotFoundError: no path was given/matched, or a listed path does not exist.
    """
    import glob as _glob

    model_paths: list[str] = []
    if args.model_files:
        model_paths.extend(args.model_files)
    if args.model_glob:
        model_paths.extend(sorted(_glob.glob(args.model_glob)))
    if not model_paths and args.model_file:
        model_paths = [args.model_file]
    if not model_paths:
        raise FileNotFoundError("No model files found, please check!")
    # Fail before any (possibly long) prediction work rather than midway through the ensemble.
    for path in model_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Model file not existed: {path}")
    return model_paths


def _load_checkpoint(path: str) -> dict:
    """Load a checkpoint onto the CPU.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects (needed for the
    stored ``args`` namespace); only load checkpoints from trusted sources.
    """
    return torch.load(path, map_location='cpu', weights_only=False)


def _stringify(v) -> str:
    """``repr`` of ``v``, falling back to ``str`` and finally a placeholder."""
    try:
        return repr(v)
    except Exception:
        try:
            return str(v)
        except Exception:
            return "<unprintable>"


def _print_dict(title: str, d: dict) -> None:
    """Print ``d`` sorted by key under ``title``; diagnostic only, never raises."""
    try:
        print(title)
        for k in sorted(d.keys()):
            try:
                print(f"- {k}: {repr(d[k])}")
            except Exception:
                print(f"- {k}: <unprintable>")
        print("=" * len(title))
    except Exception:
        # e.g. unsortable mixed-type keys: never abort the evaluation for a log line.
        pass


def _print_checkpoint_summary(first_ckpt: dict) -> None:
    """Print the hyper-parameters and auxiliary metadata stored in a checkpoint.

    Raises:
        KeyError: the checkpoint has no ``'args'`` entry.
    """
    model_args = first_ckpt['args']
    print("===== Saved Hyperparameters from checkpoint['args'] =====")
    if hasattr(model_args, "__dict__"):
        hp_items = sorted(vars(model_args).items())
    elif isinstance(model_args, dict):
        hp_items = sorted(model_args.items())
    else:
        hp_items = []
        print("⚠️ Unable to enumerate contents of model_args")
    for key, val in hp_items:
        print(f"- {key}: {_stringify(val)}")
    print("=========================================================")

    for section in ("training_args", "best_params", "full_args"):
        if isinstance(first_ckpt.get(section), dict):
            _print_dict(f"===== checkpoint['{section}'] =====", first_ckpt[section])

    ufi = first_ckpt.get("used_feature_idx")
    if ufi is not None:
        try:
            print("===== used_feature_idx =====")
            print(f"- length: {len(ufi)}")
            print(f"- head: {list(ufi[:10])}")
            print("=" * 25)
        except Exception:
            print("===== used_feature_idx =====\n<unprintable>\n============================")


def _print_environment() -> None:
    """Print torch / CUDA / tabm versions; purely diagnostic, errors are ignored."""
    try:
        print("===== Environment =====")
        print(f"- torch: {torch.__version__}")
        print(f"- cuda available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"- device: {torch.cuda.get_device_name(0)}")
            print(f"- cuda version: {torch.version.cuda}")
        print(f"- tabm: {getattr(tabm, '__version__', 'unknown')}")
        print("========================")
    except Exception:
        pass


def _select_used_features(ckpt: dict, X_np: np.ndarray) -> np.ndarray:
    """Restrict ``X_np`` to the checkpoint's ``used_feature_idx`` columns, if recorded.

    Out-of-range indices are dropped with a warning.  All columns are kept when
    no valid index remains or the stored value cannot be read as integers.
    """
    ufi = ckpt.get("used_feature_idx")
    if ufi is None:
        return X_np
    try:
        ufi_arr = np.asarray(ufi, dtype=int)
    except (TypeError, ValueError) as err:
        print(f"⚠️ Could not read used_feature_idx ({err}); using all {X_np.shape[1]} features")
        return X_np
    max_idx = X_np.shape[1] - 1
    valid = ufi_arr[(ufi_arr >= 0) & (ufi_arr <= max_idx)]
    if valid.size < ufi_arr.size:
        print(f"⚠️ Dropped {ufi_arr.size - valid.size} out-of-range used_feature_idx entries "
              f"(data has {X_np.shape[1]} features)")
    if valid.size == 0:
        return X_np
    return X_np[:, valid]


def _build_num_embeddings(m_args, X_tensor_cpu: torch.Tensor):
    """Recreate the numeric-feature embedding module described by the checkpoint args.

    Returns None when embeddings are disabled or ``embedding_type`` is not one of
    ``linear`` / ``periodic`` / ``piecewise``.
    """
    if not getattr(m_args, 'use_embeddings', False):
        return None
    emb_type = m_args.embedding_type
    if emb_type not in ('linear', 'periodic', 'piecewise'):
        print(f"⚠️ Unknown embedding_type {emb_type!r}; building model without numeric embeddings")
        return None
    import rtdl_num_embeddings  # optional dependency, only needed when embeddings are used

    n_features = X_tensor_cpu.shape[1]
    if emb_type == 'linear':
        return rtdl_num_embeddings.LinearReLUEmbeddings(n_features)
    if emb_type == 'periodic':
        return rtdl_num_embeddings.PeriodicEmbeddings(n_features, lite=False)
    # NOTE(review): bins are computed from the evaluation data; presumably the loaded
    # state_dict overrides them if they are stored in it — TODO confirm.
    return rtdl_num_embeddings.PiecewiseLinearEmbeddings(
        rtdl_num_embeddings.compute_bins(X_tensor_cpu, n_bins=48),
        d_embedding=16,
        activation=False,
        version='B',
    )


def _predict_with_checkpoint(ckpt: dict, X_all_np: np.ndarray, device: torch.device,
                             batch_size: int) -> np.ndarray:
    """Rebuild the TabM model stored in ``ckpt`` and return P(class 1) per row.

    The model output is averaged over dimension 1 (presumably TabM's k ensemble
    members — TODO confirm) before the softmax over the two classes.
    """
    m_args = ckpt['args']
    X_np = _select_used_features(ckpt, X_all_np)
    n = len(X_np)
    if n == 0:  # nothing to predict; np.concatenate([]) below would raise
        return np.empty(0, dtype=np.float32)
    X_tensor_cpu = torch.as_tensor(X_np, dtype=torch.float32)

    model = tabm.TabM.make(
        n_num_features=X_tensor_cpu.shape[1],
        cat_cardinalities=[],
        d_out=2,
        k=m_args.k,
        n_blocks=m_args.n_blocks,
        d_block=m_args.d_block,
        num_embeddings=_build_num_embeddings(m_args, X_tensor_cpu),
        arch_type=getattr(m_args, 'arch_type', 'tabm'),
    )
    model.load_state_dict(ckpt['model_state_dict'])
    model.to(device)
    model.eval()

    # Use the requested batch size as-is (only clamped to >= 1) so that a small
    # --batch_size really lowers peak GPU memory.
    bs = max(1, int(batch_size))
    n_batches = (n + bs - 1) // bs
    probs_list = []
    with torch.inference_mode():
        for b, start in enumerate(range(0, n, bs)):
            xb = X_tensor_cpu[start:start + bs].to(device)
            logits = model(xb).mean(1)
            probs_list.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
            del xb, logits
            if b % 50 == 0:
                print(f" batch {b}/{n_batches}")
    del model
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    return np.concatenate(probs_list, axis=0)


def _ensemble_predict(model_paths: list[str], first_ckpt: dict, X_all: np.ndarray,
                      device: torch.device, batch_size: int) -> np.ndarray:
    """Average P(class 1) over all checkpoints with equal weights (float64)."""
    n_models = len(model_paths)
    print(f"🔗 Loading {n_models} models for equal-weighted average prediction...")
    prob_sum = None
    for i, mp in enumerate(model_paths):
        print(f" -> {mp}")
        # model_paths[0] was already loaded for the summary; reuse it.
        ckpt = first_ckpt if i == 0 else _load_checkpoint(mp)
        probs = _predict_with_checkpoint(ckpt, X_all, device, batch_size).astype(np.float64)
        prob_sum = probs if prob_sum is None else prob_sum + probs
    y_prob_all = prob_sum / float(n_models)
    print(f"✅ Prediction completed, total {len(y_prob_all)} samples; number of models={n_models}")
    return y_prob_all


def _tesla_metrics(X_sorted: pd.DataFrame, y_pred_sorted: np.ndarray,
                   nr_correct20: int, nr_tested20: int,
                   nr_correct100: int, nr_immuno: int) -> tuple[float, float, float]:
    """Compute TESLA's TTIF, FR and AUPRC for one patient's ranked peptides.

    AUPRC uses only peptides whose normalized ``response_type`` is not
    ``NOT_TESTED``; it is NaN when no such peptide exists.
    """
    # Normalize like every other label comparison in this file.
    tested = (normalize_rt(X_sorted['response_type']) != 'NOT_TESTED').to_numpy()
    y_pred_tesla = y_pred_sorted[tested]
    y_tesla = X_sorted['response'].to_numpy()[tested]
    ttif = (nr_correct20 / nr_tested20) if nr_tested20 > 0 else 0.0
    fr = (nr_correct100 / nr_immuno) if nr_immuno > 0 else 0.0
    if y_tesla.size == 0:
        # precision_recall_curve raises on empty input; don't abort the whole run.
        auprc = float('nan')
    else:
        # NOTE(review): with no CD8 peptide sklearn warns and the value depends on
        # the sklearn version's degenerate-case handling.
        precision, recall, _ = precision_recall_curve(y_tesla, y_pred_tesla)
        auprc = auc(recall, precision)
    return ttif, fr, auprc


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of ``path`` if it has one."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _open_append_with_header(path: str, header: str):
    """Open ``path`` for appending and write ``header`` if the file is new or empty."""
    _ensure_parent_dir(path)
    is_new = (not os.path.exists(path)) or (os.path.getsize(path) == 0)
    fh = open(path, "a", encoding="utf-8")
    try:
        if is_new:
            fh.write(header)
    except BaseException:
        fh.close()
        raise
    return fh


def _write_patient_reports(df: pd.DataFrame, y_prob_all: np.ndarray,
                           args: argparse.Namespace) -> tuple[list[dict], list[dict]]:
    """Compute per-patient metrics, append them to the TSV outputs and collect rows.

    Returns:
        ``(rows_main, rows_tesla)``: one dict per processed patient for the main
        table and, when a TESLA output is requested, for the TESLA table.
    """
    import contextlib

    want_tesla = bool(args.tesla_file or args.tesla_xlsx)
    rows_main: list[dict] = []
    rows_tesla: list[dict] = []
    with contextlib.ExitStack() as stack:
        out = stack.enter_context(_open_append_with_header(args.output_file, _MAIN_HEADER))
        tesla_out = None
        if args.tesla_file:
            tesla_out = stack.enter_context(_open_append_with_header(args.tesla_file, _TESLA_HEADER))

        for patient, df_p in df.groupby("patient", sort=False):
            has_cd8 = (normalize_rt(df_p["response_type"]) == "CD8").any()
            if args.skip_no_cd8 and not has_cd8:
                continue
            # df comes straight from read_csv (default RangeIndex): labels == row positions.
            y_prob = y_prob_all[df_p.index.to_numpy()]
            (y_pred_sorted, X_sorted,
             nr_correct20, nr_tested20,
             nr_correct50, nr_tested50,
             nr_correct100, nr_tested100,
             nr_immuno, _ranks, score,
             ranks_str, mut_seqs_str, genes_str) = compute_patient_metrics(df_p, y_prob)

            out.write(f"{patient}\t{nr_correct20}\t{nr_tested20}\t{nr_correct50}\t{nr_tested50}\t"
                      f"{nr_correct100}\t{nr_tested100}\t{nr_immuno}\t{len(df_p)}\t{score:.6f}\t"
                      f"{ranks_str}\t{mut_seqs_str}\t{genes_str}\n")
            rows_main.append({
                "Patient": patient,
                "Nr_correct_top20": nr_correct20,
                "Nr_tested_top20": nr_tested20,
                "Nr_correct_top50": nr_correct50,
                "Nr_tested_top50": nr_tested50,
                "Nr_correct_top100": nr_correct100,
                "Nr_tested_top100": nr_tested100,
                "Nr_immunogenic": nr_immuno,
                "Nr_peptides": len(df_p),
                "Clf_score": score,
                "CD8_ranks": ranks_str,
                "CD8_mut_seqs": mut_seqs_str,
                "CD8_genes": genes_str,
            })

            if not want_tesla:
                continue
            if "dataset" in df_p.columns:
                dataset_val = str(df_p["dataset"].iloc[0])
            else:
                dataset_val = args.dataset_name if args.dataset_name is not None else ""
            ttif, fr, auprc = _tesla_metrics(X_sorted, y_pred_sorted,
                                             nr_correct20, nr_tested20,
                                             nr_correct100, nr_immuno)
            if tesla_out is not None:
                tesla_out.write(f"{dataset_val}\t{patient}\t{ttif:.3f}\t{fr:.3f}\t{auprc:.3f}\n")
            rows_tesla.append({
                "Dataset": dataset_val,
                "Patient": patient,
                "TTIF": ttif,
                "FR": fr,
                "AUPRC": auprc,
            })
    return rows_main, rows_tesla


def main():
    """CLI entry point: evaluate one or an ensemble of TabM checkpoints per patient.

    Reads the TSV given by --data_file and averages the class-1 probabilities of
    all resolved checkpoints with equal weights. It then appends per-patient
    ranking metrics to --output_file and, optionally, TESLA metrics to
    --tesla_file, plus optional Excel copies of both tables.
    """
    ap = argparse.ArgumentParser(description="TabM model evaluation, output format consistent with TestVotingClassifier")
    # NOTE: --model_files and --model_glob are merged; --model_file is only a fallback.
    ap.add_argument("--model_file", type=str, required=False, help="TabM model file, e.g. tabm_results/tabm_model.pth (mutually exclusive with --model_files/--model_glob, choose one of three)")
    ap.add_argument("--model_files", type=str, nargs='*', default=None, help="Multiple model files for equal-weighted average prediction")
    ap.add_argument("--model_glob", type=str, default=None, help="Use wildcards to match multiple model files (e.g. tabm_results/tabm_hyperopt_best_rep*.pth)")
    ap.add_argument("--data_file", type=str, required=True, help="Input TSV: TestVoting_selection_neopep.tsv")
    ap.add_argument("--output_file", type=str, required=True, help="Main result output file (header consistent with original)")
    ap.add_argument("--tesla_file", type=str, default=None, help="TESLA score output file (for neopep task)")
    ap.add_argument("--output_xlsx", type=str, default=None, help="Main result Excel output path (optional)")
    ap.add_argument("--tesla_xlsx", type=str, default=None, help="TESLA result Excel output path (optional)")
    ap.add_argument("--dataset_name", type=str, default=None, help="If no dataset column exists, use this value as Dataset column in TESLA")
    ap.add_argument("--skip_no_cd8", action="store_true", help="Skip patients without CD8")
    ap.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
                    help="Device selection: auto/cuda/cpu")
    ap.add_argument("--batch_size", type=int, default=1024,
                    help="Batch size to avoid GPU memory overflow (default 1024)")
    args = ap.parse_args()

    device = _select_device(args.device)
    print(f" Batch size: {args.batch_size}")

    df, X_all = _load_features(args.data_file)

    model_paths = _resolve_model_paths(args)
    first_ckpt = _load_checkpoint(model_paths[0])
    _print_checkpoint_summary(first_ckpt)
    _print_environment()

    y_prob_all = _ensemble_predict(model_paths, first_ckpt, X_all, device, args.batch_size)

    rows_main, rows_tesla = _write_patient_reports(df, y_prob_all, args)

    if args.output_xlsx and rows_main:
        _ensure_parent_dir(args.output_xlsx)
        pd.DataFrame(rows_main).to_excel(args.output_xlsx, index=False)
    if args.tesla_xlsx and rows_tesla:
        _ensure_parent_dir(args.tesla_xlsx)
        pd.DataFrame(rows_tesla).to_excel(args.tesla_xlsx, index=False)
    print(f" Evaluation completed! Processed {len(rows_main)} patients")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()