File size: 4,439 Bytes
bf07f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# analyze_results.py
import os, sys, csv, argparse, numpy as np, matplotlib.pyplot as plt
from PIL import Image
import torch, torch.nn as nn, torchvision.transforms as T
import timm, torchvision.models as tvmodels
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import cv2

# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from src.utils import get_device, get_model, get_transforms

def load_csv(path):
    """Read a CSV file with a header row into a list of row dictionaries.

    Args:
        path: Path of the CSV file to read.

    Returns:
        A list with one dict per data row, keyed by the header names.
    """
    # newline='' lets the csv module do its own line splitting, so line
    # breaks embedded in quoted fields are preserved instead of being
    # rewritten by universal-newline translation.
    with open(path, newline='') as f:
        return list(csv.DictReader(f))

def save_confusion(cm, labels, out_path):
    """Draw a confusion matrix as an annotated heat map and save it.

    Args:
        cm: 2-D array indexable as ``cm[i, j]``; must cover at least
            ``len(labels)`` rows and columns.
        labels: Tick labels, used for both axes.
        out_path: Destination passed to ``plt.savefig``.
    """
    tick_positions = range(len(labels))

    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap = ax.imshow(cm, interpolation='nearest', cmap='Blues')

    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_yticklabels(labels)

    # Write each cell's value on top of the heat map; cell (row, col) is
    # drawn at x=col, y=row.
    for row in tick_positions:
        for col in tick_positions:
            ax.text(col, row, str(cm[row, col]), ha='center', va='center', color='black')

    plt.colorbar(heatmap)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)  # release the figure so repeated calls do not accumulate

def main():
    """Evaluate a checkpoint on a test CSV and write error-analysis artifacts.

    Command-line arguments:
        --checkpoint   Checkpoint file holding a ``model_state_dict`` entry.
        --test-csv     CSV with ``image_path`` and ``label`` columns.
        --img-root     Directory joined onto relative image paths.
        --model        Model name forwarded to ``get_model``.
        --img-size     Integer image size forwarded to ``get_transforms``.
        --class-names  Comma-separated class names; position i names label i.
        --out-dir      Directory that receives all outputs.

    Outputs written under ``--out-dir``:
        confusion_matrix.png   Heat map of the confusion matrix.
        misclassified.csv      One row per wrongly predicted image.
        examples/<t>_to_<p>/   Up to 10 images for each of the 6 most
                               frequent (true, predicted) confusions.

    Per-class precision/recall/F1 and the macro F1 are printed to stdout.
    """
    parser = argparse.ArgumentParser()
    # These three have no usable default; make argparse report a clear error
    # instead of failing later on ``None``.
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--test-csv', required=True)
    parser.add_argument('--img-root', default='.')
    parser.add_argument('--model', default='swin')
    # type=int so a value given on the command line is not passed on as str.
    parser.add_argument('--img-size', type=int, default=224)
    parser.add_argument('--class-names', required=True)
    parser.add_argument('--out-dir', default='outputs/analysis')
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    os.makedirs(os.path.join(args.out_dir, 'examples'), exist_ok=True)

    class_names = [s.strip() for s in args.class_names.split(',')]
    num_classes = len(class_names)
    label_ids = list(range(num_classes))
    device = get_device()

    model = get_model(args.model, num_classes, pretrained=False)
    # NOTE: torch.load may unpickle arbitrary objects; only load checkpoints
    # from a trusted source.
    ck = torch.load(args.checkpoint, map_location='cpu')
    model.load_state_dict(ck['model_state_dict'])
    model.to(device)
    model.eval()

    rows = load_csv(args.test_csv)
    tf = get_transforms('val', args.img_size)
    preds, trues, paths, probs = [], [], [], []

    # Run the model one image at a time, keeping the full probability vector
    # so the runner-up class can be reported for misclassified images.
    with torch.no_grad():
        for row in rows:
            rel_path = row['image_path']
            if os.path.isabs(rel_path):
                img_path = rel_path
            else:
                img_path = os.path.join(args.img_root, rel_path)
            # Context manager closes the underlying file once the pixel data
            # has been converted.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
            batch = tf(img).unsqueeze(0).to(device)
            prob = torch.softmax(model(batch), dim=1).cpu().numpy()[0]
            preds.append(int(prob.argmax()))
            trues.append(int(row['label']))
            paths.append(img_path)
            probs.append(prob)

    # Pass ``labels`` explicitly so the matrix is always
    # num_classes x num_classes, even when a class never occurs in the test
    # set or in the predictions; otherwise cm[i] below would be mis-indexed.
    cm = confusion_matrix(trues, preds, labels=label_ids)
    prec, rec, f1, _ = precision_recall_fscore_support(
        trues, preds, average=None, labels=label_ids, zero_division=0)

    # print per-class metrics
    for i, name in enumerate(class_names):
        print(f'{i} {name}: support={(cm[i].sum())}, prec={prec[i]:.3f}, rec={rec[i]:.3f}, f1={f1[i]:.3f}')
    print('macro-f1:', np.mean(f1))

    # save confusion matrix image
    save_confusion(cm, class_names, os.path.join(args.out_dir, 'confusion_matrix.png'))

    # write misclassified csv
    miscsv = os.path.join(args.out_dir, 'misclassified.csv')
    # newline='' avoids blank lines between rows on platforms that translate
    # line endings.
    with open(miscsv, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['image_path', 'true', 'pred', 'top1', 'top2'])
        for path, true, pred, prob in zip(paths, trues, preds, probs):
            if true == pred:
                continue
            # top1 is the predicted class; top2 is the most probable class
            # other than top1 (left empty when there is only one class).
            ranked = np.argsort(prob)[::-1]
            runner_up = next((int(k) for k in ranked if int(k) != pred), '')
            writer.writerow([path, true, pred, pred, runner_up])

    # Save example images for top confused pairs:
    # rank the off-diagonal cells by count and keep the largest non-zero ones.
    max_pairs, max_examples = 6, 10
    cm_off = cm.copy()
    np.fill_diagonal(cm_off, 0)
    ranked_cells = sorted(
        ((cm_off[i, j], i, j) for i in label_ids for j in label_ids),
        reverse=True)
    for count, i, j in ranked_cells[:max_pairs]:
        if count == 0:
            continue
        pair_dir = os.path.join(args.out_dir, 'examples', f'{i}_to_{j}')
        os.makedirs(pair_dir, exist_ok=True)
        saved = 0
        for path, true, pred in zip(paths, trues, preds):
            if true != i or pred != j:
                continue
            # NOTE: files are stored under their basename, so two sources
            # sharing a basename overwrite each other.
            with Image.open(path) as raw:
                raw.convert('RGB').save(os.path.join(pair_dir, os.path.basename(path)))
            saved += 1
            if saved >= max_examples:
                break  # enough examples for this pair; stop scanning

    print('Saved misclassified list and example images in', args.out_dir)

# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()