laguna-martini / scripts /apply_pruning.py
nikgeo's picture
Publish Laguna Martini grouped-pruning model card and reproducibility artifacts
6f11713 verified
Raw
History Blame Contribute Delete
1.92 kB
#!/usr/bin/env python3
"""Apply atomic HEAPr masks and save zero-pruned BF16 checkpoints."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import numpy as np
from heapr.constants import DEFAULT_PRUNE_MODEL
from heapr.model_utils import load_causal_lm, load_tokenizer
from heapr.prune import apply_atomic_mask_to_model, atomic_mask_from_scores, save_pruned_model
def parse_args():
    """Parse the command-line options of the pruning script.

    Options:
        --model-id     model identifier; defaults to ``DEFAULT_PRUNE_MODEL``.
        --revision     optional revision, forwarded to the loaders.
        --scores-path  path of the score array to load (required).
        --output-root  directory under which results are written (required).
        --ratios       one or more floats; defaults to 0.10, 0.20 and 0.40.
        --dtype        dtype name passed to the model loader; defaults to
                       ``"bfloat16"``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs = (
        ("--model-id", {"default": DEFAULT_PRUNE_MODEL}),
        ("--revision", {}),
        ("--scores-path", {"required": True}),
        ("--output-root", {"required": True}),
        ("--ratios", {"nargs": "+", "type": float, "default": [0.10, 0.20, 0.40]}),
        ("--dtype", {"default": "bfloat16"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def main() -> None:
    """Write one masked checkpoint per requested pruning ratio.

    The score array and the tokenizer are loaded once. For every ratio the
    model is loaded again, a keep mask is derived from the scores, the mask
    is applied to the model, and both the mask (``atomic_keep_mask.npy``) and
    the model are saved under ``<output-root>/atomic_pruned_<NN>pct``. The
    path of each output directory is printed once it has been written.
    """
    args = parse_args()
    scores = np.load(args.scores_path)
    tokenizer = load_tokenizer(args.model_id, revision=args.revision)
    output_root = Path(args.output_root)

    for ratio in args.ratios:
        # A fresh model is loaded for every ratio; presumably
        # apply_atomic_mask_to_model mutates the model in place (its return
        # value is unused), so a masked model cannot be reused.
        model = load_causal_lm(
            args.model_id,
            revision=args.revision,
            dtype=args.dtype,
            use_cache=False,
        )
        keep_mask = atomic_mask_from_scores(scores, ratio)
        apply_atomic_mask_to_model(model, keep_mask)

        # round(), not int(): binary floating point makes 0.29 * 100 evaluate
        # to 28.999999999999996, which truncation would label as "28pct".
        percent = round(ratio * 100)
        output_dir = output_root / f"atomic_pruned_{percent:02d}pct"
        output_dir.mkdir(parents=True, exist_ok=True)
        np.save(output_dir / "atomic_keep_mask.npy", keep_mask)

        mask_summary = {
            "ratio": ratio,
            "mask_shape": list(keep_mask.shape),
            "kept": int(keep_mask.sum()),
            "pruned": int((~keep_mask).sum()),
        }
        save_pruned_model(model, tokenizer, output_dir, mask_summary=mask_summary)
        print(output_dir)

        # Drop this reference before the next iteration loads another model,
        # so two full models are not kept alive by this function at once.
        del model
# Run the pruning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()