Image Classification
Transformers
Tibetan
tibetan
script-classification
dinov3
binary
karma689's picture
Add Gyuyig vs Tsugdri binary classifier weights, metrics, and training history
e50372c verified
Raw
History Blame Contribute Delete
4.52 kB
#!/usr/bin/env python3
"""Standalone DINOv3 script classifier inference (copied to Hub model repo as ``inference.py``)."""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
# Default value of --model-id: the Hugging Face id used to build the backbone and image processor.
# The checkpoint's state dict includes backbone.* weights, so it must come from this architecture.
DINOV3_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
class DINOv3Classifier(nn.Module):
    """Pretrained transformer backbone followed by a small MLP head on the first token.

    Args:
        model_id: Hugging Face model id passed to ``AutoModel.from_pretrained``.
        num_classes: Number of logits produced by the final linear layer.
        dropout: Dropout probability used at both dropout positions of the head.
    """

    def __init__(self, model_id: str, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_id)
        width = self.backbone.config.hidden_size
        # The position of each layer defines the ``head.<i>`` keys in the state dict,
        # so this order must stay in sync with saved checkpoints.
        layers = [
            nn.LayerNorm(width),
            nn.Dropout(dropout),
            nn.Linear(width, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, pixel_values):
        """Return class logits computed from token 0 of the backbone's last hidden state."""
        hidden_states = self.backbone(pixel_values=pixel_values).last_hidden_state
        first_token = hidden_states[:, 0, :]
        return self.head(first_token)
def _resize_short_edge(img: Image.Image, target: int) -> Image.Image:
    """Resize ``img`` with bicubic filtering so its shorter side equals ``target``.

    The longer side is scaled by the same factor, truncated to an int, and clamped
    to at least 1. Square images take the "height is the short side" path.
    """
    width, height = img.size
    height_is_short = height <= width
    short_side, long_side = (height, width) if height_is_short else (width, height)
    scaled_long = max(1, int(long_side * target / short_side))
    if height_is_short:
        new_size = (scaled_long, target)
    else:
        new_size = (target, scaled_long)
    return img.resize(new_size, Image.BICUBIC)
def _center_crop(img: Image.Image, size: int = 224) -> Image.Image:
    """Resize the short side of ``img`` to ``size`` and return the central ``size`` x ``size`` square.

    ``_resize_short_edge`` leaves both sides >= ``size`` (the short side equals it and the
    long side is scaled up proportionally). The crop box therefore lies inside the image.
    ``Image.crop`` always returns an image whose dimensions equal the box anyway, so the
    result is exactly ``(size, size)`` and no padding step is needed.
    """
    img = _resize_short_edge(img, size)
    w, h = img.size
    # Clamp kept as a cheap guard; with the resize above the offsets are already >= 0.
    left = max(0, (w - size) // 2)
    top = max(0, (h - size) // 2)
    return img.crop((left, top, left + size, top + size))
def apply_preprocess(img: Image.Image, mode: str | None, *, size: int = 224) -> Image.Image:
    """Apply the preprocessing step named by ``mode`` to ``img``.

    Any falsy ``mode`` (``None``, ``""``) or ``"none"`` returns ``img`` unchanged.
    ``"center_crop"`` and ``"center_crop_whole_page"`` both return ``_center_crop(img, size)``.

    Raises:
        ValueError: for any other mode.
    """
    crop_modes = ("center_crop", "center_crop_whole_page")
    is_passthrough = not mode or mode == "none"
    if is_passthrough:
        return img
    if mode not in crop_modes:
        raise ValueError(f"Unknown preprocess mode: {mode!r}")
    return _center_crop(img, size)
def label_order(ckpt: dict) -> list[str]:
    """Return the class labels ordered by integer class index.

    Uses ``ckpt["idx_to_label"]`` (index -> label) when present and non-empty. Its keys may
    be ints or numeric strings, and its labels are converted with ``str``. Otherwise it
    falls back to ``ckpt["label_to_idx"]`` (label -> index), whose keys are returned as-is.
    Indices from either mapping are compared as integers, so string indices such as
    ``"10"`` sort after ``"2"``. Position ``i`` of the result is the label for logit ``i``.

    Raises:
        KeyError: if neither mapping is present (or both are empty).
        ValueError: if the indices are not exactly ``0 .. n-1``. A gap, offset or
            duplicate would otherwise silently shift labels relative to the model output.
    """
    idx_to_label = ckpt.get("idx_to_label") or {}
    if idx_to_label:
        pairs = [(int(k), str(v)) for k, v in idx_to_label.items()]
    else:
        label_to_idx = ckpt.get("label_to_idx") or {}
        if not label_to_idx:
            raise KeyError("checkpoint missing idx_to_label / label_to_idx")
        pairs = [(int(v), k) for k, v in label_to_idx.items()]
    # Sort on the index only, so labels of differing types are never compared.
    pairs.sort(key=lambda pair: pair[0])
    indices = [i for i, _ in pairs]
    if indices != list(range(len(pairs))):
        raise ValueError(
            f"checkpoint class indices must be exactly 0..{len(pairs) - 1}, got {indices}"
        )
    return [label for _, label in pairs]
@torch.no_grad()
def predict(model, processor, image_path: Path, device, *, preprocess: str | None, size: int):
    """Classify a single image file.

    The image is converted to RGB and passed through ``apply_preprocess``. The
    ``processor`` then produces ``pixel_values`` on ``device``, and a softmax is taken
    over dim 1 of the model's logits.

    Returns:
        ``(pred, probs)``: the argmax class index as an ``int``, and the per-class
        probabilities as a Python list (batch dimension squeezed away).
    """
    # A context manager guarantees the file handle is released. Pillow may keep it open
    # after loading, e.g. for multi-frame formats such as TIFF. ``convert`` loads pixel
    # data into a new image, so ``img`` remains valid after the file is closed.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img = apply_preprocess(img, preprocess, size=size)
    pixel_values = processor(images=img, return_tensors="pt")["pixel_values"].to(device)
    logits = model(pixel_values)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu()
    pred = int(probs.argmax())
    return pred, probs.tolist()
def main() -> None:
    """Command-line entry point: classify each ``--image`` and print the results.

    Loads the checkpoint and rebuilds ``DINOv3Classifier`` with one output per label.
    For every image it prints the predicted label and its probability, followed by
    up to three labels in descending probability order.
    """
    parser = argparse.ArgumentParser(description="DINOv3 Tibetan script page classifier")
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("final_model.pt"),
        help="Weights file (default: final_model.pt in cwd)",
    )
    parser.add_argument("--image", type=Path, nargs="+", required=True)
    parser.add_argument(
        "--preprocess",
        default="none",
        help="none (full page, DINO processor resize) | center_crop (224 short edge)",
    )
    parser.add_argument("--preprocess-size", type=int, default=224)
    parser.add_argument("--model-id", default=DINOV3_MODEL_ID)
    args = parser.parse_args()

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    idx_to_label = dict(enumerate(label_order(checkpoint)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DINOv3Classifier(args.model_id, num_classes=len(idx_to_label)).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    processor = AutoImageProcessor.from_pretrained(args.model_id)

    # "none" and "" both mean: feed the full page straight to the processor.
    preprocess_mode = args.preprocess
    if preprocess_mode in ("none", ""):
        preprocess_mode = None

    for path in args.image:
        best, probs = predict(
            model,
            processor,
            path,
            device,
            preprocess=preprocess_mode,
            size=args.preprocess_size,
        )
        name, conf = idx_to_label[best], probs[best]
        print(f"{path.name}: {name} ({conf:.3f})")
        ranked = sorted(enumerate(probs), key=lambda item: -item[1])
        for i, p in ranked[:3]:
            print(f" {idx_to_label[i]:16s} {p:.3f}")
if __name__ == "__main__":
main()