"""Stage 3: depth reduction via block ablation.

For each of the 12 transformer blocks, zero the attention output projection
(attn.proj) and the MLP output projection (mlp.fc2). Each block adds its
attention and MLP branches onto the residual stream, so when both branch
outputs are zero the block reduces to an identity (residual pass-through).
Measure the Stage 0 classifier's F1 with each block ablated on its own, rank
blocks by smallest F1 drop, then sweep cumulative ablation of the most
prunable blocks to find how many can be dropped before F1 collapses.

Outputs (written to OUT_DIR):
    block_importance.json
    block_pruning_curve.json
"""
import os, sys, json, time
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
from transformers import AutoModel
# Run configuration. The filesystem paths and N_CALIBRATION can be overridden
# through environment variables of the same name; the defaults are the
# original workstation values, so behavior is unchanged when nothing is set.
COCO_ROOT = os.environ.get('COCO_ROOT', '/home/zootest/datasets/coco')
STAGE0_CLASSIFIER = os.environ.get(
    'STAGE0_CLASSIFIER', '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json')
# Number of COCO val2017 images (lowest image ids first) scored per configuration.
N_CALIBRATION = int(os.environ.get('N_CALIBRATION', '1000'))
# Number of backbone blocks swept; assumed to equal len(model.backbone.blocks).
N_BLOCKS = 12
# Calibration images are resized to RES x RES pixels.
RES = 768
# Width of the patch-token features that score_images layer-normalizes over.
# Independent of RES even though both happen to be 768.
D = 768
OUT_DIR = os.environ.get('OUT_DIR', '/mnt/d/_tmp/1pc_repo/stage_3')
# Must stay a CUDA device: score_images autocasts with device type 'cuda'.
DEVICE = 'cuda'
def load_classifier(path=None, device=None):
    """Load the Stage 0 classifier definition from its JSON file.

    Args:
        path: JSON file to read. Defaults to the module constant
            ``STAGE0_CLASSIFIER`` when omitted.
        device: Device for the returned index tensors. Defaults to the module
            constant ``DEVICE`` when omitted.

    Returns:
        ``(pos, neg, threshold)``: ``pos`` and ``neg`` are int64 tensors built
        from the ``pos_dims`` / ``neg_dims`` lists (used by score_images as
        indices into the pooled feature vector), and ``threshold`` is the
        ``threshold`` entry converted to a Python float.

    Raises:
        KeyError: if any of ``pos_dims``, ``neg_dims`` or ``threshold`` is
            missing from the JSON object.
    """
    # Resolve defaults at call time (not in the signature) so the module
    # constants are looked up when the function actually runs.
    if path is None:
        path = STAGE0_CLASSIFIER
    if device is None:
        device = DEVICE
    with open(path, encoding='utf-8') as f:
        c = json.load(f)
    pos = torch.tensor(c['pos_dims'], dtype=torch.long, device=device)
    neg = torch.tensor(c['neg_dims'], dtype=torch.long, device=device)
    return pos, neg, float(c['threshold'])
@torch.inference_mode()
def score_images(argus, imgs, pos, neg):
    """Return one Stage 0 classifier score per image, as a tensor on DEVICE.

    For each tensor in ``imgs``, the backbone's ``forward_features`` output
    key ``'x_norm_patchtokens'`` is taken. The leading dimension is squeezed,
    which assumes a batch of one per tensor, as load_calibration produces.
    The tokens are layer-normalized over the last (size ``D``) axis without
    affine parameters and max-pooled over patches. The score is then
    ``sum(pooled[pos]) - sum(pooled[neg])``.

    Args:
        argus: model exposing ``backbone.forward_features``.
        imgs: iterable of preprocessed image tensors.
        pos, neg: index tensors into the pooled feature vector.
    """
    def _score_one(img):
        # Only the backbone forward runs under bf16 autocast; the pooling and
        # scoring below happen in float32.
        with torch.autocast('cuda', dtype=torch.bfloat16):
            feats = argus.backbone.forward_features(img)
        tokens = feats['x_norm_patchtokens'].float().squeeze(0)
        pooled = F.layer_norm(tokens, [D]).max(dim=0).values
        return (pooled[pos].sum() - pooled[neg].sum()).item()

    return torch.tensor([_score_one(img) for img in imgs], device=DEVICE)
def f1_of(scores, labels, thr):
    """Compute binary F1, precision and recall of ``scores > thr`` vs ``labels``.

    ``labels`` is expected to be a boolean tensor: it is combined with the
    prediction through ``&`` and ``~``. The prediction is a strict ``>``
    comparison. The precision and recall denominators are clamped to at
    least 1, so an empty predicted or actual positive set gives 0.0 rather
    than NaN. The F1 denominator is clamped to 1e-9 for the same reason.

    Returns:
        ``(f1, precision, recall)`` as Python floats.
    """
    predicted = scores > thr
    n_tp = (predicted & labels).sum().float()
    n_fp = (predicted & ~labels).sum().float()
    n_fn = (~predicted & labels).sum().float()
    precision = n_tp / torch.clamp(n_tp + n_fp, min=1)
    recall = n_tp / torch.clamp(n_tp + n_fn, min=1)
    harmonic = 2 * precision * recall / torch.clamp(precision + recall, min=1e-9)
    return tuple(float(v) for v in (harmonic, precision, recall))
# Attribute under which a Linear's original bias is stashed while ablated, so
# that restore_block can recover it without changing either function's
# signature or return value.
_BIAS_BACKUP_ATTR = '_ablation_bias_backup'


def _zero_linear(lin):
    """Make ``lin`` output exactly zero by zeroing its weight and, if any, its bias.

    A Linear with a zeroed weight still outputs its bias, so the bias must be
    zeroed as well for the residual branch to vanish. The original bias is
    stashed on the module only if no stash exists yet. This way, ablating an
    already-ablated layer cannot overwrite the clean backup with zeros.
    """
    with torch.no_grad():
        if lin.bias is not None:
            if not hasattr(lin, _BIAS_BACKUP_ATTR):
                setattr(lin, _BIAS_BACKUP_ATTR, lin.bias.detach().clone())
            lin.bias.zero_()
        lin.weight.zero_()


def _restore_linear(lin, weight):
    """Copy ``weight`` back into ``lin`` and restore (then drop) any stashed bias."""
    with torch.no_grad():
        lin.weight.copy_(weight)
        saved_bias = getattr(lin, _BIAS_BACKUP_ATTR, None)
        if saved_bias is not None:
            lin.bias.copy_(saved_bias)
            delattr(lin, _BIAS_BACKUP_ATTR)


def ablate_block(model, block_idx, zero=True):
    """Snapshot, and optionally zero, the output projections of one block.

    With ``zero=True``, ``attn.proj`` and ``mlp.fc2`` of
    ``model.backbone.blocks[block_idx]`` have their weights and biases set to
    zero, so both residual branches output zero and the block passes its
    input through. With ``zero=False`` the model is left untouched and only
    the snapshot is taken.

    Returns:
        ``(orig_proj, orig_fc2)``: clones of the two weight tensors taken
        before any modification. Biases zeroed here are stashed on the
        modules themselves. Pass the returned pair to ``restore_block`` to undo
        the ablation, biases included.
    """
    block = model.backbone.blocks[block_idx]
    proj, fc2 = block.attn.proj, block.mlp.fc2
    orig_proj = proj.weight.detach().clone()
    orig_fc2 = fc2.weight.detach().clone()
    if zero:
        _zero_linear(proj)
        _zero_linear(fc2)
    return orig_proj, orig_fc2


def restore_block(model, block_idx, orig_proj, orig_fc2):
    """Undo ``ablate_block``: copy the saved weights back and restore stashed biases.

    Works for snapshots taken with ``zero=False`` too. In that case there is
    no stashed bias and only the weights are copied.
    """
    block = model.backbone.blocks[block_idx]
    _restore_linear(block.attn.proj, orig_proj)
    _restore_linear(block.mlp.fc2, orig_fc2)
def load_calibration(coco, n, MEAN, STD, category_id=1):
    """Load the calibration set: normalized image tensors plus binary labels.

    Takes the ``n`` smallest image ids from ``coco``. Each image is read from
    ``COCO_ROOT/val2017``, converted to RGB and resized to ``RES x RES``
    (bilinear). It is scaled to [0, 1], moved to ``DEVICE`` and normalized as
    ``(x - MEAN) / STD``. Each resulting tensor has a leading batch dimension
    of 1.

    An image is labelled positive iff it has at least one non-crowd
    annotation whose ``category_id`` equals ``category_id``. The default of 1
    is the COCO 'person' category. Images whose only matches are crowd
    annotations count as negative.

    Args:
        coco: a ``pycocotools.coco.COCO`` index over the val2017 annotations.
        n: number of images to load.
        MEAN, STD: per-channel normalization tensors, broadcastable against
            a ``(1, 3, RES, RES)`` image and already on ``DEVICE``.
        category_id: COCO category that defines the positive class.

    Returns:
        ``(tensors, labels)``: a list of image tensors and a bool tensor of
        the same length on ``DEVICE``.
    """
    # NOTE(review): every calibration image stays resident on DEVICE as
    # float32 (about 7 MB per image at RES=768), so N_CALIBRATION directly
    # bounds GPU memory use.
    img_ids = sorted(coco.getImgIds())[:n]
    tensors, labels = [], []
    for img_id in img_ids:
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        with Image.open(path) as im:
            img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
        # .copy() yields a writable array; torch.from_numpy warns on read-only ones.
        arr = np.asarray(img, dtype=np.uint8).copy()
        # Move as uint8 (4x smaller transfer), then convert to float on device.
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
        tensors.append((x - MEAN) / STD)
        # Let pycocotools filter by category and crowd flag instead of
        # loading every annotation and scanning it in Python.
        matching = coco.getAnnIds(imgIds=img_id, catIds=[category_id], iscrowd=False)
        labels.append(len(matching) > 0)
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)
def _restore_all(model, backups):
    """Restore every block in ``backups`` ({block_idx: (proj, fc2) weights})."""
    for b, (op, of) in backups.items():
        restore_block(model, b, op, of)


def _sweep_single_blocks(model, imgs, labels, pos, neg, thr, base_f1):
    """Ablate each block on its own, rescore, restore; return one record per block."""
    per_block = []
    t0 = time.time()
    for b in range(N_BLOCKS):
        op, of = ablate_block(model, b)
        try:
            scores = score_images(model, imgs, pos, neg)
        finally:
            # Never leave the model ablated, even if scoring raises.
            restore_block(model, b, op, of)
        f1, p, rec = f1_of(scores, labels, thr)
        drop = base_f1 - f1
        per_block.append({'block': b, 'F1': f1, 'precision': p, 'recall': rec,
                          'F1_drop': drop})
        print(f' block {b:>2} F1={f1:.4f} drop={drop:+.4f} '
              f'{(time.time()-t0):.1f}s', flush=True)
    return per_block


def _sweep_cumulative(model, imgs, labels, pos, neg, thr, base_f1, ranked, backups):
    """For each K, ablate the K most-prunable blocks together and rescore.

    The model is reset from ``backups`` before every K, so each curve point
    is independent of the previous one. The model is left ablated after the
    final K, and the caller is responsible for restoring it.
    """
    curve = []
    for K in [k for k in (1, 2, 3, 4, 5, 6, 8, 10, 12) if k <= N_BLOCKS]:
        _restore_all(model, backups)
        pruned = [entry['block'] for entry in ranked[:K]]
        for b in pruned:
            # Use the same ablation routine as the single-block sweep so the
            # two measurements stay comparable.
            ablate_block(model, b)
        scores = score_images(model, imgs, pos, neg)
        f1, p, rec = f1_of(scores, labels, thr)
        curve.append({'blocks_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                      'precision': p, 'recall': rec,
                      'pruned_list': pruned})
        print(f' K={K:>2} F1={f1:.4f} drop={base_f1-f1:+.4f} '
              f'blocks pruned={pruned}', flush=True)
    return curve


def main():
    """Run the Stage 3 block-ablation study and write both JSON reports to OUT_DIR.

    The pipeline has five steps:
      1. Load the model, the Stage 0 classifier and the calibration set.
      2. Score the baseline.
      3. Run the single-block ablation sweep.
      4. Rank blocks by F1 drop and run the cumulative sweep.
      5. Write block_importance.json and block_pruning_curve.json.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading Argus', flush=True)
    model = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr = load_classifier()
    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
    STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)
    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)

    print('[baseline]', flush=True)
    base_scores = score_images(model, imgs, pos, neg)
    base_f1, base_p, base_r = f1_of(base_scores, labels, thr)
    print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f}', flush=True)

    # NOTE(review): thr is the Stage 0 threshold and is reused unchanged for
    # every ablated model. An F1 drop may therefore partly reflect a shifted
    # score distribution rather than lost features; consider also reporting
    # a threshold-free metric.
    per_block = _sweep_single_blocks(model, imgs, labels, pos, neg, thr, base_f1)
    # Smallest F1 drop first means the most prunable block comes first.
    ranked = sorted(per_block, key=lambda entry: entry['F1_drop'])

    print('[curve] cumulative block ablation', flush=True)
    # zero=False snapshots the clean weights without touching the model.
    backups = {b: ablate_block(model, b, zero=False) for b in range(N_BLOCKS)}
    try:
        curve = _sweep_cumulative(model, imgs, labels, pos, neg, thr, base_f1,
                                  ranked, backups)
    finally:
        _restore_all(model, backups)

    with open(f'{OUT_DIR}/block_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'per_block': per_block,
                   'ranked_most_prunable_first': [(entry['block'], entry['F1_drop'])
                                                  for entry in ranked]},
                  f, indent=2)
    with open(f'{OUT_DIR}/block_pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve}, f, indent=2)
    print(f'[done] -> {OUT_DIR}', flush=True)


if __name__ == '__main__':
    main()