# NOTE(review): the next four lines look like residue copied from a repository
# web page; they are kept as comments so the module parses.
# 1-parameter-classifier / stage_3 /block_ablation.py
# phanerozoic's picture
# Stage 3: depth reduction results (only 1 block cleanly prunable)
# 3729ac4 verified
"""Stage 3: depth reduction via block ablation.

For each of the 12 transformer blocks, zero both output projections: the
attention projection (attn.proj) and the MLP fc2 layer (mlp.fc2). Each block
adds its attention and MLP outputs onto the residual stream, so once both
projections emit exactly zero (weights and any biases) the block degenerates
to an identity (residual pass-through). Measure F1 of the Stage 0 classifier
under each ablation, rank blocks by smallest F1 drop, then sweep cumulative
skipping to find how many blocks can be dropped before F1 collapses.

Outputs:
    block_importance.json
    block_pruning_curve.json
"""
import os, sys, json, time
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
from transformers import AutoModel
# Run configuration. The machine-specific paths and the calibration size can be
# overridden through BLOCK_ABLATION_* environment variables; the defaults are
# the values this script was originally run with.
COCO_ROOT = os.environ.get('BLOCK_ABLATION_COCO_ROOT', '/home/zootest/datasets/coco')
STAGE0_CLASSIFIER = os.environ.get('BLOCK_ABLATION_STAGE0_CLASSIFIER',
                                   '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json')
N_CALIBRATION = int(os.environ.get('BLOCK_ABLATION_N_CALIBRATION', '1000'))
N_BLOCKS = 12  # number of transformer blocks that are ablated
RES = 768      # images are resized to RES x RES before scoring
D = 768        # feature width passed to layer_norm over patch tokens
OUT_DIR = os.environ.get('BLOCK_ABLATION_OUT_DIR', '/mnt/d/_tmp/1pc_repo/stage_3')
DEVICE = 'cuda'
def load_classifier():
    """Read the Stage 0 classifier spec from STAGE0_CLASSIFIER.

    Returns:
        (pos, neg, threshold): long index tensors on DEVICE built from the
        ``pos_dims`` and ``neg_dims`` lists, and ``threshold`` as a float.
    """
    with open(STAGE0_CLASSIFIER) as fh:
        spec = json.load(fh)

    def as_index(key):
        # Feature-dimension indices used to gather from the pooled vector.
        return torch.tensor(spec[key], dtype=torch.long, device=DEVICE)

    return as_index('pos_dims'), as_index('neg_dims'), float(spec['threshold'])
@torch.inference_mode()
def score_images(argus, imgs, pos, neg):
scores = []
for x in imgs:
with torch.autocast('cuda', dtype=torch.bfloat16):
out = argus.backbone.forward_features(x)
patches = out['x_norm_patchtokens'].float().squeeze(0)
ln = F.layer_norm(patches, [D])
pooled = ln.max(dim=0).values
scores.append((pooled[pos].sum() - pooled[neg].sum()).item())
return torch.tensor(scores, device=DEVICE)
def f1_of(scores, labels, thr):
    """Binary F1, precision and recall of ``scores > thr`` against bool ``labels``.

    Denominators are clamped so a class with no members yields 0 instead of NaN.

    Returns:
        (f1, precision, recall) as Python floats.
    """
    predicted = scores > thr
    rejected = ~predicted
    negatives = ~labels
    tp = torch.sum(predicted & labels).float()
    fp = torch.sum(predicted & negatives).float()
    fn = torch.sum(rejected & labels).float()
    precision = tp / torch.clamp(tp + fp, min=1)
    recall = tp / torch.clamp(tp + fn, min=1)
    f1 = 2 * precision * recall / torch.clamp(precision + recall, min=1e-9)
    return float(f1), float(precision), float(recall)
def _backup_linear(layer):
    """Return a detached clone of ``layer.weight`` carrying a bias backup.

    The clone of ``layer.bias`` (None when the layer has no bias) is attached
    as the ``ablation_bias`` attribute. Callers therefore still receive a plain
    weight tensor, and _restore_linear can also recover the bias.
    """
    saved = layer.weight.detach().clone()
    bias = getattr(layer, 'bias', None)
    saved.ablation_bias = None if bias is None else bias.detach().clone()
    return saved


def _zero_linear(layer):
    """Zero ``layer``'s weight and (if present) bias in place so it outputs exactly 0."""
    with torch.no_grad():
        layer.weight.zero_()
        bias = getattr(layer, 'bias', None)
        if bias is not None:
            bias.zero_()


def _restore_linear(layer, saved):
    """Copy a _backup_linear result (or a bare weight tensor) back into ``layer``."""
    with torch.no_grad():
        layer.weight.copy_(saved)
        saved_bias = getattr(saved, 'ablation_bias', None)
        bias = getattr(layer, 'bias', None)
        if saved_bias is not None and bias is not None:
            bias.copy_(saved_bias)


def ablate_block(model, block_idx, zero=True):
    """Snapshot, and optionally silence, the output projections of one block.

    Targets ``attn.proj`` and ``mlp.fc2`` of ``model.backbone.blocks[block_idx]``.
    With ``zero=True`` both weights AND biases are zeroed, so both projections
    output exactly 0 and the block reduces to its residual pass-through.
    Zeroing only the weights, as before, left each layer emitting its bias
    vector, which still flowed through the residual path.

    Args:
        model: object exposing ``backbone.blocks[i].attn.proj`` / ``.mlp.fc2``.
        block_idx: index of the block to snapshot/ablate.
        zero: if False, only take the snapshot and leave the block untouched.

    Returns:
        (orig_proj, orig_fc2): cloned weight tensors. Each carries an
        ``ablation_bias`` attribute holding the cloned bias (or None), which
        restore_block uses.
    """
    block = model.backbone.blocks[block_idx]
    orig_proj = _backup_linear(block.attn.proj)
    orig_fc2 = _backup_linear(block.mlp.fc2)
    if zero:
        _zero_linear(block.attn.proj)
        _zero_linear(block.mlp.fc2)
    return orig_proj, orig_fc2


def restore_block(model, block_idx, orig_proj, orig_fc2):
    """Undo ablate_block by writing the saved projections back into the block.

    Accepts the pair returned by ablate_block. Bare weight tensors without an
    attached ``ablation_bias`` restore the weights only.
    """
    block = model.backbone.blocks[block_idx]
    _restore_linear(block.attn.proj, orig_proj)
    _restore_linear(block.mlp.fc2, orig_fc2)
def load_calibration(coco, n, MEAN, STD):
    """Load the first ``n`` COCO val2017 images (ascending image id) as model inputs.

    Args:
        coco: pycocotools COCO index over the val2017 annotations.
        n: number of images to load.
        MEAN, STD: per-channel normalisation tensors; they must broadcast
            against a (1, 3, RES, RES) tensor and live on DEVICE.

    Returns:
        (tensors, labels): a list of normalised (1, 3, RES, RES) float tensors
        on DEVICE, and a bool tensor on DEVICE that is True where the image has
        a non-crowd annotation with category_id == 1 (person in standard COCO).
    """
    img_ids = sorted(coco.getImgIds())[:n]
    tensors, labels = [], []
    for img_id in img_ids:
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        # Context manager closes the underlying file handle deterministically.
        with Image.open(path) as im:
            img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.uint8).copy()
        # HWC uint8 -> 1xCxHxW float in [0, 1] on DEVICE (was hard-coded .cuda()).
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
        tensors.append((x - MEAN) / STD)
        ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        labels.append(any(a['category_id'] == 1 for a in coco.loadAnns(ann_ids)))
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)
def _single_block_sweep(model, imgs, labels, pos, neg, thr, base_f1):
    """Ablate each block on its own and return one metrics record per block."""
    per_block = []
    t0 = time.time()
    for b in range(N_BLOCKS):
        op, of = ablate_block(model, b)
        try:
            scores = score_images(model, imgs, pos, neg)
        finally:
            # Restore even if scoring raises, so the model is never left ablated.
            restore_block(model, b, op, of)
        f1, p, r = f1_of(scores, labels, thr)
        drop = base_f1 - f1
        per_block.append({'block': b, 'F1': f1, 'precision': p, 'recall': r,
                          'F1_drop': drop})
        print(f' block {b:>2} F1={f1:.4f} drop={drop:+.4f} '
              f'{(time.time()-t0):.1f}s', flush=True)
    return per_block


def _cumulative_curve(model, imgs, labels, pos, neg, thr, base_f1, ranked,
                      ks=(1, 2, 3, 4, 5, 6, 8, 10, 12)):
    """For each K in ``ks``, ablate the K most-prunable blocks together and record F1."""
    # Snapshot every block without zeroing anything (zero=False).
    backups = {b: ablate_block(model, b, zero=False) for b in range(N_BLOCKS)}
    curve = []
    try:
        for K in ks:
            # Start each K from the intact model.
            for b, (op, of) in backups.items():
                restore_block(model, b, op, of)
            pruned = [entry['block'] for entry in ranked[:K]]
            for b in pruned:
                ablate_block(model, b)  # snapshot discarded; `backups` holds originals
            scores = score_images(model, imgs, pos, neg)
            f1, p, rr = f1_of(scores, labels, thr)
            curve.append({'blocks_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                          'precision': p, 'recall': rr,
                          'pruned_list': pruned})
            print(f' K={K:>2} F1={f1:.4f} drop={base_f1-f1:+.4f} '
                  f'blocks pruned={pruned}', flush=True)
    finally:
        for b, (op, of) in backups.items():
            restore_block(model, b, op, of)
    return curve


def main():
    """Run the Stage 3 block-ablation study and write both JSON reports to OUT_DIR.

    Steps:
    1. Score a baseline.
    2. Ablate each block individually.
    3. Rank blocks by F1 drop.
    4. Ablate the K most-prunable blocks cumulatively.

    The model's weights are restored after every measurement.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading Argus', flush=True)
    model = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr = load_classifier()
    # Standard ImageNet per-channel mean/std, created on DEVICE to match the images.
    MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
    STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)
    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)

    print('[baseline]', flush=True)
    base_scores = score_images(model, imgs, pos, neg)
    base_f1, base_p, base_r = f1_of(base_scores, labels, thr)
    print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f}', flush=True)

    per_block = _single_block_sweep(model, imgs, labels, pos, neg, thr, base_f1)
    # Most prunable first: smallest F1 drop when the block is removed alone.
    ranked = sorted(per_block, key=lambda rec: rec['F1_drop'])

    print('[curve] cumulative block ablation', flush=True)
    curve = _cumulative_curve(model, imgs, labels, pos, neg, thr, base_f1, ranked)

    with open(f'{OUT_DIR}/block_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'per_block': per_block,
                   'ranked_most_prunable_first': [(rec['block'], rec['F1_drop'])
                                                  for rec in ranked]},
                  f, indent=2)
    with open(f'{OUT_DIR}/block_pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve}, f, indent=2)
    print(f'[done] -> {OUT_DIR}', flush=True)


if __name__ == '__main__':
    main()