# analyze_results.py
#
# Evaluate a trained image classifier on a CSV-listed test set:
#  * prints per-class precision / recall / F1 and macro-F1,
#  * saves a confusion-matrix PNG,
#  * writes a CSV of misclassified images (with top-1 / top-2 predictions),
#  * copies up to 10 example images for each of the 6 most-confused class pairs.
import os, sys, csv, argparse, numpy as np, matplotlib.pyplot as plt
from PIL import Image
import torch, torch.nn as nn, torchvision.transforms as T
import timm, torchvision.models as tvmodels
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import cv2

# Add parent directory to path so the project-local `src` package resolves.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from src.utils import get_device, get_model, get_transforms


def load_csv(path):
    """Read a CSV file and return its rows as a list of dicts (one per row).

    `newline=''` is the csv-module-documented way to open files for csv I/O.
    """
    with open(path, newline='') as f:
        return list(csv.DictReader(f))


def save_confusion(cm, labels, out_path):
    """Render confusion matrix `cm` (len(labels) x len(labels), rows = true
    class, cols = predicted class) as an annotated heatmap PNG at `out_path`."""
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
    ax.set_xticks(range(len(labels)))
    ax.set_yticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_yticklabels(labels)
    # Annotate every cell with its raw count.
    for i in range(len(labels)):
        for j in range(len(labels)):
            ax.text(j, i, str(cm[i, j]), ha='center', va='center', color='black')
    plt.colorbar(im)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)


def main():
    parser = argparse.ArgumentParser(
        description='Analyze test-set predictions of a trained classifier.')
    # FIX: checkpoint / test-csv / class-names are mandatory; without
    # required=True a missing --class-names crashed with AttributeError
    # on .split instead of a clean argparse error.
    parser.add_argument('--checkpoint', required=True,
                        help='Path to the model checkpoint')
    parser.add_argument('--test-csv', required=True,
                        help='CSV with image_path and label columns')
    parser.add_argument('--img-root', default='.',
                        help='Root directory prepended to relative image paths')
    parser.add_argument('--model', default='swin')
    # FIX: was missing type=int, so a user-supplied --img-size arrived as str.
    parser.add_argument('--img-size', type=int, default=224)
    parser.add_argument('--class-names', required=True,
                        help='Comma-separated class names, in label-index order')
    parser.add_argument('--out-dir', default='outputs/analysis')
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    class_names = [s.strip() for s in args.class_names.split(',')]
    num_classes = len(class_names)

    device = get_device()
    model = get_model(args.model, num_classes, pretrained=False)
    ck = torch.load(args.checkpoint, map_location='cpu')
    model.load_state_dict(ck['model_state_dict'])
    model.to(device)
    model.eval()

    rows = load_csv(args.test_csv)
    tf = get_transforms('val', args.img_size)
    preds, trues, paths, probs = [], [], [], []
    os.makedirs(os.path.join(args.out_dir, 'examples'), exist_ok=True)

    # Run inference image-by-image, collecting predictions and softmax probs.
    for row in rows:
        img_path = (row['image_path'] if os.path.isabs(row['image_path'])
                    else os.path.join(args.img_root, row['image_path']))
        img = Image.open(img_path).convert('RGB')
        t = tf(img).unsqueeze(0).to(device)
        with torch.no_grad():
            out = model(t)
        prob = torch.softmax(out, dim=1).cpu().numpy()[0]
        preds.append(int(prob.argmax()))
        trues.append(int(row['label']))
        paths.append(img_path)
        probs.append(prob)

    # FIX: pass labels= so the matrix is always num_classes x num_classes,
    # even when some class is absent from the test set or never predicted;
    # otherwise cm can be smaller than class_names and misalign below.
    label_ids = list(range(num_classes))
    cm = confusion_matrix(trues, preds, labels=label_ids)
    prec, rec, f1, _ = precision_recall_fscore_support(
        trues, preds, average=None, labels=label_ids, zero_division=0)

    # Per-class metrics (support = row sum of the confusion matrix).
    for i, name in enumerate(class_names):
        print(f'{i} {name}: support={cm[i].sum()}, '
              f'prec={prec[i]:.3f}, rec={rec[i]:.3f}, f1={f1[i]:.3f}')
    print('macro-f1:', np.mean(f1))

    save_confusion(cm, class_names,
                   os.path.join(args.out_dir, 'confusion_matrix.png'))

    # Write one row per misclassified image.
    miscsv = os.path.join(args.out_dir, 'misclassified.csv')
    # FIX: newline='' prevents blank rows on Windows (csv module requirement).
    with open(miscsv, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['image_path', 'true', 'pred', 'top1', 'top2'])
        for path, true, pred, prob in zip(paths, trues, preds, probs):
            if true != pred:
                ranked = np.argsort(prob)[::-1]
                # FIX: original wrote the argmax into BOTH top1 and top2
                # (top2[0] was the top-1 index); top2 is now the actual
                # second-ranked class.
                writer.writerow([path, true, pred,
                                 int(ranked[0]), int(ranked[1])])

    # Save example images for the top confused (off-diagonal) cells.
    cm_off = cm.copy()
    np.fill_diagonal(cm_off, 0)
    flat = sorted(((cm_off[i, j], i, j)
                   for i in range(num_classes)
                   for j in range(num_classes)), reverse=True)
    for count, i, j in flat[:6]:  # top 6 confusion pairs
        if count == 0:
            continue
        pair_dir = os.path.join(args.out_dir, 'examples', f'{i}_to_{j}')
        os.makedirs(pair_dir, exist_ok=True)
        saved = 0
        for path, true, pred, _prob in zip(paths, trues, preds, probs):
            if true == i and pred == j:
                img = Image.open(path).convert('RGB')
                img.save(os.path.join(pair_dir, os.path.basename(path)))
                saved += 1
                if saved >= 10:
                    # FIX: stop scanning once 10 examples are saved
                    # (original kept iterating the full prediction list).
                    break

    print('Saved misclassified list and example images in', args.out_dir)


if __name__ == '__main__':
    main()