File size: 6,690 Bytes
3729ac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""Stage 3: depth reduction via block ablation.

For each of the 12 transformer blocks, zero the weights of both the attention
output projection (attn.proj) and the MLP output projection (mlp.fc2). Because
each block is x + attn(x) + mlp(x), this degenerates the block to a residual
pass-through (the projection biases, if the layers have any, are left
untouched, so the block is an exact identity only when they are absent).
Measure F1 with the Stage 0 classifier, rank blocks by smallest F1 drop, then
sweep cumulative skipping to identify how many blocks can be dropped without
collapsing.

Output:
  block_importance.json
  block_pruning_curve.json
"""
import os, sys, json, time
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
from transformers import AutoModel

# Root of the local COCO dataset; the val2017 images and the
# instances_val2017 annotation file are read from beneath it.
COCO_ROOT = '/home/zootest/datasets/coco'
# JSON file holding the Stage 0 classifier (keys: pos_dims, neg_dims, threshold).
STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
# Number of images (lowest sorted COCO image ids) used for every F1 measurement.
N_CALIBRATION = 1000
# Number of backbone blocks that are ablated and ranked.
N_BLOCKS = 12
# Side length, in pixels, that every image is resized to (RES x RES).
RES = 768
# Size of the last (feature) dimension handed to layer_norm in score_images.
D = 768
# Directory that receives block_importance.json and block_pruning_curve.json.
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_3'
# Device used for the model, classifier index tensors, score tensors and labels.
DEVICE = 'cuda'


def load_classifier():
    """Read the Stage 0 classifier definition from ``STAGE0_CLASSIFIER``.

    Returns:
        A ``(pos, neg, threshold)`` tuple. ``pos`` and ``neg`` are int64
        index tensors on ``DEVICE`` built from the JSON keys ``pos_dims`` and
        ``neg_dims``; ``threshold`` is the JSON ``threshold`` value as a float.

    Raises:
        OSError: if the classifier file cannot be opened.
        KeyError: if one of the three expected keys is missing.
    """
    with open(STAGE0_CLASSIFIER) as fh:
        spec = json.load(fh)

    def _as_index(key):
        # Index tensors must be integer-typed to be usable for fancy indexing.
        return torch.tensor(spec[key], dtype=torch.long, device=DEVICE)

    pos_idx = _as_index('pos_dims')
    neg_idx = _as_index('neg_dims')
    return pos_idx, neg_idx, float(spec['threshold'])


@torch.inference_mode()
def score_images(argus, imgs, pos, neg):
    """Compute one scalar classifier score per image.

    For every image the backbone's patch tokens are layer-normalised over the
    feature dimension (size ``D``), max-pooled over dimension 0, and reduced
    to ``sum(pooled[pos]) - sum(pooled[neg])``.

    Args:
        argus: model exposing ``backbone.forward_features(x)`` whose result
            can be indexed with ``'x_norm_patchtokens'``.
        imgs: iterable of input tensors, one per image.
            NOTE(review): assumes each is a batch of one so that
            ``squeeze(0)`` leaves a (tokens, D) tensor -- confirm.
        pos: index tensor of feature dimensions that add to the score.
        neg: index tensor of feature dimensions that subtract from the score.

    Returns:
        1-D tensor on ``DEVICE`` with one score per element of ``imgs``.
    """
    per_image = []
    for image in imgs:
        # Only the backbone forward pass runs under bf16 autocast; the
        # pooling and scoring below are done in float32.
        with torch.autocast('cuda', dtype=torch.bfloat16):
            features = argus.backbone.forward_features(image)
        tokens = features['x_norm_patchtokens'].float().squeeze(0)
        normed = F.layer_norm(tokens, [D])
        pooled, _ = normed.max(dim=0)
        margin = pooled[pos].sum() - pooled[neg].sum()
        per_image.append(margin.item())
    return torch.tensor(per_image, device=DEVICE)


def f1_of(scores, labels, thr):
    """Return ``(f1, precision, recall)`` for thresholded scores.

    An image is predicted positive when its score is strictly greater than
    ``thr``.

    Args:
        scores: 1-D tensor of per-image scores.
        labels: tensor of ground-truth flags, same shape as ``scores``;
            expected to be boolean because ``&`` and ``~`` are applied to it.
        thr: decision threshold.

    Returns:
        Three Python floats. Precision and recall are 0.0 when their
        denominator would be zero (it is clamped to 1), and F1 is 0.0 when
        precision + recall is zero (clamped to 1e-9).
    """
    predicted = scores > thr
    actual_neg = ~labels
    true_pos = (predicted & labels).sum().float()
    false_pos = (predicted & actual_neg).sum().float()
    false_neg = (~predicted & labels).sum().float()
    # Clamping the count-valued denominators to 1 avoids 0/0 without
    # changing any non-degenerate result (counts are integers >= 1 there).
    precision = true_pos / (true_pos + false_pos).clamp(min=1)
    recall = true_pos / (true_pos + false_neg).clamp(min=1)
    harmonic_denominator = (precision + recall).clamp(min=1e-9)
    f1 = 2 * precision * recall / harmonic_denominator
    return float(f1), float(precision), float(recall)


def ablate_block(model, block_idx, zero=True):
    """Snapshot, and by default zero, the output-projection weights of a block.

    The two affected layers are ``attn.proj`` and ``mlp.fc2`` of
    ``model.backbone.blocks[block_idx]``. Only their ``weight`` tensors are
    touched; bias terms, if the layers have any, are neither saved nor
    modified.

    Args:
        model: object exposing ``backbone.blocks`` (indexable by block index).
        block_idx: index of the block to ablate.
        zero: when False, only take the snapshot and leave the weights as-is.

    Returns:
        ``(orig_proj, orig_fc2)``: detached copies of the two weight tensors
        as they were before this call, suitable for ``restore_block``.
    """
    target = model.backbone.blocks[block_idx]
    layers = (target.attn.proj, target.mlp.fc2)
    # Clone before zeroing so the snapshot never aliases the live weights.
    saved_proj, saved_fc2 = (layer.weight.detach().clone() for layer in layers)
    if not zero:
        return saved_proj, saved_fc2
    with torch.no_grad():
        for layer in layers:
            layer.weight.data.zero_()
    return saved_proj, saved_fc2


def restore_block(model, block_idx, orig_proj, orig_fc2):
    """Copy saved weights back into a block's ``attn.proj`` and ``mlp.fc2``.

    Args:
        model: object exposing ``backbone.blocks`` (indexable by block index).
        block_idx: index of the block to restore.
        orig_proj: weight values to write into ``attn.proj.weight``.
        orig_fc2: weight values to write into ``mlp.fc2.weight``.

    The copy is done in place on the existing weight tensors, so the
    layers keep their tensor objects and only the values change.
    """
    target = model.backbone.blocks[block_idx]
    proj_weight = target.attn.proj.weight
    proj_weight.data.copy_(orig_proj)
    fc2_weight = target.mlp.fc2.weight
    fc2_weight.data.copy_(orig_fc2)


def load_calibration(coco, n, MEAN, STD):
    """Load the first ``n`` COCO val2017 images and their binary labels.

    Images are taken in ascending image-id order, converted to RGB, resized
    to ``RES`` x ``RES`` with bilinear resampling, scaled to [0, 1] and
    normalised as ``(x - MEAN) / STD``.

    Args:
        coco: pycocotools ``COCO`` instance for the val2017 annotations.
        n: maximum number of images to load.
        MEAN: per-channel mean, broadcastable against a (1, 3, RES, RES)
            tensor and located on ``DEVICE``.
        STD: per-channel standard deviation, same requirements as ``MEAN``.

    Returns:
        ``(tensors, labels)``: a list of (1, 3, RES, RES) float tensors on
        ``DEVICE`` and a bool tensor on ``DEVICE`` that is True for images
        with at least one non-crowd annotation whose ``category_id`` is 1.
        NOTE(review): category 1 is presumably COCO's "person" class --
        confirm against the annotation file.
    """
    img_ids = sorted(coco.getImgIds())[:n]
    tensors, labels = [], []
    for img_id in img_ids:
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        # Context manager closes the file handle deterministically; the
        # converted/resized image is an independent in-memory copy.
        with Image.open(path) as raw:
            img = raw.convert('RGB').resize((RES, RES), Image.BILINEAR)
        # Copy so torch.from_numpy receives a writable array.
        arr = np.asarray(img, dtype=np.uint8).copy()
        # HWC uint8 -> 1xCxHxW float in [0, 1]. Placed on DEVICE (rather than
        # a hard-coded .cuda()) so inputs and labels share one device setting.
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
        tensors.append((x - MEAN) / STD)
        ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
        labels.append(any(a['category_id'] == 1 for a in coco.loadAnns(ann_ids)))
    return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)


def _restore_all_blocks(model, backups):
    """Write every saved ``(proj, fc2)`` weight pair in ``backups`` back."""
    for b, (op, of) in backups.items():
        restore_block(model, b, op, of)


def main():
    """Run the Stage 3 block-ablation experiment end to end.

    Steps:
      1. Load the model, the Stage 0 classifier and the calibration images.
      2. Measure the baseline F1.
      3. Ablate each block on its own and record the resulting F1 drop.
      4. Rank blocks by F1 drop (smallest first) and measure F1 when the K
         most-prunable blocks are ablated together, for a fixed list of K.
      5. Write ``block_importance.json`` and ``block_pruning_curve.json``
         into ``OUT_DIR``.

    Weights are always restored after an ablation, including when scoring
    raises, so the model is never left in a partially ablated state.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading Argus', flush=True)
    model = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
    pos, neg, thr = load_classifier()

    # Normalisation constants, shaped to broadcast over (1, 3, H, W) inputs.
    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
    STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)

    print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)

    print('[baseline]', flush=True)
    base_scores = score_images(model, imgs, pos, neg)
    base_f1, base_p, base_r = f1_of(base_scores, labels, thr)
    print(f'  baseline F1={base_f1:.4f}  P={base_p:.4f}  R={base_r:.4f}', flush=True)

    # Per-block individual ablation
    per_block = []
    t0 = time.time()
    for b in range(N_BLOCKS):
        op, of = ablate_block(model, b)
        try:
            scores = score_images(model, imgs, pos, neg)
        finally:
            # Put the weights back even if scoring fails.
            restore_block(model, b, op, of)
        f1, p, r = f1_of(scores, labels, thr)
        drop = base_f1 - f1
        per_block.append({'block': b, 'F1': f1, 'precision': p, 'recall': r,
                          'F1_drop': drop})
        print(f'  block {b:>2}  F1={f1:.4f}  drop={drop:+.4f}  '
              f'{(time.time()-t0):.1f}s', flush=True)

    # Smallest F1 drop first, i.e. most prunable first.
    ranked = sorted(per_block, key=lambda x: x['F1_drop'])

    # Cumulative ablation curve
    print('[curve] cumulative block ablation', flush=True)
    curve = []
    # zero=False takes the snapshot without modifying the model, so no
    # ablate-then-restore round trip is needed to get clean backups.
    backups = {b: ablate_block(model, b, zero=False) for b in range(N_BLOCKS)}
    # A K larger than N_BLOCKS would only repeat the "all blocks" point under
    # a misleading blocks_pruned value, so such entries are skipped.
    sweep = [k for k in [1, 2, 3, 4, 5, 6, 8, 10, 12] if k <= N_BLOCKS]
    try:
        for K in sweep:
            # Start every point of the curve from the unmodified model.
            _restore_all_blocks(model, backups)
            pruned = [entry['block'] for entry in ranked[:K]]
            # Ablate top-K most-prunable
            for b in pruned:
                ablate_block(model, b)
            scores = score_images(model, imgs, pos, neg)
            f1, p, rr = f1_of(scores, labels, thr)
            curve.append({'blocks_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
                          'precision': p, 'recall': rr,
                          'pruned_list': pruned})
            print(f'  K={K:>2}  F1={f1:.4f}  drop={base_f1-f1:+.4f}  '
                  f'blocks pruned={pruned}', flush=True)
    finally:
        # Restore
        _restore_all_blocks(model, backups)

    with open(f'{OUT_DIR}/block_importance.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'per_block': per_block,
                   'ranked_most_prunable_first': [(r['block'], r['F1_drop'])
                                                    for r in ranked]},
                  f, indent=2)
    with open(f'{OUT_DIR}/block_pruning_curve.json', 'w') as f:
        json.dump({'baseline_F1': base_f1, 'curve': curve}, f, indent=2)
    print(f'[done] -> {OUT_DIR}', flush=True)


if __name__ == '__main__':
    main()