# DINO-Protomorph/src/protomorph/inference.py
# Author: shiowo
# Commit message: Upload ProtoMorph-DINO scaffold and random head checkpoint
# Commit: 63089c1 (verified)
from __future__ import annotations
import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence
import torch
from PIL import Image
from .config import ProtoMorphConfig
from .model import ProtoMorphDINOv3
def load_image(path: str | Path) -> Image.Image:
    """Open the image file at ``path`` and return it converted to RGB.

    The file is opened inside a ``with`` block so its handle is always
    released. ``Image.open`` is lazy, and Pillow can keep the file open after
    loading for some formats, such as multi-frame images. ``convert`` returns a
    new image object, so the result stays usable after the file is closed.
    """
    with Image.open(path) as img:
        return img.convert("RGB")
def load_labels(path: Optional[str | Path], num_classes: int) -> List[str]:
    """Return exactly ``num_classes`` human-readable class labels.

    Supported sources:

    * ``None``: generated names ``class_0`` ... ``class_{num_classes - 1}``.
    * A ``.json`` file (suffix matched case-insensitively) holding either a
      list of labels, or an object mapping string class ids (``"0"``, ``"1"``,
      ...) to labels.
    * Any other file: one label per non-blank line, with surrounding
      whitespace stripped.

    Missing entries are filled with ``class_{i}`` and surplus entries are
    dropped, so the result can be indexed by every class id below
    ``num_classes``. Files are decoded as UTF-8, and a leading BOM is ignored.
    """
    fallback = [f"class_{i}" for i in range(num_classes)]
    if path is None:
        return fallback
    p = Path(path)
    # "utf-8-sig" decodes plain UTF-8 and also strips a BOM. Otherwise a BOM
    # would make json.loads fail or end up inside the first label.
    text = p.read_text(encoding="utf-8-sig")
    if p.suffix.lower() == ".json":
        data = json.loads(text)
        if isinstance(data, dict):
            # JSON object keys are always strings, so ids are looked up as str(i).
            return [data.get(str(i), fallback[i]) for i in range(num_classes)]
        labels = list(data)
    else:
        # NOTE: blank lines are skipped, so a blank line in the middle of the
        # file shifts every later label down by one class id.
        labels = [line.strip() for line in text.splitlines() if line.strip()]
    # Pad with generated names, or truncate, so both list-shaped sources
    # (JSON list and text lines) yield exactly num_classes entries.
    return labels[:num_classes] + fallback[len(labels):]
def build_model(
    config_path: str | Path,
    checkpoint_path: Optional[str | Path],
    device: str = "cuda",
    local_files_only: bool = False,
    allow_random_head: bool = False,
) -> ProtoMorphDINOv3:
    """Build a ``ProtoMorphDINOv3`` in eval mode and load its custom head.

    Args:
        config_path: JSON config file, parsed with ``ProtoMorphConfig.from_json``.
        checkpoint_path: Custom-head checkpoint passed to
            ``model.load_custom_head``. It may be ``None`` or point to a missing
            file only when ``allow_random_head`` is true.
        device: Requested torch device string. Any value other than ``"cpu"``
            falls back to CPU when CUDA is unavailable.
        local_files_only: Forwarded to the ``ProtoMorphDINOv3`` constructor.
        allow_random_head: Smoke-test escape hatch. When the checkpoint is
            missing, skip loading and keep the head as constructed. Its logits
            are then not meaningful.

    Returns:
        The model, moved to the resolved device and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint is missing and ``allow_random_head``
            is false. This is raised before the model is constructed.
    """
    cfg = ProtoMorphConfig.from_json(config_path)
    has_checkpoint = checkpoint_path is not None and Path(checkpoint_path).exists()
    # Fail fast: validate the checkpoint before constructing the model, which
    # may be expensive, instead of building it only to raise afterwards.
    # Config errors still take precedence, as before.
    if not has_checkpoint and not allow_random_head:
        raise FileNotFoundError(
            f"Missing custom-head checkpoint: {checkpoint_path}. "
            "Pass --allow-random-head only for smoke tests; random logits are not meaningful."
        )
    # NOTE: any non-"cpu" request (e.g. "mps" or "cuda:1") silently falls back
    # to CPU when CUDA is unavailable.
    keep_requested = device == "cpu" or torch.cuda.is_available()
    device_obj = torch.device(device if keep_requested else "cpu")
    model = ProtoMorphDINOv3(cfg, local_files_only=local_files_only).to(device_obj).eval()
    if has_checkpoint:
        model.load_custom_head(checkpoint_path)
    return model
def _topk_entries(probs_row: torch.Tensor, labels: Sequence[str], k: int) -> List[Dict]:
    """Return the top-``k`` entries of one 1-D probability row as ranked dicts.

    Class ids past the end of ``labels`` are labelled ``class_{id}`` instead of
    raising ``IndexError``.
    """
    values, indices = probs_row.topk(k)
    entries: List[Dict] = []
    for rank, (idx, val) in enumerate(zip(indices.tolist(), values.tolist()), start=1):
        class_id = int(idx)
        label = labels[class_id] if class_id < len(labels) else f"class_{class_id}"
        entries.append({"rank": rank, "class_id": class_id, "label": label, "prob": float(val)})
    return entries


@torch.no_grad()
def predict_paths(
    model: ProtoMorphDINOv3,
    image_paths: Sequence[str | Path],
    labels: List[str],
    topk: int = 5,
    device: str = "cuda",
    force_hard: bool = False,
) -> List[Dict]:
    """Run the model on image files and return one JSON-friendly record per image.

    Args:
        model: Called once as ``model(images, device=device, force_hard=force_hard)``.
            The returned mapping must provide ``logits``, ``main_logits``,
            ``hard_mask``, ``gate_pmax``, ``gate_margin``, ``gate_entropy``,
            ``patch_hw`` and ``pixel_hw``.
        image_paths: Image files, each loaded as RGB with ``load_image``. Any
            iterable is accepted, and it is consumed exactly once.
        labels: Class names indexed by class id. Ids beyond the list are
            reported as ``class_{id}``.
        topk: Number of ranked classes per image, capped at the number of classes.
        device: Forwarded to the model call.
        force_hard: Forwarded to the model call.

    Returns:
        One record per path, in input order. Each record contains:

        * ``image``: the path as a string.
        * ``hard_case``: ``hard_mask`` as a bool.
        * ``gate``: ``pmax``, ``margin`` and ``entropy`` as floats.
        * ``topk`` / ``main_topk``: ranked predictions from the softmax of
          ``logits`` / ``main_logits``.
        * ``patch_hw`` and ``pixel_hw``: the model's values, unchanged.

        An empty input returns ``[]`` without calling the model.
    """
    # Materialise once: the paths are used twice (loading and reporting), and a
    # generator would otherwise be exhausted before the reporting loop.
    paths = list(image_paths)
    if not paths:
        return []
    images = [load_image(p) for p in paths]
    out = model(images, device=device, force_hard=force_hard)
    # Upcast before softmax so half-precision logits still give float32 probabilities.
    probs = out["logits"].float().softmax(dim=-1).cpu()
    main_probs = out["main_logits"].float().softmax(dim=-1).cpu()
    hard_mask = out["hard_mask"].cpu().tolist()
    gate_pmax = out["gate_pmax"].float().cpu().tolist()
    gate_margin = out["gate_margin"].float().cpu().tolist()
    gate_entropy = out["gate_entropy"].float().cpu().tolist()
    k = min(topk, probs.shape[-1])  # same for every image, so hoisted out of the loop
    results: List[Dict] = []
    for i, path in enumerate(paths):
        results.append(
            {
                "image": str(path),
                "hard_case": bool(hard_mask[i]),
                "gate": {
                    "pmax": float(gate_pmax[i]),
                    "margin": float(gate_margin[i]),
                    "entropy": float(gate_entropy[i]),
                },
                "topk": _topk_entries(probs[i], labels, k),
                "main_topk": _topk_entries(main_probs[i], labels, k),
                "patch_hw": out["patch_hw"],
                "pixel_hw": out["pixel_hw"],
            }
        )
    return results