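"""
Evaluation script for the end-to-end TransMIL + Query2Label model on the
18-class multi-label thyroid classification task.

Loads config.yaml and the best (or latest) checkpoint, runs inference on the
test split, prints global and per-class multi-label metrics, and writes a
per-class report to evaluation_report.csv in the checkpoint directory.
"""
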
import os

import yaml
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
)

from models.transmil_q2l import TransMIL_Query2Label_E2E
from thyroid_dataset import create_dataloaders, TARGET_CLASSES

# Kept for reference only; TARGET_CLASSES is imported from thyroid_dataset above.
'''
# 18-class label definition (kept consistent with training)
TARGET_CLASSES = [
    "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
    "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
    "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留", "切除术后",
    "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
]
'''


def get_best_checkpoint_path(save_dir):
    """Locate the best checkpoint, falling back to the latest one."""
    best_path = os.path.join(save_dir, 'checkpoint_best.pth')
    if os.path.exists(best_path):
        return best_path

    latest_path = os.path.join(save_dir, 'checkpoint_latest.pth')
    if os.path.exists(latest_path):
        print(f"Warning: 'checkpoint_best.pth' not found. Using '{latest_path}' instead.")
        return latest_path

    raise FileNotFoundError(f"No checkpoints found in {save_dir}")


def compute_metrics(y_true, y_pred_probs, threshold=0.5):
    """
    Compute a comprehensive set of multi-label metrics.

    y_true:       [N, num_classes] binary ground truth (0 or 1)
    y_pred_probs: [N, num_classes] predicted probabilities in [0.0, 1.0]
    """
    metrics = {}

    y_pred_binary = (y_pred_probs >= threshold).astype(int)

    # Ranking metrics computed on the raw probabilities
    metrics['mAP'] = average_precision_score(y_true, y_pred_probs, average='macro')
    metrics['weighted_mAP'] = average_precision_score(y_true, y_pred_probs, average='weighted')

    # roc_auc_score raises ValueError if some class has only one label value in y_true
    try:
        metrics['macro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='macro')
        metrics['micro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='micro')
    except ValueError:
        metrics['macro_auroc'] = 0.0
        metrics['micro_auroc'] = 0.0

    # Threshold-dependent metrics
    metrics['micro_f1'] = f1_score(y_true, y_pred_binary, average='micro', zero_division=0)
    metrics['macro_f1'] = f1_score(y_true, y_pred_binary, average='macro', zero_division=0)

    # Subset accuracy: a sample counts only if every label matches exactly
    metrics['subset_accuracy'] = accuracy_score(y_true, y_pred_binary)

    # Per-class breakdown
    class_metrics = []
    for i, class_name in enumerate(TARGET_CLASSES):
        yt = y_true[:, i]
        yp = y_pred_probs[:, i]
        yb = y_pred_binary[:, i]

        support = int(yt.sum())

        if support > 0:
            ap = average_precision_score(yt, yp)
            try:
                auroc = roc_auc_score(yt, yp)
            except ValueError:
                auroc = 0.5

            f1 = f1_score(yt, yb, zero_division=0)
            rec = recall_score(yt, yb, zero_division=0)
            prec = precision_score(yt, yb, zero_division=0)
        else:
            # No positives in the test set: fall back to neutral defaults
            ap, auroc, f1, rec, prec = 0.0, 0.5, 0.0, 0.0, 0.0

        class_metrics.append({
            "Class": class_name,
            "Support": support,
            "AP": ap,
            "AUROC": auroc,
            "F1": f1,
            "Precision": prec,
            "Recall": rec
        })

    return metrics, pd.DataFrame(class_metrics)
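
# Minimal usage sketch for compute_metrics (illustrative only; random inputs,
# not real predictions):
#   y_true = np.random.randint(0, 2, size=(8, len(TARGET_CLASSES)))
#   y_prob = np.random.rand(8, len(TARGET_CLASSES))
#   global_metrics, class_df = compute_metrics(y_true, y_prob)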


def main():
    # Load the experiment configuration
    config_path = 'config.yaml'
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
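
    # Keys this script reads from config.yaml:
    #   model:    num_class, hidden_dim, nheads, num_decoder_layers, use_ppeg (optional)
    #   training: save_dir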

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Evaluating on {device}")

    print("Loading Test Data...")
    _, _, test_loader = create_dataloaders(config)

    print("Initializing Model...")
    model = TransMIL_Query2Label_E2E(
        num_class=config['model']['num_class'],
        hidden_dim=config['model']['hidden_dim'],
        nheads=config['model']['nheads'],
        num_decoder_layers=config['model']['num_decoder_layers'],
        pretrained_resnet=False,   # weights come from the checkpoint below
        use_checkpointing=False,   # gradient checkpointing is not needed at eval time
        use_ppeg=config['model'].get('use_ppeg', False)
    )

    ckpt_path = get_best_checkpoint_path(config['training']['save_dir'])
    print(f"Loading checkpoint from: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Strip the 'module.' prefix left by DataParallel/DistributedDataParallel
    # so the weights load into an unwrapped model.
    state_dict = checkpoint['model_state_dict']
    new_state_dict = {}
    for k, v in state_dict.items():
        name = k.replace("module.", "")
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

    model.to(device)
    model.eval()

    print("Running Inference...")
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(test_loader):
            images = batch['images'].to(device)
            num_instances = batch['num_instances_per_case']
            labels = batch['labels'].numpy()

            # Multi-label head: one independent sigmoid probability per class
            logits = model(images, num_instances)
            probs = torch.sigmoid(logits).cpu().numpy()

            all_preds.append(probs)
            all_targets.append(labels)

    y_pred_probs = np.concatenate(all_preds, axis=0)
    y_true = np.concatenate(all_targets, axis=0)

    print("\nComputing Metrics...")
    global_metrics, class_df = compute_metrics(y_true, y_pred_probs)

    print("\n" + "=" * 60)
    print(" GLOBAL PERFORMANCE SUMMARY ")
    print("=" * 60)
    print(f" mAP (Macro)   : {global_metrics['mAP']:.4f}")
    print(f" mAP (Weighted): {global_metrics['weighted_mAP']:.4f}")
    print(f" AUROC (Macro) : {global_metrics['macro_auroc']:.4f}")
    print(f" AUROC (Micro) : {global_metrics['micro_auroc']:.4f}")
    print(f" F1 (Micro)    : {global_metrics['micro_f1']:.4f}")
    print(f" F1 (Macro)    : {global_metrics['macro_f1']:.4f}")
    print(f" Subset Acc    : {global_metrics['subset_accuracy']:.4f}")
    print("-" * 60)

    print("\n" + "=" * 100)
    print(" PER-CLASS PERFORMANCE DETAILS (Sorted by Support) ")
    print("=" * 100)

    class_df = class_df.sort_values(by='Support', ascending=False)

    headers = ["Class", "Support", "AP", "AUROC", "F1", "Precision", "Recall"]
    head_fmt = "{:<24} {:>8} {:>10} {:>10} {:>10} {:>12} {:>10}"
    print(head_fmt.format(*headers))
    print("-" * 100)

    for _, row in class_df.iterrows():
        cls_name = row['Class']

        # str.format's width counts characters, not columns, so CJK class names
        # break the alignment. Approximate the display width by encoding to GBK,
        # where CJK characters take two bytes (i.e., two terminal columns).
        display_width = len(cls_name.encode('gbk', errors='replace'))

        target_width = 24
        padding = max(0, target_width - display_width)
        aligned_name = cls_name + " " * padding

        print(f"{aligned_name} {int(row['Support']):>8d} {row['AP']:>10.4f} {row['AUROC']:>10.4f} {row['F1']:>10.4f} {row['Precision']:>12.4f} {row['Recall']:>10.4f}")

    print("=" * 100)

    # utf-8-sig keeps the Chinese class names readable when the CSV is opened in Excel
    result_csv = os.path.join(config['training']['save_dir'], 'evaluation_report.csv')
    class_df.to_csv(result_csv, index=False, encoding='utf-8-sig')
    print(f"\nDetailed report saved to: {result_csv}")


if __name__ == "__main__":
    main()