| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
# Prepend the repository root (one level above this script's directory) to sys.path
# so the deferred `src.protomorph.*` imports in main() resolve when this file is
# executed directly as a script rather than installed as a package.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse CLI options for creating the initial ProtoMorph checkpoint bundle."""
    parser = argparse.ArgumentParser(
        description="Create initial ProtoMorph custom-head safetensors checkpoint"
    )
    parser.add_argument("--out-dir", default="checkpoints")
    parser.add_argument("--dino-model-name", default="facebook/dinov3-vits16-pretrain-lvd1689m")
    # Integer architecture/size knobs all share the same shape; register them
    # from one table instead of repeating add_argument per flag.
    integer_options = (
        ("--num-classes", 10),
        ("--embed-dim", None),
        ("--image-size", 512),
        ("--proto-count", 64),
        ("--memory-tokens", 16),
        ("--rbf-count", 128),
        ("--num-heads", 8),
    )
    for flag, default_value in integer_options:
        parser.add_argument(flag, type=int, default=default_value)
    parser.add_argument("--force", action="store_true", help="Overwrite existing config/checkpoint/labels")
    return parser.parse_args()
|
|
|
|
def main() -> None:
    """Create a fresh, randomly initialized ProtoMorph head checkpoint bundle.

    Writes three artifacts into --out-dir:
      * config.json                  -- serialized ProtoMorphConfig
      * protomorph_head.safetensors  -- random head weights (smoke-test only)
      * labels.txt                   -- placeholder class names, one per line

    An existing complete bundle is left untouched unless --force is given.
    """
    args = parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cfg_path = out_dir / "config.json"
    ckpt_path = out_dir / "protomorph_head.safetensors"
    labels_path = out_dir / "labels.txt"

    # Refuse to clobber a complete existing bundle unless explicitly forced.
    if not args.force and cfg_path.exists() and ckpt_path.exists() and labels_path.exists():
        print(f"Existing checkpoint bundle found in {out_dir}; not overwriting. Pass --force to recreate it.")
        return

    # Deferred imports: torch/safetensors are slow to load, and the early-exit
    # path above should not pay their startup cost.
    from safetensors.torch import save_file
    from src.protomorph.config import ProtoMorphConfig
    from src.protomorph.model import ProtoMorphHead, infer_embed_dim_from_model_name

    # Use an explicit None check (not truthiness) so a user-supplied
    # `--embed-dim 0` is not silently replaced by the inferred value.
    if args.embed_dim is not None:
        embed_dim = args.embed_dim
    else:
        embed_dim = infer_embed_dim_from_model_name(args.dino_model_name)
    cfg = ProtoMorphConfig(
        dino_model_name=args.dino_model_name,
        num_classes=args.num_classes,
        embed_dim=embed_dim,
        image_size=args.image_size,
        proto_count=args.proto_count,
        memory_tokens=args.memory_tokens,
        rbf_count=args.rbf_count,
        num_heads=args.num_heads,
    )
    head = ProtoMorphHead(cfg)
    cfg.to_json(cfg_path)
    save_file(head.state_dict(), str(ckpt_path))
    # Placeholder label names; replace with the real class names before training.
    # Explicit encoding keeps output stable across platforms/locales.
    labels_path.write_text(
        "\n".join(f"class_{i}" for i in range(args.num_classes)) + "\n",
        encoding="utf-8",
    )
    print(f"Wrote {cfg_path}")
    print(f"Wrote {ckpt_path}")
    print(f"Wrote {labels_path}")
    print("Important: this is a random head for plumbing/smoke tests. Train it before real predictions.")
|
|
|
|
# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|