"""Apply the Stage 2 head mask to an Argus backbone.

Loads Argus, reads head_importance.json to get the ranked most-prunable head
list, and zeroes the output-projection columns for the top-K heads. F1 at
K=10 reaches 0.9159 on COCO val, improving on the Stage 0 baseline of 0.8939.

Usage:
    model = load_pruned_argus(K=10)
    score, pred = model(image_tensor)  # with the Stage 1 head on top
"""
import os
import json
import warnings

import torch
import torch.nn as nn
from transformers import AutoModel

# Default width of each attention head's slice of the output projection's input.
HEAD_DIM = 64


def _zero_head_outputs(model, heads, head_dim=HEAD_DIM):
    """Zero the output-projection input columns that carry each listed head.

    ``heads`` is an iterable of 3-element entries unpacked as
    ``(block, head, f1_drop)``; only ``block`` and ``head`` are used. For each
    entry, columns ``[head * head_dim, (head + 1) * head_dim)`` of
    ``model.backbone.blocks[block].attn.proj.weight`` are set to zero, which
    removes that head's contribution to the projection. The bias is untouched.

    Raises:
        ValueError: If a head's column range is negative or extends past the
            projection's input width. Without this check the slice would be
            empty or truncated and the mask would silently do nothing.
    """
    blocks = model.backbone.blocks
    with torch.no_grad():  # in-place edits on parameters must not be tracked by autograd
        for block, head, _drop in heads:
            proj = blocks[block].attn.proj
            in_width = proj.weight.shape[1]  # nn.Linear weight is (out_features, in_features)
            start, stop = head * head_dim, (head + 1) * head_dim
            if head < 0 or stop > in_width:
                raise ValueError(
                    f'head {head} of block {block} (columns {start}:{stop}) does not fit '
                    f'in attn.proj input width {in_width} with head_dim={head_dim}'
                )
            proj.weight[:, start:stop].zero_()


def load_pruned_argus(repo_or_path='phanerozoic/argus', K=10, importance_json=None,
                      head_dim=HEAD_DIM):
    """Load Argus and zero out the top-K most-prunable attention heads.

    Reads the list stored under ``'ranked_most_prunable_first'`` in
    ``importance_json``. Each entry is unpacked as ``(block, head, F1_drop)``.
    For each of the first ``K`` entries, the method zeroes the columns of
    ``backbone.blocks[block].attn.proj.weight`` that receive that head's
    output. The projection bias and all other parameters are left untouched.

    Masking does change the backbone's output. The per-head ``F1_drop`` in the
    ranking presumably records the measured task-level cost of removing that
    head (verify against the script that produced the JSON).

    Args:
        repo_or_path: Hugging Face repo id or local directory passed to
            ``AutoModel.from_pretrained`` (loaded with ``trust_remote_code``).
        K: Number of heads to mask, taken from the front of the ranking.
            Must be non-negative.
        importance_json: Path to the ranking file. Defaults to
            ``head_importance.json`` next to this module.
        head_dim: Width of each head's slice of the projection input.
            Defaults to ``HEAD_DIM`` (64).

    Returns:
        The loaded model, with the mask applied in place.

    Raises:
        ValueError: If ``K`` is negative, or a ranked head does not fit
            inside the projection's input width.
    """
    if K < 0:
        # A negative K would make ranked[:K] mean "all but the last |K|".
        raise ValueError(f'K must be non-negative, got {K}')

    if importance_json is None:
        importance_json = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'head_importance.json')
    with open(importance_json, encoding='utf-8') as f:
        imp = json.load(f)
    ranked = imp['ranked_most_prunable_first']  # list of (block, head, F1_drop)

    if K > len(ranked):
        warnings.warn(
            f'Requested K={K} heads but the ranking only lists {len(ranked)}; '
            f'masking all of them.',
            stacklevel=2,
        )
    heads_to_prune = ranked[:K]

    model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
    _zero_head_outputs(model, heads_to_prune, head_dim)
    return model


if __name__ == '__main__':
    model = load_pruned_argus('/mnt/d/Argus', K=10)
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum((p != 0).sum().item() for p in model.parameters())
    print(f'Argus loaded and masked. total params: {total:,} nonzero: {nonzero:,}')
    # NOTE: this counts every zero-valued parameter, including any that were
    # already zero before masking, so it is an upper bound on what the mask removed.
    print(f'Effective param reduction: {(total - nonzero):,}')