"""Stage 3: depth reduction via block ablation.

For each of the N_BLOCKS transformer blocks, zero the attention output
projection (``attn.proj``) and the MLP output projection (``mlp.fc2``) --
weights AND biases. Assuming each block is a residual stack of the form
``x = x + attn_branch(x); x = x + mlp_branch(x)`` in which each branch ends in
that projection (optionally followed by a scaling that maps 0 to 0, such as
LayerScale -- TODO confirm for Argus), both branches then contribute exactly
zero and the block degenerates to an identity (residual pass-through).
Zeroing only the weights is not enough: the projection bias would still be
added to the residual stream for every token.

Measure F1 of the Stage 0 classifier with each block ablated, rank blocks by
smallest F1 drop, then sweep cumulative skipping of the most-prunable blocks
to find how many can be dropped before the classifier collapses.

Outputs (written to OUT_DIR):
    block_importance.json
    block_pruning_curve.json
"""
import json
import os
import time
from typing import NamedTuple, Optional

import numpy as np
import torch
import torch.nn.functional as F

# PIL, pycocotools and transformers are imported inside the functions that
# need them, so the pure helpers below (f1_of, ablate_block, restore_block)
# can be imported and unit-tested without those packages installed.

COCO_ROOT = '/home/zootest/datasets/coco'
STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
N_CALIBRATION = 1000
N_BLOCKS = 12
RES = 768  # square input resolution images are resized to
D = 768    # width of the patch-token features (layer_norm normalized_shape)
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_3'
DEVICE = 'cuda'
# Sizes of the cumulative-ablation sweep; values above N_BLOCKS are skipped.
CUMULATIVE_K = (1, 2, 3, 4, 5, 6, 8, 10, 12)


def load_classifier():
    """Load the Stage 0 classifier definition from STAGE0_CLASSIFIER.

    Returns:
        (pos, neg, threshold): ``pos`` / ``neg`` are int64 index tensors on
        DEVICE built from ``pos_dims`` / ``neg_dims``; ``threshold`` is a float.
    """
    with open(STAGE0_CLASSIFIER) as f:
        c = json.load(f)
    pos = torch.tensor(c['pos_dims'], dtype=torch.long, device=DEVICE)
    neg = torch.tensor(c['neg_dims'], dtype=torch.long, device=DEVICE)
    return pos, neg, float(c['threshold'])


@torch.inference_mode()
def score_images(argus, imgs, pos, neg):
    """Return a 1-D tensor on DEVICE with one classifier score per image.

    For each image: run ``argus.backbone.forward_features`` under bf16
    autocast, layer-normalize the patch tokens over the feature dimension,
    max-pool over patches, and score as ``sum(pooled[pos]) - sum(pooled[neg])``.
    Each element of ``imgs`` is assumed to be a single-image batch -- TODO
    confirm if other batch sizes are ever passed.
    """
    scores = []
    for x in imgs:
        with torch.autocast('cuda', dtype=torch.bfloat16):
            out = argus.backbone.forward_features(x)
        # Drop the batch dim (assumes batch size 1); post-process in fp32.
        patches = out['x_norm_patchtokens'].float().squeeze(0)
        ln = F.layer_norm(patches, [D])
        pooled = ln.max(dim=0).values
        scores.append((pooled[pos].sum() - pooled[neg].sum()).item())
    return torch.tensor(scores, device=DEVICE)


def f1_of(scores, labels, thr):
    """Binary F1 / precision / recall of ``scores > thr`` against bool ``labels``.

    Denominators are clamped so degenerate cases (no predicted or no actual
    positives) yield 0.0 instead of NaN.

    Returns:
        (f1, precision, recall) as Python floats.
    """
    pred = scores > thr
    tp = (pred & labels).sum().float()
    fp = (pred & ~labels).sum().float()
    fn = (~pred & labels).sum().float()
    prec = tp / (tp + fp).clamp(min=1)
    rec = tp / (tp + fn).clamp(min=1)
    f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9)
    return float(f1), float(prec), float(rec)


class _LinearBackup(NamedTuple):
    """Snapshot of one nn.Linear's parameters, used to undo an ablation."""

    weight: torch.Tensor
    bias: Optional[torch.Tensor]  # None when the layer has no bias


def _snapshot_linear(linear):
    """Return detached clones of ``linear``'s weight and (if present) bias."""
    bias = None if linear.bias is None else linear.bias.detach().clone()
    return _LinearBackup(linear.weight.detach().clone(), bias)


def _zero_linear(linear):
    """Zero ``linear``'s weight and bias in place so it outputs exact zeros."""
    with torch.no_grad():
        linear.weight.zero_()
        if linear.bias is not None:
            linear.bias.zero_()


def _restore_linear(linear, backup):
    """Copy a snapshot back into ``linear`` in place.

    ``backup`` is normally a _LinearBackup. A bare tensor (the older,
    weight-only snapshot format) is also accepted and restores the weight only.
    """
    with torch.no_grad():
        if isinstance(backup, torch.Tensor):
            linear.weight.copy_(backup)
            return
        linear.weight.copy_(backup.weight)
        if backup.bias is not None and linear.bias is not None:
            linear.bias.copy_(backup.bias)


def ablate_block(model, block_idx, zero=True):
    """Snapshot and (by default) zero the output projections of one block.

    Zeroes ``attn.proj`` and ``mlp.fc2`` -- weight and bias -- of
    ``model.backbone.blocks[block_idx]``, so both residual branches add
    exactly zero and the block passes its input through unchanged.

    Args:
        model: object exposing an indexable ``model.backbone.blocks``.
        block_idx: index of the block to ablate.
        zero: if False, only take the snapshot and leave the block untouched.

    Returns:
        (orig_proj, orig_fc2): opaque snapshots to hand to restore_block. The
        snapshot is taken before zeroing, so calling this twice on the same
        block without restoring in between captures zeros.
    """
    block = model.backbone.blocks[block_idx]
    orig_proj = _snapshot_linear(block.attn.proj)
    orig_fc2 = _snapshot_linear(block.mlp.fc2)
    if zero:
        _zero_linear(block.attn.proj)
        _zero_linear(block.mlp.fc2)
    return orig_proj, orig_fc2


def restore_block(model, block_idx, orig_proj, orig_fc2):
    """Undo ablate_block for ``block_idx`` using the snapshots it returned."""
    block = model.backbone.blocks[block_idx]
    _restore_linear(block.attn.proj, orig_proj)
    _restore_linear(block.mlp.fc2, orig_fc2)


def load_calibration(coco, n, MEAN, STD):
    """Load the first ``n`` COCO val2017 images (by sorted id) and their labels.

    Each image is converted to RGB, resized to RES x RES (bilinear), scaled to
    [0, 1], normalized with MEAN / STD and kept on DEVICE as a (1, 3, RES, RES)
    float32 tensor. NOTE: with the defaults that is about 7 MB per image, so
    roughly 7 GB of device memory for N_CALIBRATION images.

    The label is True when the image has at least one non-crowd annotation
    with category_id 1 ('person' in the COCO instances taxonomy).

    Returns:
        (tensors, labels): list of image tensors and a bool tensor on DEVICE.
    """
    from PIL import Image

    img_ids = sorted(coco.getImgIds())[:n]
    tensors, labels = [], []
    for img_id in img_ids:
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        # Context manager closes the underlying file once pixels are loaded.
        with Image.open(path) as im:
            img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.uint8).copy()
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
        tensors.append((x - MEAN) / STD)
        ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        labels.append(any(a['category_id'] == 1 for a in coco.loadAnns(ann_ids)))
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)


def _per_block_ablation(model, evaluate, base_f1):
    """Ablate each block on its own, re-score, restore; return per-block results."""
    per_block = []
    t0 = time.time()
    for b in range(N_BLOCKS):
        orig_proj, orig_fc2 = ablate_block(model, b)
        try:
            f1, p, r = evaluate()
        finally:
            # Restore even if scoring fails so the model is never left ablated.
            restore_block(model, b, orig_proj, orig_fc2)
        drop = base_f1 - f1
        per_block.append({'block': b, 'F1': f1, 'precision': p, 'recall': r,
                          'F1_drop': drop})
        print(f' block {b:>2} F1={f1:.4f} drop={drop:+.4f} '
              f'{(time.time()-t0):.1f}s', flush=True)
    return per_block


def _cumulative_curve(model, evaluate, base_f1, ranked):
    """For each K in CUMULATIVE_K, ablate the K most-prunable blocks together."""
    # Snapshot every block once, without zeroing, so each K starts from the
    # original weights.
    backups = {b: ablate_block(model, b, zero=False) for b in range(N_BLOCKS)}

    def restore_all():
        for b, (orig_proj, orig_fc2) in backups.items():
            restore_block(model, b, orig_proj, orig_fc2)

    curve = []
    try:
        for K in CUMULATIVE_K:
            if K > N_BLOCKS:
                continue
            restore_all()
            pruned = [entry['block'] for entry in ranked[:K]]
            for b in pruned:
                ablate_block(model, b)  # snapshot unused; `backups` has originals
            f1, p, rr = evaluate()
            curve.append({'blocks_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                          'precision': p, 'recall': rr, 'pruned_list': pruned})
            print(f' K={K:>2} F1={f1:.4f} drop={base_f1-f1:+.4f} '
                  f'blocks pruned={pruned}', flush=True)
    finally:
        restore_all()
    return curve


def main():
    """Run the per-block and cumulative ablation study and write the JSON outputs."""
    from pycocotools.coco import COCO
    from transformers import AutoModel

    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading Argus', flush=True)
    model = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr = load_classifier()
    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
    STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)

    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)

    def evaluate():
        """(F1, P, R) on the calibration set with the model's current weights."""
        return f1_of(score_images(model, imgs, pos, neg), labels, thr)

    print('[baseline]', flush=True)
    base_f1, base_p, base_r = evaluate()
    print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f}', flush=True)

    # Per-block individual ablation, ranked most-prunable (smallest drop) first.
    per_block = _per_block_ablation(model, evaluate, base_f1)
    ranked = sorted(per_block, key=lambda entry: entry['F1_drop'])

    print('[curve] cumulative block ablation', flush=True)
    curve = _cumulative_curve(model, evaluate, base_f1, ranked)

    with open(f'{OUT_DIR}/block_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1,
                   'per_block': per_block,
                   'ranked_most_prunable_first': [(r['block'], r['F1_drop']) for r in ranked]},
                  f, indent=2)
    with open(f'{OUT_DIR}/block_pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve}, f, indent=2)
    print(f'[done] -> {OUT_DIR}', flush=True)


if __name__ == '__main__':
    main()