| """Apply the Stage 2 head mask to an Argus backbone. |
| |
| Loads Argus, reads head_importance.json to get the ranked-most-prunable |
| head list, and zeroes the output-projection columns for the top-K heads. |
| F1 at K=10 reaches 0.9159 on COCO val, improving the Stage 0 baseline |
| of 0.8939. |
| |
| Usage: |
| model = load_pruned_argus(K=10) |
| score, pred = model(image_tensor) # with the Stage 1 head on top |
| """ |
| import os, json |
| import torch |
| import torch.nn as nn |
| from transformers import AutoModel |
|
|
# Number of input columns of ``attn.proj`` that belong to one attention head.
# NOTE(review): assumes every Argus attention head is 64 wide -- confirm
# against the backbone config.
HEAD_DIM = 64


def _load_prunable_heads(importance_json, K):
    """Return the first ``K`` entries of ``ranked_most_prunable_first``.

    Raises:
        ValueError: if ``K`` is negative or larger than the ranked list.
    """
    with open(importance_json, encoding='utf-8') as f:
        ranked = json.load(f)['ranked_most_prunable_first']
    # A negative K would slice from the end (pruning almost everything), and an
    # oversized K would silently prune fewer heads than requested.
    if not 0 <= K <= len(ranked):
        raise ValueError(
            f'K must be between 0 and {len(ranked)} (heads ranked in '
            f'{importance_json}), got {K}')
    return ranked[:K]


def _zero_head_columns(proj, head, head_dim):
    """Zero the ``head_dim`` input columns of ``proj.weight`` owned by ``head``.

    Raises:
        IndexError: if ``head`` does not fit inside ``proj.weight``'s columns
            (slicing alone would silently zero nothing).
    """
    lo, hi = head * head_dim, (head + 1) * head_dim
    n_cols = proj.weight.shape[1]
    if head < 0 or hi > n_cols:
        raise IndexError(
            f'head {head} (columns {lo}:{hi}) is outside proj.weight with '
            f'{n_cols} input columns')
    with torch.no_grad():
        proj.weight[:, lo:hi].zero_()


def load_pruned_argus(repo_or_path='phanerozoic/argus', K=10,
                      importance_json=None, head_dim=None):
    """Load Argus and zero the output-projection columns of the top-K heads.

    Heads are taken from the front of ``ranked_most_prunable_first`` in the
    importance JSON. Each entry is unpacked as ``(block, head, _)``; the third
    field is ignored here (presumably a score/metric drop -- confirm).

    Args:
        repo_or_path: Hugging Face repo id or local directory passed to
            ``AutoModel.from_pretrained`` (with ``trust_remote_code=True``).
        K: number of heads to prune. ``0`` returns the unmodified model.
        importance_json: path to the head-importance JSON. Defaults to
            ``head_importance.json`` next to this file.
        head_dim: input columns of ``attn.proj`` per head. Defaults to
            ``HEAD_DIM``.

    Returns:
        The loaded model. For each selected head, the corresponding columns of
        ``backbone.blocks[block].attn.proj.weight`` are zeroed. Biases and all
        other parameters are untouched. The backbone's outputs will generally
        change, because the pruned heads no longer contribute. This is a
        one-off weight edit, not a persistent mask: further training can make
        these weights nonzero again.

    Raises:
        ValueError: if ``K`` is negative or exceeds the number of ranked heads.
        IndexError: if an entry names a block or head outside the model.
    """
    if importance_json is None:
        importance_json = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'head_importance.json')
    if head_dim is None:
        head_dim = HEAD_DIM

    # Read and validate the head list before the expensive model load, so a
    # bad K or importance file fails fast.
    heads_to_prune = _load_prunable_heads(importance_json, K)

    model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
    blocks = model.backbone.blocks
    for block, head, _drop in heads_to_prune:
        # Negative indices would silently address blocks from the end.
        if not 0 <= block < len(blocks):
            raise IndexError(
                f'block {block} out of range for {len(blocks)} blocks')
        _zero_head_columns(blocks[block].attn.proj, head, head_dim)
    return model
|
|
|
|
if __name__ == '__main__':
    import sys

    # Usage: python <this file> [CHECKPOINT_PATH] [K]
    # Defaults preserve the original author's local checkpoint and K=10.
    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else '/mnt/d/Argus'
    k = int(sys.argv[2]) if len(sys.argv) > 2 else 10

    model = load_pruned_argus(ckpt_path, K=k)
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum((p != 0).sum().item() for p in model.parameters())
    print(f'Argus loaded and masked. total params: {total:,} nonzero: {nonzero:,}')
    # Zeroed weights are still stored, so this counts zero-valued entries
    # (including any that were already zero before masking). It is not a
    # reduction in model size.
    print(f'Zero-valued params (incl. pre-existing zeros): {(total - nonzero):,}')
|
|