| """Stage 2: attention-head pruning. |
| |
Single-head ablation sweep on EUPE-ViT-B. For each of the 144 (block, head)
pairs, zero the input columns of that block's attention output projection
(attn.proj.weight) that correspond to that head, run the calibration set (the
first N_CALIBRATION COCO val2017 images by image id, one image at a time)
through the backbone end-to-end, and record:
  - mean per-image L2 deviation of the 100 target output dims (max-pooled,
    layer-normed patch features) vs the unablated baseline
  - F1 of the Stage 0 person classifier on that calibration set
| |
| Sort heads by F1 impact. Sweep the cumulative pruning curve: how many |
| heads can be zeroed before F1 drops by 0.01 / 0.02 / 0.05. |
| |
| Output: head_importance.json, pruning_curve.json. |
| """ |
| import os, sys, json, time |
| import copy |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from pycocotools.coco import COCO |
| from transformers import AutoModel |
|
|
# Make the local Argus directory importable; main() also loads the model from
# this same path via AutoModel.from_pretrained(..., trust_remote_code=True).
sys.path.insert(0, '/mnt/d/Argus')

# ---- filesystem locations ----
COCO_ROOT = '/home/zootest/datasets/coco'
VAL_CACHE = f'{COCO_ROOT}/val_feature_cache_768/val.pt'
STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_2'

# ---- input / model geometry ----
RES = 768                    # images are resized to RES x RES pixels
D = 768                      # feature width that patch tokens are layer-normed over
N_BLOCKS, N_HEADS = 12, 12   # transformer blocks; attention heads per block
HEAD_DIM = D // N_HEADS      # width of one head's slice of attn.proj's input

# ---- run settings ----
N_CALIBRATION = 1000         # number of COCO val images in the calibration set
DEVICE = 'cuda'
|
|
|
|
def load_classifier():
    """Read the Stage 0 classifier spec from STAGE0_CLASSIFIER.

    The JSON file must provide 'pos_dims', 'neg_dims' (lists of feature
    indices) and 'threshold'.

    Returns:
        (pos, neg, thr, target_dims): ``pos``/``neg`` are long tensors on
        DEVICE, ``thr`` is a float, and ``target_dims`` is the sorted unique
        union of ``pos`` and ``neg`` (also on DEVICE).
    """
    with open(STAGE0_CLASSIFIER) as fh:
        spec = json.load(fh)

    def as_index_tensor(key):
        return torch.tensor(spec[key], dtype=torch.long, device=DEVICE)

    pos_dims = as_index_tensor('pos_dims')
    neg_dims = as_index_tensor('neg_dims')
    threshold = float(spec['threshold'])
    # torch.unique sorts by default and collapses dims present in both lists.
    union_dims = torch.cat((pos_dims, neg_dims)).unique()
    return pos_dims, neg_dims, threshold, union_dims
|
|
|
|
@torch.inference_mode()
def score_images(argus, img_tensors, pos, neg):
    """Score each pre-normalized image with the Stage 0 linear rule.

    Every element of ``img_tensors`` is passed through the backbone on its
    own (presumably shaped (1, 3, H, W); the leading dim is squeezed away).
    The backbone's 'x_norm_patchtokens' are cast to float32, layer-normed
    again over width D and max-pooled over dim 0 (the patch axis). The score
    is then sum(pooled[pos]) - sum(pooled[neg]).

    Returns:
        1-D float tensor (default device, i.e. CPU) with one score per image.
    """
    def score_one(img):
        with torch.autocast('cuda', dtype=torch.bfloat16):
            feats = argus.backbone.forward_features(img)
        tokens = feats['x_norm_patchtokens'].float().squeeze(0)
        pooled = F.layer_norm(tokens, [D]).max(dim=0).values
        return (pooled[pos].sum() - pooled[neg].sum()).item()

    return torch.tensor([score_one(img) for img in img_tensors])
|
|
|
|
@torch.inference_mode()
def pooled_targets(argus, img_tensors, target_dims):
    """Max-pooled, layer-normed backbone features restricted to ``target_dims``.

    Each element of ``img_tensors`` is run through the backbone separately
    (assumed (1, 3, H, W) and already normalized; TODO confirm). Its
    'x_norm_patchtokens' are cast to float32, layer-normed over width D and
    max-pooled over dim 0 (the patch axis after squeezing the batch dim).

    Returns:
        Tensor of shape (N, len(target_dims)), one row per input image.
    """
    rows = []
    for img in img_tensors:
        with torch.autocast('cuda', dtype=torch.bfloat16):
            tokens = argus.backbone.forward_features(img)['x_norm_patchtokens']
        per_patch = F.layer_norm(tokens.float().squeeze(0), [D])
        rows.append(per_patch.max(dim=0).values[target_dims])
    return torch.stack(rows)
|
|
|
|
def load_calibration(coco, n, MEAN, STD):
    """Load and normalize the first ``n`` COCO val2017 images (by ascending id).

    Args:
        coco: a pycocotools ``COCO`` index over the val2017 annotations.
        n: number of images to load (fewer if the index holds fewer).
        MEAN, STD: per-channel statistics that must broadcast against a
            (1, 3, RES, RES) tensor (main() passes (1, 3, 1, 1) tensors).

    Returns:
        (tensors, labels):
          tensors: list of (1, 3, RES, RES) float tensors on DEVICE, computed
              as (pixels / 255 - MEAN) / STD.
          labels: bool tensor on DEVICE, True where the image has at least one
              non-crowd annotation with category_id == 1 (person).
    """
    img_ids = sorted(coco.getImgIds())[:n]
    tensors, labels = [], []
    for img_id in img_ids:
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        # Image.open is lazy and holds the file handle; the context manager
        # closes it deterministically. convert() returns an independent copy.
        with Image.open(path) as src:
            img = src.convert('RGB').resize((RES, RES), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.uint8).copy()
        # Move the compact uint8 tensor to the device first, then convert to float there.
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
        tensors.append((x - MEAN) / STD)
        # iscrowd=False keeps only non-crowd annotations, so an image whose
        # only person annotation is a crowd region is labelled negative.
        ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        labels.append(any(a['category_id'] == 1 for a in coco.loadAnns(ann_ids)))
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)
|
|
|
|
def compute_f1(scores, labels, thr):
    """Binary F1, precision and recall of thresholded scores.

    An item is predicted positive when its score is strictly greater than
    ``thr``. ``labels`` is the matching ground-truth mask. Denominators are
    clamped (counts to >= 1, precision + recall to >= 1e-9), so degenerate
    cases give 0.0 instead of NaN.

    Returns:
        (f1, precision, recall) as Python floats.
    """
    predicted = scores > thr
    true_pos = (predicted & labels).sum().float()
    # tp + fp is the number of predicted positives, and tp + fn is the number
    # of actual positives. Both are exact integer counts.
    precision = true_pos / predicted.sum().float().clamp(min=1)
    recall = true_pos / labels.sum().float().clamp(min=1)
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-9)
    return float(f1), float(precision), float(recall)
|
|
|
|
def main():
    """Run the Stage 2 head-ablation sweep and write its JSON reports.

    1. Load Argus, the Stage 0 classifier spec and the first N_CALIBRATION
       COCO val2017 images.
    2. Evaluate the unablated model: baseline F1/P/R and pooled target dims.
    3. For every (block, head), zero that head's slice of the block's
       attn.proj weight, re-evaluate, and restore the weight.
    4. Rank heads by individual F1 drop (smallest first), prune the top-K
       cumulatively for a fixed list of K values, and report the largest
       tested K that stays within 0.01 / 0.02 / 0.05 F1 drop.
    5. Write head_importance.json and pruning_curve.json to OUT_DIR.

    Requires DEVICE ('cuda') and the data paths configured at module top.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    print('[init] loading Argus', flush=True)
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()

    pos, neg, thr, target_dims = load_classifier()
    print(f' |target_dims|={len(target_dims)} threshold={thr:.3f}', flush=True)

    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
    STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)

    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)
    pos_rate = labels.float().mean().item()
    print(f' loaded. person_rate in calib = {pos_rate:.3f}', flush=True)

    # Gather target_dims, pos and neg in ONE pooled_targets call, so every
    # configuration costs a single backbone pass over the calibration set.
    # The classifier score (sum(pos) - sum(neg), the same rule as score_images)
    # and the target features are then column slices of the same result.
    n_t, n_p = len(target_dims), len(pos)
    gather_dims = torch.cat([target_dims, pos, neg])

    @torch.inference_mode()
    def evaluate():
        """Return (scores (N,), targets (N, |target_dims|)) for the current weights."""
        feats = pooled_targets(argus, imgs, gather_dims)
        scores = feats[:, n_t:n_t + n_p].sum(dim=1) - feats[:, n_t + n_p:].sum(dim=1)
        return scores, feats[:, :n_t]

    projs = [argus.backbone.blocks[b].attn.proj for b in range(N_BLOCKS)]
    orig_weights = {b: projs[b].weight.detach().clone() for b in range(N_BLOCKS)}

    def zero_head(b, h):
        """Zero head h's input columns of block b's attention output projection."""
        # Assumes the standard ViT layout, where proj's input is the per-head
        # outputs concatenated along the feature axis, so head h owns columns
        # [h*HEAD_DIM, (h+1)*HEAD_DIM). TODO confirm against the Argus code.
        with torch.no_grad():
            projs[b].weight[:, h * HEAD_DIM:(h + 1) * HEAD_DIM] = 0.0

    def restore(blocks=range(N_BLOCKS)):
        """Copy the saved unablated weights back into the given blocks."""
        with torch.no_grad():
            for b in blocks:
                projs[b].weight.copy_(orig_weights[b])

    print('[baseline] scoring without ablation', flush=True)
    t0 = time.time()
    base_scores, base_targets = evaluate()
    base_f1, base_p, base_r = compute_f1(base_scores, labels, thr)
    print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f} '
          f'({len(imgs)/(time.time()-t0):.1f} img/s)', flush=True)

    print(f'[sweep] {N_BLOCKS * N_HEADS} head ablations', flush=True)
    results = []
    for b in range(N_BLOCKS):
        for h in range(N_HEADS):
            t_h = time.time()
            zero_head(b, h)
            try:
                scores, targets = evaluate()
            finally:
                restore([b])  # never leave a head zeroed, even if evaluation raises
            f1, p, r = compute_f1(scores, labels, thr)
            l2 = (targets - base_targets).pow(2).sum(dim=1).sqrt().mean().item()
            drop = base_f1 - f1
            results.append({
                'block': b, 'head': h, 'F1': f1, 'precision': p, 'recall': r,
                'F1_drop': drop, 'target_L2': l2,
            })
            print(f' B{b:>2}H{h:>2} F1={f1:.4f} drop={drop:+.4f} '
                  f'L2={l2:.3f} {time.time()-t_h:.1f}s', flush=True)

    # Smallest individual F1 drop first, i.e. most prunable first.
    ranked = sorted(results, key=lambda e: e['F1_drop'])

    print('[curve] cumulative pruning (heads ranked by smallest individual drop)', flush=True)
    curve = []
    try:
        for K in [1, 5, 10, 15, 20, 30, 40, 50, 60, 80, 100, 120, 144]:
            restore()
            for entry in ranked[:K]:
                zero_head(entry['block'], entry['head'])
            scores, _ = evaluate()
            f1, p, r = compute_f1(scores, labels, thr)
            curve.append({'heads_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                          'precision': p, 'recall': r})
            print(f' K={K:>3} pruned F1={f1:.4f} drop={base_f1-f1:+.4f}', flush=True)
    finally:
        restore()

    # Largest tested K whose cumulative F1 drop stays below each tolerance.
    # The curve is scanned in K order and the scan stops at the first violation.
    max_heads_within_drop = {}
    for tol in (0.01, 0.02, 0.05):
        best = 0
        for point in curve:
            if point['F1_drop'] >= tol:
                break
            best = point['heads_pruned']
        max_heads_within_drop[f'{tol:.2f}'] = best
        print(f' F1 drop < {tol:.2f}: up to {best} heads (tested K values only)', flush=True)

    with open(f'{OUT_DIR}/head_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'baseline_P': base_p, 'baseline_R': base_r,
                   'n_calibration': N_CALIBRATION, 'per_head': results,
                   'ranked_most_prunable_first': [(r['block'], r['head'], r['F1_drop'])
                                                  for r in ranked]}, f, indent=2)
    with open(f'{OUT_DIR}/pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve,
                   'max_heads_within_drop': max_heads_within_drop}, f, indent=2)

    print(f'[done] results -> {OUT_DIR}', flush=True)
|
|
|
|
# Script entry point: run the full Stage 2 sweep only when executed directly,
# so importing this module has no side effects beyond its top-level setup.
if __name__ == '__main__':
    main()
|
|