# HintsPredictionModel / evaluate.py
# Initial upload: HintsPrediction (commit 343e05c, verified)
import os
import yaml
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import (
average_precision_score,
roc_auc_score,
f1_score,
precision_score,
recall_score,
accuracy_score
)
# Import the project model and data loader
from models.transmil_q2l import TransMIL_Query2Label_E2E
from thyroid_dataset import create_dataloaders, TARGET_CLASSES
'''
# 18-class label definitions (kept consistent with training).
# Retained here for reference only; the live list is imported from thyroid_dataset.
TARGET_CLASSES = [
    "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
    "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
    "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留", "切除术后",
    "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
]
'''
def get_best_checkpoint_path(save_dir):
    """Locate the checkpoint file to evaluate.

    Prefers 'checkpoint_best.pth' inside `save_dir`; falls back to
    'checkpoint_latest.pth' with a warning.

    Raises:
        FileNotFoundError: when neither checkpoint file exists.
    """
    candidates = (('checkpoint_best.pth', False),
                  ('checkpoint_latest.pth', True))
    for fname, is_fallback in candidates:
        path = os.path.join(save_dir, fname)
        if os.path.exists(path):
            if is_fallback:
                print(f"Warning: 'checkpoint_best.pth' not found. Using '{path}' instead.")
            return path
    raise FileNotFoundError(f"No checkpoints found in {save_dir}")
def compute_metrics(y_true, y_pred_probs, threshold=0.5):
    """Compute global and per-class multi-label classification metrics.

    Args:
        y_true: [N, num_classes] binary ground-truth matrix (0 or 1).
        y_pred_probs: [N, num_classes] predicted probabilities in [0.0, 1.0].
        threshold: cut-off used to binarize probabilities for F1/precision/recall.

    Returns:
        Tuple of (dict of global metrics, DataFrame of per-class metrics).
    """
    # Binarize once; reused by all threshold-dependent metrics.
    y_pred_binary = (y_pred_probs >= threshold).astype(int)

    # --- Global metrics ---
    # mAP (mean Average Precision) is the headline multi-label metric here.
    metrics = {
        'mAP': average_precision_score(y_true, y_pred_probs, average='macro'),
        'weighted_mAP': average_precision_score(y_true, y_pred_probs, average='weighted'),
    }
    # AUROC is undefined when a class has only one label value in y_true;
    # fall back to 0.0 for both averages in that case.
    try:
        metrics['macro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='macro')
        metrics['micro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='micro')
    except ValueError:
        metrics['macro_auroc'] = 0.0
        metrics['micro_auroc'] = 0.0
    metrics['micro_f1'] = f1_score(y_true, y_pred_binary, average='micro')
    metrics['macro_f1'] = f1_score(y_true, y_pred_binary, average='macro')
    # Subset accuracy: a sample counts only if every label matches exactly.
    metrics['subset_accuracy'] = accuracy_score(y_true, y_pred_binary)

    # --- Per-class metrics ---
    per_class_rows = []
    for idx, class_name in enumerate(TARGET_CLASSES):
        truth = y_true[:, idx]
        prob = y_pred_probs[:, idx]
        pred = y_pred_binary[:, idx]
        support = int(truth.sum())

        if support == 0:
            # No positive samples: ranking metrics are undefined for this class.
            ap, auroc, f1, rec, prec = 0.0, 0.5, 0.0, 0.0, 0.0
        else:
            ap = average_precision_score(truth, prob)
            try:
                auroc = roc_auc_score(truth, prob)
            except ValueError:
                auroc = 0.5  # only one label value present → AUC undefined
            f1 = f1_score(truth, pred)
            rec = recall_score(truth, pred)
            prec = precision_score(truth, pred, zero_division=0)

        per_class_rows.append({
            "Class": class_name,
            "Support": support,
            "AP": ap,
            "AUROC": auroc,
            "F1": f1,
            "Precision": prec,
            "Recall": rec
        })

    return metrics, pd.DataFrame(per_class_rows)
def _build_model(config, device):
    """Instantiate the network, load the selected checkpoint, switch to eval mode."""
    print("Initializing Model...")
    model = TransMIL_Query2Label_E2E(
        num_class=config['model']['num_class'],
        hidden_dim=config['model']['hidden_dim'],
        nheads=config['model']['nheads'],
        num_decoder_layers=config['model']['num_decoder_layers'],
        pretrained_resnet=False,   # weights come from our checkpoint; no download needed
        use_checkpointing=False,   # gradient checkpointing is pointless at inference time
        use_ppeg=config['model'].get('use_ppeg', False)
    )

    ckpt_path = get_best_checkpoint_path(config['training']['save_dir'])
    print(f"Loading checkpoint from: {ckpt_path}")
    # NOTE(security): weights_only=False unpickles arbitrary objects — only load
    # checkpoints produced by our own training runs.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Strip a possible DataParallel/DistributedDataParallel "module." prefix.
    state_dict = {k.replace("module.", ""): v
                  for k, v in checkpoint['model_state_dict'].items()}
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _collect_predictions(model, loader, device):
    """Run inference over `loader`; return (y_true, y_pred_probs) numpy arrays."""
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch in tqdm(loader):
            images = batch['images'].to(device)
            num_instances = batch['num_instances_per_case']
            labels = batch['labels'].numpy()  # keep targets on CPU

            logits = model(images, num_instances)
            probs = torch.sigmoid(logits).cpu().numpy()

            all_preds.append(probs)
            all_targets.append(labels)
    return np.concatenate(all_targets, axis=0), np.concatenate(all_preds, axis=0)


def _display_width(text):
    """Approximate terminal display width of `text` (CJK glyphs count as 2 columns).

    Uses the GBK byte length as a width proxy like the original report code, but
    with errors='replace' so characters outside GBK count as 1 column instead of
    raising UnicodeEncodeError and crashing the report.
    """
    return len(text.encode('gbk', errors='replace'))


def _print_global_summary(metrics):
    """Pretty-print the global metric summary block."""
    print("\n" + "="*60)
    print(" GLOBAL PERFORMANCE SUMMARY ")
    print("="*60)
    print(f" mAP (Macro) : {metrics['mAP']:.4f}")
    print(f" mAP (Weighted): {metrics['weighted_mAP']:.4f}")
    print(f" AUROC (Macro) : {metrics['macro_auroc']:.4f}")
    print(f" AUROC (Micro) : {metrics['micro_auroc']:.4f}")
    print(f" F1 (Micro) : {metrics['micro_f1']:.4f}")
    print(f" F1 (Macro) : {metrics['macro_f1']:.4f}")
    print(f" Subset Acc : {metrics['subset_accuracy']:.4f}")
    print("-" * 60)


def _print_per_class_table(class_df):
    """Print the per-class metric table with CJK-aware column alignment."""
    print("\n" + "="*100)
    print(" PER-CLASS PERFORMANCE DETAILS (Sorted by Support) ")
    print("="*100)
    headers = ["Class", "Support", "AP", "AUROC", "F1", "Precision", "Recall"]
    head_fmt = "{:<24} {:>8} {:>10} {:>10} {:>10} {:>12} {:>10}"
    print(head_fmt.format(*headers))
    print("-" * 100)
    for _, row in class_df.iterrows():
        cls_name = row['Class']
        # str.format pads by code points, which misaligns wide CJK glyphs;
        # pad manually using the estimated display width instead. Clamp to 0
        # so names wider than the 24-column target never produce an error.
        padding = max(0, 24 - _display_width(cls_name))
        aligned_name = cls_name + " " * padding
        print(f"{aligned_name} {int(row['Support']):>8d} {row['AP']:>10.4f} "
              f"{row['AUROC']:>10.4f} {row['F1']:>10.4f} "
              f"{row['Precision']:>12.4f} {row['Recall']:>10.4f}")
    print("="*100)


def main():
    """Evaluate the trained model on the test split and write a per-class CSV report."""
    # 1. Load configuration.
    config_path = 'config.yaml'  # adjust if the config lives elsewhere
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Evaluating on {device}")

    # 2. Test data.
    print("Loading Test Data...")
    _, _, test_loader = create_dataloaders(config)

    # 3. Model + checkpoint.
    model = _build_model(config, device)

    # 4. Inference.
    print("Running Inference...")
    y_true, y_pred_probs = _collect_predictions(model, test_loader, device)

    # 5. Metrics.
    print("\nComputing Metrics...")
    global_metrics, class_df = compute_metrics(y_true, y_pred_probs)

    # 6. Console report (per-class table sorted by sample count).
    _print_global_summary(global_metrics)
    class_df = class_df.sort_values(by='Support', ascending=False)
    _print_per_class_table(class_df)

    # 7. Persist the per-class table (utf-8-sig so Excel renders CJK correctly).
    result_csv = os.path.join(config['training']['save_dir'], 'evaluation_report.csv')
    class_df.to_csv(result_csv, index=False, encoding='utf-8-sig')
    print(f"\nDetailed report saved to: {result_csv}")
# Script entry point: run the full evaluation pipeline.
if __name__ == "__main__":
    main()