"""Stage 3: depth reduction via block ablation.

For each of the 12 transformer blocks, zero the output projections of both
the attention branch (attn.proj) and the MLP branch (mlp.fc2). Each branch's
output is added to the residual stream, so once both branches output zero the
block reduces to an identity (residual pass-through). This only holds if the
projection biases are zeroed along with the weights; a leftover bias is still
added to the residual stream. Measure F1 of the Stage 0 classifier with each
block ablated, rank blocks by smallest F1 drop, then sweep cumulative skipping
of the top-ranked blocks to find how many can be dropped before F1 collapses.

Outputs:
    block_importance.json
    block_pruning_curve.json
"""
| import os, sys, json, time |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
| from pycocotools.coco import COCO |
| from transformers import AutoModel |
|
|
# Root of the COCO dataset: main() reads annotations/instances_val2017.json
# from here and load_calibration() reads images from val2017/.
COCO_ROOT = '/home/zootest/datasets/coco'
# Stage 0 classifier JSON; load_classifier() reads its 'pos_dims',
# 'neg_dims' and 'threshold' keys.
STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
# Number of calibration images (the first N val image ids in sorted order).
N_CALIBRATION = 1000
# Number of blocks swept; assumed to equal len(model.backbone.blocks) -- TODO confirm.
N_BLOCKS = 12
# Side length in pixels that every calibration image is resized to.
RES = 768
# Feature width given to F.layer_norm in score_images; presumably the
# backbone's embedding dimension -- TODO confirm against the model config.
D = 768
# Directory that receives block_importance.json and block_pruning_curve.json.
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_3'
# Device for the model, the classifier index tensors and the label tensor.
DEVICE = 'cuda'
|
|
|
|
def load_classifier():
    """Load the Stage 0 classifier spec from ``STAGE0_CLASSIFIER``.

    Returns:
        (pos, neg, threshold): ``pos`` and ``neg`` are long index tensors on
        ``DEVICE`` built from the JSON ``'pos_dims'`` / ``'neg_dims'`` lists,
        and ``threshold`` is the JSON ``'threshold'`` value as a float.
    """
    with open(STAGE0_CLASSIFIER) as fh:
        spec = json.load(fh)

    def as_index(key):
        # Index tensors live on DEVICE so they can gather from GPU features.
        return torch.tensor(spec[key], dtype=torch.long, device=DEVICE)

    return as_index('pos_dims'), as_index('neg_dims'), float(spec['threshold'])
|
|
|
|
@torch.inference_mode()
def score_images(argus, imgs, pos, neg):
    """Score every image with the Stage 0 max-pooled linear classifier.

    For each image: run ``argus.backbone.forward_features`` under bf16
    autocast, take ``'x_norm_patchtokens'``, apply a parameter-free layer norm
    over the last dimension (width ``D``), then max-pool over tokens. The score
    is the sum of the pooled features at ``pos`` minus the sum at ``neg``.

    Args:
        argus: model exposing ``backbone.forward_features``.
        imgs: iterable of preprocessed image tensors; each presumably has
            batch size 1 (the leading dim is squeezed) -- TODO confirm.
        pos, neg: index tensors selecting feature dimensions.

    Returns:
        1-D float tensor on ``DEVICE`` with one score per image, in order.
    """
    def _score_one(x):
        with torch.autocast('cuda', dtype=torch.bfloat16):
            feats = argus.backbone.forward_features(x)
        tokens = feats['x_norm_patchtokens'].float().squeeze(0)
        normed = F.layer_norm(tokens, [D])
        peak = normed.max(dim=0).values  # per-feature max over tokens
        return (peak[pos].sum() - peak[neg].sum()).item()

    return torch.tensor([_score_one(x) for x in imgs], device=DEVICE)
|
|
|
|
def f1_of(scores, labels, thr):
    """Compute F1, precision and recall of thresholded scores.

    An item is predicted positive when ``score > thr``. The denominators of
    precision and recall are clamped to at least 1, and that of F1 to at least
    1e-9, so degenerate cases give 0.0 instead of NaN.

    Args:
        scores: tensor of classifier scores.
        labels: boolean tensor of ground-truth labels, aligned with ``scores``.
        thr: decision threshold.

    Returns:
        ``(f1, precision, recall)`` as Python floats.
    """
    predicted = scores > thr
    true_pos = (predicted & labels).sum().float()
    false_pos = (predicted & ~labels).sum().float()
    false_neg = (~predicted & labels).sum().float()

    precision = true_pos / (true_pos + false_pos).clamp(min=1)
    recall = true_pos / (true_pos + false_neg).clamp(min=1)
    harmonic = 2 * precision * recall / (precision + recall).clamp(min=1e-9)
    return float(harmonic), float(precision), float(recall)
|
|
|
|
def _snapshot_linear(layer):
    """Return detached clones ``(weight, bias)`` of a linear layer (bias may be None)."""
    bias = layer.bias
    return (layer.weight.detach().clone(),
            None if bias is None else bias.detach().clone())


def _restore_linear(layer, saved):
    """Copy a backup from _snapshot_linear (or a bare weight tensor) back into layer."""
    with torch.no_grad():
        if isinstance(saved, torch.Tensor):
            # Legacy backup format: only the weight was saved.
            layer.weight.copy_(saved)
            return
        weight, bias = saved
        layer.weight.copy_(weight)
        if bias is not None and layer.bias is not None:
            layer.bias.copy_(bias)


def ablate_block(model, block_idx, zero=True):
    """Snapshot, and optionally zero, a block's branch output projections.

    Targets ``attn.proj`` and ``mlp.fc2`` of ``model.backbone.blocks[block_idx]``.
    Both the weight AND the bias are zeroed. Zeroing only the weight, as an
    earlier version did, leaves ``y = b`` for a linear layer, so the bias
    would still be added to the residual stream and the block would not be
    the intended pass-through.

    Args:
        model: object exposing ``backbone.blocks``; each block is assumed to
            have linear-like ``attn.proj`` / ``mlp.fc2`` with ``.weight`` and
            optional ``.bias``.
        block_idx: index into ``model.backbone.blocks``.
        zero: if False, only take the snapshot and leave the block untouched.

    Returns:
        ``(orig_proj, orig_fc2)``: opaque backups for ``restore_block``. Each is
        a ``(weight, bias)`` pair of detached clones; bias is None if absent.
    """
    block = model.backbone.blocks[block_idx]
    orig_proj = _snapshot_linear(block.attn.proj)
    orig_fc2 = _snapshot_linear(block.mlp.fc2)
    if zero:
        with torch.no_grad():
            for layer in (block.attn.proj, block.mlp.fc2):
                layer.weight.zero_()
                if layer.bias is not None:
                    layer.bias.zero_()
    return orig_proj, orig_fc2


def restore_block(model, block_idx, orig_proj, orig_fc2):
    """Undo ``ablate_block`` by copying saved projections back into the block.

    Args:
        model: same object that was passed to ``ablate_block``.
        block_idx: index into ``model.backbone.blocks``.
        orig_proj, orig_fc2: backups returned by ``ablate_block``. Bare weight
            tensors are also accepted; then only the weights are restored.
    """
    block = model.backbone.blocks[block_idx]
    _restore_linear(block.attn.proj, orig_proj)
    _restore_linear(block.mlp.fc2, orig_fc2)
|
|
|
|
def load_calibration(coco, n, MEAN, STD):
    """Load and normalize the first ``n`` COCO val images plus person labels.

    Images are taken in sorted image-id order from ``{COCO_ROOT}/val2017``.
    Each is converted to RGB, resized to ``RES x RES`` (bilinear), moved to
    the GPU as a ``(1, 3, RES, RES)`` float tensor in [0, 1], and normalized
    with ``(x - MEAN) / STD``. An image's label is True when any non-crowd
    annotation has ``category_id == 1``.

    Returns:
        (tensors, labels): list of normalized image tensors and a bool tensor
        on ``DEVICE`` with one label per image.
    """
    chosen = sorted(coco.getImgIds())[:n]
    tensors, labels = [], []
    for img_id in chosen:
        file_name = coco.loadImgs(img_id)[0]['file_name']
        with Image.open(f"{COCO_ROOT}/val2017/{file_name}") as raw:
            rgb = raw.convert('RGB').resize((RES, RES), Image.BILINEAR)
        # .copy() gives torch.from_numpy a writable, owned buffer.
        pixels = torch.from_numpy(np.asarray(rgb, dtype=np.uint8).copy())
        x = pixels.permute(2, 0, 1).unsqueeze(0).cuda().float() / 255.0
        tensors.append((x - MEAN) / STD)

        ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        labels.append(any(ann['category_id'] == 1 for ann in coco.loadAnns(ann_ids)))
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)
|
|
|
|
def _sweep_single_blocks(model, imgs, labels, pos, neg, thr, base_f1):
    """Ablate each block on its own, score, restore; return one record per block."""
    per_block = []
    t0 = time.time()
    for b in range(N_BLOCKS):
        op, of = ablate_block(model, b)
        try:
            scores = score_images(model, imgs, pos, neg)
        finally:
            # Restore even if scoring raises, so the model is never left with
            # a block silently ablated.
            restore_block(model, b, op, of)
        f1, p, r = f1_of(scores, labels, thr)
        drop = base_f1 - f1
        per_block.append({'block': b, 'F1': f1, 'precision': p, 'recall': r,
                          'F1_drop': drop})
        print(f' block {b:>2} F1={f1:.4f} drop={drop:+.4f} '
              f'{(time.time()-t0):.1f}s', flush=True)
    return per_block


def _sweep_cumulative(model, imgs, labels, pos, neg, thr, base_f1, ranked):
    """Ablate the K most prunable blocks together for several K; return the curve."""
    # Pristine snapshot of every block; zero=False copies without ablating.
    backups = {b: ablate_block(model, b, zero=False) for b in range(N_BLOCKS)}

    def restore_all():
        for b, (op, of) in backups.items():
            restore_block(model, b, op, of)

    # K values beyond N_BLOCKS would just repeat the all-blocks measurement.
    ks = [k for k in (1, 2, 3, 4, 5, 6, 8, 10, 12) if k <= N_BLOCKS]
    curve = []
    try:
        for K in ks:
            restore_all()  # every K starts from the unablated model
            pruned = [r['block'] for r in ranked[:K]]
            # Reuse ablate_block so the cumulative sweep ablates exactly as
            # the single-block sweep did.
            for b in pruned:
                ablate_block(model, b)
            scores = score_images(model, imgs, pos, neg)
            f1, p, rr = f1_of(scores, labels, thr)
            curve.append({'blocks_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                          'precision': p, 'recall': rr,
                          'pruned_list': pruned})
            print(f' K={K:>2} F1={f1:.4f} drop={base_f1-f1:+.4f} '
                  f'blocks pruned={pruned}', flush=True)
    finally:
        restore_all()
    return curve


def _write_results(base_f1, per_block, ranked, curve):
    """Write block_importance.json and block_pruning_curve.json into OUT_DIR."""
    with open(f'{OUT_DIR}/block_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'per_block': per_block,
                   'ranked_most_prunable_first': [(r['block'], r['F1_drop'])
                                                  for r in ranked]},
                  f, indent=2)
    with open(f'{OUT_DIR}/block_pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve}, f, indent=2)


def main():
    """Run the Stage 3 study: baseline, per-block ablation, cumulative curve.

    Blocks are ranked by single-block F1 drop (smallest first). The cumulative
    curve then greedily ablates the top-K of that fixed ranking. Both result
    files are written to OUT_DIR.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading Argus', flush=True)
    model = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr = load_classifier()

    # Standard ImageNet normalization, shaped to broadcast over (1, 3, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)

    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, mean, std)

    print('[baseline]', flush=True)
    base_scores = score_images(model, imgs, pos, neg)
    base_f1, base_p, base_r = f1_of(base_scores, labels, thr)
    print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f}', flush=True)

    per_block = _sweep_single_blocks(model, imgs, labels, pos, neg, thr, base_f1)
    # Smallest F1 drop first = most prunable first (stable sort keeps block order on ties).
    ranked = sorted(per_block, key=lambda x: x['F1_drop'])

    print('[curve] cumulative block ablation', flush=True)
    curve = _sweep_cumulative(model, imgs, labels, pos, neg, thr, base_f1, ranked)

    _write_results(base_f1, per_block, ranked, curve)
    print(f'[done] -> {OUT_DIR}', flush=True)
|
|
|
|
# Run the full Stage 3 sweep only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|