# phanerozoic's picture
# Stage 2: attention-head pruning results + mask + apply_mask.py
# a7e09b2 verified
"""Apply the Stage 2 head mask to an Argus backbone.
Loads Argus, reads head_importance.json to get the list of heads ranked
most-prunable first, and zeroes the output-projection columns for the top-K heads.
F1 at K=10 reaches 0.9159 on COCO val, improving the Stage 0 baseline
of 0.8939.
Usage:
model = load_pruned_argus(K=10)
score, pred = model(image_tensor) # with the Stage 1 head on top
"""
import os, json
import torch
import torch.nn as nn
from transformers import AutoModel
HEAD_DIM = 64  # fallback per-head width, used when the attention module exposes no `head_dim`
def load_pruned_argus(repo_or_path='phanerozoic/argus', K=10,
                      importance_json=None):
    """Load Argus and zero out the top-K most-prunable attention heads.

    A head is masked by zeroing its slice of input columns in the block's
    attention output projection (``attn.proj``). This assumes the usual
    head-major layout of the concatenated head outputs. Only those weight
    columns are modified; every other parameter (including
    ``attn.proj.bias``) is left untouched. Masking DOES change the
    backbone's outputs.

    Args:
        repo_or_path: Hugging Face repo id or local directory passed to
            ``AutoModel.from_pretrained`` (with ``trust_remote_code=True``).
        K: number of heads to mask, taken from the front of the ranking.
            Must satisfy ``0 <= K <= len(ranking)``.
        importance_json: path to the head-importance JSON file. Defaults to
            ``head_importance.json`` next to this module. Its
            ``'ranked_most_prunable_first'`` entry is read as a sequence of
            ``(block, head, F1_drop)`` triples.

    Returns:
        The loaded model, with the selected heads masked in place.

    Raises:
        ValueError: if ``K`` is negative or larger than the number of ranked heads.
        IndexError: if a ranked ``(block, head)`` pair does not exist in the model.
    """
    if K < 0:
        # A negative K would slice from the end (ranked[:-1]) and silently
        # mask nearly every ranked head.
        raise ValueError(f'K must be non-negative, got {K}')
    if importance_json is None:
        importance_json = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                       'head_importance.json')
    with open(importance_json, encoding='utf-8') as f:
        imp = json.load(f)
    ranked = imp['ranked_most_prunable_first']  # list of (block, head, F1_drop)
    if K > len(ranked):
        # Slicing would silently return fewer heads than requested.
        raise ValueError(
            f'K={K} exceeds the {len(ranked)} heads ranked in {importance_json}')
    heads_to_prune = ranked[:K]

    model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
    blocks = model.backbone.blocks
    with torch.no_grad():
        for block, head, _drop in heads_to_prune:
            # Negative indices would wrap around to the wrong block.
            if not 0 <= block < len(blocks):
                raise IndexError(f'block {block} out of range (model has {len(blocks)})')
            attn = blocks[block].attn
            proj = attn.proj
            # Prefer the module's own head width if it exposes one (e.g.
            # timm-style Attention); otherwise fall back to HEAD_DIM.
            head_dim = getattr(attn, 'head_dim', HEAD_DIM)
            n_heads = proj.weight.shape[1] // head_dim
            # An out-of-range head would yield an empty slice and the mask
            # would silently do nothing.
            if not 0 <= head < n_heads:
                raise IndexError(
                    f'head {head} out of range for block {block} ({n_heads} heads)')
            proj.weight[:, head * head_dim:(head + 1) * head_dim].zero_()
    return model
if __name__ == '__main__':
    import sys

    # Optional first CLI argument overrides the machine-specific default path.
    src = sys.argv[1] if len(sys.argv) > 1 else '/mnt/d/Argus'
    model = load_pruned_argus(src, K=10)
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum(torch.count_nonzero(p).item() for p in model.parameters())
    print(f'Argus loaded and masked. total params: {total:,} nonzero: {nonzero:,}')
    # Masked weights are zeroed, not removed, so storage is unchanged. This
    # counts every exactly-zero entry: the masked proj columns plus any zeros
    # already present in the checkpoint.
    print(f'Effective param reduction: {(total - nonzero):,}')