#!/usr/bin/env python3
"""Standalone DINOv3 script classifier inference (copied to Hub model repo as ``inference.py``)."""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

DINOV3_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"


class DINOv3Classifier(nn.Module):
    """Pretrained backbone followed by a small MLP classification head.

    The backbone is loaded with ``AutoModel.from_pretrained(model_id)``; the
    head maps the backbone's ``hidden_size`` features to ``num_classes``
    logits through a 128-unit hidden layer.
    """

    def __init__(self, model_id: str, num_classes: int, dropout: float = 0.1):
        """Build the backbone and the classification head.

        Args:
            model_id: Identifier or path passed to ``AutoModel.from_pretrained``.
            num_classes: Number of output logits.
            dropout: Dropout probability used twice inside the head.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_id)
        hidden = self.backbone.config.hidden_size
        self.head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, pixel_values):
        """Return class logits for a batch of preprocessed images."""
        out = self.backbone(pixel_values=pixel_values)
        # Only the token at sequence position 0 feeds the head
        # (presumably the CLS token -- confirm against the backbone docs).
        cls = out.last_hidden_state[:, 0, :]
        return self.head(cls)


def _resize_short_edge(img: Image.Image, target: int) -> Image.Image:
    """Resize ``img`` so its shorter side equals ``target``, keeping aspect ratio.

    The longer side is scaled by the same factor, truncated to an int and
    clamped to at least one pixel. Bicubic resampling is used.
    """
    w, h = img.size
    if h <= w:
        new_h = target
        new_w = max(1, int(w * target / h))
    else:
        new_w = target
        new_h = max(1, int(h * target / w))
    return img.resize((new_w, new_h), Image.BICUBIC)


def _center_crop(img: Image.Image, size: int = 224) -> Image.Image:
    """Resize the short edge to ``size`` and return the central ``size`` x ``size`` crop."""
    img = _resize_short_edge(img, size)
    w, h = img.size
    left = max(0, (w - size) // 2)
    top = max(0, (h - size) // 2)
    crop = img.crop((left, top, left + size, top + size))
    if crop.size != (size, size):
        # Defensive fallback: paste an undersized crop at the top-left of a
        # white canvas. After the short-edge resize both sides should already
        # be >= size, so this is not expected to trigger in practice.
        padded = Image.new("RGB", (size, size), (255, 255, 255))
        padded.paste(crop, (0, 0))
        return padded
    return crop


def apply_preprocess(img: Image.Image, mode: str | None, *, size: int = 224) -> Image.Image:
    """Apply the named preprocessing mode to ``img``.

    Args:
        img: Input image.
        mode: ``None``, ``""`` or ``"none"`` return ``img`` unchanged;
            ``"center_crop"`` and ``"center_crop_whole_page"`` both apply
            :func:`_center_crop`.
        size: Side length of the square crop.

    Raises:
        ValueError: If ``mode`` is not one of the recognised names.
    """
    if not mode or mode == "none":
        return img
    if mode in ("center_crop", "center_crop_whole_page"):
        return _center_crop(img, size)
    raise ValueError(f"Unknown preprocess mode: {mode!r}")


def label_order(ckpt: dict) -> list[str]:
    """Return class labels ordered by class index, as stored in ``ckpt``.

    ``idx_to_label`` is preferred (keys are compared as ints, so string keys
    such as ``"10"`` sort numerically); otherwise ``label_to_idx`` is inverted
    by sorting its keys on their index values.

    Raises:
        KeyError: If neither mapping is present or both are empty.
    """
    idx = ckpt.get("idx_to_label") or {}
    if idx:
        return [str(idx[k]) for k in sorted(idx.keys(), key=lambda x: int(x))]
    raw = ckpt.get("label_to_idx") or {}
    if raw:
        # Coerce to str so both branches honour the declared return type and
        # the ``:16s`` format spec used when printing results.
        return [str(k) for k in sorted(raw.keys(), key=lambda k: raw[k])]
    raise KeyError("checkpoint missing idx_to_label / label_to_idx")


@torch.no_grad()
def predict(model, processor, image_path: Path, device, *, preprocess: str | None, size: int):
    """Classify one image file.

    Args:
        model: Callable mapping a ``pixel_values`` tensor to logits.
        processor: Image processor called as ``processor(images=..., return_tensors="pt")``.
        image_path: Path of the image to classify.
        device: Device the pixel tensor is moved to before the forward pass.
        preprocess: Mode forwarded to :func:`apply_preprocess`.
        size: Crop size forwarded to :func:`apply_preprocess`.

    Returns:
        ``(pred, probs)``: the argmax class index and the softmax
        probabilities as a list of floats.
    """
    # Use a context manager so the underlying file handle is released as soon
    # as the pixel data has been converted into an independent RGB image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img = apply_preprocess(img, preprocess, size=size)
    pv = processor(images=img, return_tensors="pt")["pixel_values"].to(device)
    logits = model(pv)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu()
    pred = int(probs.argmax())
    return pred, probs.tolist()


def main() -> None:
    """Parse CLI arguments, load the checkpoint and print predictions per image."""
    ap = argparse.ArgumentParser(description="DINOv3 Tibetan script page classifier")
    ap.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("final_model.pt"),
        help="Weights file (default: final_model.pt in cwd)",
    )
    ap.add_argument("--image", type=Path, nargs="+", required=True)
    ap.add_argument(
        "--preprocess",
        default="none",
        help="none (full page, DINO processor resize) | center_crop (224 short edge)",
    )
    ap.add_argument("--preprocess-size", type=int, default=224)
    ap.add_argument("--model-id", default=DINOV3_MODEL_ID)
    args = ap.parse_args()

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects and
    # can execute code embedded in the file. Only load checkpoints you trust.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    classes = label_order(ckpt)
    idx_to_label = dict(enumerate(classes))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DINOv3Classifier(args.model_id, num_classes=len(classes)).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()  # disable dropout for inference
    processor = AutoImageProcessor.from_pretrained(args.model_id)

    # "none" and the empty string both mean "no extra preprocessing".
    prep = None if args.preprocess in ("none", "") else args.preprocess

    for path in args.image:
        pred, probs = predict(
            model, processor, path, device, preprocess=prep, size=args.preprocess_size
        )
        name = idx_to_label[pred]
        conf = probs[pred]
        print(f"{path.name}: {name} ({conf:.3f})")
        # Three most probable classes, highest first.
        top3 = sorted(enumerate(probs), key=lambda x: -x[1])[:3]
        for i, p in top3:
            print(f" {idx_to_label[i]:16s} {p:.3f}")


if __name__ == "__main__":
    main()