File size: 1,355 Bytes
b308b74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#!/usr/bin/env python
import argparse, os
import torch
import torch.nn.utils.prune as prune
from transformers import AutoModelForMaskedLM


def global_magnitude_prune(model, target_density: float):
    """Prune all ``torch.nn.Linear`` weights to a global target density.

    Weights are ranked by absolute value (L1 magnitude) jointly across every
    Linear layer, and the smallest ``total - int(total * target_density)`` of
    them are zeroed. The pruning reparametrization is then removed, so the
    zeros are baked into each ``weight`` parameter and the model can be saved
    like any other.

    Args:
        model: A ``torch.nn.Module``; its Linear weights are modified in place.
        target_density: Fraction of Linear weights to keep, in ``(0, 1]``.

    Returns:
        The same ``model`` object, after pruning.

    Raises:
        ValueError: If ``target_density`` is not in ``(0, 1]`` (NaN included).
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 < target_density <= 1.0:
        raise ValueError(
            f"target_density must be in (0, 1], got {target_density!r}"
        )

    # Collect each Linear weight exactly once. Two Linear modules can share the
    # same Parameter object (weight tying); listing it twice would double-count
    # it in ``total`` and hand the same tensor to the pruner twice.
    # NOTE(review): if the model ties an output Linear to its input embedding
    # matrix, pruning that Linear also zeroes the shared embedding weights --
    # confirm that is intended.
    modules = []
    seen = set()
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and id(m.weight) not in seen:
            seen.add(id(m.weight))
            modules.append((m, 'weight'))
    if not modules:
        # No Linear layers: nothing to prune, so skip the pruning call entirely.
        return model

    total = sum(m.weight.numel() for m, _ in modules)
    keep = int(total * target_density)
    # An integer ``amount`` is taken by torch as an absolute number of weights
    # to prune (a float would be interpreted as a fraction).
    prune.global_unstructured(
        modules,
        pruning_method=prune.L1Unstructured,
        amount=max(0, total - keep),
    )
    # Bake the masks into the weights and drop the weight_orig / weight_mask
    # reparametrization so the saved checkpoint has plain ``weight`` tensors.
    for m, _ in modules:
        prune.remove(m, 'weight')
    return model


def main():
    """Command-line entry point: load a masked-LM checkpoint, prune it, save it.

    Flags:
        --in: passed straight to ``AutoModelForMaskedLM.from_pretrained``
            (e.g. a checkpoint directory).
        --target_density: fraction of Linear weights to keep, in (0, 1].
        --out: directory the pruned model is written to (created if missing).
    """
    ap = argparse.ArgumentParser(
        description="Globally magnitude-prune the Linear layers of a masked LM."
    )
    ap.add_argument('--in', dest='inp', required=True,
                    help='input model passed to from_pretrained')
    ap.add_argument('--target_density', type=float, required=True,
                    help='fraction of Linear weights to keep, in (0, 1]')
    ap.add_argument('--out', required=True,
                    help='output directory for the pruned model')
    args = ap.parse_args()
    # Validate before loading the model: loading can be slow, and a bad value
    # should produce a usage error (exit status 2) rather than a traceback.
    if not 0 < args.target_density <= 1.0:
        ap.error(
            f"--target_density must be in (0, 1], got {args.target_density}"
        )

    model = AutoModelForMaskedLM.from_pretrained(args.inp)
    pruned = global_magnitude_prune(model, args.target_density)
    os.makedirs(args.out, exist_ok=True)
    pruned.save_pretrained(args.out)
    print(f"Pruned and saved → {args.out}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()