File size: 1,283 Bytes
70520f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, f1_score


def evaluate(
    model,
    loader: DataLoader,
    device: torch.device,
    *,
    cls_threshold: float = 0.5,
    mask_threshold: float = 0.5,
) -> dict:
    """Evaluate a joint segmentation + classification model over ``loader``.

    The model is switched to eval mode (and left there) and run under
    ``torch.no_grad()``. For each batch it must return ``(pred_mask, pred_logit)``.

    Args:
        model: Callable returning ``(pred_mask, pred_logit)`` for a batch of
            images. ``pred_mask`` is thresholded directly (no sigmoid), so it
            is presumably already a probability map -- TODO confirm against
            the model definition.
        loader: Yields ``(imgs, masks, labels)`` batches.
        device: Device that images and masks are moved to before inference.
        cls_threshold: Score above which a sample is predicted positive
            (used for F1). Defaults to 0.5.
        mask_threshold: Value above which a ``pred_mask`` pixel counts as
            foreground. Ground-truth masks are always binarized at 0.5.

    Returns:
        Dict of floats rounded to 4 decimals:
        ``'auc'`` -- ROC AUC of ``sigmoid(pred_logit)`` vs. labels; 0.0 when
        fewer than two label values are present (AUC is undefined there).
        ``'f1'`` -- binary F1 at ``cls_threshold``; 0.0 for an empty loader.
        ``'pixel_iou'`` -- foreground IoU over all pixels of the dataset;
        0.0 when no pixel is foreground in either prediction or ground truth.
    """
    model.eval()
    all_labels, all_scores = [], []
    # Pixel counts are accumulated over the whole dataset so the IoU does not
    # depend on batch size (averaging per-batch IoUs gave a short final batch
    # the same weight as a full one).
    total_inter, total_union = 0, 0

    with torch.no_grad():
        for imgs, masks, labels in loader:
            imgs, masks = imgs.to(device), masks.to(device)
            pred_mask, pred_logit = model(imgs)

            # .float() first: .numpy() rejects bfloat16 (e.g. under autocast);
            # reshape (unlike view) also accepts non-contiguous outputs.
            probs = torch.sigmoid(pred_logit).float().reshape(-1).cpu().numpy()
            all_scores.extend(probs.tolist())
            # Flatten exactly like the scores: labels shaped (B, 1) would
            # otherwise become nested lists that set() below cannot hash.
            all_labels.extend(torch.as_tensor(labels).reshape(-1).cpu().tolist())

            pred_bin = pred_mask > mask_threshold
            gt_bin = masks > 0.5
            # Boolean ops with integer sums give exact pixel counts
            # (float32 sums lose precision beyond 2**24 pixels).
            total_inter += (pred_bin & gt_bin).sum().item()
            total_union += (pred_bin | gt_bin).sum().item()

    if all_labels:
        preds = [1 if s > cls_threshold else 0 for s in all_scores]
        auc = roc_auc_score(all_labels, all_scores) if len(set(all_labels)) > 1 else 0.0
        f1 = f1_score(all_labels, preds, zero_division=0)
    else:
        # Nothing was evaluated: report zeros rather than calling sklearn on
        # empty inputs.
        auc, f1 = 0.0, 0.0

    return {
        'auc':       round(float(auc), 4),
        'f1':        round(float(f1), 4),
        'pixel_iou': round(total_inter / (total_union + 1e-6), 4),
    }