File size: 2,488 Bytes
63089c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import annotations

import argparse
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the checkpoint-initialisation script.

    Returns:
        Namespace holding ``out_dir``, ``dino_model_name``, the integer head
        hyper-parameters (``num_classes``, ``embed_dim``, ``image_size``,
        ``proto_count``, ``memory_tokens``, ``rbf_count``, ``num_heads``) and
        the boolean ``force`` flag.
    """
    parser = argparse.ArgumentParser(
        description="Create initial ProtoMorph custom-head safetensors checkpoint"
    )
    parser.add_argument("--out-dir", default="checkpoints")
    parser.add_argument(
        "--dino-model-name",
        default="facebook/dinov3-vits16-pretrain-lvd1689m",
    )

    # Integer hyper-parameters, registered in the same order they appear in
    # --help. A None default for --embed-dim lets the caller infer it from
    # the backbone name.
    int_options = (
        ("--num-classes", 10),
        ("--embed-dim", None),
        ("--image-size", 512),
        ("--proto-count", 64),
        ("--memory-tokens", 16),
        ("--rbf-count", 128),
        ("--num-heads", 8),
    )
    for flag, fallback in int_options:
        parser.add_argument(flag, type=int, default=fallback)

    parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite existing config/checkpoint/labels",
    )
    return parser.parse_args()


def main() -> None:
    """Write a freshly initialised ProtoMorph head bundle to ``--out-dir``.

    The bundle consists of three files:

    * ``config.json`` -- the ``ProtoMorphConfig`` built from the CLI options;
    * ``protomorph_head.safetensors`` -- the ``state_dict`` of a newly
      constructed (untrained) ``ProtoMorphHead``;
    * ``labels.txt`` -- placeholder class names ``class_0`` .. ``class_{N-1}``,
      one per line.

    Without ``--force`` no existing bundle file is overwritten. A complete
    bundle is left untouched and the function returns. A partial bundle
    aborts with an error naming the files that are already present.

    Raises:
        SystemExit: if some, but not all, bundle files already exist and
            ``--force`` was not given.
    """
    args = parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cfg_path = out_dir / "config.json"
    ckpt_path = out_dir / "protomorph_head.safetensors"
    labels_path = out_dir / "labels.txt"
    bundle = (cfg_path, ckpt_path, labels_path)

    if not args.force:
        existing = [path for path in bundle if path.exists()]
        if len(existing) == len(bundle):
            print(f"Existing checkpoint bundle found in {out_dir}; not overwriting. Pass --force to recreate it.")
            return
        if existing:
            # Previously a partial bundle was silently regenerated, clobbering
            # whichever files did exist (e.g. a hand-edited labels.txt), even
            # though --force is documented as the opt-in for overwriting.
            names = ", ".join(path.name for path in existing)
            raise SystemExit(
                f"Partial checkpoint bundle in {out_dir} ({names} already exist); "
                "not overwriting. Pass --force to recreate it."
            )

    # Imported lazily, after the early-exit check, presumably so that the
    # "already exists" path does not pay the safetensors/torch import cost.
    from safetensors.torch import save_file
    from src.protomorph.config import ProtoMorphConfig
    from src.protomorph.model import ProtoMorphHead, infer_embed_dim_from_model_name

    # Infer only when --embed-dim was omitted, instead of silently replacing
    # any falsy explicit value with the inferred one.
    if args.embed_dim is not None:
        embed_dim = args.embed_dim
    else:
        embed_dim = infer_embed_dim_from_model_name(args.dino_model_name)

    cfg = ProtoMorphConfig(
        dino_model_name=args.dino_model_name,
        num_classes=args.num_classes,
        embed_dim=embed_dim,
        image_size=args.image_size,
        proto_count=args.proto_count,
        memory_tokens=args.memory_tokens,
        rbf_count=args.rbf_count,
        num_heads=args.num_heads,
    )
    head = ProtoMorphHead(cfg)
    cfg.to_json(cfg_path)
    save_file(head.state_dict(), str(ckpt_path))
    labels_path.write_text(
        "\n".join(f"class_{i}" for i in range(args.num_classes)) + "\n",
        encoding="utf-8",
    )
    print(f"Wrote {cfg_path}")
    print(f"Wrote {ckpt_path}")
    print(f"Wrote {labels_path}")
    print("Important: this is a random head for plumbing/smoke tests. Train it before real predictions.")


# Script entry point: parse CLI options and write the checkpoint bundle.
if __name__ == "__main__":
    main()