"""Stage 2: attention-head pruning.

Single-head ablation sweep on EUPE-ViT-B. For each of the N_BLOCKS x N_HEADS
(block, head) pairs (12 x 12 = 144), zero the columns of that block's attention
output projection that correspond to that head, run every calibration image
through the backbone, and record:

- mean per-image L2 deviation of the classifier's target dims (the union of
  its pos/neg dims) vs. the unablated baseline
- F1 / precision / recall on COCO val2017 for the Stage 0 person classifier

Heads are sorted by F1 impact. A cumulative pruning curve then zeroes the K
most prunable heads together; from the sampled curve we report how many heads
can be zeroed before F1 drops by more than 0.01 / 0.02 / 0.05.

Output: head_importance.json, pruning_curve.json.
"""
import os, sys, json, time
import copy
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from pycocotools.coco import COCO
from transformers import AutoModel

# Presumably makes the Argus remote-code modules importable; verify.
sys.path.insert(0, '/mnt/d/Argus')

COCO_ROOT = '/home/zootest/datasets/coco'
VAL_CACHE = f'{COCO_ROOT}/val_feature_cache_768/val.pt'  # not used by this script
STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
RES = 768
D = 768
N_BLOCKS = 12
N_HEADS = 12
HEAD_DIM = D // N_HEADS  # 64
N_CALIBRATION = 1000  # COCO val images used for the sweep
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_2'
DEVICE = 'cuda'

# K values sampled on the cumulative pruning curve.
CURVE_KS = (1, 5, 10, 15, 20, 30, 40, 50, 60, 80, 100, 120, 144)
# F1-drop budgets summarized from the curve (see module docstring).
DROP_THRESHOLDS = (0.01, 0.02, 0.05)


def load_classifier():
    """Load the Stage 0 person classifier from STAGE0_CLASSIFIER.

    Returns:
        (pos, neg, thr, target_dims): long tensors of positive and negative
        feature dims on DEVICE, the float decision threshold, and the sorted
        unique union of pos and neg dims.
    """
    with open(STAGE0_CLASSIFIER) as f:
        c = json.load(f)
    pos = torch.tensor(c['pos_dims'], dtype=torch.long, device=DEVICE)
    neg = torch.tensor(c['neg_dims'], dtype=torch.long, device=DEVICE)
    thr = float(c['threshold'])
    target_dims = torch.cat([pos, neg]).unique()
    return pos, neg, thr, target_dims


@torch.inference_mode()
def _pool_image(argus, x):
    """Return the (D,) pooled feature vector for one pre-normalized image.

    Runs the backbone under bf16 autocast, takes the normalized patch tokens,
    applies a parameter-free layer norm over the feature axis, then max-pools
    over patches. ``squeeze(0)`` assumes a batch of one image, matching what
    ``load_calibration`` produces.
    """
    with torch.autocast('cuda', dtype=torch.bfloat16):
        out = argus.backbone.forward_features(x)
    patches = out['x_norm_patchtokens'].float().squeeze(0)
    ln = F.layer_norm(patches, [D])
    return ln.max(dim=0).values


def _score_from_pooled(pooled, pos, neg):
    """Classifier score: sum of positive dims minus sum of negative dims."""
    return pooled[pos].sum() - pooled[neg].sum()


@torch.inference_mode()
def score_images(argus, img_tensors, pos, neg):
    """Return (N,) classifier scores (CPU float tensor) for pre-normalized images."""
    scores = [_score_from_pooled(_pool_image(argus, x), pos, neg).item()
              for x in img_tensors]
    return torch.tensor(scores)


@torch.inference_mode()
def pooled_targets(argus, img_tensors, target_dims):
    """Return (N, |target_dims|) pooled layer-normed features at the target dims."""
    return torch.stack([_pool_image(argus, x)[target_dims] for x in img_tensors])


@torch.inference_mode()
def score_and_targets(argus, img_tensors, pos, neg, target_dims):
    """Single-pass equivalent of ``score_images`` + ``pooled_targets``.

    One backbone forward per image instead of two.

    Returns:
        (scores, targets): scores is an (N,) float tensor on DEVICE; targets
        is (N, |target_dims|) on the device the features were computed on.
    """
    scores, targets = [], []
    for x in img_tensors:
        pooled = _pool_image(argus, x)
        scores.append(_score_from_pooled(pooled, pos, neg).item())
        targets.append(pooled[target_dims])
    if not targets:  # torch.stack rejects an empty list
        return (torch.empty(0, device=DEVICE),
                torch.empty(0, len(target_dims), device=DEVICE))
    return torch.tensor(scores, device=DEVICE), torch.stack(targets)


def load_calibration(coco, n, MEAN, STD):
    """Load the first ``n`` COCO val images (by sorted id) and person labels.

    Each image is resized to RES x RES, scaled to [0, 1], and normalized with
    MEAN/STD. A label is True when the image has a non-crowd annotation with
    category_id 1.

    NOTE(review): every image is kept on the GPU as float32. That is about
    n x 3 x RES x RES x 4 bytes, roughly 7 GB at the defaults.

    Returns:
        (tensors, labels): a list of (1, 3, RES, RES) CUDA tensors and an (n,)
        bool tensor on DEVICE.
    """
    img_ids = sorted(coco.getImgIds())[:n]
    labels = []
    tensors = []
    for img_id in img_ids:
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        with Image.open(path) as im:
            img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.uint8).copy()
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).cuda().float() / 255.0
        tensors.append((x - MEAN) / STD)
        ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        labels.append(any(a['category_id'] == 1 for a in coco.loadAnns(ann_ids)))
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)


def compute_f1(scores, labels, thr):
    """Return (F1, precision, recall) as floats for predictions ``scores > thr``.

    Denominators are clamped so that empty prediction or label sets give 0
    instead of NaN.
    """
    pred = scores > thr
    tp = (pred & labels).sum().float()
    fp = (pred & ~labels).sum().float()
    fn = (~pred & labels).sum().float()
    prec = tp / (tp + fp).clamp(min=1)
    rec = tp / (tp + fn).clamp(min=1)
    f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9)
    return float(f1), float(prec), float(rec)


def _proj_weight(argus, b):
    """Weight of block ``b``'s attention output projection."""
    return argus.backbone.blocks[b].attn.proj.weight


def _head_cols(h):
    """Input-column slice of proj.weight fed by head ``h``.

    NOTE(review): this assumes per-head outputs are concatenated head-major
    (channel = h * HEAD_DIM + d) before ``proj``, as in timm/DINOv2-style
    Attention. Confirm against the Argus backbone.
    """
    return slice(h * HEAD_DIM, (h + 1) * HEAD_DIM)


@torch.no_grad()
def _zero_heads(argus, heads):
    """Zero the proj.weight columns of every (block, head) pair in ``heads``."""
    for b, h in heads:
        _proj_weight(argus, b).data[:, _head_cols(h)] = 0.0


@torch.no_grad()
def _restore(argus, snapshot):
    """Copy the saved proj weights in ``snapshot`` (block -> tensor) back in place."""
    for b, w in snapshot.items():
        _proj_weight(argus, b).data.copy_(w)


def _run_sweep(argus, imgs, labels, pos, neg, thr, target_dims,
               base_f1, base_targets, snapshot):
    """Ablate each head alone and return one result dict per (block, head)."""
    results = []
    for b in range(N_BLOCKS):
        for h in range(N_HEADS):
            t_h = time.time()
            _zero_heads(argus, [(b, h)])
            try:
                scores, targets = score_and_targets(argus, imgs, pos, neg, target_dims)
            finally:
                _restore(argus, {b: snapshot[b]})
            f1, p, r = compute_f1(scores, labels, thr)
            # Mean over images of the per-image L2 distance at the target dims.
            l2 = (targets - base_targets).pow(2).sum(dim=1).sqrt().mean().item()
            drop = base_f1 - f1
            results.append({
                'block': b, 'head': h,
                'F1': f1, 'precision': p, 'recall': r,
                'F1_drop': drop, 'target_L2': l2,
            })
            print(f'  B{b:>2}H{h:>2}  F1={f1:.4f} drop={drop:+.4f} '
                  f'L2={l2:.3f}  {time.time()-t_h:.1f}s', flush=True)
    return results


def _run_curve(argus, ranked, imgs, labels, pos, neg, thr, base_f1, snapshot):
    """Zero the first K heads of ``ranked`` for each K in CURVE_KS and score.

    Weights are restored from ``snapshot`` before each K, so every point
    starts from the unablated model.
    """
    curve = []
    for K in CURVE_KS:
        _restore(argus, snapshot)
        _zero_heads(argus, [(r['block'], r['head']) for r in ranked[:K]])
        scores = score_images(argus, imgs, pos, neg).to(DEVICE)
        f1, p, r_ = compute_f1(scores, labels, thr)
        curve.append({'heads_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                      'precision': p, 'recall': r_})
        print(f'  K={K:>3} pruned  F1={f1:.4f} drop={base_f1-f1:+.4f}', flush=True)
    return curve


def _max_heads_before_drop(curve, thresholds):
    """For each threshold t, return the largest sampled K reached before F1 drops by more than t.

    Walks the curve in increasing K and stops at the first point whose drop
    exceeds t. The resolution is limited to the sampled CURVE_KS. Result keys
    are the thresholds formatted as strings (e.g. '0.01').
    """
    points = sorted(curve, key=lambda c: c['heads_pruned'])
    summary = {}
    for t in thresholds:
        best = 0
        for point in points:
            if point['F1_drop'] > t:
                break
            best = point['heads_pruned']
        summary[f'{t:.2f}'] = best
    return summary


def main():
    """Run the baseline, the per-head sweep and the cumulative curve, then write the JSON outputs."""
    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading Argus', flush=True)
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr, target_dims = load_classifier()
    print(f'  |target_dims|={len(target_dims)}  threshold={thr:.3f}', flush=True)

    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
    STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)
    pos_rate = labels.float().mean().item()
    print(f'  loaded. person_rate in calib = {pos_rate:.3f}', flush=True)

    print('[baseline] scoring without ablation', flush=True)
    t0 = time.time()
    base_scores, base_targets = score_and_targets(argus, imgs, pos, neg, target_dims)
    base_f1, base_p, base_r = compute_f1(base_scores, labels, thr)
    print(f'  baseline F1={base_f1:.4f}  P={base_p:.4f}  R={base_r:.4f}  '
          f'({len(imgs)/(time.time()-t0):.1f} img/s)', flush=True)

    # A single snapshot of every proj.weight, used for all restores.
    snapshot = {b: _proj_weight(argus, b).detach().clone() for b in range(N_BLOCKS)}

    try:
        print(f'[sweep] {N_BLOCKS * N_HEADS} head ablations', flush=True)
        results = _run_sweep(argus, imgs, labels, pos, neg, thr, target_dims,
                             base_f1, base_targets, snapshot)

        # Rank by F1 drop (smallest drop = most prunable); sort is stable.
        ranked = sorted(results, key=lambda r: r['F1_drop'])

        print(f'[curve] cumulative pruning (heads ranked by smallest individual drop)', flush=True)
        curve = _run_curve(argus, ranked, imgs, labels, pos, neg, thr, base_f1, snapshot)
    finally:
        # Leave the model unablated even if scoring raises part-way.
        _restore(argus, snapshot)

    budget = _max_heads_before_drop(curve, DROP_THRESHOLDS)
    for t, k in budget.items():
        print(f'  max heads pruned with F1 drop <= {t}: {k}', flush=True)

    with open(f'{OUT_DIR}/head_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'baseline_P': base_p, 'baseline_R': base_r,
                   'n_calibration': N_CALIBRATION,
                   'per_head': results,
                   'ranked_most_prunable_first': [(r['block'], r['head'], r['F1_drop']) for r in ranked]},
                  f, indent=2)
    with open(f'{OUT_DIR}/pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve,
                   'max_heads_before_drop': budget}, f, indent=2)
    print(f'[done] results -> {OUT_DIR}', flush=True)


if __name__ == '__main__':
    main()