from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch

from src.protomorph.inference import build_model, load_labels, predict_paths
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the inference CLI parser and parse ``sys.argv``.

    ``--image`` is mandatory and may be repeated to run a batch of images;
    every other option has a default or is a boolean flag.
    """
    parser = argparse.ArgumentParser(description="ProtoMorph-DINOv3 inference CLI")
    # (flag, add_argument keyword options), registered in this exact order so
    # the generated --help listing keeps the same layout.
    option_specs = (
        ("--config", {"default": "checkpoints/config.json"}),
        ("--checkpoint", {"default": "checkpoints/protomorph_head.safetensors"}),
        ("--labels", {"default": None, "help": "txt/json labels. Defaults to class_0..class_N"}),
        (
            "--image",
            {"action": "append", "required": True, "help": "Image path. Repeat for batch inference."},
        ),
        ("--topk", {"type": int, "default": 5}),
        ("--device", {"default": "cuda"}),
        ("--force-hard", {"action": "store_true", "help": "Always run/fuse hard expert branch."}),
        ("--local-files-only", {"action": "store_true"}),
        ("--allow-random-head", {"action": "store_true", "help": "Smoke test only; logits are random."}),
        ("--output", {"default": None, "help": "Optional JSON output path"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    namespace = parser.parse_args()
    return namespace
|
|
|
|
def _resolve_device(requested: str) -> str:
    """Return the device string to run on, falling back to CPU only for missing CUDA.

    Only CUDA requests (``"cuda"``, ``"cuda:1"``, ...) are redirected to CPU
    when CUDA is unavailable, and a warning is printed to stderr. Any other
    explicit device (e.g. ``"mps"``) is passed through unchanged, so the user's
    choice is not silently downgraded.

    Raises:
        RuntimeError: from ``torch.device`` if ``requested`` is not a valid
            device string.
    """
    if torch.device(requested).type == "cuda" and not torch.cuda.is_available():
        print(
            f"warning: CUDA is not available; falling back to CPU (requested {requested!r})",
            file=sys.stderr,
        )
        return "cpu"
    return requested


def main() -> None:
    """Run ProtoMorph inference on the CLI-supplied images and emit JSON results.

    Builds the model from ``--config``/``--checkpoint`` and resolves class
    labels. It then predicts over every ``--image`` path and prints the results
    to stdout as indented JSON. When ``--output`` is given, the same JSON is
    also written to that file, and its parent directories are created if needed.
    """
    args = parse_args()
    device = _resolve_device(args.device)
    model = build_model(
        args.config,
        args.checkpoint,
        device=device,
        local_files_only=args.local_files_only,
        allow_random_head=args.allow_random_head,
    )
    # NOTE(review): num_classes presumably sizes the default class_* label list
    # when --labels is omitted; confirm against load_labels.
    labels = load_labels(args.labels, model.cfg.num_classes)
    results = predict_paths(
        model,
        args.image,
        labels,
        topk=args.topk,
        device=device,
        force_hard=args.force_hard,
    )
    text = json.dumps(results, indent=2)
    print(text)
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding: the default text encoding is platform-dependent.
        out_path.write_text(text + "\n", encoding="utf-8")
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|