# File size: 8,137 Bytes | a7e09b2
"""Stage 2: attention-head pruning.
Single-head ablation sweep on EUPE-ViT-B. For each of the 144 (block, head)
pairs, zero the columns of that block's attention output projection that
correspond to that head, run a calibration batch through the full pipeline
end-to-end, and record:
- L2 deviation of the 100 target output dims vs unablated baseline
- F1 on COCO val 2017 for the Stage 0 person classifier
Sort heads by F1 impact. Sweep the cumulative pruning curve: how many
heads can be zeroed before F1 drops by 0.01 / 0.02 / 0.05.
Output: head_importance.json, pruning_curve.json.
"""
import os, sys, json, time
import copy
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from pycocotools.coco import COCO
from transformers import AutoModel
# Put the local Argus checkout first on the import path (main() loads the
# model from this same directory).
sys.path.insert(0, '/mnt/d/Argus')
# Dataset root; the code below reads val2017/ images and
# annotations/instances_val2017.json from under it.
COCO_ROOT = '/home/zootest/datasets/coco'
# NOTE(review): not referenced anywhere in the visible code.
VAL_CACHE = f'{COCO_ROOT}/val_feature_cache_768/val.pt'
# JSON file providing 'pos_dims', 'neg_dims' and 'threshold' (see load_classifier).
STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
# Calibration images are resized to RES x RES before normalization.
RES = 768
# normalized_shape handed to F.layer_norm on the patch tokens; also the width
# that is split evenly across heads to obtain HEAD_DIM.
D = 768
# Number of backbone blocks and attention heads per block iterated by the sweep.
N_BLOCKS = 12
N_HEADS = 12
HEAD_DIM = D // N_HEADS # 64
N_CALIBRATION = 1000 # COCO val images used for the sweep
# Directory that receives head_importance.json and pruning_curve.json.
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_2'
# Device for the model, the classifier index tensors and the label tensor.
DEVICE = 'cuda'
def load_classifier():
    """Load the Stage 0 classifier description from ``STAGE0_CLASSIFIER``.

    The JSON file must provide ``pos_dims``, ``neg_dims`` and ``threshold``.

    Returns:
        A 4-tuple ``(pos, neg, thr, target_dims)`` where ``pos`` and ``neg``
        are long index tensors on ``DEVICE``, ``thr`` is a Python float, and
        ``target_dims`` is the sorted, de-duplicated union of both index sets.
    """
    with open(STAGE0_CLASSIFIER) as handle:
        spec = json.load(handle)

    def _index_tensor(key):
        # Index lists are stored as plain JSON arrays of ints.
        return torch.tensor(spec[key], dtype=torch.long, device=DEVICE)

    pos = _index_tensor('pos_dims')
    neg = _index_tensor('neg_dims')
    thr = float(spec['threshold'])
    # Every dimension the classifier reads, each listed once.
    target_dims = torch.unique(torch.cat((pos, neg)))
    return pos, neg, thr, target_dims
@torch.inference_mode()
def score_images(argus, img_tensors, pos, neg):
    """Compute one classifier score per image.

    Each element of ``img_tensors`` is passed to
    ``argus.backbone.forward_features`` on its own (presumably a single
    pre-normalized image with a leading batch axis of size 1, as built by
    ``load_calibration``). The patch tokens are layer-normed, max-pooled over
    the token axis, and the score is the sum of the pooled values at ``pos``
    minus the sum at ``neg``.

    Returns:
        A CPU tensor holding one score per input image, in input order.
    """
    per_image = []
    for image in img_tensors:
        # Only the backbone forward runs under bf16 autocast; the pooling
        # below operates on an explicit float32 copy of the tokens.
        with torch.autocast('cuda', dtype=torch.bfloat16):
            features = argus.backbone.forward_features(image)
        tokens = features['x_norm_patchtokens'].float().squeeze(0)
        normed = F.layer_norm(tokens, [D])
        peak = torch.max(normed, dim=0).values
        margin = peak[pos].sum() - peak[neg].sum()
        per_image.append(margin.item())
    return torch.tensor(per_image)
@torch.inference_mode()
def pooled_targets(argus, img_tensors, target_dims):
    """Collect the pooled, layer-normed features at ``target_dims`` per image.

    Runs the same per-image pipeline as the classifier scoring (backbone
    forward under bf16 autocast, float32 layer norm over the patch tokens,
    max over the token axis) but keeps the selected feature values instead of
    reducing them to a score.

    Returns:
        A stacked tensor with one row per image and one column per entry of
        ``target_dims``.
    """
    rows = []
    for image in img_tensors:
        with torch.autocast('cuda', dtype=torch.bfloat16):
            features = argus.backbone.forward_features(image)
        tokens = features['x_norm_patchtokens'].float().squeeze(0)
        normed = F.layer_norm(tokens, [D])
        peak = torch.max(normed, dim=0).values
        rows.append(peak[target_dims])
    return torch.stack(rows)
def load_calibration(coco, n, MEAN, STD):
    """Load the first ``n`` val2017 images (by sorted image id) with labels.

    Args:
        coco: pycocotools ``COCO`` handle for the val2017 annotations.
        n: maximum number of images to load.
        MEAN, STD: per-channel normalization tensors that broadcast against a
            ``(1, 3, RES, RES)`` CUDA tensor (main() passes ``(1, 3, 1, 1)``).

    Returns:
        ``(tensors, labels)``: a list of normalized ``(1, 3, RES, RES)`` CUDA
        float tensors, and a bool tensor on ``DEVICE`` that is True where the
        image has at least one non-crowd annotation with ``category_id == 1``.
    """
    chosen_ids = sorted(coco.getImgIds())[:n]
    tensors, has_person = [], []
    for image_id in chosen_ids:
        record = coco.loadImgs(image_id)[0]
        image_path = f"{COCO_ROOT}/val2017/{record['file_name']}"

        rgb = Image.open(image_path).convert('RGB')
        resized = rgb.resize((RES, RES), Image.BILINEAR)
        # Copy so torch.from_numpy receives a writable, owned array.
        pixels = np.asarray(resized, dtype=np.uint8).copy()
        # HWC uint8 -> 1xCxHxW float in [0, 1], then channel-wise normalize.
        chw = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
        scaled = chw.cuda().float() / 255.0
        tensors.append((scaled - MEAN) / STD)

        annotations = coco.loadAnns(coco.getAnnIds(imgIds=image_id, iscrowd=False))
        has_person.append(any(ann['category_id'] == 1 for ann in annotations))
    return tensors, torch.tensor(has_person, dtype=torch.bool, device=DEVICE)
def compute_f1(scores, labels, thr):
    """Threshold ``scores`` at ``thr`` and score the result against ``labels``.

    Args:
        scores: tensor of classifier scores; a score strictly greater than
            ``thr`` counts as a positive prediction.
        labels: bool tensor of ground truth with the same shape and device.
        thr: decision threshold.

    Returns:
        ``(f1, precision, recall)`` as Python floats. Precision and recall are
        0.0 when their denominator would be zero, and F1 is 0.0 when both
        precision and recall are zero.
    """
    predicted = scores > thr
    true_pos = torch.sum(predicted & labels).float()
    false_pos = torch.sum(predicted & ~labels).float()
    false_neg = torch.sum(~predicted & labels).float()

    # Counts are integers, so clamping the denominator to 1 only changes the
    # 0/0 case (which then yields 0).
    precision = true_pos / torch.clamp(true_pos + false_pos, min=1)
    recall = true_pos / torch.clamp(true_pos + false_neg, min=1)
    harmonic_denominator = torch.clamp(precision + recall, min=1e-9)
    f1 = 2 * precision * recall / harmonic_denominator
    return float(f1), float(precision), float(recall)
def _zero_head(argus, block, head):
    """Zero, in place, the attn.proj weight columns fed by one attention head."""
    weight = argus.backbone.blocks[block].attn.proj.weight
    with torch.no_grad():
        weight[:, head * HEAD_DIM:(head + 1) * HEAD_DIM] = 0.0


def _restore_proj_weights(argus, saved):
    """Copy saved attn.proj weights (``{block: tensor}``) back into the model."""
    with torch.no_grad():
        for block, original in saved.items():
            argus.backbone.blocks[block].attn.proj.weight.copy_(original)


def main():
    """Run the single-head ablation sweep and the cumulative pruning curve.

    Steps:
      1. Load the Argus model, the Stage 0 classifier and the calibration set.
      2. Score the unablated model to obtain baseline F1 and target features.
      3. For every (block, head) pair, zero that head's columns of the block's
         attention output projection, re-score, restore the weights, and
         record F1 / precision / recall / F1 drop / mean L2 deviation of the
         target features.
      4. Rank heads by F1 drop (smallest first) and measure F1 with the K most
         prunable heads zeroed simultaneously, for a fixed schedule of K.
      5. Write ``head_importance.json`` and ``pruning_curve.json`` to OUT_DIR.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    n_total = N_BLOCKS * N_HEADS

    print('[init] loading Argus', flush=True)
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr, target_dims = load_classifier()
    print(f' |target_dims|={len(target_dims)} threshold={thr:.3f}', flush=True)

    # Per-channel normalization constants, shaped to broadcast over
    # (1, 3, H, W) image tensors.
    MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
    STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)
    pos_rate = labels.float().mean().item()
    print(f' loaded. person_rate in calib = {pos_rate:.3f}', flush=True)

    print('[baseline] scoring without ablation', flush=True)
    t0 = time.time()
    # score_images returns a CPU tensor; labels live on DEVICE.
    base_scores = score_images(argus, imgs, pos, neg).to(DEVICE)
    base_targets = pooled_targets(argus, imgs, target_dims)
    base_f1, base_p, base_r = compute_f1(base_scores, labels, thr)
    print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f} '
          f'({len(imgs)/(time.time()-t0):.1f} img/s)', flush=True)

    # Snapshot every block's proj.weight once. The per-head sweep restores
    # after each ablation, so the same snapshot also serves the cumulative
    # curve below.
    orig_weights = {
        b: argus.backbone.blocks[b].attn.proj.weight.detach().clone()
        for b in range(N_BLOCKS)
    }

    # Per-head ablation sweep
    print(f'[sweep] {n_total} head ablations', flush=True)
    results = []
    for b in range(N_BLOCKS):
        for h in range(N_HEADS):
            t_h = time.time()
            _zero_head(argus, b, h)
            scores = score_images(argus, imgs, pos, neg).to(DEVICE)
            targets = pooled_targets(argus, imgs, target_dims)
            # Undo the ablation before the next head is touched.
            _restore_proj_weights(argus, {b: orig_weights[b]})

            f1, p, r = compute_f1(scores, labels, thr)
            # Mean over images of the L2 distance between ablated and
            # baseline target features.
            l2 = (targets - base_targets).pow(2).sum(dim=1).sqrt().mean().item()
            drop = base_f1 - f1
            results.append({
                'block': b, 'head': h, 'F1': f1, 'precision': p, 'recall': r,
                'F1_drop': drop, 'target_L2': l2,
            })
            print(f' B{b:>2}H{h:>2} F1={f1:.4f} drop={drop:+.4f} '
                  f'L2={l2:.3f} {time.time()-t_h:.1f}s', flush=True)

    # Rank by F1 drop (smallest drop = most prunable)
    ranked = sorted(results, key=lambda entry: entry['F1_drop'])

    # Cumulative pruning curve: prune the K heads with smallest F1 drop, measure F1
    print('[curve] cumulative pruning (heads ranked by smallest individual drop)', flush=True)
    # Fixed schedule, capped at the number of heads that exist and always
    # ending with "all heads pruned".
    schedule = (1, 5, 10, 15, 20, 30, 40, 50, 60, 80, 100, 120)
    k_values = sorted({k for k in schedule if k < n_total} | {n_total})
    curve = []
    for K in k_values:
        # Start each point from the unablated weights.
        _restore_proj_weights(argus, orig_weights)
        for entry in ranked[:K]:
            _zero_head(argus, entry['block'], entry['head'])
        scores = score_images(argus, imgs, pos, neg).to(DEVICE)
        f1, p, rec = compute_f1(scores, labels, thr)
        curve.append({'heads_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                      'precision': p, 'recall': rec})
        print(f' K={K:>3} pruned F1={f1:.4f} drop={base_f1-f1:+.4f}', flush=True)

    # Final restore
    _restore_proj_weights(argus, orig_weights)

    with open(f'{OUT_DIR}/head_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'baseline_P': base_p, 'baseline_R': base_r,
                   # Report the number of images actually scored.
                   'n_calibration': len(imgs), 'per_head': results,
                   'ranked_most_prunable_first': [
                       [entry['block'], entry['head'], entry['F1_drop']]
                       for entry in ranked]}, f, indent=2)
    with open(f'{OUT_DIR}/pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve}, f, indent=2)
    print(f'[done] results -> {OUT_DIR}', flush=True)
# Run the sweep only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|