File size: 2,488 Bytes
63089c1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for creating the initial checkpoint bundle.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so existing ``parse_args()`` callers behave exactly as before.
            Passing an explicit list lets the parser be driven
            programmatically (e.g. from tests).

    Returns:
        Namespace with ``out_dir``, ``dino_model_name``, ``num_classes``,
        ``embed_dim`` (``None`` unless given), ``image_size``,
        ``proto_count``, ``memory_tokens``, ``rbf_count``, ``num_heads``
        and ``force``.
    """
    ap = argparse.ArgumentParser(description="Create initial ProtoMorph custom-head safetensors checkpoint")
    ap.add_argument("--out-dir", default="checkpoints")
    ap.add_argument("--dino-model-name", default="facebook/dinov3-vits16-pretrain-lvd1689m")
    ap.add_argument("--num-classes", type=int, default=10)
    # Left as None when omitted so the caller can tell "not provided" apart
    # from an explicit value.
    ap.add_argument("--embed-dim", type=int, default=None)
    ap.add_argument("--image-size", type=int, default=512)
    ap.add_argument("--proto-count", type=int, default=64)
    ap.add_argument("--memory-tokens", type=int, default=16)
    ap.add_argument("--rbf-count", type=int, default=128)
    ap.add_argument("--num-heads", type=int, default=8)
    ap.add_argument("--force", action="store_true", help="Overwrite existing config/checkpoint/labels")
    return ap.parse_args(argv)
def main() -> None:
    """Create a randomly initialised ProtoMorph head checkpoint bundle.

    Writes ``config.json``, ``protomorph_head.safetensors`` and
    ``labels.txt`` into ``--out-dir`` (created if missing). When all three
    files already exist and ``--force`` was not given, a notice is printed
    and nothing is written.
    """
    args = parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cfg_path = out_dir / "config.json"
    ckpt_path = out_dir / "protomorph_head.safetensors"
    labels_path = out_dir / "labels.txt"

    # Only a complete bundle is protected; if any of the three files is
    # missing the whole bundle is (re)written even without --force.
    if not args.force and cfg_path.exists() and ckpt_path.exists() and labels_path.exists():
        print(f"Existing checkpoint bundle found in {out_dir}; not overwriting. Pass --force to recreate it.")
        return

    # These imports run only when a bundle is actually going to be written.
    from safetensors.torch import save_file
    from src.protomorph.config import ProtoMorphConfig
    from src.protomorph.model import ProtoMorphHead, infer_embed_dim_from_model_name

    # An omitted (or zero) --embed-dim falls back to the value derived from
    # the backbone model name.
    embed_dim = args.embed_dim or infer_embed_dim_from_model_name(args.dino_model_name)
    cfg = ProtoMorphConfig(
        dino_model_name=args.dino_model_name,
        num_classes=args.num_classes,
        embed_dim=embed_dim,
        image_size=args.image_size,
        proto_count=args.proto_count,
        memory_tokens=args.memory_tokens,
        rbf_count=args.rbf_count,
        num_heads=args.num_heads,
    )
    head = ProtoMorphHead(cfg)

    cfg.to_json(cfg_path)
    save_file(head.state_dict(), str(ckpt_path))
    # Placeholder label names, one per line; the encoding is explicit so the
    # file does not depend on the platform's locale default.
    labels = "\n".join(f"class_{i}" for i in range(args.num_classes)) + "\n"
    labels_path.write_text(labels, encoding="utf-8")

    print(f"Wrote {cfg_path}")
    print(f"Wrote {ckpt_path}")
    # Report the labels file too, so every written artifact is listed.
    print(f"Wrote {labels_path}")
    print("Important: this is a random head for plumbing/smoke tests. Train it before real predictions.")
# Run the checkpoint-creation routine only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|