#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""TabM model evaluation; output format is consistent with TestVotingClassifier."""
import argparse
import glob
import os

import numpy as np
import pandas as pd
import torch
import tabm
from sklearn.metrics import precision_recall_curve, auc


def normalize_rt(s: pd.Series) -> pd.Series:
    """Normalize response_type values for robust comparison (trim + uppercase)."""
    return s.astype(str).str.strip().str.upper()


def compute_patient_metrics(df_p: pd.DataFrame, y_prob: np.ndarray) -> tuple:
    """Rank one patient's peptides by predicted probability and compute top-k metrics."""
    X_r = df_p.copy()
    X_r['ML_pred'] = y_prob
    X_r['response'] = (normalize_rt(X_r['response_type']) == 'CD8').astype(int)
    X_r = X_r.sort_values(by=['ML_pred'], ascending=False).reset_index(drop=True)

    # Positions (0-based ranks) of immunogenic (CD8) and tested-negative peptides.
    idx_pos = np.where(X_r['response'].to_numpy() == 1)[0]
    idx_neg = np.where(normalize_rt(X_r['response_type']) == 'NEGATIVE')[0]

    def topk_counts(k: int):
        # "Tested" in the top-k = immunogenic hits + tested negatives.
        k_eff = min(k, len(X_r))
        nr_correct = int(np.sum(idx_pos < k_eff))
        nr_tested = nr_correct + int(np.sum(idx_neg < k_eff))
        return nr_correct, nr_tested

    nr_correct20, nr_tested20 = topk_counts(20)
    nr_correct50, nr_tested50 = topk_counts(50)
    nr_correct100, nr_tested100 = topk_counts(100)
    nr_immuno = int(np.sum(X_r['response'] == 1))

    # Exponential rank score: each immunogenic peptide contributes exp(-alpha * rank).
    alpha = 0.005
    score = float(np.sum(np.exp(-alpha * idx_pos)))

    if nr_immuno > 0:
        # idx_pos from np.where is already in ascending rank order, and selecting
        # rows with response == 1 from the sorted frame yields the same order,
        # so no extra argsort is needed.
        ranks_str = ",".join(f"{int(r) + 1}" for r in idx_pos)
        mut_seqs = X_r.loc[X_r['response'] == 1, 'mutant_seq'].to_numpy()
        mut_seqs_str = ",".join(str(s) for s in mut_seqs)
        genes = X_r.loc[X_r['response'] == 1, 'gene'].to_numpy()
        genes_str = ",".join(str(g) for g in genes)
    else:
        ranks_str = ""
        mut_seqs_str = ""
        genes_str = ""

    return (X_r['ML_pred'].to_numpy(), X_r, nr_correct20, nr_tested20,
            nr_correct50, nr_tested50, nr_correct100, nr_tested100,
            nr_immuno, idx_pos, score, ranks_str, mut_seqs_str, genes_str)
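
# Illustrative sketch (not used by the pipeline): how the exponential rank score
# above behaves. A CD8 hit at rank 1 (0-based rank 0) contributes
# exp(-0.005 * 0) = 1.0; one at rank 101 contributes exp(-0.005 * 100) ~= 0.61,
# so earlier hits dominate the per-patient Clf_score.
def _rank_score_sketch() -> float:
    ranks = np.array([0, 100, 500])         # hypothetical 0-based ranks of CD8 hits
    contributions = np.exp(-0.005 * ranks)  # ~[1.00, 0.61, 0.08]
    return float(contributions.sum())       # ~1.69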


def predict_in_batches(model, X_all, device, batch_size=1024):
    """Standalone batched-inference helper for a single TabM model.

    X_all is expected to be a float32 CPU tensor; returns P(class == 1) per row.
    """
    model.eval()
    y_prob_all = []
    with torch.inference_mode():
        for i in range(0, len(X_all), batch_size):
            batch_end = min(i + batch_size, len(X_all))
            batch_X = X_all[i:batch_end].to(device)
            # TabM returns (batch, k, d_out); average the k heads before softmax.
            batch_pred = model(batch_X).mean(1)
            batch_pred = torch.softmax(batch_pred, dim=1)[:, 1]
            y_prob_all.append(batch_pred.cpu())
            del batch_X, batch_pred
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return torch.cat(y_prob_all, dim=0).numpy()
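
# Illustrative sketch (not called anywhere): the mean-over-heads reduction used
# above and in _predict_with_model below. TabM emits one logit vector per
# ensemble head, i.e. a (batch, k, d_out) tensor; the script averages the
# k heads in logit space before applying softmax.
def _reduce_tabm_logits_sketch(logits_bkd: torch.Tensor) -> torch.Tensor:
    # logits_bkd: (batch, k, 2) -> (batch,) probability of the positive class
    mean_logits = logits_bkd.mean(dim=1)            # average the k heads
    return torch.softmax(mean_logits, dim=1)[:, 1]  # P(class == 1)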


def main():
    ap = argparse.ArgumentParser(
        description="TabM model evaluation; output format consistent with TestVotingClassifier")
    ap.add_argument("--model_file", type=str, required=False,
                    help="Single TabM model file, e.g. tabm_results/tabm_model.pth "
                         "(used only if neither --model_files nor --model_glob is given)")
    ap.add_argument("--model_files", type=str, nargs='*', default=None,
                    help="Multiple model files for an equal-weighted average prediction")
    ap.add_argument("--model_glob", type=str, default=None,
                    help="Wildcard pattern matching multiple model files "
                         "(e.g. tabm_results/tabm_hyperopt_best_rep*.pth)")
    ap.add_argument("--data_file", type=str, required=True,
                    help="Input TSV: TestVoting_selection_neopep.tsv")
    ap.add_argument("--output_file", type=str, required=True,
                    help="Main result output file (header consistent with the original)")
    ap.add_argument("--tesla_file", type=str, default=None,
                    help="TESLA score output file (for the neopep task)")
    ap.add_argument("--output_xlsx", type=str, default=None,
                    help="Main result Excel output path (optional)")
    ap.add_argument("--tesla_xlsx", type=str, default=None,
                    help="TESLA result Excel output path (optional)")
    ap.add_argument("--dataset_name", type=str, default=None,
                    help="If no dataset column exists, use this value as the Dataset column in TESLA output")
    ap.add_argument("--skip_no_cd8", action="store_true",
                    help="Skip patients without any CD8 peptide")
    ap.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
                    help="Device selection: auto/cuda/cpu")
    ap.add_argument("--batch_size", type=int, default=1024,
                    help="Batch size to avoid GPU memory overflow (default 1024)")
    args = ap.parse_args()

    # Device selection
    if args.device == "auto":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"🚀 Auto-selected GPU: {torch.cuda.get_device_name(0)}")
            print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        else:
            device = torch.device('cpu')
            print("⚠️ No GPU detected, using CPU")
    elif args.device == "cuda":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"🚀 Forcing GPU: {torch.cuda.get_device_name(0)}")
        else:
            raise RuntimeError("CUDA specified but no GPU detected")
    else:
        device = torch.device('cpu')
        print("Using CPU")
    print(f" Batch size: {args.batch_size}")

    # Read data
    df = pd.read_csv(args.data_file, sep="\t", header=0, low_memory=False)
    print(f"📈 Data shape: {df.shape}")

    # Required-column check
    required_cols = ["patient", "response_type", "gene", "mutant_seq"]
    for c in required_cols:
        if c not in df.columns:
            raise KeyError(f"Missing required column: {c}")

    # Feature columns = all columns except the metadata columns. Any other
    # non-numeric column (e.g. an optional 'dataset' column) is coerced to NaN
    # and filled with 0.0; checkpoints may carry a 'used_feature_idx' that
    # subsets these columns to the ones seen during training.
    feature_cols = [c for c in df.columns if c not in required_cols]
    X_all = df[feature_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0).to_numpy()
    print(f" Number of features: {X_all.shape[1]}")

    # Resolve model files: explicit list and/or glob; fall back to --model_file.
    model_paths: list[str] = []
    if args.model_files:
        model_paths.extend(list(args.model_files))
    if args.model_glob:
        model_paths.extend(sorted(glob.glob(args.model_glob)))
    if not model_paths and args.model_file:
        model_paths = [args.model_file]
    if not model_paths:
        raise FileNotFoundError("No model files found, please check!")

    first_ckpt = torch.load(model_paths[0], map_location='cpu', weights_only=False)
    model_args = first_ckpt['args']

    def _predict_with_model(model_path: str, X_all_np: np.ndarray) -> np.ndarray:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file does not exist: {model_path}")
        ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
        m_args = ckpt['args']

        # If the checkpoint recorded which feature columns were used during
        # training, subset the input accordingly (silently skipped on error).
        X_np = X_all_np
        if ckpt.get("used_feature_idx") is not None:
            try:
                ufi_arr = np.array(ckpt["used_feature_idx"], dtype=int)
                max_idx = X_np.shape[1] - 1
                ufi_arr = ufi_arr[(ufi_arr >= 0) & (ufi_arr <= max_idx)]
                if len(ufi_arr) > 0:
                    X_np = X_np[:, ufi_arr]
            except Exception:
                pass

        X_tensor_cpu = torch.as_tensor(X_np, dtype=torch.float32)

        # Optional numerical-feature embeddings, mirroring the training setup.
        num_embeddings = None
        if getattr(m_args, 'use_embeddings', False):
            import rtdl_num_embeddings
            if m_args.embedding_type == 'linear':
                num_embeddings = rtdl_num_embeddings.LinearReLUEmbeddings(X_tensor_cpu.shape[1])
            elif m_args.embedding_type == 'periodic':
                num_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(X_tensor_cpu.shape[1], lite=False)
            elif m_args.embedding_type == 'piecewise':
                num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
                    rtdl_num_embeddings.compute_bins(X_tensor_cpu, n_bins=48),
                    d_embedding=16,
                    activation=False,
                    version='B',
                )

        model = tabm.TabM.make(
            n_num_features=X_tensor_cpu.shape[1],
            cat_cardinalities=[],
            d_out=2,
            k=m_args.k,
            n_blocks=m_args.n_blocks,
            d_block=m_args.d_block,
            num_embeddings=num_embeddings,
            arch_type=getattr(m_args, 'arch_type', 'tabm'),
        )
        model.load_state_dict(ckpt['model_state_dict'])
        model.to(device)
        model.eval()

        bs = args.batch_size  # honor the user-requested batch size
        probs_list = []
        n = len(X_tensor_cpu)
        with torch.inference_mode():
            for i in range(0, n, bs):
                j = min(i + bs, n)
                xb = X_tensor_cpu[i:j].to(device)
                logits = model(xb).mean(1)  # average the k ensemble heads
                probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
                probs_list.append(probs)
                del xb, logits
                if torch.cuda.is_available() and device.type == 'cuda':
                    torch.cuda.empty_cache()
                if (i // bs) % 50 == 0:
                    print(f" batch {i // bs}/{(n + bs - 1) // bs}")
        return np.concatenate(probs_list, axis=0)

    def _stringify(v):
        try:
            return repr(v)
        except Exception:
            try:
                return str(v)
            except Exception:
                return ""

    print("===== Saved Hyperparameters from checkpoint['args'] =====")
    if hasattr(model_args, "__dict__"):
        hp_items = sorted(vars(model_args).items())
    elif isinstance(model_args, dict):
        hp_items = sorted(model_args.items())
    else:
        hp_items = []
        print("⚠️ Unable to enumerate contents of model_args")
    for key, val in hp_items:
        print(f"- {key}: {_stringify(val)}")
    print("=========================================================")

    def _p_dict(title, d):
        try:
            print(title)
            for k in sorted(d.keys()):
                try:
                    print(f"- {k}: {repr(d[k])}")
                except Exception:
                    print(f"- {k}: ")
            print("=" * len(title))
        except Exception:
            pass

    if isinstance(first_ckpt.get("training_args"), dict):
        _p_dict("===== checkpoint['training_args'] =====", first_ckpt["training_args"])
    if isinstance(first_ckpt.get("best_params"), dict):
        _p_dict("===== checkpoint['best_params'] =====", first_ckpt["best_params"])
    if isinstance(first_ckpt.get("full_args"), dict):
        _p_dict("===== checkpoint['full_args'] =====", first_ckpt["full_args"])

    if first_ckpt.get("used_feature_idx") is not None:
        try:
            ufi = first_ckpt["used_feature_idx"]
            print("===== used_feature_idx =====")
            print(f"- length: {len(ufi)}")
            print(f"- head: {list(ufi[:10])}")
            print("=" * 25)
        except Exception:
            print("===== used_feature_idx =====\n\n============================")

    try:
        print("===== Environment =====")
        print(f"- torch: {torch.__version__}")
        print(f"- cuda available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"- device: {torch.cuda.get_device_name(0)}")
            print(f"- cuda version: {torch.version.cuda}")
        print(f"- tabm: {getattr(tabm, '__version__', 'unknown')}")
        print("========================")
    except Exception:
        pass

    # Equal-weighted ensemble: average the per-model probabilities.
    n_models = len(model_paths)
    print(f"🔗 Loading {n_models} models for equal-weighted average prediction...")
    y_prob_all = None
    for mp in model_paths:
        print(f" -> {mp}")
        probs = _predict_with_model(mp, X_all)
        if y_prob_all is None:
            y_prob_all = probs.astype(np.float64)
        else:
            y_prob_all += probs
    y_prob_all = (y_prob_all / float(n_models)).astype(np.float64)
    print(f"✅ Prediction completed, total {len(y_prob_all)} samples; number of models={n_models}")
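    # Per-patient evaluation. Each row of the main output summarizes one
    # patient: top-k hit counts (k = 20/50/100), the number of immunogenic
    # and total peptides, the exponential rank score (Clf_score), and the
    # ranks/sequences/genes of the CD8 hits. The TESLA-style metrics below
    # derive from the same ranking: TTIF = Nr_correct_top20 / Nr_tested_top20
    # and FR = Nr_correct_top100 / Nr_immunogenic.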
    rows_main = []
    rows_tesla = []
    need_header = (not os.path.exists(args.output_file)) or (os.path.getsize(args.output_file) == 0)
    with open(args.output_file, "a", encoding="utf-8") as f:
        if need_header:
            f.write("Patient\tNr_correct_top20\tNr_tested_top20\tNr_correct_top50\tNr_tested_top50\t"
                    "Nr_correct_top100\tNr_tested_top100\tNr_immunogenic\tNr_peptides\tClf_score\t"
                    "CD8_ranks\tCD8_mut_seqs\tCD8_genes\n")
        for patient, df_p in df.groupby("patient", sort=False):
            has_cd8 = (normalize_rt(df_p["response_type"]) == "CD8").any()
            if args.skip_no_cd8 and not has_cd8:
                continue
            # df was read fresh, so its RangeIndex doubles as positional indices.
            idx = df_p.index.to_numpy()
            y_prob = y_prob_all[idx]
            (y_pred_sorted, X_sorted, nr_correct20, nr_tested20, nr_correct50, nr_tested50,
             nr_correct100, nr_tested100, nr_immuno, _, score,
             ranks_str, mut_seqs_str, genes_str) = compute_patient_metrics(df_p, y_prob)
            f.write(f"{patient}\t{nr_correct20}\t{nr_tested20}\t{nr_correct50}\t{nr_tested50}\t"
                    f"{nr_correct100}\t{nr_tested100}\t{nr_immuno}\t{len(df_p)}\t{score:.6f}\t"
                    f"{ranks_str}\t{mut_seqs_str}\t{genes_str}\n")
            rows_main.append({
                "Patient": patient,
                "Nr_correct_top20": nr_correct20,
                "Nr_tested_top20": nr_tested20,
                "Nr_correct_top50": nr_correct50,
                "Nr_tested_top50": nr_tested50,
                "Nr_correct_top100": nr_correct100,
                "Nr_tested_top100": nr_tested100,
                "Nr_immunogenic": nr_immuno,
                "Nr_peptides": len(df_p),
                "Clf_score": score,
                "CD8_ranks": ranks_str,
                "CD8_mut_seqs": mut_seqs_str,
                "CD8_genes": genes_str,
            })

            if args.tesla_file or args.tesla_xlsx:
                if "dataset" in df_p.columns:
                    dataset_val = str(df_p["dataset"].iloc[0])
                else:
                    dataset_val = args.dataset_name if args.dataset_name is not None else ""
                # Restrict the PR curve to peptides that were actually tested.
                idx_nt = normalize_rt(X_sorted['response_type']) != 'NOT_TESTED'
                y_pred_tesla = pd.Series(y_pred_sorted)[idx_nt].to_numpy()
                y_tesla = X_sorted.loc[idx_nt, 'response'].to_numpy()
                ttif = (nr_correct20 / nr_tested20) if nr_tested20 > 0 else 0.0
                fr = (nr_correct100 / nr_immuno) if nr_immuno > 0 else 0.0
                if y_tesla.sum() > 0:
                    precision, recall, _ = precision_recall_curve(y_tesla, y_pred_tesla)
                    auprc = auc(recall, precision)
                else:
                    auprc = 0.0  # AUPRC is undefined without any positive peptide
                if args.tesla_file:
                    new_tesla = (not os.path.exists(args.tesla_file)) or (os.path.getsize(args.tesla_file) == 0)
                    with open(args.tesla_file, "a", encoding="utf-8") as tf:
                        if new_tesla:
                            tf.write("Dataset\tPatient\tTTIF\tFR\tAUPRC\n")
                        tf.write(f"{dataset_val}\t{patient}\t{ttif:.3f}\t{fr:.3f}\t{auprc:.3f}\n")
                rows_tesla.append({
                    "Dataset": dataset_val,
                    "Patient": patient,
                    "TTIF": ttif,
                    "FR": fr,
                    "AUPRC": auprc,
                })

    if args.output_xlsx and rows_main:
        os.makedirs(os.path.dirname(args.output_xlsx) or '.', exist_ok=True)
        pd.DataFrame(rows_main).to_excel(args.output_xlsx, index=False)
    if args.tesla_xlsx and rows_tesla:
        os.makedirs(os.path.dirname(args.tesla_xlsx) or '.', exist_ok=True)
        pd.DataFrame(rows_tesla).to_excel(args.tesla_xlsx, index=False)
    print(f"✅ Evaluation completed! Processed {len(rows_main)} patients")


if __name__ == "__main__":
    main()
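
# Example invocation (script name and paths are illustrative; flags are the
# ones defined above):
#   python tabm_eval.py \
#       --model_glob 'tabm_results/tabm_hyperopt_best_rep*.pth' \
#       --data_file TestVoting_selection_neopep.tsv \
#       --output_file tabm_eval.tsv \
#       --tesla_file tabm_tesla.tsv \
#       --skip_no_cd8 --device auto --batch_size 1024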