"""
shared_utils.py
---------------
Shared utilities for the PersonViT and YOLO26 ablation notebooks.

Both notebooks share identical or near-identical logic for:
  - JSON serialisation of numpy types
  - Results persistence (save/load)
  - Reproducibility seed setting
  - Matplotlib palette and style
  - Radar chart plotting
  - Grouped bar comparison chart
  - Model profiling (params, FLOPs, latency)
"""
import os
import json
import math
import time
import random
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams
from tqdm import tqdm

import torch
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# 1. PALETTE & PLOT STYLE
# ---------------------------------------------------------------------------
# Shared colour cycle for per-run series in the comparison charts.
PALETTE = [
    '#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E',
    '#BC4B51', '#5F4B8B', '#E08D79', '#1B998B', '#E84855',
]


def setup_plot_style():
    """Apply the shared Matplotlib style used in all ablation notebooks.

    Prefers ``'seaborn-v0_8-darkgrid'`` (the style name used by newer
    Matplotlib releases). If that name is not available but the legacy
    ``'seaborn-darkgrid'`` name is (older Matplotlib), the legacy name is
    used instead, so the notebooks run on both.
    """
    style = 'seaborn-v0_8-darkgrid'
    available = plt.style.available
    if style not in available and 'seaborn-darkgrid' in available:
        style = 'seaborn-darkgrid'
    plt.style.use(style)
    rcParams.update({'font.size': 10, 'axes.titlesize': 12, 'figure.titlesize': 13})


# ---------------------------------------------------------------------------
# 2. REPRODUCIBILITY
# ---------------------------------------------------------------------------
def set_seed(seed: int = 42, *, deterministic: bool = False):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + CUDA).

    Parameters
    ----------
    seed : int
        Seed value applied to every RNG.
    deterministic : bool, keyword-only, default False
        When True, also set ``torch.backends.cudnn.deterministic = True`` and
        ``torch.backends.cudnn.benchmark = False``. Seeding alone does not make
        cuDNN convolutions bit-reproducible, because cuDNN may autotune or
        choose non-deterministic kernels. Enabling this trades speed for
        repeatability.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f'[seed] All seeds set to {seed}')


# ---------------------------------------------------------------------------
# 3. JSON SERIALISATION
# ---------------------------------------------------------------------------
def json_serialisable(v):
    """
    Recursively convert numpy scalars / arrays to Python-native types so
    that the results dict can be passed to json.dump without errors.

    Conversion rules:
      - bool / np.bool_        -> bool (checked first: ``bool`` is a subclass
                                  of ``int`` and must not become 0/1)
      - float / np.floating    -> float, rounded to 6 decimal places
      - int / np.integer       -> int
      - np.ndarray             -> (nested) list, elements converted recursively
      - dict                   -> dict with values converted (keys untouched)
      - list / tuple           -> list with elements converted
      - anything else          -> returned unchanged

    This function unifies _json_serialisable() (PersonViT) and _json_safe()
    (YOLO26), which are functionally identical.
    """
    if isinstance(v, (bool, np.bool_)):
        return bool(v)
    if isinstance(v, (float, np.floating)):
        return round(float(v), 6)
    if isinstance(v, (int, np.integer)):
        return int(v)
    if isinstance(v, np.ndarray):
        # tolist() yields native Python scalars; recurse to apply rounding.
        return json_serialisable(v.tolist())
    if isinstance(v, dict):
        return {kk: json_serialisable(vv) for kk, vv in v.items()}
    if isinstance(v, (list, tuple)):
        # json.dump writes tuples as arrays anyway; converting here also
        # reaches numpy values nested inside tuples.
        return [json_serialisable(x) for x in v]
    return v


# ---------------------------------------------------------------------------
# 4. RESULTS REGISTRY PERSISTENCE
# ---------------------------------------------------------------------------
def save_results(results: dict, results_path: Path, exclude_keys=('history',)):
    """
    Persist a results dict to disk as JSON.

    Parameters
    ----------
    results      : dict  — the RESULTS registry to serialise; each value is
                           expected to be a per-run dict.
    results_path : Path  — destination .json file (str also accepted). Missing
                           parent directories are created.
    exclude_keys : tuple — per-run keys to strip before saving (e.g. 'history'
                           in PersonViT, which is large and not needed for the
                           summary JSON).
    """
    results_path = Path(results_path)
    results_path.parent.mkdir(parents=True, exist_ok=True)
    serialisable = {
        k: json_serialisable({kk: vv for kk, vv in v.items() if kk not in exclude_keys})
        for k, v in results.items()
    }
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(serialisable, f, indent=2)
    print(f'[checkpoint] Results saved → {results_path} ({len(results)} runs)')


def load_results(results: dict, results_path: Path):
    """
    Reload a previously saved results dict from disk into *results* in-place.

    Entries read from disk overwrite same-named entries already in *results*.
    If the file does not exist, *results* is left untouched.

    Parameters
    ----------
    results      : dict — the RESULTS registry to populate.
    results_path : Path — source .json file (str also accepted).
    """
    results_path = Path(results_path)
    if results_path.exists():
        with open(results_path, encoding='utf-8') as f:
            loaded = json.load(f)
        results.update(loaded)
        # Report what was read from disk, not the size of the merged registry.
        print(f'[checkpoint] Loaded {len(loaded)} runs from {results_path}')
    else:
        print('[checkpoint] No previous results found — starting fresh.')
# ---------------------------------------------------------------------------
# 5. MODEL PROFILING
# ---------------------------------------------------------------------------
def _profile_flops(model, dummy, total_params: int):
    """Return ``(flops_giga, params_str)`` using thop, or a 0.0 fallback if thop fails."""
    try:
        from thop import profile as thop_profile, clever_format
        with torch.no_grad():
            macs, params = thop_profile(model, inputs=(dummy,), verbose=False)
        _, params_str = clever_format([macs, params], '%.3f')
        # thop counts multiply-accumulates; one MAC is reported as 2 FLOPs.
        return macs * 2 / 1e9, params_str
    except Exception as exc:  # best-effort: thop missing or an op unsupported
        print(f'  [profile] thop unavailable or failed '
              f'({type(exc).__name__}: {exc}); reporting GFLOPs as 0.0')
        return 0.0, f'{total_params / 1e6:.1f}M'


def _measure_latency_ms(model, dummy, device, n_warmup: int, n_runs: int):
    """Run warm-up passes, then return a list of per-run forward times in ms."""
    times = []
    with torch.no_grad():
        if device.type == 'cuda':
            # Bind events/synchronisation to the target GPU (no-op when the
            # device has no explicit index).
            with torch.cuda.device(device):
                for _ in range(n_warmup):
                    model(dummy)
                # Kernels launch asynchronously: drain warm-up work first.
                torch.cuda.synchronize()
                start_ev = torch.cuda.Event(enable_timing=True)
                end_ev = torch.cuda.Event(enable_timing=True)
                for _ in range(n_runs):
                    start_ev.record()
                    model(dummy)
                    end_ev.record()
                    torch.cuda.synchronize()
                    times.append(start_ev.elapsed_time(end_ev))
        else:
            for _ in range(n_warmup):
                model(dummy)
            for _ in range(n_runs):
                t0 = time.perf_counter()
                model(dummy)
                times.append((time.perf_counter() - t0) * 1000)
    return times


def profile_model(model, img_height: int, img_width: int, device,
                  n_warmup: int = 10, n_runs: int = 100):
    """
    Measure total params, GFLOPs, per-image latency (ms), and throughput (img/s).

    Unifies profile_model() from PersonViT (uses thop + CUDA events) and
    profile_model() from YOLO26 (uses thop + perf_counter). CUDA events are
    used when *device* is a CUDA device; otherwise ``time.perf_counter``.

    The model is put in eval mode for both FLOP counting and timing, so
    BatchNorm/Dropout behave as in inference and BatchNorm running statistics
    are not updated by the profiling passes. The original train/eval flag of
    every submodule is restored afterwards, even if profiling raises.

    Parameters
    ----------
    model      : nn.Module — must already be on the correct device.
    img_height : int
    img_width  : int
    device     : torch.device | str — e.g. torch.device('cuda') or 'cpu'.
    n_warmup   : int — warm-up forward passes before timing.
    n_runs     : int — timed forward passes (must be >= 1).

    Returns
    -------
    dict with keys: total_params, flops_giga, inference_ms, throughput.
    flops_giga is 0.0 when thop is unavailable or fails on the model.

    Raises
    ------
    ValueError if n_runs < 1.
    """
    if n_runs < 1:
        raise ValueError(f'n_runs must be >= 1, got {n_runs}')
    device = torch.device(device)
    total_params = int(sum(p.numel() for p in model.parameters()))
    # Single dummy batch of one RGB image, reused for FLOPs and timing.
    dummy = torch.randn(1, 3, img_height, img_width).to(device)

    # Remember each submodule's mode so mixed train/eval setups survive.
    saved_modes = [(m, m.training) for m in model.modules()]
    model.eval()
    try:
        flops_giga, pfmt = _profile_flops(model, dummy, total_params)
        times = _measure_latency_ms(model, dummy, device, n_warmup, n_runs)
    finally:
        for module, mode in saved_modes:
            module.training = mode

    ms = float(np.mean(times))
    throughput = 1000.0 / ms if ms > 0 else float('inf')
    print(f' Params : {pfmt} ({total_params:,})')
    print(f' GFLOPs : {flops_giga:.2f}')
    print(f' Latency : {ms:.2f} ms/img | {throughput:.1f} img/s')
    return {
        'total_params': total_params,
        'flops_giga': flops_giga,
        'inference_ms': ms,
        'throughput': throughput,
    }


# ---------------------------------------------------------------------------
# 6. RADAR CHART
# ---------------------------------------------------------------------------
def radar_chart(results_dict: dict, title: str, metric_labels: list,
                metric_keys: list, save_path=None):
    """
    Generic N-axis radar chart shared by both notebooks (one axis per metric).

    Differences between the two notebooks are parameterised:
      - PersonViT : metric_labels=['mAP','Rank-1','Rank-5','Rank-10'],
                    metric_keys  =['mAP','rank1','rank5','rank10']
      - YOLO26    : metric_labels=['mAP@50','mAP@50-95','Precision','Recall'],
                    metric_keys  =['mAP50','mAP50_95','precision','recall']

    Parameters
    ----------
    results_dict  : {run_key: metrics_dict}; a metrics_dict may carry an
                    optional 'display_name' used as the legend label.
    title         : chart title string
    metric_labels : list[str] — axis labels (length N)
    metric_keys   : list[str] — dict keys (length N, values in [0,1]; they are
                    scaled by 100 and missing keys are plotted as 0)
    save_path     : str | Path | None

    Raises
    ------
    ValueError if metric_labels and metric_keys differ in length.

    Colours cycle through PALETTE, so more runs than colours are still drawn
    (colours repeat) instead of being dropped.
    """
    N = len(metric_labels)
    if len(metric_keys) != N:
        raise ValueError(f'metric_labels has {N} entries but metric_keys has '
                         f'{len(metric_keys)}')
    angles = [n / float(N) * 2 * np.pi for n in range(N)]
    angles += angles[:1]  # close the polygon

    fig, ax = plt.subplots(figsize=(7, 7), subplot_kw=dict(polar=True))
    ax.set_theta_offset(np.pi / 2)  # first axis at 12 o'clock
    ax.set_theta_direction(-1)      # clockwise
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(metric_labels, fontsize=11, fontweight='bold')
    ax.set_ylim(0, 100)
    ax.set_yticks([20, 40, 60, 80, 100])
    ax.set_yticklabels(['20', '40', '60', '80', '100'], fontsize=7, color='grey')

    for i, (name, v) in enumerate(results_dict.items()):
        color = PALETTE[i % len(PALETTE)]
        vals = [v.get(mk, 0) * 100 for mk in metric_keys]
        vals += vals[:1]
        label = v.get('display_name', name)
        ax.plot(angles, vals, '-o', lw=2, color=color, ms=6, label=label)
        ax.fill(angles, vals, alpha=0.08, color=color)

    ax.set_title(title, fontsize=13, fontweight='bold', pad=20)
    ax.legend(loc='upper right', bbox_to_anchor=(1.35, 1.15), fontsize=9)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f'Saved → {save_path}')
    plt.show()
# ---------------------------------------------------------------------------
# 7. GROUPED BAR COMPARISON CHART
# ---------------------------------------------------------------------------
def plot_bar_comparison(results_dict: dict, metric_key_a: str, metric_label_a: str,
                        metric_key_b: str, metric_label_b: str, title: str,
                        save_path=None, color_a='#2E86AB', color_b='#F18F01'):
    """
    Draw side-by-side bars of two metrics for every run, expressed in percent.

    Examples of how the notebooks call it:
      PersonViT → metric_key_a='mAP',   metric_label_a='mAP',
                  metric_key_b='rank1', metric_label_b='Rank-1'
      YOLO26    → metric_key_a='mAP50', metric_label_a='mAP@50',
                  metric_key_b='mAP50_95', metric_label_b='mAP@50-95'

    Each raw value is multiplied by 100; a metric missing from a run is drawn
    as 0. Every bar is annotated with its height (one decimal). X tick labels
    use a run's optional 'display_name', falling back to its dict key. The
    figure is optionally saved to *save_path* and then shown.
    """
    run_keys = list(results_dict)
    n_runs = len(run_keys)

    def _as_percent(metric_key):
        # Missing metrics default to 0 before scaling.
        return np.array([results_dict[k].get(metric_key, 0) * 100 for k in run_keys])

    left_vals = _as_percent(metric_key_a)
    right_vals = _as_percent(metric_key_b)

    bar_width = 0.35
    centres = np.arange(n_runs)
    _, ax = plt.subplots(figsize=(max(10, n_runs * 1.6), 6))

    bar_style = dict(alpha=.85, edgecolor='k', lw=.5)
    left_bars = ax.bar(centres - bar_width / 2, left_vals, bar_width,
                       label=metric_label_a, color=color_a, **bar_style)
    right_bars = ax.bar(centres + bar_width / 2, right_vals, bar_width,
                        label=metric_label_b, color=color_b, **bar_style)

    # Value annotation sitting just above each bar.
    for rect in (*left_bars, *right_bars):
        height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2, height + .3,
                f'{height:.1f}', ha='center', va='bottom', fontsize=7.5)

    tick_names = [results_dict[k].get('display_name', k) for k in run_keys]
    ax.set_xticks(centres)
    ax.set_xticklabels(tick_names)
    plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
    ax.set_ylabel('Score (%)')
    ax.set_title(f'{title} — {metric_label_a} & {metric_label_b}')
    ax.legend()
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f'Saved → {save_path}')
    plt.show()