Image Classification
Transformers
Tibetan
tibetan
script-classification
dinov3
binary
File size: 4,524 Bytes
e50372c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""Standalone DINOv3 script classifier inference (copied to Hub model repo as ``inference.py``)."""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

# Default model id of the DINOv3 backbone; used as the default value of the
# ``--model-id`` CLI option and passed to the ``from_pretrained`` loaders.
DINOV3_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"


class DINOv3Classifier(nn.Module):
    """Pretrained backbone followed by a small MLP classification head.

    The backbone is loaded with ``AutoModel.from_pretrained(model_id)``. The head
    maps the backbone's ``config.hidden_size`` features to ``num_classes`` logits
    through LayerNorm, dropout, a 128-unit linear layer with GELU, dropout and a
    final linear layer.

    Args:
        model_id: Identifier handed to ``AutoModel.from_pretrained``.
        num_classes: Number of output logits.
        dropout: Dropout probability used at both dropout layers of the head.
    """

    def __init__(self, model_id: str, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_id)
        width = self.backbone.config.hidden_size
        # The position of each layer determines its ``head.<index>`` key in the
        # state dict, so the order must stay fixed for checkpoints to load.
        layers = [
            nn.LayerNorm(width),
            nn.Dropout(dropout),
            nn.Linear(width, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, pixel_values):
        """Return logits computed from the first token of the backbone output."""
        token_states = self.backbone(pixel_values=pixel_values).last_hidden_state
        # Position 0 along the token axis is used as the image summary
        # (presumably the CLS token of the backbone).
        first_token = token_states[:, 0, :]
        return self.head(first_token)


def _resize_short_edge(img: Image.Image, target: int) -> Image.Image:
    """Resize ``img`` with bicubic resampling so its shorter side equals ``target``.

    The longer side is scaled by the same factor, truncated to an integer and
    clamped to at least one pixel.
    """
    width, height = img.size
    if height > width:
        # Width is the short side.
        scaled = (target, max(1, int(height * target / width)))
    else:
        # Height is the short side (this branch also covers square images).
        scaled = (max(1, int(width * target / height)), target)
    return img.resize(scaled, Image.BICUBIC)


def _center_crop(img: Image.Image, size: int = 224) -> Image.Image:
    """Return a ``size`` x ``size`` crop taken from the centre of ``img``.

    The image is first resized so that its short edge equals ``size``; the crop
    window is then centred on the resized image.
    """
    resized = _resize_short_edge(img, size)
    width, height = resized.size
    x0 = max(0, (width - size) // 2)
    y0 = max(0, (height - size) // 2)
    patch = resized.crop((x0, y0, x0 + size, y0 + size))
    if patch.size == (size, size):
        return patch
    # Defensive fallback: if the crop did not come back at the requested size,
    # place it in the top-left corner of a white canvas of the right size.
    canvas = Image.new("RGB", (size, size), (255, 255, 255))
    canvas.paste(patch, (0, 0))
    return canvas


def apply_preprocess(img: Image.Image, mode: str | None, *, size: int = 224) -> Image.Image:
    """Apply the optional preprocessing step selected by ``mode``.

    Args:
        img: Input image.
        mode: ``None``, an empty string or ``"none"`` to return ``img`` untouched;
            ``"center_crop"`` or ``"center_crop_whole_page"`` to centre-crop.
        size: Side length passed to the centre crop.

    Raises:
        ValueError: If ``mode`` is any other non-empty value.
    """
    if mode in ("center_crop", "center_crop_whole_page"):
        return _center_crop(img, size)
    if mode and mode != "none":
        raise ValueError(f"Unknown preprocess mode: {mode!r}")
    # Falsy mode or the explicit "none": pass the image through unchanged.
    return img


def label_order(ckpt: dict) -> list[str]:
    """Return the class labels of a checkpoint ordered by class index.

    ``ckpt["idx_to_label"]`` is preferred: its keys are sorted by their integer
    value and the labels are converted to ``str``. Otherwise the keys of
    ``ckpt["label_to_idx"]`` are returned, sorted by their mapped index.

    Raises:
        KeyError: If both mappings are missing or empty.
    """
    by_index = ckpt.get("idx_to_label")
    if by_index:
        ordered_keys = list(by_index.keys())
        # Keys may be ints or numeric strings; compare them numerically.
        ordered_keys.sort(key=int)
        return [str(by_index[key]) for key in ordered_keys]

    by_label = ckpt.get("label_to_idx")
    if by_label:
        return sorted(by_label, key=by_label.__getitem__)

    raise KeyError("checkpoint missing idx_to_label / label_to_idx")


@torch.no_grad()
def predict(model, processor, image_path: Path, device, *, preprocess: str | None, size: int):
    """Classify a single image file.

    Args:
        model: Callable that maps a ``pixel_values`` tensor to logits.
        processor: Image processor called as ``processor(images=..., return_tensors="pt")``.
        image_path: Path of the image to classify.
        device: Device the pixel tensor is moved to before the forward pass.
        preprocess: Preprocess mode forwarded to ``apply_preprocess``.
        size: Target size forwarded to ``apply_preprocess``.

    Returns:
        ``(pred, probs)``: the index of the most probable class and the softmax
        probabilities as a plain list of floats.
    """
    # Open inside a context manager so the underlying file is closed right away;
    # ``convert`` returns an independent, fully loaded image that stays usable
    # after the source image has been closed.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img = apply_preprocess(img, preprocess, size=size)
    pv = processor(images=img, return_tensors="pt")["pixel_values"].to(device)
    logits = model(pv)
    # One image per call: drop the leading batch axis after the softmax.
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu()
    pred = int(probs.argmax())
    return pred, probs.tolist()


def main() -> None:
    """Command-line entry point: load the checkpoint and classify each image.

    For every ``--image`` path the top prediction with its probability is
    printed, followed by up to three classes ranked by probability.
    """
    parser = argparse.ArgumentParser(description="DINOv3 Tibetan script page classifier")
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("final_model.pt"),
        help="Weights file (default: final_model.pt in cwd)",
    )
    parser.add_argument("--image", type=Path, nargs="+", required=True)
    parser.add_argument(
        "--preprocess",
        default="none",
        help="none (full page, DINO processor resize) | center_crop (224 short edge)",
    )
    parser.add_argument("--preprocess-size", type=int, default=224)
    parser.add_argument("--model-id", default=DINOV3_MODEL_ID)
    opts = parser.parse_args()

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code; only use checkpoints from a trusted source.
    checkpoint = torch.load(opts.checkpoint, map_location="cpu", weights_only=False)
    names = dict(enumerate(label_order(checkpoint)))

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    net = DINOv3Classifier(opts.model_id, num_classes=len(names))
    net = net.to(device)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()
    processor = AutoImageProcessor.from_pretrained(opts.model_id)

    # "none" and the empty string both mean "skip the extra preprocessing".
    mode = opts.preprocess
    if mode in ("none", ""):
        mode = None

    for image_path in opts.image:
        best, scores = predict(
            net,
            processor,
            image_path,
            device,
            preprocess=mode,
            size=opts.preprocess_size,
        )
        print(f"{image_path.name}: {names[best]} ({scores[best]:.3f})")
        # Rank classes by descending probability and show the leading three.
        ranked = sorted(enumerate(scores), key=lambda pair: -pair[1])
        for class_idx, score in ranked[:3]:
            print(f"  {names[class_idx]:16s} {score:.3f}")


# Run the command-line interface only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()