File size: 3,630 Bytes
63089c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import torch
from PIL import Image

from .config import ProtoMorphConfig
from .model import ProtoMorphDINOv3


def load_image(path: str | Path) -> Image.Image:
    """Open the file at *path* with PIL and return it converted to RGB."""
    img = Image.open(path)
    return img.convert("RGB")


def load_labels(path: Optional[str | Path], num_classes: int) -> List[str]:
    """Load human-readable class names, always returning exactly ``num_classes`` entries.

    Supported sources:
      * ``None`` — synthesize ``class_0 .. class_{n-1}`` placeholders.
      * ``*.json`` holding a dict — keys looked up as ``str(i)`` first (standard
        JSON keys are strings), then ``i``, falling back to a placeholder.
      * ``*.json`` holding a list — used in order.
      * any other file — one label per non-blank line.

    List-style sources (JSON list or text file) are padded with placeholders
    when too short and truncated when too long, so the result always aligns
    with the model's ``num_classes`` output dimension.
    """
    if path is None:
        return [f"class_{i}" for i in range(num_classes)]
    p = Path(path)
    if p.suffix.lower() == ".json":
        data = json.loads(p.read_text())
        if isinstance(data, dict):
            return [data.get(str(i), data.get(i, f"class_{i}")) for i in range(num_classes)]
        labels = list(data)
    else:
        labels = [line.strip() for line in p.read_text().splitlines() if line.strip()]
    # Pad / truncate uniformly for both list-style sources (previously only the
    # text branch did this, so a short/long JSON list could misalign class ids).
    if len(labels) < num_classes:
        labels += [f"class_{i}" for i in range(len(labels), num_classes)]
    return labels[:num_classes]


def build_model(
    config_path: str | Path,
    checkpoint_path: Optional[str | Path],
    device: str = "cuda",
    local_files_only: bool = False,
    allow_random_head: bool = False,
) -> ProtoMorphDINOv3:
    """Instantiate a ProtoMorphDINOv3 in eval mode and load its custom head.

    Falls back to CPU when a non-CPU device is requested but CUDA is
    unavailable. Raises FileNotFoundError when the head checkpoint is missing,
    unless ``allow_random_head`` is set (smoke-test escape hatch only — a
    randomly initialized head produces meaningless logits).
    """
    cfg = ProtoMorphConfig.from_json(config_path)
    # Honor the requested device only when it is CPU or CUDA is actually present.
    if device == "cpu" or torch.cuda.is_available():
        target = torch.device(device)
    else:
        target = torch.device("cpu")
    model = ProtoMorphDINOv3(cfg, local_files_only=local_files_only)
    model = model.to(target)
    model.eval()

    ckpt = None if checkpoint_path is None else Path(checkpoint_path)
    if ckpt is not None and ckpt.exists():
        model.load_custom_head(checkpoint_path)
    elif not allow_random_head:
        raise FileNotFoundError(
            f"Missing custom-head checkpoint: {checkpoint_path}. "
            "Pass --allow-random-head only for smoke tests; random logits are not meaningful."
        )
    return model


@torch.no_grad()
def predict_paths(
    model: ProtoMorphDINOv3,
    image_paths: Sequence[str | Path],
    labels: List[str],
    topk: int = 5,
    device: str = "cuda",
    force_hard: bool = False,
) -> List[Dict]:
    """Classify images on disk and return one result dict per input path.

    Each dict contains: the image path, the gate's hard-case decision and raw
    statistics (pmax / margin / entropy), the top-k of the fused ``logits``
    and of the backbone-only ``main_logits`` (each entry: rank, class_id,
    label, prob), plus the model's reported ``patch_hw`` / ``pixel_hw``.

    ``topk`` is clamped to the number of classes. ``labels`` is assumed to
    have at least as many entries as the logit dimension (see load_labels).
    """
    images = [load_image(p) for p in image_paths]
    out = model(images, device=device, force_hard=force_hard)
    probs = out["logits"].softmax(dim=-1).float().cpu()
    main_probs = out["main_logits"].softmax(dim=-1).float().cpu()
    hard_mask = out["hard_mask"].cpu().tolist()
    gate_pmax = out["gate_pmax"].float().cpu().tolist()
    gate_margin = out["gate_margin"].float().cpu().tolist()
    gate_entropy = out["gate_entropy"].float().cpu().tolist()

    # Loop-invariant: clamp k once, not per image.
    k = min(topk, probs.shape[-1])

    def _topk_entries(row: "torch.Tensor") -> List[Dict]:
        # Shared formatter for fused and main-head top-k (was duplicated inline).
        values, indices = row.topk(k)
        return [
            {"rank": r + 1, "class_id": int(idx), "label": labels[int(idx)], "prob": float(val)}
            for r, (idx, val) in enumerate(zip(indices.tolist(), values.tolist()))
        ]

    results: List[Dict] = []
    for i, path in enumerate(image_paths):
        results.append(
            {
                "image": str(path),
                "hard_case": bool(hard_mask[i]),
                "gate": {
                    "pmax": float(gate_pmax[i]),
                    "margin": float(gate_margin[i]),
                    "entropy": float(gate_entropy[i]),
                },
                "topk": _topk_entries(probs[i]),
                "main_topk": _topk_entries(main_probs[i]),
                "patch_hw": out["patch_hw"],
                "pixel_hw": out["pixel_hw"],
            }
        )
    return results