| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms as T |
| from pathlib import Path |
| import numpy as np |
| from PIL import Image |
|
|
|
|
| |
# Canonical viewpoint labels. A label's position in this list is the class id
# the classifier head predicts (argmax index -> label).
POSE_CLASSES = [
    "front",
    "front-left",
    "front-right",
    "left",
    "right",
    "back-left",
    "back-right",
    "back",
]
NUM_CLASSES = len(POSE_CLASSES)

# model_size -> (torch.hub entry point in facebookresearch/dinov2,
#                feature dimension fed into the classifier head)
DINO_MODELS = {
    "small": ("dinov2_vits14", 384),
    "base": ("dinov2_vitb14", 768),
    "large": ("dinov2_vitl14", 1024),
}
|
|
|
|
class _PoseClassifierModel(nn.Module):
    """Frozen DINOv2 feature extractor followed by a trainable MLP head.

    Produces ``NUM_CLASSES`` logits per input. The layer layout mirrors
    ``train_pose_classifier.PoseClassifier`` so that checkpoints saved from
    that model load into this one without state_dict key mismatches.

    Parameters
    ----------
    model_size : str
        Key into ``DINO_MODELS`` selecting the backbone variant.
    dropout : float
        Dropout probability used after each hidden layer of the head.
    """

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()
        hub_entry, embed_dim = DINO_MODELS[model_size]

        # Pretrained backbone from torch.hub, used only as a fixed feature
        # extractor: no gradients and always constructed in eval mode.
        self.backbone = torch.hub.load('facebookresearch/dinov2', hub_entry)
        self.backbone.requires_grad_(False)
        self.backbone.eval()

        # LayerNorm, two (Linear -> GELU -> Dropout) stages, then the final
        # classifier. The construction order matters: it fixes the Sequential
        # indices (head.0 ... head.7) that checkpoint keys refer to.
        layers = [nn.LayerNorm(embed_dim)]
        in_features = embed_dim
        for out_features in (256, 128):
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            in_features = out_features
        layers.append(nn.Linear(in_features, NUM_CLASSES))
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for the input batch ``x``.

        Backbone features are computed under ``torch.no_grad`` so only the
        head takes part in autograd.
        NOTE(review): assumes the hub model returns one pooled feature vector
        of width ``embed_dim`` per image — confirm for new backbones.
        """
        with torch.no_grad():
            feats = self.backbone(x)
        return self.head(feats)
|
|
|
|
class ViewPointClassifier:
    """
    Predicts one of 8 canonical zebra viewpoints:
        front, front-left, front-right, left, right, back-left, back-right, back

    Uses a frozen DINOv2 backbone (variant chosen by ``model_size``, "small"
    by default) with a trained MLP head loaded from a checkpoint.

    __call__(crops) -> list[str]
        Each crop is a PIL.Image, an np.ndarray (assumed OpenCV-style BGR/BGRA
        channel order) or a torch.Tensor. Returns one predicted pose label per
        crop, in input order.
    """
    LABELS = POSE_CLASSES

    def _to_pil(self, img):
        """Accept PIL.Image | np.ndarray | torch.Tensor -> PIL.Image (RGB).

        NumPy arrays with 3 or 4 channels are treated as BGR / BGRA (OpenCV
        convention) and reordered to RGB / RGBA first. Every branch ends with
        ``convert("RGB")``. This ensures grayscale or alpha inputs still yield
        the 3-channel image that the normalisation transform expects.

        Raises
        ------
        TypeError
            If ``img`` is not one of the supported types.
        """
        if isinstance(img, Image.Image):
            return img.convert("RGB")

        if isinstance(img, np.ndarray):
            if img.ndim == 3 and img.shape[2] == 1:
                img = img[..., 0]  # (H, W, 1) -> (H, W) single-channel
            elif img.ndim == 3 and img.shape[2] == 3:
                img = img[..., ::-1]  # BGR -> RGB
            elif img.ndim == 3 and img.shape[2] == 4:
                img = img[..., [2, 1, 0, 3]]  # BGRA -> RGBA
            # Hand PIL a plain contiguous buffer regardless of the slicing above.
            return Image.fromarray(np.ascontiguousarray(img)).convert("RGB")

        if torch.is_tensor(img):
            return T.ToPILImage()(img.detach().cpu()).convert("RGB")

        raise TypeError(f"Unsupported crop type {type(img)}")

    def __init__(
        self,
        weight_path="checkpoints/best_pose_model.pth",
        model_size: str = "small",
        device: str = "cpu",
    ):
        """Build the model, load trained weights and set up preprocessing.

        Parameters
        ----------
        weight_path : str | os.PathLike
            Checkpoint file. Either a dict holding the weights under
            ``'model_state_dict'`` (the training-script format) or a bare
            state_dict.
        model_size : str
            Backbone variant, a key of ``DINO_MODELS``.
        device : str
            Torch device the model and input batches are placed on.
        """
        self.device = torch.device(device)

        # Backbone comes from torch.hub; all weights are then overwritten
        # from the checkpoint below.
        self.model = _PoseClassifierModel(model_size=model_size)

        ckpt = torch.load(weight_path, map_location=self.device)
        # Accept both the wrapped training checkpoint and a bare state_dict.
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state_dict = ckpt["model_state_dict"]
        else:
            state_dict = ckpt
        self.model.load_state_dict(state_dict)
        self.model.eval().to(self.device)

        # Eval preprocessing: resize shorter side to 256, center-crop 224,
        # normalise with the standard ImageNet mean/std.
        # NOTE(review): presumably mirrors the training pipeline — confirm.
        self.tf = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    @torch.inference_mode()
    def __call__(self, crops):
        """
        Parameters
        ----------
        crops : sequence of PIL.Image | np.ndarray | torch.Tensor
            One crop per detection; see ``_to_pil`` for accepted formats.

        Returns
        -------
        list[str]
            Predicted pose label for each crop, in input order,
            e.g. 'front', 'back-left'. Empty input gives an empty list.
        """
        if not crops:
            return []
        pil_crops = [self._to_pil(c) for c in crops]
        batch = torch.stack([self.tf(c) for c in pil_crops]).to(self.device)
        logits = self.model(batch)
        # Convert to plain ints once instead of indexing with 0-d tensors.
        pred_ids = torch.argmax(logits, dim=-1).tolist()
        return [self.LABELS[i] for i in pred_ids]
|
|
| |
if __name__ == "__main__":
    # Quick manual smoke test: classify up to 4 random crops from a directory.
    # Usage: python <this_file> [crop_dir]
    import random
    import sys

    img_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("some/test/crops")
    paths = list(img_dir.glob("*.jpg"))
    if not paths:
        sys.exit(f"No .jpg crops found in {img_dir}")

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the sample size at the number of available crops.
    samples = []
    for p in random.sample(paths, min(4, len(paths))):
        # convert() loads the pixels into a new image, so the file can be closed.
        with Image.open(p) as im:
            samples.append(im.convert("RGB"))

    clf = ViewPointClassifier(device="cpu")
    print(clf(samples))
|
|