#!/usr/bin/env python3 """Apply atomic HEAPr masks and save zero-pruned BF16 checkpoints.""" from __future__ import annotations import argparse import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) import numpy as np from heapr.constants import DEFAULT_PRUNE_MODEL from heapr.model_utils import load_causal_lm, load_tokenizer from heapr.prune import apply_atomic_mask_to_model, atomic_mask_from_scores, save_pruned_model def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model-id", default=DEFAULT_PRUNE_MODEL) parser.add_argument("--revision") parser.add_argument("--scores-path", required=True) parser.add_argument("--output-root", required=True) parser.add_argument("--ratios", nargs="+", type=float, default=[0.10, 0.20, 0.40]) parser.add_argument("--dtype", default="bfloat16") return parser.parse_args() def main() -> None: args = parse_args() scores = np.load(args.scores_path) tokenizer = load_tokenizer(args.model_id, revision=args.revision) for ratio in args.ratios: model = load_causal_lm(args.model_id, revision=args.revision, dtype=args.dtype, use_cache=False) keep_mask = atomic_mask_from_scores(scores, ratio) apply_atomic_mask_to_model(model, keep_mask) output_dir = Path(args.output_root) / f"atomic_pruned_{int(ratio * 100):02d}pct" output_dir.mkdir(parents=True, exist_ok=True) np.save(output_dir / "atomic_keep_mask.npy", keep_mask) save_pruned_model( model, tokenizer, output_dir, mask_summary={ "ratio": ratio, "mask_shape": list(keep_mask.shape), "kept": int(keep_mask.sum()), "pruned": int((~keep_mask).sum()), }, ) print(output_dir) if __name__ == "__main__": main()