|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
import tabm |
|
|
from sklearn.metrics import precision_recall_curve, auc |
|
|
|
|
|
def normalize_rt(s: pd.Series) -> pd.Series:
    """Return ``s`` as whitespace-trimmed, upper-cased strings.

    Lets ``response_type`` labels be compared regardless of case or padding
    (e.g. ``" cd8 "`` -> ``"CD8"``). The index of ``s`` is preserved.
    """
    text_accessor = s.astype(str).str
    trimmed = text_accessor.strip()
    return trimmed.str.upper()
|
|
|
|
|
def compute_patient_metrics(df_p: pd.DataFrame, y_prob: np.ndarray,
                            alpha: float = 0.005) -> tuple:
    """Rank one patient's peptides by predicted probability and summarize the ranking.

    Args:
        df_p: Rows for a single patient. Must contain ``response_type``,
            ``mutant_seq`` and ``gene`` columns.
        y_prob: Predicted positive-class probability for each row of ``df_p``,
            in the same positional order as ``df_p``.
        alpha: Decay rate of the rank-weighted score
            ``sum(exp(-alpha * rank0))`` over CD8 peptides, where ``rank0`` is
            the 0-based rank after sorting. Defaults to 0.005.

    Returns:
        A 15-tuple:
        ``(sorted_predictions, sorted_frame,
        nr_correct20, nr_tested20, nr_correct50, nr_tested50,
        nr_correct100, nr_tested100,
        nr_immuno, cd8_positions, score,
        ranks_str, mut_seqs_str, genes_str)``

        - ``sorted_frame`` is a copy of ``df_p`` with added ``ML_pred`` and
          ``response`` (1 if ``response_type`` normalizes to ``CD8``, else 0)
          columns. It is sorted by ``ML_pred`` descending, with its index reset.
        - ``nr_correct@k`` counts CD8 peptides in the top k.
          ``nr_tested@k`` counts CD8 plus ``NEGATIVE`` peptides in the top k.
        - ``cd8_positions`` holds the 0-based ranks of the CD8 peptides.
        - The three strings are comma-joined 1-based CD8 ranks, mutant
          sequences and genes, in rank order. They are empty if there is no
          CD8 peptide.
    """
    X_r = df_p.copy()
    X_r['ML_pred'] = y_prob
    X_r['response'] = (normalize_rt(X_r['response_type']) == 'CD8').astype(int)

    X_r = X_r.sort_values(by=['ML_pred'], ascending=False).reset_index(drop=True)

    # Normalize once on the sorted frame and reuse for the NEGATIVE lookup.
    rt_sorted = normalize_rt(X_r['response_type']).to_numpy()
    is_pos = X_r['response'].to_numpy() == 1
    # flatnonzero yields ascending positions, i.e. CD8 peptides in rank order.
    idx_pos = np.flatnonzero(is_pos)
    idx_tested = np.flatnonzero(rt_sorted == 'NEGATIVE')
    n_rows = len(X_r)

    def topk_counts(k: int):
        k_eff = min(k, n_rows)
        nr_correct = int(np.count_nonzero(idx_pos < k_eff))
        nr_tested = nr_correct + int(np.count_nonzero(idx_tested < k_eff))
        return nr_correct, nr_tested

    nr_correct20, nr_tested20 = topk_counts(20)
    nr_correct50, nr_tested50 = topk_counts(50)
    nr_correct100, nr_tested100 = topk_counts(100)

    nr_immuno = int(idx_pos.size)
    score = float(np.sum(np.exp(-alpha * idx_pos)))

    # Rows selected by is_pos keep their sorted order, which matches idx_pos.
    # No re-sorting is needed, and empty selections join to "".
    ranks_str = ",".join(f"{int(r + 1)}" for r in idx_pos)
    mut_seqs_str = ",".join(str(s) for s in X_r.loc[is_pos, 'mutant_seq'].to_numpy())
    genes_str = ",".join(str(g) for g in X_r.loc[is_pos, 'gene'].to_numpy())

    return (X_r['ML_pred'].to_numpy(), X_r,
            nr_correct20, nr_tested20,
            nr_correct50, nr_tested50,
            nr_correct100, nr_tested100,
            nr_immuno, idx_pos, score,
            ranks_str, mut_seqs_str, genes_str)
|
|
|
|
|
|
|
|
def predict_in_batches(model, X_all, device, batch_size=1024):
    """Run ``model`` over ``X_all`` in mini-batches and return class-1 probabilities.

    For each batch, the model output is averaged over dim 1 and then
    soft-maxed over dim 1, and column 1 is kept. This presumably expects an
    output of shape (batch, ensemble_k, n_classes), as with TabM — TODO
    confirm for other models.

    Args:
        model: A torch module. It is switched to eval mode here.
        X_all: 2-D features. Either a ``torch.Tensor`` (used as-is) or
            anything ``torch.as_tensor`` accepts, e.g. a NumPy array, which is
            converted to float32.
        device: The ``torch.device``, or device string, that each batch is
            moved to.
        batch_size: Rows per forward pass. Must be >= 1.

    Returns:
        A 1-D NumPy array with one probability per row of ``X_all``. It is
        empty when ``X_all`` has no rows.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not torch.is_tensor(X_all):
        X_all = torch.as_tensor(X_all, dtype=torch.float32)
    device = torch.device(device)
    # Only release cached CUDA memory when we are actually running on CUDA.
    release_cuda_cache = device.type == 'cuda' and torch.cuda.is_available()

    model.eval()
    n_rows = len(X_all)
    chunks = []
    with torch.inference_mode():
        for start in range(0, n_rows, batch_size):
            batch_X = X_all[start:start + batch_size].to(device)
            logits = model(batch_X).mean(1)
            chunks.append(torch.softmax(logits, dim=1)[:, 1].cpu())
            # Drop device references before trimming the cache to keep peak memory low.
            del batch_X, logits
            if release_cuda_cache:
                torch.cuda.empty_cache()

    if not chunks:
        return np.empty(0, dtype=np.float32)
    return torch.cat(chunks, dim=0).numpy()
|
|
|
|
|
def _select_device(choice: str) -> torch.device:
    """Map the ``--device`` choice ("auto"/"cuda"/"cpu") to a device and report it.

    Raises:
        RuntimeError: If ``choice == "cuda"`` but CUDA is unavailable.
    """
    if choice == "cpu":
        print("️ Using CPU")
        return torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        if choice == "auto":
            print(f"🚀 Auto-selected GPU: {torch.cuda.get_device_name(0)}")
            print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        else:
            print(f"🚀 Force using GPU: {torch.cuda.get_device_name(0)}")
        return device
    if choice == "cuda":
        raise RuntimeError("CUDA specified but no GPU detected")
    print("⚠️ No GPU detected, using CPU")
    return torch.device('cpu')


def _resolve_model_paths(args) -> list[str]:
    """Collect model paths from ``--model_files`` and ``--model_glob``.

    Falls back to ``--model_file`` only when the other two yield nothing.

    Raises:
        FileNotFoundError: If no model path could be resolved.
    """
    import glob as _glob

    model_paths: list[str] = []
    if args.model_files:
        model_paths.extend(args.model_files)
    if args.model_glob:
        model_paths.extend(sorted(_glob.glob(args.model_glob)))
    if not model_paths and args.model_file:
        model_paths = [args.model_file]
    if not model_paths:
        raise FileNotFoundError("No model files found, please check!")
    return model_paths


def _load_checkpoint(model_path: str):
    """Load a checkpoint onto the CPU.

    NOTE(security): ``weights_only=False`` unpickles arbitrary Python objects.
    It is presumably required because ``ckpt['args']`` is a pickled object
    that is accessed by attribute below. Only load checkpoints from trusted
    sources.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not existed: {model_path}")
    return torch.load(model_path, map_location='cpu', weights_only=False)


def _stringify(v):
    """Best-effort printable form of ``v`` (repr, then str, then a placeholder)."""
    try:
        return repr(v)
    except Exception:
        try:
            return str(v)
        except Exception:
            return "<unprintable>"


def _print_dict_section(title, d):
    """Print ``d`` sorted by key under ``title``. Printing is best-effort and never raises."""
    try:
        print(title)
        for k in sorted(d.keys()):
            try:
                print(f"- {k}: {repr(d[k])}")
            except Exception:
                print(f"- {k}: <unprintable>")
        print("=" * len(title))
    except Exception:
        pass


def _print_checkpoint_summary(ckpt) -> None:
    """Print saved hyperparameters, optional metadata sections and the runtime environment.

    Raises:
        KeyError: If ``ckpt`` has no ``'args'`` entry.
    """
    model_args = ckpt['args']
    print("===== Saved Hyperparameters from checkpoint['args'] =====")
    if hasattr(model_args, "__dict__"):
        hp_items = sorted(vars(model_args).items())
    elif isinstance(model_args, dict):
        hp_items = sorted(model_args.items())
    else:
        hp_items = []
        print("⚠️ Unable to enumerate contents of model_args")
    for key, val in hp_items:
        print(f"- {key}: {_stringify(val)}")
    print("=========================================================")

    for section in ("training_args", "best_params", "full_args"):
        if isinstance(ckpt.get(section), dict):
            _print_dict_section(f"===== checkpoint['{section}'] =====", ckpt[section])

    ufi = ckpt.get("used_feature_idx")
    if ufi is not None:
        try:
            print("===== used_feature_idx =====")
            print(f"- length: {len(ufi)}")
            print(f"- head: {list(ufi[:10])}")
            print("=" * 25)
        except Exception:
            print("===== used_feature_idx =====\n<unprintable>\n============================")

    # Environment report is informational only; never let it abort evaluation.
    try:
        print("===== Environment =====")
        print(f"- torch: {torch.__version__}")
        print(f"- cuda available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"- device: {torch.cuda.get_device_name(0)}")
            print(f"- cuda version: {torch.version.cuda}")
        print(f"- tabm: {getattr(tabm, '__version__', 'unknown')}")
        print("========================")
    except Exception:
        pass


def _apply_used_feature_idx(ckpt, X_np: np.ndarray) -> np.ndarray:
    """Restrict ``X_np`` to the columns listed in ``ckpt['used_feature_idx']``, if any.

    Out-of-range indices are dropped with a warning. If the index list is
    unreadable or nothing valid remains, all columns are used, also with a
    warning.
    """
    ufi = ckpt.get("used_feature_idx")
    if ufi is None:
        return X_np
    try:
        ufi_arr = np.asarray(ufi, dtype=int)
    except (TypeError, ValueError) as err:
        print(f"⚠️ Ignoring unreadable used_feature_idx ({err}); using all {X_np.shape[1]} features")
        return X_np
    in_range = (ufi_arr >= 0) & (ufi_arr < X_np.shape[1])
    if not in_range.all():
        print(f"⚠️ Dropping {int(np.count_nonzero(~in_range))} out-of-range used_feature_idx entries")
    ufi_arr = ufi_arr[in_range]
    if len(ufi_arr) == 0:
        print(f"⚠️ No valid used_feature_idx entries; using all {X_np.shape[1]} features")
        return X_np
    return X_np[:, ufi_arr]


def _build_tabm_model(m_args, X_tensor_cpu: torch.Tensor):
    """Rebuild the TabM architecture described by the checkpoint's ``args``."""
    n_features = X_tensor_cpu.shape[1]
    num_embeddings = None
    if getattr(m_args, 'use_embeddings', False):
        # An unknown embedding_type leaves num_embeddings as None, as before.
        if m_args.embedding_type == 'linear':
            import rtdl_num_embeddings
            num_embeddings = rtdl_num_embeddings.LinearReLUEmbeddings(n_features)
        elif m_args.embedding_type == 'periodic':
            import rtdl_num_embeddings
            num_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(n_features, lite=False)
        elif m_args.embedding_type == 'piecewise':
            import rtdl_num_embeddings
            # NOTE(review): bins are computed from the *evaluation* data here.
            # Presumably the trained bin buffers in the state dict overwrite
            # them on load_state_dict, which requires matching shapes. Also,
            # n_bins=48 and d_embedding=16 are hard-coded and must match the
            # training configuration — TODO confirm.
            num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
                rtdl_num_embeddings.compute_bins(X_tensor_cpu, n_bins=48),
                d_embedding=16,
                activation=False,
                version='B',
            )
    return tabm.TabM.make(
        n_num_features=n_features,
        cat_cardinalities=[],
        d_out=2,
        k=m_args.k,
        n_blocks=m_args.n_blocks,
        d_block=m_args.d_block,
        num_embeddings=num_embeddings,
        arch_type=getattr(m_args, 'arch_type', 'tabm'),
    )


def _predict_with_checkpoint(ckpt, X_all_np: np.ndarray, device: torch.device,
                             batch_size: int) -> np.ndarray:
    """Return class-1 probabilities for every row of ``X_all_np`` using one checkpoint."""
    X_np = _apply_used_feature_idx(ckpt, X_all_np)
    X_tensor_cpu = torch.as_tensor(X_np, dtype=torch.float32)
    model = _build_tabm_model(ckpt['args'], X_tensor_cpu)
    model.load_state_dict(ckpt['model_state_dict'])
    model.to(device)
    model.eval()

    n = len(X_tensor_cpu)
    n_batches = (n + batch_size - 1) // batch_size
    release_cuda_cache = torch.cuda.is_available() and device.type == 'cuda'
    probs_list = []
    with torch.inference_mode():
        for b, start in enumerate(range(0, n, batch_size)):
            xb = X_tensor_cpu[start:start + batch_size].to(device)
            # Average the k ensemble members' logits, then take P(class 1).
            logits = model(xb).mean(1)
            probs_list.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
            del xb, logits
            if release_cuda_cache:
                torch.cuda.empty_cache()
            if b % 50 == 0:
                print(f" batch {b}/{n_batches}")
    if not probs_list:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(probs_list, axis=0)


def _tesla_metrics(X_sorted, y_pred_sorted, nr_correct20, nr_tested20,
                   nr_correct100, nr_immuno):
    """Compute (TTIF, FR, AUPRC) for one patient from its sorted ranking.

    AUPRC uses only peptides whose normalized ``response_type`` is not
    ``NOT_TESTED``. It is reported as 0.0 when no tested peptide exists,
    mirroring the 0.0 fallback used for TTIF and FR.
    """
    tested = (normalize_rt(X_sorted['response_type']) != 'NOT_TESTED').to_numpy()
    y_pred_tesla = np.asarray(y_pred_sorted)[tested]
    y_tesla = X_sorted['response'].to_numpy()[tested]
    ttif = (nr_correct20 / nr_tested20) if nr_tested20 > 0 else 0.0
    fr = (nr_correct100 / nr_immuno) if nr_immuno > 0 else 0.0
    if y_tesla.size == 0:
        auprc = 0.0
    else:
        precision, recall, _ = precision_recall_curve(y_tesla, y_pred_tesla)
        auprc = auc(recall, precision)
    return ttif, fr, auprc


def _append_tesla_row(path, dataset_val, patient, ttif, fr, auprc) -> None:
    """Append one TESLA row to ``path``, writing the header first if the file is new or empty."""
    new_file = (not os.path.exists(path)) or (os.path.getsize(path) == 0)
    with open(path, "a", encoding="utf-8") as tf:
        if new_file:
            tf.write("Dataset\tPatient\tTTIF\tFR\tAUPRC\n")
        tf.write(f"{dataset_val}\t{patient}\t{ttif:.3f}\t{fr:.3f}\t{auprc:.3f}\n")


def main():
    """Evaluate one or more TabM checkpoints on a neopeptide TSV and write per-patient metrics.

    Steps:
        1. Read ``--data_file``. All non-required columns are coerced to
           numeric, with NaN mapped to 0.0.
        2. Average the class-1 probabilities of every resolved checkpoint with
           equal weights.
        3. For each patient, append one row of top-k and rank metrics to
           ``--output_file`` (a header is written only if the file is new or
           empty).
        4. Optionally write TESLA metrics (TTIF/FR/AUPRC) to ``--tesla_file``,
           and Excel copies of both tables.
    """
    ap = argparse.ArgumentParser(description="TabM model evaluation, output format consistent with TestVotingClassifier")
    ap.add_argument("--model_file", type=str, required=False, help="TabM model file, e.g. tabm_results/tabm_model.pth (mutually exclusive with --model_files/--model_glob, choose one of three)")
    ap.add_argument("--model_files", type=str, nargs='*', default=None, help="Multiple model files for equal-weighted average prediction")
    ap.add_argument("--model_glob", type=str, default=None, help="Use wildcards to match multiple model files (e.g. tabm_results/tabm_hyperopt_best_rep*.pth)")
    ap.add_argument("--data_file", type=str, required=True, help="Input TSV: TestVoting_selection_neopep.tsv")
    ap.add_argument("--output_file", type=str, required=True, help="Main result output file (header consistent with original)")
    ap.add_argument("--tesla_file", type=str, default=None, help="TESLA score output file (for neopep task)")
    ap.add_argument("--output_xlsx", type=str, default=None, help="Main result Excel output path (optional)")
    ap.add_argument("--tesla_xlsx", type=str, default=None, help="TESLA result Excel output path (optional)")
    ap.add_argument("--dataset_name", type=str, default=None, help="If no dataset column exists, use this value as Dataset column in TESLA")
    ap.add_argument("--skip_no_cd8", action="store_true", help="Skip patients without CD8")
    ap.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
                    help="Device selection: auto/cuda/cpu")
    ap.add_argument("--batch_size", type=int, default=1024,
                    help="Batch size to avoid GPU memory overflow (default 1024)")
    args = ap.parse_args()
    if args.batch_size < 1:
        ap.error("--batch_size must be a positive integer")

    device = _select_device(args.device)
    print(f" Batch size: {args.batch_size}")

    df = pd.read_csv(args.data_file, sep="\t", header=0, low_memory=False)
    # Guarantee a 0..n-1 index so row labels can be used as positions into y_prob_all.
    df = df.reset_index(drop=True)
    print(f"📈 Data shape: {df.shape}")

    required_cols = ["patient", "response_type", "gene", "mutant_seq"]
    for c in required_cols:
        if c not in df.columns:
            raise KeyError(f"Missing required column: {c}")

    # NOTE(review): every other column (including any non-numeric one such as
    # "dataset") becomes a feature, with non-numeric values mapped to 0.0.
    # Presumably this matches the training pipeline — TODO confirm.
    feature_cols = [c for c in df.columns if c not in required_cols]
    X_all = df[feature_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0).to_numpy()
    print(f" Number of features: {X_all.shape[1]}")

    model_paths = _resolve_model_paths(args)
    first_ckpt = _load_checkpoint(model_paths[0])
    _print_checkpoint_summary(first_ckpt)

    n_models = len(model_paths)
    print(f"🔗 Loading {n_models} models for equal-weighted average prediction...")
    prob_sum = np.zeros(len(X_all), dtype=np.float64)
    for i, mp in enumerate(model_paths):
        print(f" -> {mp}")
        # Reuse the already-loaded first checkpoint instead of reading it twice.
        ckpt = first_ckpt if i == 0 else _load_checkpoint(mp)
        prob_sum += _predict_with_checkpoint(ckpt, X_all, device, args.batch_size)
    y_prob_all = prob_sum / float(n_models)

    print(f"✅ Prediction completed, total {len(y_prob_all)} samples; number of models={n_models}")

    rows_main = []
    rows_tesla = []
    want_tesla = bool(args.tesla_file or args.tesla_xlsx)

    need_header = (not os.path.exists(args.output_file)) or (os.path.getsize(args.output_file) == 0)
    with open(args.output_file, "a", encoding="utf-8") as f:
        if need_header:
            f.write("Patient\tNr_correct_top20\tNr_tested_top20\tNr_correct_top50\tNr_tested_top50\t"
                    "Nr_correct_top100\tNr_tested_top100\tNr_immunogenic\tNr_peptides\tClf_score\t"
                    "CD8_ranks\tCD8_mut_seqs\tCD8_genes\n")

        for patient, df_p in df.groupby("patient", sort=False):
            has_cd8 = (normalize_rt(df_p["response_type"]) == "CD8").any()
            if args.skip_no_cd8 and not has_cd8:
                continue

            # Labels equal positions thanks to reset_index above.
            y_prob = y_prob_all[df_p.index.to_numpy()]

            (y_pred_sorted, X_sorted,
             nr_correct20, nr_tested20,
             nr_correct50, nr_tested50,
             nr_correct100, nr_tested100,
             nr_immuno, _cd8_positions, score,
             ranks_str, mut_seqs_str, genes_str) = compute_patient_metrics(df_p, y_prob)

            f.write(f"{patient}\t{nr_correct20}\t{nr_tested20}\t{nr_correct50}\t{nr_tested50}\t"
                    f"{nr_correct100}\t{nr_tested100}\t{nr_immuno}\t{len(df_p)}\t{score:.6f}\t"
                    f"{ranks_str}\t{mut_seqs_str}\t{genes_str}\n")

            rows_main.append({
                "Patient": patient,
                "Nr_correct_top20": nr_correct20,
                "Nr_tested_top20": nr_tested20,
                "Nr_correct_top50": nr_correct50,
                "Nr_tested_top50": nr_tested50,
                "Nr_correct_top100": nr_correct100,
                "Nr_tested_top100": nr_tested100,
                "Nr_immunogenic": nr_immuno,
                "Nr_peptides": len(df_p),
                "Clf_score": score,
                "CD8_ranks": ranks_str,
                "CD8_mut_seqs": mut_seqs_str,
                "CD8_genes": genes_str,
            })

            if not want_tesla:
                continue

            if "dataset" in df_p.columns:
                dataset_val = str(df_p["dataset"].iloc[0])
            else:
                dataset_val = args.dataset_name if args.dataset_name is not None else ""

            ttif, fr, auprc = _tesla_metrics(X_sorted, y_pred_sorted, nr_correct20,
                                             nr_tested20, nr_correct100, nr_immuno)

            if args.tesla_file:
                _append_tesla_row(args.tesla_file, dataset_val, patient, ttif, fr, auprc)

            rows_tesla.append({
                "Dataset": dataset_val,
                "Patient": patient,
                "TTIF": ttif,
                "FR": fr,
                "AUPRC": auprc,
            })

    if args.output_xlsx and rows_main:
        os.makedirs(os.path.dirname(args.output_xlsx) or '.', exist_ok=True)
        pd.DataFrame(rows_main).to_excel(args.output_xlsx, index=False)
    if args.tesla_xlsx and rows_tesla:
        os.makedirs(os.path.dirname(args.tesla_xlsx) or '.', exist_ok=True)
        pd.DataFrame(rows_tesla).to_excel(args.tesla_xlsx, index=False)

    print(f" Evaluation completed! Processed {len(rows_main)} patients")
|
|
|
|
|
# Run the evaluation CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()