# 1-parameter-classifier / stage_2 / head_ablation.py
# phanerozoic — commit a7e09b2 (verified):
#   Stage 2: attention-head pruning results + mask + apply_mask.py
"""Stage 2: attention-head pruning.

Single-head ablation sweep on EUPE-ViT-B. For each of the 144 (block, head)
pairs, zero the columns of that block's attention output projection that
correspond to that head, run the calibration set (the first N_CALIBRATION
COCO val 2017 images) through the full pipeline end-to-end, and record:
  - the L2 deviation of the 100 target output dims (the union of the Stage 0
    classifier's positive and negative dims) vs. the unablated baseline;
  - the F1 of the Stage 0 person classifier on the calibration set.
Sort heads by F1 impact. Sweep the cumulative pruning curve: how many
heads can be zeroed before F1 drops by 0.01 / 0.02 / 0.05.
Output: head_importance.json, pruning_curve.json.
"""
import os, sys, json, time
import copy
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from pycocotools.coco import COCO
from transformers import AutoModel
# Put the local Argus checkout on the import path.
sys.path.insert(0, '/mnt/d/Argus')

# --- Data / artifact locations ---------------------------------------------
COCO_ROOT = '/home/zootest/datasets/coco'
# NOTE(review): appears unused in this script.
VAL_CACHE = COCO_ROOT + '/val_feature_cache_768/val.pt'
STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_2'

# --- Model geometry ---------------------------------------------------------
RES = 768                    # images are resized to RES x RES pixels
D = 768                      # patch-token feature width
N_BLOCKS, N_HEADS = 12, 12   # transformer blocks, attention heads per block
HEAD_DIM = D // N_HEADS      # 64 proj-input columns per head

# --- Run settings -----------------------------------------------------------
N_CALIBRATION = 1000         # COCO val images used for the sweep
DEVICE = 'cuda'
def load_classifier():
    """Load the Stage 0 classifier spec from STAGE0_CLASSIFIER.

    Returns:
        pos: long tensor (on DEVICE) of the spec's ``pos_dims``.
        neg: long tensor (on DEVICE) of the spec's ``neg_dims``.
        thr: the spec's ``threshold`` as a Python float.
        target_dims: de-duplicated union of pos and neg (``torch.unique``).
    """
    with open(STAGE0_CLASSIFIER) as fh:
        spec = json.load(fh)

    def as_index(key):
        # Index tensors live on the same device the pooled features will.
        return torch.tensor(spec[key], dtype=torch.long, device=DEVICE)

    pos_dims = as_index('pos_dims')
    neg_dims = as_index('neg_dims')
    threshold = float(spec['threshold'])
    union = torch.cat([pos_dims, neg_dims]).unique()
    return pos_dims, neg_dims, threshold, union
@torch.inference_mode()
def _max_pooled_features(argus, x):
    """Return the pooled feature vector that both scoring functions below use.

    Runs the backbone under bfloat16 autocast and takes ``x_norm_patchtokens``.
    The tokens are cast to float32 and the leading axis is squeezed (a batch of
    one image is assumed; TODO confirm callers never pass larger batches).
    A non-affine layer norm is then applied over the last D features, and the
    result is max-pooled over patches.
    """
    with torch.autocast('cuda', dtype=torch.bfloat16):
        out = argus.backbone.forward_features(x)
    patches = out['x_norm_patchtokens'].float().squeeze(0)
    return F.layer_norm(patches, [D]).max(dim=0).values


@torch.inference_mode()
def score_images(argus, img_tensors, pos, neg):
    """Return (N,) classifier scores for a batch of pre-normalized images.

    Each score is ``pooled[pos].sum() - pooled[neg].sum()`` on that image's
    pooled feature vector. The result is a float32 CPU tensor, empty when
    ``img_tensors`` is empty.
    """
    per_image = []
    for x in img_tensors:
        pooled = _max_pooled_features(argus, x)
        per_image.append(pooled[pos].sum() - pooled[neg].sum())
    if not per_image:
        return torch.tensor([])
    # One device->host copy at the end instead of a blocking .item() per image.
    return torch.stack(per_image).float().cpu()


@torch.inference_mode()
def pooled_targets(argus, img_tensors, target_dims):
    """Return (N, |target_dims|) pooled layer-normed features at the target dims."""
    return torch.stack([_max_pooled_features(argus, x)[target_dims]
                        for x in img_tensors])
def load_calibration(coco, n, MEAN, STD):
    """Load the first ``n`` COCO images (by sorted image id) as normalized tensors.

    Each image is converted to RGB and resized to RES x RES with bilinear
    resampling (aspect ratio is not preserved). It is then scaled to [0, 1],
    moved to DEVICE, and normalized as ``(x - MEAN) / STD``.

    Args:
        coco: ``pycocotools.coco.COCO`` index whose images live under
            ``{COCO_ROOT}/val2017``.
        n: number of images to load (fewer if the index holds fewer).
        MEAN, STD: broadcastable normalization tensors on DEVICE
            (presumably shaped (1, 3, 1, 1) by the caller).

    Returns:
        (tensors, labels): a list of (1, 3, RES, RES) float tensors, and an
        (N,) bool tensor on DEVICE. A label is True when the image has at least
        one retrieved annotation with ``category_id == 1`` (COCO "person").
    """
    img_ids = sorted(coco.getImgIds())[:n]
    tensors = []
    labels = []
    for img_id in img_ids:
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        # The context manager closes the file handle deterministically.
        with Image.open(path) as im:
            img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.uint8).copy()
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
        tensors.append((x - MEAN) / STD)
        # NOTE(review): iscrowd=False should make pycocotools drop crowd
        # annotations, so an image whose only people are crowd regions is
        # labeled negative. Confirm this matches how Stage 0 was labeled.
        ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        labels.append(any(a['category_id'] == 1 for a in coco.loadAnns(ann_ids)))
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)
def compute_f1(scores, labels, thr):
    """Return ``(f1, precision, recall)`` as Python floats.

    A sample is predicted positive when its score is strictly greater than
    ``thr``. ``labels`` is a bool tensor aligned with ``scores``. Empty
    denominators are clamped, so a metric with no support evaluates to 0.0
    instead of NaN.
    """
    predicted = scores > thr
    n_tp = (predicted & labels).sum().float()
    n_predicted = predicted.sum().float()  # tp + fp
    n_actual = labels.sum().float()        # tp + fn
    precision = n_tp / n_predicted.clamp(min=1)
    recall = n_tp / n_actual.clamp(min=1)
    harmonic = 2 * precision * recall / (precision + recall).clamp(min=1e-9)
    return float(harmonic), float(precision), float(recall)
def _evaluate(argus, imgs, pos, neg, target_dims):
    """Return ``(scores, targets)`` for the calibration set with ONE forward pass per image.

    The classifier score and the target-dim L2 metric both come from the same
    pooled feature vector. Calling ``score_images`` and then
    ``pooled_targets`` would push every image through the backbone twice.
    Instead, fetch all D pooled dims once and slice.

    Returns:
        scores: (N,) tensor, pooled sum over ``pos`` minus sum over ``neg``.
        targets: (N, |target_dims|) pooled features at ``target_dims``.
    """
    pooled = pooled_targets(argus, imgs, torch.arange(D, device=DEVICE))  # (N, D)
    scores = pooled[:, pos].sum(dim=1) - pooled[:, neg].sum(dim=1)
    return scores, pooled[:, target_dims]


def _zero_head(argus, block, head):
    """Zero the input columns of ``block``'s attention output projection that belong to ``head``.

    Assumes the proj input is the per-head outputs concatenated head-major, so
    head ``h`` occupies columns [h*HEAD_DIM, (h+1)*HEAD_DIM). The bias is left
    untouched.
    """
    w = argus.backbone.blocks[block].attn.proj.weight
    with torch.no_grad():
        w.data[:, head * HEAD_DIM:(head + 1) * HEAD_DIM] = 0.0


def _restore_proj_weights(argus, snapshot, blocks=None):
    """Copy saved proj weights back into ``blocks`` (default: every block in ``snapshot``)."""
    with torch.no_grad():
        for b in (snapshot if blocks is None else blocks):
            argus.backbone.blocks[b].attn.proj.weight.data.copy_(snapshot[b])


def _max_heads_within(curve, tol):
    """Return the largest K on the ascending-K ``curve`` reached before F1_drop first exceeds ``tol``.

    Resolution is limited to the K grid that was evaluated. Returns 0 when
    even the smallest K already exceeds ``tol``.
    """
    best = 0
    for point in curve:
        if point['F1_drop'] > tol:
            break
        best = point['heads_pruned']
    return best


def main():
    """Run the single-head ablation sweep and cumulative pruning curve; write JSON results."""
    curve_ks = (1, 5, 10, 15, 20, 30, 40, 50, 60, 80, 100, 120, 144)
    drop_tolerances = (0.01, 0.02, 0.05)
    n_total = N_BLOCKS * N_HEADS

    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading Argus', flush=True)
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr, target_dims = load_classifier()
    print(f' |target_dims|={len(target_dims)} threshold={thr:.3f}', flush=True)
    MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
    STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)
    pos_rate = labels.float().mean().item()
    print(f' loaded. person_rate in calib = {pos_rate:.3f}', flush=True)

    print('[baseline] scoring without ablation', flush=True)
    t0 = time.time()
    base_scores, base_targets = _evaluate(argus, imgs, pos, neg, target_dims)
    base_f1, base_p, base_r = compute_f1(base_scores, labels, thr)
    print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f} '
          f'({len(imgs)/(time.time()-t0):.1f} img/s)', flush=True)

    # One snapshot of every block's proj weight; every ablation is undone from it.
    snapshot = {b: argus.backbone.blocks[b].attn.proj.weight.detach().clone()
                for b in range(N_BLOCKS)}

    # Per-head ablation sweep.
    print(f'[sweep] {n_total} head ablations', flush=True)
    results = []
    for b in range(N_BLOCKS):
        for h in range(N_HEADS):
            t_h = time.time()
            try:
                _zero_head(argus, b, h)
                scores, targets = _evaluate(argus, imgs, pos, neg, target_dims)
            finally:
                # Only block b was touched; restore it even if evaluation fails.
                _restore_proj_weights(argus, snapshot, blocks=(b,))
            f1, p, r = compute_f1(scores, labels, thr)
            l2 = (targets - base_targets).pow(2).sum(dim=1).sqrt().mean().item()
            drop = base_f1 - f1
            results.append({
                'block': b, 'head': h, 'F1': f1, 'precision': p, 'recall': r,
                'F1_drop': drop, 'target_L2': l2,
            })
            print(f' B{b:>2}H{h:>2} F1={f1:.4f} drop={drop:+.4f} '
                  f'L2={l2:.3f} {time.time()-t_h:.1f}s', flush=True)

    # Rank by F1 drop (smallest drop = most prunable). sorted() is stable.
    ranked = sorted(results, key=lambda res: res['F1_drop'])

    # Cumulative pruning curve: prune the K heads with the smallest individual drop.
    print('[curve] cumulative pruning (heads ranked by smallest individual drop)', flush=True)
    curve = []
    try:
        for K in (k for k in curve_ks if k <= n_total):
            _restore_proj_weights(argus, snapshot)
            for entry in ranked[:K]:
                _zero_head(argus, entry['block'], entry['head'])
            scores, _ = _evaluate(argus, imgs, pos, neg, target_dims)
            f1, p, r = compute_f1(scores, labels, thr)
            curve.append({'heads_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                          'precision': p, 'recall': r})
            print(f' K={K:>3} pruned F1={f1:.4f} drop={base_f1-f1:+.4f}', flush=True)
    finally:
        _restore_proj_weights(argus, snapshot)

    # The module docstring promises these: how many heads fit under each F1-drop budget.
    max_heads = {f'{tol:.2f}': _max_heads_within(curve, tol) for tol in drop_tolerances}
    for key, k in max_heads.items():
        print(f' max heads pruned with F1 drop <= {key}: {k}', flush=True)

    with open(f'{OUT_DIR}/head_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'baseline_P': base_p, 'baseline_R': base_r,
                   'n_calibration': N_CALIBRATION, 'per_head': results,
                   'ranked_most_prunable_first': [(res['block'], res['head'], res['F1_drop'])
                                                  for res in ranked]}, f, indent=2)
    with open(f'{OUT_DIR}/pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve,
                   'max_heads_within_drop': max_heads}, f, indent=2)
    print(f'[done] results -> {OUT_DIR}', flush=True)


if __name__ == '__main__':
    main()