| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from PIL import Image |
| import gradio as gr |
| from transformers import AutoModel |
|
|
| |
# Checkpoint file whose contents are passed to load_state_dict when the model
# is built. Being a relative path, it is resolved against the process's
# current working directory, not this script's directory.
path_to_model = (
    "dino2_classifier_cropped_body_v1.pth"
)
|
|
| |
def pad_to_square(image, fill=0):
    """Center *image* on a square canvas whose side equals its longer edge.

    The canvas has the same mode as *image* and is filled with *fill*. The
    original pixels are pasted so the extra space is split between opposite
    sides. Integer division puts any odd leftover pixel on the right/bottom.

    Returns a new image; *image* itself is left untouched.
    """
    width, height = image.size
    side = width if width >= height else height
    # Top-left corner at which the source lands inside the square canvas.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new(image.mode, (side, side), fill)
    canvas.paste(image, offset)
    return canvas
|
|
| |
def setup_transform(use_padding=True, use_augmentation=False):
    """Build the torchvision preprocessing pipeline for the DINOv2 model.

    Args:
        use_padding: If True, pad the image to a square with ``pad_to_square``
            and resize it to 224x224, which keeps the whole frame. If False,
            resize the shorter side to 256 and center-crop 224x224, which can
            cut off content near the borders.
        use_augmentation: If True, insert random rotation, color jitter,
            translation, resized crop and blur. These make the output
            non-deterministic, so they only make sense for training.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a float tensor
        normalized with the ImageNet mean/std.
    """
    if use_padding:
        # Wrap the module-level function in transforms.Lambda instead of an
        # inline ``lambda``: a local lambda cannot be pickled, which breaks
        # DataLoader workers under the "spawn" start method. It also gives
        # Compose a readable repr.
        base_transforms = [
            transforms.Lambda(pad_to_square),
            transforms.Resize(224),  # square input -> 224x224
        ]
    else:
        base_transforms = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]

    augmentation_transforms = []
    if use_augmentation:
        augmentation_transforms = [
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
        ]

    final_transforms = [
        transforms.ToTensor(),
        # ImageNet channel statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(base_transforms + augmentation_transforms + final_transforms)
|
|
| |
class DINOv2ArcFace(nn.Module):
    """``facebook/dinov2-base`` backbone with a head selected by ``usage``.

    ``usage`` decides which head ``__init__`` builds and what ``forward``
    returns:

    * ``'classifier'``: the backbone is frozen; dropout plus a linear layer
      map the backbone's hidden size to ``num_classes`` logits.
    * ``'embeddings'``: a linear projection to ``embedding_dim``, followed
      by L2 normalization.
    * ``'finetune'``: builds the same projection layer as ``'embeddings'``,
      but ``forward`` raises ``ValueError`` in this mode.

    NOTE(review): ``margin`` and ``scale`` are stored, and ``forward`` accepts
    ``labels``, but this class uses none of them. Presumably the ArcFace loss
    lives in separate training code; confirm before relying on them.
    """

    # All modes accepted by __init__, and the subset forward() can serve.
    _USAGES = ('finetune', 'classifier', 'embeddings')
    _INFERENCE_USAGES = ('classifier', 'embeddings')

    def __init__(self, usage='classifier', num_classes=33, embedding_dim=512, margin=0.5, scale=64.0):
        # Validate before downloading the backbone. A typo then fails
        # immediately, rather than producing a model with no head whose
        # forward() always raises.
        if usage not in self._USAGES:
            raise ValueError(
                f"Unknown usage {usage!r}; expected one of {self._USAGES}"
            )
        super().__init__()
        self.usage = usage
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.margin = margin
        self.scale = scale
        # Applied to the pooled features in 'classifier' mode; no-op in eval().
        self.dropout = nn.Dropout(p=0.5)
        self.backbone = AutoModel.from_pretrained("facebook/dinov2-base")
        hidden_size = self.backbone.config.hidden_size
        if usage == 'classifier':
            # Linear probe: only the classifier head receives gradients.
            self.backbone.requires_grad_(False)
            self.classifier = nn.Linear(hidden_size, num_classes)
        else:  # 'finetune' or 'embeddings'
            self.embedding = nn.Linear(hidden_size, embedding_dim)

    def forward(self, x, labels=None):
        """Run the backbone and the mode-specific head.

        Args:
            x: Batch of preprocessed images. Presumably ``(B, 3, H, W)`` pixel
                values as produced by ``setup_transform``; confirm.
            labels: Accepted for API compatibility; unused here.

        Returns:
            In ``'classifier'`` mode, logits of shape ``(B, num_classes)``.
            In ``'embeddings'`` mode, L2-normalized embeddings of shape
            ``(B, embedding_dim)``.

        Raises:
            ValueError: If the model was built with ``usage='finetune'``.
        """
        # Check the mode first, so an unsupported mode does not pay for a
        # full backbone pass before failing.
        if self.usage not in self._INFERENCE_USAGES:
            raise ValueError("Use mode 'classifier' or 'embeddings' for inference")
        # Token 0 of the last hidden state (the CLS token in HF DINOv2 output).
        features = self.backbone(x).last_hidden_state[:, 0, :]
        if self.usage == 'classifier':
            return self.classifier(self.dropout(features))
        return F.normalize(self.embedding(features), p=2, dim=1)
|
|
| |
# Number of classifier outputs. It must agree with the checkpoint's head and
# with the class_names mapping that predict() indexes by output position.
NUM_CLASSES = 33
model = DINOv2ArcFace(usage="classifier", num_classes=NUM_CLASSES)
# The file holds a state_dict (it goes straight into load_state_dict). With
# weights_only=True, unpickling is restricted to tensors and plain
# containers, so a tampered file cannot execute code. It also avoids the
# warning that recent PyTorch releases emit when the argument is omitted.
model.load_state_dict(
    torch.load(path_to_model, map_location="cpu", weights_only=True)
)
# Put Dropout (and any other train/eval-dependent layers) in inference mode.
model.eval()
|
|
| |
# Classifier output index -> individual's name; predict() labels output i
# with class_names[i]. The last class is the catch-all 'unknown'.
class_names = dict(enumerate([
    'Abril', 'Akaloi', 'Alira', 'Apeiara', 'Ariely', 'Bagua', 'Benita',
    'Bernard', 'Bororo', 'Estella', 'Guaraci', 'Ipepo', 'Jaju', 'Kamaikua',
    'Kasimir', 'Katniss', 'Kwang', 'Lua', 'Madalena', 'Marcela', 'Medrosa',
    'Ousado', 'Overa', 'Oxum', 'Patricia', 'Pixana', 'Pollyanna', 'Pyte',
    'Saseka', 'Solar', 'Ti', 'Tomas', 'unknown',
]))
|
|
| |
# Deterministic inference preprocessing: square padding, no random augmentation.
transform = setup_transform(use_augmentation=False, use_padding=True)
|
|
| |
def predict(image):
    """Classify a jaguar photo and return per-individual probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``{individual_name: probability}`` dict covering every classifier
        output (the format ``gr.Label`` expects), or ``None`` when no image
        was given, which leaves the output component empty.
    """
    # Gradio passes None for an empty input; calling .convert on it would
    # raise AttributeError and surface as an opaque error in the UI.
    if image is None:
        return None
    # Force 3 channels. Grayscale/RGBA/palette uploads would not match the
    # 3-channel mean/std used by the preprocessing Normalize step.
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
    probs = F.softmax(logits[0], dim=0).tolist()  # plain Python floats
    return {class_names[i]: p for i, p in enumerate(probs)}
|
|
| |
# Web UI. It is bound to ``demo`` so the Interface object can be reused;
# Gradio's reload mode (``gradio app.py``) looks for that name by default.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="🐆 Jaguar Identifier (DINOv2 + ArcFace)",
    # 33 output classes = 32 named individuals + the catch-all 'unknown'.
    description=(
        "Upload an image of a jaguar. The model will classify it among "
        "32 known individuals, or as 'unknown'."
    ),
)
demo.launch()
|
|