"""Create an initial ProtoMorph custom-head checkpoint bundle.

Running this script writes three files into the output directory:
``config.json``, ``protomorph_head.safetensors`` and ``labels.txt``.
The saved head is freshly constructed and untrained; it is meant for
plumbing/smoke tests only.
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Put the directory one level above this script's folder on the import path
# so the ``src.protomorph`` imports inside main() can be resolved.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the checkpoint initialiser.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, matching a plain ``parse_args()`` call.

    Returns:
        Namespace with ``out_dir``, ``dino_model_name``, ``num_classes``,
        ``embed_dim``, ``image_size``, ``proto_count``, ``memory_tokens``,
        ``rbf_count``, ``num_heads`` and ``force``.
    """
    ap = argparse.ArgumentParser(
        description="Create initial ProtoMorph custom-head safetensors checkpoint"
    )
    ap.add_argument("--out-dir", default="checkpoints")
    ap.add_argument("--dino-model-name", default="facebook/dinov3-vits16-pretrain-lvd1689m")
    ap.add_argument("--num-classes", type=int, default=10)
    # None means "not given"; main() then derives the value from the model name.
    ap.add_argument("--embed-dim", type=int, default=None)
    ap.add_argument("--image-size", type=int, default=512)
    ap.add_argument("--proto-count", type=int, default=64)
    ap.add_argument("--memory-tokens", type=int, default=16)
    ap.add_argument("--rbf-count", type=int, default=128)
    ap.add_argument("--num-heads", type=int, default=8)
    ap.add_argument(
        "--force",
        action="store_true",
        help="Overwrite existing config/checkpoint/labels",
    )
    return ap.parse_args(argv)


def main(argv: list[str] | None = None) -> None:
    """Write the config, head checkpoint and label file into the output dir.

    The output directory is created if it does not exist. When all three
    bundle files are already present and ``--force`` was not given, nothing
    is written and the function returns after printing a notice. If only some
    of the files exist, the bundle is regenerated and the existing files are
    overwritten.

    Args:
        argv: Argument strings forwarded to :func:`parse_args`; ``None`` means
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cfg_path = out_dir / "config.json"
    ckpt_path = out_dir / "protomorph_head.safetensors"
    labels_path = out_dir / "labels.txt"

    # Only a complete bundle blocks regeneration.
    bundle_complete = cfg_path.exists() and ckpt_path.exists() and labels_path.exists()
    if not args.force and bundle_complete:
        print(
            f"Existing checkpoint bundle found in {out_dir}; not overwriting. "
            "Pass --force to recreate it."
        )
        return

    # Imported here rather than at module top, so the early-return path above
    # never loads safetensors or the project packages.
    from safetensors.torch import save_file

    from src.protomorph.config import ProtoMorphConfig
    from src.protomorph.model import ProtoMorphHead, infer_embed_dim_from_model_name

    # A missing (None) or zero --embed-dim falls back to the value inferred
    # from the backbone model name.
    embed_dim = args.embed_dim or infer_embed_dim_from_model_name(args.dino_model_name)

    cfg = ProtoMorphConfig(
        dino_model_name=args.dino_model_name,
        num_classes=args.num_classes,
        embed_dim=embed_dim,
        image_size=args.image_size,
        proto_count=args.proto_count,
        memory_tokens=args.memory_tokens,
        rbf_count=args.rbf_count,
        num_heads=args.num_heads,
    )
    head = ProtoMorphHead(cfg)

    cfg.to_json(cfg_path)
    save_file(head.state_dict(), str(ckpt_path))

    # Placeholder label names, one per line: class_0 .. class_{num_classes - 1}.
    labels_text = "\n".join(f"class_{i}" for i in range(args.num_classes)) + "\n"
    labels_path.write_text(labels_text, encoding="utf-8")

    print(f"Wrote {cfg_path}")
    print(f"Wrote {ckpt_path}")
    print(f"Wrote {labels_path}")
    print("Important: this is a random head for plumbing/smoke tests. Train it before real predictions.")


if __name__ == "__main__":
    main()