DINO-Protomorph / infer.py
shiowo's picture
Upload ProtoMorph-DINO scaffold and random head checkpoint
63089c1 verified
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch

from src.protomorph.inference import build_model, load_labels, predict_paths
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the ProtoMorph-DINOv3 inference script.

    ``--image`` is mandatory and may be repeated; its values are collected
    into a list (``action="append"``) for batch inference. Boolean switches
    (``--force-hard``, ``--local-files-only``, ``--allow-random-head``) default
    to ``False``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    # (flag, add_argument keyword options) pairs, in the order they should
    # appear in ``--help``.
    cli_options = (
        ("--config", dict(default="checkpoints/config.json")),
        ("--checkpoint", dict(default="checkpoints/protomorph_head.safetensors")),
        ("--labels", dict(default=None, help="txt/json labels. Defaults to class_0..class_N")),
        (
            "--image",
            dict(action="append", required=True, help="Image path. Repeat for batch inference."),
        ),
        ("--topk", dict(type=int, default=5)),
        ("--device", dict(default="cuda")),
        ("--force-hard", dict(action="store_true", help="Always run/fuse hard expert branch.")),
        ("--local-files-only", dict(action="store_true")),
        (
            "--allow-random-head",
            dict(action="store_true", help="Smoke test only; logits are random."),
        ),
        ("--output", dict(default=None, help="Optional JSON output path")),
    )
    parser = argparse.ArgumentParser(description="ProtoMorph-DINOv3 inference CLI")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def _resolve_device(requested: str) -> str:
    """Return *requested*, or ``"cpu"`` if it names an accelerator that is unavailable.

    Only CUDA (``"cuda"``, ``"cuda:N"``) and MPS (``"mps"``) requests are
    checked. Any other device string, including ``"cpu"``, is returned
    unchanged so torch itself validates it. A warning is printed to stderr
    whenever the fallback happens, so the switch to CPU is never silent.
    """
    backend = requested.split(":", 1)[0]
    if backend == "cuda":
        available = torch.cuda.is_available()
    elif backend == "mps":
        # torch.backends.mps is absent on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        available = mps_backend is not None and mps_backend.is_available()
    else:
        return requested
    if available:
        return requested
    print(
        f"warning: device {requested!r} is unavailable; falling back to 'cpu'",
        file=sys.stderr,
    )
    return "cpu"


def main() -> None:
    """Run ProtoMorph inference on the images named on the command line.

    Builds the model from ``--config``/``--checkpoint``, loads labels, and
    prints the predictions as indented JSON to stdout. When ``--output`` is
    given, the same JSON (plus a trailing newline) is also written to that
    file, and missing parent directories are created first.
    """
    args = parse_args()
    device = _resolve_device(args.device)
    model = build_model(
        args.config,
        args.checkpoint,
        device=device,
        local_files_only=args.local_files_only,
        allow_random_head=args.allow_random_head,
    )
    # The class count is read from the loaded model's config.
    labels = load_labels(args.labels, model.cfg.num_classes)
    results = predict_paths(
        model,
        args.image,
        labels,
        topk=args.topk,
        device=device,
        force_hard=args.force_hard,
    )
    text = json.dumps(results, indent=2)
    print(text)
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # json.dumps escapes non-ASCII by default, but pin the encoding so the
        # write never depends on the platform locale.
        out_path.write_text(text + "\n", encoding="utf-8")


if __name__ == "__main__":
    main()