Spaces:
Runtime error
Runtime error
| # analyze_results.py | |
| import os, sys, csv, argparse, numpy as np, matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torch, torch.nn as nn, torchvision.transforms as T | |
| import timm, torchvision.models as tvmodels | |
| from sklearn.metrics import precision_recall_fscore_support, confusion_matrix | |
| import cv2 | |
| # Add parent directory to path for imports | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) | |
| from src.utils import get_device, get_model, get_transforms | |
| def load_csv(path): | |
| with open(path) as f: | |
| reader = csv.DictReader(f) | |
| return [r for r in reader] | |
| def save_confusion(cm, labels, out_path): | |
| fig, ax = plt.subplots(figsize=(8,8)) | |
| im = ax.imshow(cm, interpolation='nearest', cmap='Blues') | |
| ax.set_xticks(range(len(labels))); ax.set_yticks(range(len(labels))) | |
| ax.set_xticklabels(labels, rotation=45, ha='right'); ax.set_yticklabels(labels) | |
| for i in range(len(labels)): | |
| for j in range(len(labels)): | |
| ax.text(j,i, str(cm[i,j]), ha='center', va='center', color='black') | |
| plt.colorbar(im) | |
| plt.tight_layout(); plt.savefig(out_path); plt.close(fig) | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--checkpoint') | |
| parser.add_argument('--test-csv') | |
| parser.add_argument('--img-root', default='.') | |
| parser.add_argument('--model', default='swin') | |
| parser.add_argument('--img-size', default=224) | |
| parser.add_argument('--class-names') | |
| parser.add_argument('--out-dir', default='outputs/analysis') | |
| args = parser.parse_args() | |
| os.makedirs(args.out_dir, exist_ok=True) | |
| class_names = [s.strip() for s in args.class_names.split(',')] | |
| num_classes = len(class_names) | |
| device = get_device() | |
| model = get_model(args.model, num_classes, pretrained=False) | |
| ck = torch.load(args.checkpoint, map_location='cpu') | |
| model.load_state_dict(ck['model_state_dict']) | |
| model.to(device); model.eval() | |
| rows = load_csv(args.test_csv) | |
| tf = get_transforms('val', args.img_size) | |
| preds, trues, paths, probs = [], [], [], [] | |
| os.makedirs(os.path.join(args.out_dir,'examples'), exist_ok=True) | |
| for r in rows: | |
| img_path = r['image_path'] if os.path.isabs(r['image_path']) else os.path.join(args.img_root, r['image_path']) | |
| img = Image.open(img_path).convert('RGB') | |
| t = tf(img).unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
| out = model(t) | |
| p = torch.softmax(out, dim=1).cpu().numpy()[0] | |
| pred = int(p.argmax()) | |
| preds.append(pred); trues.append(int(r['label'])); paths.append(img_path); probs.append(p) | |
| cm = confusion_matrix(trues, preds) | |
| p, r, f1, _ = precision_recall_fscore_support(trues, preds, average=None, labels=list(range(num_classes)), zero_division=0) | |
| # print per-class metrics | |
| for i,name in enumerate(class_names): | |
| print(f'{i} {name}: support={(cm[i].sum())}, prec={p[i]:.3f}, rec={r[i]:.3f}, f1={f1[i]:.3f}') | |
| print('macro-f1:', np.mean(f1)) | |
| # save confusion matrix image | |
| save_confusion(cm, class_names, os.path.join(args.out_dir,'confusion_matrix.png')) | |
| # write misclassified csv | |
| miscsv = os.path.join(args.out_dir,'misclassified.csv') | |
| with open(miscsv,'w') as f: | |
| writer = csv.writer(f); writer.writerow(['image_path','true','pred','top1','top2']) | |
| for path, t, pr, prob in zip(paths,trues,preds,probs): | |
| if t!=pr: | |
| top2 = np.argsort(prob)[-2:][::-1].tolist() | |
| writer.writerow([path, t, pr, int(np.argmax(prob)), int(top2[0])]) | |
| # Save example images for top confused pairs | |
| # find the biggest off-diagonal cells | |
| cm_off = cm.copy(); np.fill_diagonal(cm_off, 0) | |
| flat = [(cm_off[i,j],i,j) for i in range(num_classes) for j in range(num_classes)] | |
| flat = sorted(flat, reverse=True) | |
| for count,i,j in flat[:6]: # top 6 confusion pairs | |
| if count==0: continue | |
| pair_dir = os.path.join(args.out_dir, 'examples', f'{i}_to_{j}') | |
| os.makedirs(pair_dir, exist_ok=True) | |
| saved=0 | |
| for path,t,pred,prob in zip(paths,trues,preds,probs): | |
| if t==i and pred==j and saved<10: | |
| img = Image.open(path).convert('RGB') | |
| img.save(os.path.join(pair_dir, os.path.basename(path))) | |
| saved+=1 | |
| print('Saved misclassified list and example images in', args.out_dir) | |
| if __name__=='__main__': | |
| main() | |