"""Inference helpers: image/label loading, model construction, and batched prediction."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import torch
from PIL import Image

from .config import ProtoMorphConfig
from .model import ProtoMorphDINOv3


def _resolve_device(device: str) -> str:
    """Return *device*, or ``"cpu"`` when CUDA is unavailable and *device* is not ``"cpu"``.

    Shared by ``build_model`` and ``predict_paths`` so that the model and the
    inputs it is asked to process are resolved with the same fallback rule.
    """
    if torch.cuda.is_available() or device == "cpu":
        return device
    return "cpu"


def _label_for(labels: Sequence[str], class_id: int) -> str:
    """Return the label for *class_id*, or ``class_<id>`` when *labels* is too short."""
    if class_id < len(labels):
        return labels[class_id]
    return f"class_{class_id}"


def _topk_rows(row_probs: torch.Tensor, k: int, labels: Sequence[str]) -> List[Dict]:
    """Build the ranked ``rank``/``class_id``/``label``/``prob`` entries for one probability row."""
    values, indices = row_probs.topk(k)
    return [
        {
            "rank": rank,
            "class_id": int(idx),
            "label": _label_for(labels, int(idx)),
            "prob": float(val),
        }
        for rank, (idx, val) in enumerate(zip(indices.tolist(), values.tolist()), start=1)
    ]


def load_image(path: str | Path) -> Image.Image:
    """Open the image at *path* and return it converted to RGB.

    The source image is closed before returning; the converted copy does not
    depend on the underlying file handle.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def load_labels(path: Optional[str | Path], num_classes: int) -> List[str]:
    """Load class labels, always returning exactly *num_classes* entries.

    Args:
        path: ``None`` for generated ``class_<i>`` names; a ``.json`` file
            holding either a mapping keyed by class index or a list of
            labels; or any other file, read as one label per non-blank line.
        num_classes: Number of labels to return.

    Returns:
        A list of length *num_classes*. Missing entries are filled with
        ``class_<i>`` and surplus entries are dropped.
    """
    if path is None:
        return [f"class_{i}" for i in range(num_classes)]

    p = Path(path)
    if p.suffix.lower() == ".json":
        data = json.loads(p.read_text(encoding="utf-8"))
        if isinstance(data, dict):
            # JSON object keys are strings; the int lookup is kept as a fallback.
            return [data.get(str(i), data.get(i, f"class_{i}")) for i in range(num_classes)]
        labels = list(data)
    else:
        labels = [
            line.strip()
            for line in p.read_text(encoding="utf-8").splitlines()
            if line.strip()
        ]

    # Normalise both list-style sources the same way so that callers can
    # index the result by any class id below num_classes.
    if len(labels) < num_classes:
        labels += [f"class_{i}" for i in range(len(labels), num_classes)]
    return labels[:num_classes]


def build_model(
    config_path: str | Path,
    checkpoint_path: Optional[str | Path],
    device: str = "cuda",
    local_files_only: bool = False,
    allow_random_head: bool = False,
) -> ProtoMorphDINOv3:
    """Construct the model in eval mode and load its custom head if available.

    Args:
        config_path: JSON config passed to ``ProtoMorphConfig.from_json``.
        checkpoint_path: Custom-head checkpoint; may be ``None``.
        device: Requested device. Falls back to ``"cpu"`` when CUDA is not
            available and the request is not ``"cpu"``.
        local_files_only: Forwarded to the ``ProtoMorphDINOv3`` constructor.
        allow_random_head: When true, a missing checkpoint is tolerated and
            the head is left as constructed.

    Returns:
        The model, moved to the resolved device and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint is missing and
            *allow_random_head* is false.
    """
    cfg = ProtoMorphConfig.from_json(config_path)

    has_checkpoint = checkpoint_path is not None and Path(checkpoint_path).exists()
    # Fail before building the model so a missing checkpoint does not cost a
    # full model construction first.
    if not has_checkpoint and not allow_random_head:
        raise FileNotFoundError(
            f"Missing custom-head checkpoint: {checkpoint_path}. "
            "Pass --allow-random-head only for smoke tests; random logits are not meaningful."
        )

    device_obj = torch.device(_resolve_device(device))
    model = ProtoMorphDINOv3(cfg, local_files_only=local_files_only).to(device_obj).eval()
    if has_checkpoint:
        model.load_custom_head(checkpoint_path)
    return model


@torch.no_grad()
def predict_paths(
    model: ProtoMorphDINOv3,
    image_paths: Sequence[str | Path],
    labels: List[str],
    topk: int = 5,
    device: str = "cuda",
    force_hard: bool = False,
) -> List[Dict]:
    """Run the model on the given image files and return per-image predictions.

    Args:
        model: Called as ``model(images, device=..., force_hard=...)``; its
            result is read through the keys ``logits``, ``main_logits``,
            ``hard_mask``, ``gate_pmax``, ``gate_margin``, ``gate_entropy``,
            ``patch_hw`` and ``pixel_hw``.
        image_paths: Image files to classify, processed as a single batch.
        labels: Class names indexed by class id; ids beyond the end of the
            list are reported as ``class_<id>``.
        topk: Maximum number of ranked classes per image (capped at the
            number of classes).
        device: Requested device, resolved with the same CPU fallback that
            ``build_model`` applies.
        force_hard: Forwarded to the model unchanged.

    Returns:
        One dict per input path with the keys ``image``, ``hard_case``,
        ``gate``, ``topk``, ``main_topk``, ``patch_hw`` and ``pixel_hw``.
        An empty *image_paths* yields an empty list without calling the model.
    """
    if len(image_paths) == 0:
        return []

    images = [load_image(p) for p in image_paths]
    out = model(images, device=_resolve_device(device), force_hard=force_hard)

    probs = out["logits"].softmax(dim=-1).float().cpu()
    main_probs = out["main_logits"].softmax(dim=-1).float().cpu()
    hard_mask = out["hard_mask"].cpu().tolist()
    gate_pmax = out["gate_pmax"].float().cpu().tolist()
    gate_margin = out["gate_margin"].float().cpu().tolist()
    gate_entropy = out["gate_entropy"].float().cpu().tolist()

    # The cap depends only on the class dimension, so compute it once.
    k = min(topk, probs.shape[-1])

    results: List[Dict] = []
    for i, path in enumerate(image_paths):
        results.append(
            {
                "image": str(path),
                "hard_case": bool(hard_mask[i]),
                "gate": {
                    "pmax": float(gate_pmax[i]),
                    "margin": float(gate_margin[i]),
                    "entropy": float(gate_entropy[i]),
                },
                "topk": _topk_rows(probs[i], k, labels),
                "main_topk": _topk_rows(main_probs[i], k, labels),
                # Batch-level values from the model, attached to every result as-is.
                "patch_hw": out["patch_hw"],
                "pixel_hw": out["pixel_hw"],
            }
        )
    return results