# mmla-dino-pose / pose_classifier.py
# Last commit 2be7251 by jennamk14: "Add README, training/inference code, and
# trained DINOv2-small pose-classifier checkpoint"
# navigation_scripts/pose_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
from pathlib import Path
import numpy as np
from PIL import Image
# Label order must match train_pose_classifier.py: output index i of the
# classifier head is reported as POSE_CLASSES[i].
POSE_CLASSES = [
    "front",
    "front-left",
    "front-right",
    "left",
    "right",
    "back-left",
    "back-right",
    "back",
]
NUM_CLASSES = len(POSE_CLASSES)

# Backbone size -> (torch.hub entry point in facebookresearch/dinov2,
#                   width of the feature vector fed to the MLP head).
DINO_MODELS = dict(
    small=("dinov2_vits14", 384),
    base=("dinov2_vitb14", 768),
    large=("dinov2_vitl14", 1024),
)
class _PoseClassifierModel(nn.Module):
    """Frozen DINOv2 backbone followed by a trainable MLP head (8 pose classes).

    Mirrors ``train_pose_classifier.PoseClassifier`` so that checkpoints saved
    during training load with matching parameter names (``backbone.*`` and
    ``head.<index>.*``).

    Args:
        model_size: key into ``DINO_MODELS`` ('small', 'base' or 'large').
        dropout: dropout probability applied after each hidden MLP layer.
    """

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()
        hub_name, embed_dim = DINO_MODELS[model_size]

        # Backbone comes from torch.hub (may need network access / hub cache);
        # its weights are frozen and it is put in eval mode.
        self.backbone = torch.hub.load('facebookresearch/dinov2', hub_name)
        self.backbone.requires_grad_(False)
        self.backbone.eval()

        # NOTE: layer order defines the state-dict keys (head.0 ... head.7) and
        # must stay identical to the training script.
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES),
        )

    def forward(self, x):
        """Return unnormalised class scores (last dimension is NUM_CLASSES)."""
        # The backbone is frozen, so its activations never need gradients.
        with torch.no_grad():
            embedding = self.backbone(x)
        logits = self.head(embedding)
        return logits
class ViewPointClassifier:
    """
    Predicts one of 8 canonical zebra viewpoints:
        front, front-left, front-right, left, right, back-left, back-right, back

    Uses a frozen DINOv2 backbone (size chosen by ``model_size``, 'small' by
    default) with a trained MLP head.

    __call__(crops) -> list[str]
        Each crop is a PIL.Image, a NumPy array (3/4-channel arrays are treated
        as BGR/BGRA) or a torch tensor accepted by
        ``torchvision.transforms.ToPILImage``. Returns the predicted pose label
        for each crop.
    """

    LABELS = POSE_CLASSES

    def _to_pil(self, img):
        """Accept PIL.Image | np.ndarray | torch.Tensor -> PIL.Image (RGB).

        Every branch ends in ``.convert("RGB")`` so grayscale/RGBA inputs reach
        the 3-channel ``Normalize`` transform with the channel count it expects.

        Raises:
            TypeError: if ``img`` is none of the supported types.
        """
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        if isinstance(img, np.ndarray):
            if img.ndim == 3 and img.shape[2] == 1:
                # HxWx1 -> HxW: PIL cannot build an image with a trailing singleton axis.
                img = img[..., 0]
            elif img.ndim == 3 and img.shape[2] in (3, 4):
                # Arrays are treated as BGR(A): swap B and R, keep alpha in place.
                # Fancy indexing also yields a C-contiguous copy for PIL.
                order = [2, 1, 0] + list(range(3, img.shape[2]))
                img = img[..., order]
            return Image.fromarray(img).convert("RGB")
        if torch.is_tensor(img):
            # detach(): ToPILImage converts to NumPy, which fails on grad-tracking tensors.
            return T.ToPILImage()(img.detach().cpu()).convert("RGB")
        raise TypeError(f"Unsupported crop type {type(img)}")

    def __init__(
        self,
        weight_path="checkpoints/best_pose_model.pth",
        model_size: str = "small",
        device: str = "cpu",
    ):
        self.device = torch.device(device)
        # Build the same architecture used in training
        self.model = _PoseClassifierModel(model_size=model_size)
        # Load checkpoint. Accept both the training-script format
        # {'model_state_dict': ...} and a bare state dict.
        ckpt = torch.load(weight_path, map_location=self.device)
        if isinstance(ckpt, dict):
            state_dict = ckpt.get("model_state_dict", ckpt)
        else:
            state_dict = ckpt
        self.model.load_state_dict(state_dict)
        self.model.eval().to(self.device)
        # Match the validation transforms from training
        self.tf = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    @torch.inference_mode()
    def __call__(self, crops):
        """
        Parameters
        ----------
        crops : iterable of PIL.Image | np.ndarray | torch.Tensor, or None
            One crop per detection (see ``_to_pil`` for accepted formats).

        Returns
        -------
        list[str]
            Predicted pose label for each crop, e.g. 'front', 'back-left'.
            Empty list when ``crops`` is None or empty.
        """
        if crops is None:
            return []
        # Materialise first: `not crops` raises for multi-element arrays/tensors,
        # and an empty generator would otherwise reach torch.stack([]).
        crops = list(crops)
        if not crops:
            return []
        pil_crops = [self._to_pil(c) for c in crops]
        batch = torch.stack([self.tf(c) for c in pil_crops]).to(self.device)
        logits = self.model(batch)  # shape [N, NUM_CLASSES]
        # Single-label: highest-scoring class per crop, as plain Python ints.
        preds = logits.argmax(dim=-1).tolist()
        return [self.LABELS[i] for i in preds]
# ───────── quick sanity check ─────────
if __name__ == "__main__":
    # Quick manual check: classify up to 4 random crops from a local folder.
    import random

    img_dir = Path("some/test/crops")  # directory of zebra chip .jpgs
    paths = sorted(img_dir.glob("*.jpg"))
    if not paths:
        raise SystemExit(f"No .jpg crops found in {img_dir}")
    samples = []
    # min(): random.sample raises ValueError when asked for more items than exist.
    for p in random.sample(paths, min(4, len(paths))):
        # Load eagerly inside `with` so each file handle is closed.
        with Image.open(p) as im:
            samples.append(im.convert("RGB"))
    clf = ViewPointClassifier(device="cpu")
    print(clf(samples))