# Jaguar identifier demo (author: shahabdaiani) — commit 9396996: updated the model path.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr
from transformers import AutoModel
# === Path to model file (in same folder) ===
# Checkpoint containing the trained classifier's state_dict; loaded further below.
path_to_model = "dino2_classifier_cropped_body_v1.pth"
# === Padding function ===
def pad_to_square(image, fill=0):
    """Center *image* on a square canvas padded with *fill* pixels.

    The canvas side equals the larger of the image's width/height, so the
    aspect ratio is preserved and no content is cropped.
    """
    width, height = image.size
    side = max(width, height)
    canvas = Image.new(image.mode, (side, side), fill)
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(image, offset)
    return canvas
# === Transform setup ===
def setup_transform(use_padding=True, use_augmentation=False):
    """Build the image preprocessing pipeline.

    Args:
        use_padding: if True, pad to a square (aspect-preserving) and resize
            to 224; otherwise resize the short side to 256 and center-crop 224.
        use_augmentation: if True, insert light train-time augmentations
            between the geometric step and tensor conversion.

    Returns:
        A torchvision ``transforms.Compose`` producing a normalized tensor.
    """
    base_transforms = []
    if use_padding:
        # Pass the helper directly: the previous `lambda img: pad_to_square(img)`
        # wrapper added nothing and made the pipeline unpicklable.
        base_transforms.append(pad_to_square)
        base_transforms.append(transforms.Resize(224))
    else:
        base_transforms.extend([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])

    augmentation_transforms = []
    if use_augmentation:
        augmentation_transforms.extend([
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
        ])

    # Standard ImageNet normalization statistics, matching DINOv2 pretraining.
    final_transforms = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ]
    return transforms.Compose(base_transforms + augmentation_transforms + final_transforms)
# === Custom DINOv2 classifier ===
class DINOv2ArcFace(nn.Module):
    """DINOv2-base backbone with a mode-dependent head.

    Modes (``usage``):
        'classifier': frozen backbone + linear head over ``num_classes``.
        'embeddings': linear projection to L2-normalized ``embedding_dim`` vectors.
        'finetune':   trainable backbone + projection (training-time only;
                      ``forward`` raises for it).
    """

    def __init__(self, usage='classifier', num_classes=33, embedding_dim=512, margin=0.5, scale=64.0):
        super().__init__()
        self.usage = usage
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        # ArcFace hyper-parameters; stored for checkpoint compatibility even
        # though this inference-side forward pass never reads them.
        self.margin = margin
        self.scale = scale
        self.dropout = nn.Dropout(p=0.5)
        self.backbone = AutoModel.from_pretrained("facebook/dinov2-base")

        hidden = self.backbone.config.hidden_size
        if self.usage == 'finetune':
            self.embedding = nn.Linear(hidden, self.embedding_dim)
        elif self.usage == 'classifier':
            # Freeze the backbone: only the linear head was trained.
            self.backbone.requires_grad_(False)
            self.classifier = nn.Linear(hidden, self.num_classes)
        elif self.usage == 'embeddings':
            self.embedding = nn.Linear(hidden, self.embedding_dim)

    def forward(self, x, labels=None):
        # The CLS token of the last hidden layer is the global image descriptor.
        cls_token = self.backbone(x).last_hidden_state[:, 0, :]
        if self.usage == 'classifier':
            return self.classifier(self.dropout(cls_token))
        if self.usage == 'embeddings':
            return F.normalize(self.embedding(cls_token), p=2, dim=1)
        raise ValueError("Use mode 'classifier' or 'embeddings' for inference")
# === Load model ===
NUM_CLASSES = 33

model = DINOv2ArcFace(usage="classifier", num_classes=NUM_CLASSES)
# weights_only=True: the checkpoint is a plain state_dict, so refuse
# arbitrary pickled objects (torch.load runs the pickle machinery, which
# can execute code; this restriction is the default from torch 2.6).
model.load_state_dict(torch.load(path_to_model, map_location="cpu", weights_only=True))
model.eval()  # disable dropout for deterministic inference

# === Class mapping: class index -> individual jaguar name (32 known + 'unknown') ===
class_names = {
    0: 'Abril', 1: 'Akaloi', 2: 'Alira', 3: 'Apeiara', 4: 'Ariely', 5: 'Bagua', 6: 'Benita', 7: 'Bernard', 8: 'Bororo',
    9: 'Estella', 10: 'Guaraci', 11: 'Ipepo', 12: 'Jaju', 13: 'Kamaikua', 14: 'Kasimir', 15: 'Katniss', 16: 'Kwang',
    17: 'Lua', 18: 'Madalena', 19: 'Marcela', 20: 'Medrosa', 21: 'Ousado', 22: 'Overa', 23: 'Oxum', 24: 'Patricia',
    25: 'Pixana', 26: 'Pollyanna', 27: 'Pyte', 28: 'Saseka', 29: 'Solar', 30: 'Ti', 31: 'Tomas', 32: 'unknown'
}

# === Inference-time preprocessing: pad-to-square + resize, no augmentation ===
transform = setup_transform(use_padding=True, use_augmentation=False)
# === Gradio prediction function ===
def predict(image):
    """Classify a PIL image; return {jaguar name: probability} for every class."""
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0)  # add leading batch dimension
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits[0], dim=0)
    return {class_names[idx]: float(probs[idx]) for idx in range(NUM_CLASSES)}
# === Gradio UI ===
# Build the interface first, then launch — keeps the app object addressable.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="🐆 Jaguar Identifier (DINOv2 + ArcFace)",
    description="Upload an image of a jaguar. The model will classify it among 33 known individuals.",
)
demo.launch()