"""Evaluate a trained TransMIL + Query2Label checkpoint on the test split.

The script loads ``config.yaml``, restores the best (or, failing that, the
latest) checkpoint, runs inference over the test loader, prints global and
per-class multi-label metrics, and writes the per-class table to
``evaluation_report.csv`` inside the checkpoint directory.
"""

import os
import unicodedata

import yaml
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
)

# Project model and data loaders.
from models.transmil_q2l import TransMIL_Query2Label_E2E
from thyroid_dataset import create_dataloaders, TARGET_CLASSES

# Reference only: the 18 label names, which must stay consistent with the ones
# used during training. The definition actually used is the TARGET_CLASSES
# imported from thyroid_dataset above.
#
#     "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
#     "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级", "钙化", "甲亢", "囊肿",
#     "淋巴结", "胶质潴留", "切除术后", "弥漫性病变", "结节性甲状腺肿",
#     "桥本氏甲状腺炎", "反应性", "转移性"


def get_best_checkpoint_path(save_dir):
    """Return the checkpoint file to evaluate from ``save_dir``.

    ``checkpoint_best.pth`` is preferred. When it is missing,
    ``checkpoint_latest.pth`` is used instead and a warning is printed.

    Args:
        save_dir: Directory that holds the training checkpoints.

    Returns:
        Path of the checkpoint file that exists.

    Raises:
        FileNotFoundError: If neither checkpoint file exists in ``save_dir``.
    """
    best_path = os.path.join(save_dir, 'checkpoint_best.pth')
    if os.path.exists(best_path):
        return best_path

    # No "best" checkpoint: fall back to the latest one.
    latest_path = os.path.join(save_dir, 'checkpoint_latest.pth')
    if os.path.exists(latest_path):
        print(f"Warning: 'checkpoint_best.pth' not found. Using '{latest_path}' instead.")
        return latest_path

    raise FileNotFoundError(f"No checkpoints found in {save_dir}")


def _auroc_or_none(y_true, y_score, **kwargs):
    """Return ``roc_auc_score(...)`` as a float, or None when it is undefined.

    AUROC is undefined when ``y_true`` holds a single label value. That is
    reported here as None whether scikit-learn signals it by raising
    ValueError or by returning NaN.
    """
    try:
        value = float(roc_auc_score(y_true, y_score, **kwargs))
    except ValueError:
        return None
    return None if np.isnan(value) else value


def _macro_auroc(y_true, y_pred_probs):
    """Return macro AUROC, skipping classes for which AUROC is undefined.

    When the direct macro computation is undefined (at least one class has
    only one label value), the mean is taken over the remaining classes so
    that one degenerate class does not wipe out the whole metric. ``0.0`` is
    returned only when no class at all has a defined AUROC.
    """
    macro = _auroc_or_none(y_true, y_pred_probs, average='macro')
    if macro is not None:
        return macro

    per_class = (
        _auroc_or_none(y_true[:, i], y_pred_probs[:, i])
        for i in range(y_true.shape[1])
    )
    defined = [value for value in per_class if value is not None]
    return float(np.mean(defined)) if defined else 0.0


def compute_metrics(y_true, y_pred_probs, threshold=0.5):
    """Compute global and per-class multi-label metrics.

    Args:
        y_true: Array of shape [N, num_classes] with 0/1 ground-truth labels.
        y_pred_probs: Array of shape [N, num_classes] with scores in 0.0-1.0.
        threshold: Scores greater than or equal to this value count as a
            positive prediction for the thresholded metrics (F1, precision,
            recall, subset accuracy).

    Returns:
        A ``(metrics, class_df)`` tuple. ``metrics`` is a dict with the keys
        ``mAP``, ``weighted_mAP``, ``macro_auroc``, ``micro_auroc``,
        ``micro_f1``, ``macro_f1`` and ``subset_accuracy``. ``class_df`` is a
        DataFrame with one row per entry of ``TARGET_CLASSES`` and the columns
        ``Class``, ``Support``, ``AP``, ``AUROC``, ``F1``, ``Precision`` and
        ``Recall``.
    """
    metrics = {}

    # 1. Binarize the predictions.
    y_pred_binary = (y_pred_probs >= threshold).astype(int)

    # 2. Global metrics.
    # mAP (mean Average Precision) is the headline multi-label metric.
    metrics['mAP'] = average_precision_score(y_true, y_pred_probs, average='macro')
    metrics['weighted_mAP'] = average_precision_score(y_true, y_pred_probs, average='weighted')

    # AUROC: macro and micro are computed independently so that an undefined
    # macro value cannot force the micro value to the sentinel as well.
    metrics['macro_auroc'] = _macro_auroc(y_true, y_pred_probs)
    micro_auroc = _auroc_or_none(y_true, y_pred_probs, average='micro')
    metrics['micro_auroc'] = 0.0 if micro_auroc is None else micro_auroc

    # F1. zero_division=0 yields the same value as the default while staying
    # consistent with the precision call below and avoiding warning noise.
    metrics['micro_f1'] = f1_score(y_true, y_pred_binary, average='micro', zero_division=0)
    metrics['macro_f1'] = f1_score(y_true, y_pred_binary, average='macro', zero_division=0)

    # Exact match ratio (subset accuracy): every label of a sample must match.
    metrics['subset_accuracy'] = accuracy_score(y_true, y_pred_binary)

    # 3. Per-class metrics.
    class_metrics = []
    for i, class_name in enumerate(TARGET_CLASSES):
        yt = y_true[:, i]
        yp = y_pred_probs[:, i]
        yb = y_pred_binary[:, i]

        # Number of positive samples for this class.
        support = int(yt.sum())

        if support > 0:
            ap = average_precision_score(yt, yp)
            auroc = _auroc_or_none(yt, yp)
            if auroc is None:
                # Only one label value present: report chance level.
                auroc = 0.5
            f1 = f1_score(yt, yb, zero_division=0)
            rec = recall_score(yt, yb, zero_division=0)
            prec = precision_score(yt, yb, zero_division=0)
        else:
            # Without positives these metrics cannot be computed; use
            # placeholders (chance level for AUROC, zero elsewhere).
            ap, auroc, f1, rec, prec = 0.0, 0.5, 0.0, 0.0, 0.0

        class_metrics.append({
            "Class": class_name,
            "Support": support,
            "AP": ap,
            "AUROC": auroc,
            "F1": f1,
            "Precision": prec,
            "Recall": rec,
        })

    return metrics, pd.DataFrame(class_metrics)


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` without a leading ``module.`` on keys.

    Only a leading prefix is removed (presumably added by a parallel wrapper
    at training time -- confirm against the training script); a ``module.``
    substring elsewhere in a key is left untouched.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _display_width(text):
    """Return the number of terminal columns ``text`` occupies.

    Wide and full-width East Asian characters count as two columns and
    everything else as one, so CJK class names line up with ASCII ones.
    """
    return sum(
        2 if unicodedata.east_asian_width(char) in ("W", "F") else 1
        for char in text
    )


def main():
    """Run the full evaluation: load, infer, score, print and save."""
    # 1. Load the configuration.
    config_path = 'config.yaml'  # Make sure this path is correct.
    # Explicit encoding so the result does not depend on the platform default.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Evaluating on {device}")

    # 2. Build the data loaders; only the test loader is used here.
    print("Loading Test Data...")
    _, _, test_loader = create_dataloaders(config)

    # 3. Build the model.
    print("Initializing Model...")
    model = TransMIL_Query2Label_E2E(
        num_class=config['model']['num_class'],
        hidden_dim=config['model']['hidden_dim'],
        nheads=config['model']['nheads'],
        num_decoder_layers=config['model']['num_decoder_layers'],
        pretrained_resnet=False,  # Our own weights are loaded below.
        use_checkpointing=False,  # Not needed for inference.
        use_ppeg=config['model'].get('use_ppeg', False),
    )

    # 4. Load the weights.
    ckpt_path = get_best_checkpoint_path(config['training']['save_dir'])
    print(f"Loading checkpoint from: {ckpt_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Checkpoint keys may carry a "module." prefix the bare model lacks.
    model.load_state_dict(_strip_module_prefix(checkpoint['model_state_dict']))
    model.to(device)
    model.eval()

    # 5. Inference loop.
    print("Running Inference...")
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(test_loader):
            images = batch['images'].to(device)
            num_instances = batch['num_instances_per_case']
            labels = batch['labels'].numpy()  # CPU numpy

            # Forward pass; sigmoid turns logits into per-class probabilities.
            logits = model(images, num_instances)
            probs = torch.sigmoid(logits).cpu().numpy()

            all_preds.append(probs)
            all_targets.append(labels)

    if not all_preds:
        # np.concatenate would raise the same exception type with a far less
        # helpful message.
        raise ValueError("Test loader produced no batches; nothing to evaluate.")

    # Stack the per-batch arrays along the sample axis.
    y_pred_probs = np.concatenate(all_preds, axis=0)
    y_true = np.concatenate(all_targets, axis=0)

    # 6. Compute the metrics.
    print("\nComputing Metrics...")
    global_metrics, class_df = compute_metrics(y_true, y_pred_probs)

    # 7. Print the results.
    print("\n" + "="*60)
    print(" GLOBAL PERFORMANCE SUMMARY ")
    print("="*60)
    print(f" mAP (Macro) : {global_metrics['mAP']:.4f}")
    print(f" mAP (Weighted): {global_metrics['weighted_mAP']:.4f}")
    print(f" AUROC (Macro) : {global_metrics['macro_auroc']:.4f}")
    print(f" AUROC (Micro) : {global_metrics['micro_auroc']:.4f}")
    print(f" F1 (Micro) : {global_metrics['micro_f1']:.4f}")
    print(f" F1 (Macro) : {global_metrics['macro_f1']:.4f}")
    print(f" Subset Acc : {global_metrics['subset_accuracy']:.4f}")
    print("-" * 60)

    print("\n" + "="*100)
    print(" PER-CLASS PERFORMANCE DETAILS (Sorted by Support) ")
    print("="*100)

    # Most frequent classes first.
    class_df = class_df.sort_values(by='Support', ascending=False)

    # The table is formatted by hand because str.format pads by character
    # count, which misaligns columns when class names contain wide characters.
    headers = ["Class", "Support", "AP", "AUROC", "F1", "Precision", "Recall"]

    # First column left-aligned in 24 columns, the rest right-aligned.
    head_fmt = "{:<24} {:>8} {:>10} {:>10} {:>10} {:>12} {:>10}"
    print(head_fmt.format(*headers))
    print("-" * 100)

    target_width = 24
    for _, row in class_df.iterrows():
        cls_name = row['Class']

        # Pad up to the target display width; names wider than that get no
        # padding at all.
        padding = max(0, target_width - _display_width(cls_name))
        aligned_name = cls_name + " " * padding

        print(f"{aligned_name} {int(row['Support']):>8d} {row['AP']:>10.4f} {row['AUROC']:>10.4f} {row['F1']:>10.4f} {row['Precision']:>12.4f} {row['Recall']:>10.4f}")

    print("="*100)

    # Save the per-class table as CSV; utf-8-sig writes a BOM so that
    # spreadsheet tools detect the encoding of the non-ASCII class names.
    result_csv = os.path.join(config['training']['save_dir'], 'evaluation_report.csv')
    class_df.to_csv(result_csv, index=False, encoding='utf-8-sig')
    print(f"\nDetailed report saved to: {result_csv}")


if __name__ == "__main__":
    main()