# neuro/scripts/prune_backbone.py
# (web-viewer residue kept as comments: "Evogoatml's picture" / "initial build" / b308b74)
#!/usr/bin/env python
import argparse, os
import torch
import torch.nn.utils.prune as prune
from transformers import AutoModelForMaskedLM
def global_magnitude_prune(model, target_density: float):
    """Prune ``model`` in place to ``target_density`` by global L1 magnitude.

    The ``weight`` of every ``torch.nn.Linear`` submodule is pooled into one
    global ranking, and the entries with the smallest absolute value are
    pruned until about ``target_density`` of the pooled weights remain.
    Biases and non-Linear modules are not passed to the pruner. The pruning
    reparametrization is removed before returning, so each layer ends up
    with a plain ``weight`` parameter again.

    Args:
        model: Module whose Linear layers are pruned (modified in place).
        target_density: Fraction of Linear weights to keep, in ``(0, 1]``.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If ``target_density`` is not in ``(0, 1]``.
    """
    # Raise explicitly instead of using ``assert`` so the check survives
    # ``python -O``; the negated form also rejects NaN.
    if not 0 < target_density <= 1.0:
        raise ValueError(
            f"target_density must be in (0, 1], got {target_density!r}"
        )

    targets = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, torch.nn.Linear)
    ]
    # With no Linear layers there is nothing to rank; calling the pruner
    # with an empty parameter list would fail inside torch.
    if not targets:
        return model

    total = sum(getattr(module, attr).numel() for module, attr in targets)
    # round() rather than int(): float products such as 100 * 0.29 land just
    # below the intended integer, and truncation would prune one weight too
    # many.
    keep = round(total * target_density)
    num_to_prune = total - keep
    if num_to_prune <= 0:
        # Density of 1.0 (or a rounding no-op): weights stay as they are.
        return model

    # An integer ``amount`` is interpreted as an absolute number of entries.
    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=num_to_prune,
    )

    # Make the pruning permanent: fold the mask into ``weight`` and drop the
    # ``weight_orig`` / ``weight_mask`` reparametrization.
    for module, attr in targets:
        prune.remove(module, attr)
    return model
def main():
    """Command-line entry point: load a masked-LM model, prune it, save it.

    Required flags:
        --in: value handed to ``AutoModelForMaskedLM.from_pretrained``.
        --target_density: float forwarded to ``global_magnitude_prune``.
        --out: directory (created if missing) that receives the pruned
            model via ``save_pretrained``.
    """
    parser = argparse.ArgumentParser()
    # ``in`` is a Python keyword, so the parsed value is stored as ``inp``.
    parser.add_argument('--in', dest='inp', required=True)
    parser.add_argument('--target_density', type=float, required=True)
    parser.add_argument('--out', required=True)
    args = parser.parse_args()

    # Load, then prune; the pruning routine hands back the model it was given.
    backbone = global_magnitude_prune(
        AutoModelForMaskedLM.from_pretrained(args.inp),
        args.target_density,
    )

    # Ensure the destination exists before writing the checkpoint.
    os.makedirs(args.out, exist_ok=True)
    backbone.save_pretrained(args.out)
    print(f"Pruned and saved → {args.out}")


if __name__ == '__main__':
    main()