# navigation_scripts/pose_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
from pathlib import Path
import numpy as np
from PIL import Image

# Must match train_pose_classifier.py
POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right', 'back-left', 'back-right', 'back']
NUM_CLASSES = len(POSE_CLASSES)

# Maps a size key to (torch.hub entry point name, feature width fed to the head).
DINO_MODELS = {
    'small': ('dinov2_vits14', 384),
    'base': ('dinov2_vitb14', 768),
    'large': ('dinov2_vitl14', 1024),
}


class _PoseClassifierModel(nn.Module):
    """DINOv2 + MLP head for 8-class pose classification (mirrors train_pose_classifier.PoseClassifier).

    The backbone is fetched through ``torch.hub`` and frozen (``requires_grad``
    is disabled on all of its parameters and it is put in eval mode); only the
    MLP head carries trainable parameters.

    Parameters
    ----------
    model_size : str
        Key into ``DINO_MODELS`` ('small', 'base' or 'large'). An unknown key
        raises ``KeyError``.
    dropout : float
        Dropout probability used after each hidden layer of the head.
    """

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()
        model_name, feat_dim = DINO_MODELS[model_size]

        self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        # The backbone is frozen, so skip autograd bookkeeping for it.
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)


class ViewPointClassifier:
    """
    Predicts one of 8 canonical zebra viewpoints:
        front, front-left, front-right, left, right,
        back-left, back-right, back

    Uses a frozen DINOv2 backbone (size selected by ``model_size``,
    'small' by default) with a trained MLP head.

    __call__(crops) → list[str]
        Each crop is a PIL.Image, a np.ndarray or a torch.Tensor (see
        ``_to_pil``). Returns the predicted pose label for each crop.
    """

    LABELS = POSE_CLASSES

    def _to_pil(self, img):
        """Accept PIL.Image | np.ndarray | torch.Tensor -> PIL.Image (RGB).

        Every branch ends with ``convert("RGB")`` so the result always has
        three channels, which the 3-channel normalisation in ``self.tf``
        requires (grayscale or 4-channel inputs would otherwise fail there).

        Raises
        ------
        TypeError
            If ``img`` is none of the supported types.
        """
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        if isinstance(img, np.ndarray):
            # 3-channel arrays are treated as BGR (channel order is reversed);
            # any other layout is handed to PIL unchanged.
            if img.ndim == 3 and img.shape[2] == 3:
                img = img[..., ::-1]  # BGR → RGB
            return Image.fromarray(img).convert("RGB")
        if torch.is_tensor(img):
            return T.ToPILImage()(img.detach().cpu()).convert("RGB")
        raise TypeError(f"Unsupported crop type {type(img)}")

    def __init__(
        self,
        weight_path="checkpoints/best_pose_model.pth",
        model_size: str = "small",
        device: str = "cpu",
    ):
        """Build the model, load the trained weights and set up preprocessing.

        Parameters
        ----------
        weight_path : str or path-like
            Checkpoint file containing a ``'model_state_dict'`` entry.
        model_size : str
            Key into ``DINO_MODELS``; must match the size the checkpoint was
            trained with, otherwise ``load_state_dict`` will reject it.
        device : str
            Torch device string the model and input batches are placed on.
        """
        self.device = torch.device(device)

        # Build the same architecture used in training
        self.model = _PoseClassifierModel(model_size=model_size)

        # Load checkpoint (saved by train_pose_classifier.py)
        # NOTE(review): depending on the installed torch version, torch.load
        # may unpickle arbitrary objects; only load trusted checkpoints.
        ckpt = torch.load(weight_path, map_location=self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.model.eval().to(self.device)

        # Match the validation transforms from training
        self.tf = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    @torch.inference_mode()
    def __call__(self, crops):
        """
        Parameters
        ----------
        crops : list[PIL.Image | np.ndarray | torch.Tensor]
            One crop per detection.

        Returns
        -------
        list[str]
            Predicted pose label for each crop, e.g. 'front', 'back-left'.
            An empty input yields an empty list.
        """
        if not crops:
            return []

        pil_crops = [self._to_pil(c) for c in crops]
        batch = torch.stack([self.tf(c) for c in pil_crops]).to(self.device)
        logits = self.model(batch)  # shape [N, 8]
        # single-label: take the arg-max class and move to plain Python ints
        pred_indices = torch.argmax(logits, dim=-1).tolist()
        return [self.LABELS[i] for i in pred_indices]


# ───────── quick sanity check ─────────
if __name__ == "__main__":
    import random

    img_dir = Path("some/test/crops")  # directory of zebra chip .jpgs
    candidates = list(img_dir.glob("*.jpg"))
    # Cap the sample size so a directory with fewer than 4 images does not
    # make random.sample raise ValueError.
    chosen = random.sample(candidates, min(4, len(candidates)))

    samples = []
    for path in chosen:
        # convert() loads the pixel data, so the file handle can be closed.
        with Image.open(path) as im:
            samples.append(im.convert("RGB"))

    clf = ViewPointClassifier(device="cpu")
    print(clf(samples))