# navigation_scripts/pose_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
from pathlib import Path
import numpy as np
from PIL import Image
# Pose labels, in class-index order. The order must match
# train_pose_classifier.py, because predictions are mapped back to labels
# by position in this list.
POSE_CLASSES = [
    'front',
    'front-left',
    'front-right',
    'left',
    'right',
    'back-left',
    'back-right',
    'back',
]

# Size of the classifier output layer (one logit per pose label).
NUM_CLASSES = len(POSE_CLASSES)
# Supported backbone sizes: size key -> (torch.hub entry point name,
# feature dimension fed into the classification head).
DINO_MODELS = dict(
    small=('dinov2_vits14', 384),
    base=('dinov2_vitb14', 768),
    large=('dinov2_vitl14', 1024),
)
class _PoseClassifierModel(nn.Module):
    """Frozen DINOv2 backbone followed by an MLP head for pose classification.

    Intended to mirror ``train_pose_classifier.PoseClassifier`` so that a
    checkpoint saved during training can be loaded here. The layer order in
    ``head`` therefore must not change (state-dict keys depend on it).
    """

    def __init__(self, model_size='small', dropout=0.3):
        """Build the backbone and the classification head.

        Parameters
        ----------
        model_size : str
            Key into ``DINO_MODELS`` ('small', 'base' or 'large').
        dropout : float
            Dropout probability used after each hidden layer of the head.
        """
        super().__init__()
        hub_name, embed_dim = DINO_MODELS[model_size]

        # The backbone is fetched through torch.hub and used purely as a
        # fixed feature extractor: no gradients, eval mode.
        backbone = torch.hub.load('facebookresearch/dinov2', hub_name)
        backbone.requires_grad_(False)
        backbone.eval()
        self.backbone = backbone

        head_layers = [
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return one logit per pose class for each item in the batch ``x``."""
        # Backbone features are computed without autograd tracking; only the
        # head is applied outside the no_grad context.
        with torch.no_grad():
            embedding = self.backbone(x)
        logits = self.head(embedding)
        return logits
class ViewPointClassifier:
    """Predict one of 8 canonical zebra viewpoints for image crops.

    Labels: front, front-left, front-right, left, right, back-left,
    back-right, back.

    Uses a frozen DINOv2 backbone ("small" by default) with a trained MLP
    head.

    __call__(crops) → list[str]
        Each crop may be a PIL.Image, an np.ndarray or a torch.Tensor; it is
        converted to an RGB PIL image before preprocessing. Returns the
        predicted pose label for each crop.
    """

    LABELS = POSE_CLASSES

    def __init__(
        self,
        weight_path="checkpoints/best_pose_model.pth",
        model_size: str = "small",
        device: str = "cpu",
    ):
        """Build the model, load the trained weights and set up preprocessing.

        Parameters
        ----------
        weight_path : str or path-like
            Checkpoint file; must contain a ``'model_state_dict'`` entry.
        model_size : str
            Backbone size key understood by ``_PoseClassifierModel``.
        device : str
            Torch device string on which inference runs.
        """
        self.device = torch.device(device)

        # Build the same architecture used in training
        self.model = _PoseClassifierModel(model_size=model_size)

        # Load checkpoint (saved by train_pose_classifier.py).
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source, or
        # pass weights_only=True if the checkpoint format allows it.
        ckpt = torch.load(weight_path, map_location=self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.model.eval().to(self.device)

        # Intended to match the validation transforms from training
        self.tf = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    def _to_pil(self, img):
        """Accept PIL.Image | np.ndarray | torch.Tensor -> PIL.Image (RGB).

        Every branch ends in ``convert("RGB")`` so that grayscale or
        alpha-carrying inputs still yield the 3 channels the normalisation
        step expects.

        Raises
        ------
        TypeError
            If ``img`` is none of the supported types.
        """
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        if isinstance(img, np.ndarray):
            if img.ndim == 3 and img.shape[2] in (3, 4):
                # Colour arrays are assumed to be in BGR(A) channel order
                # (TODO confirm against callers): take channels 2, 1, 0,
                # which swaps BGR → RGB and drops an alpha channel if present.
                img = img[..., 2::-1]
            return Image.fromarray(img).convert("RGB")
        if torch.is_tensor(img):
            return T.ToPILImage()(img.cpu()).convert("RGB")
        raise TypeError(f"Unsupported crop type {type(img)}")

    @torch.inference_mode()
    def __call__(self, crops):
        """Classify a batch of crops.

        Parameters
        ----------
        crops : list of PIL.Image | np.ndarray | torch.Tensor
            One crop per detection.

        Returns
        -------
        list[str]
            Predicted pose label for each crop, e.g. 'front', 'back-left'.
            An empty input yields an empty list.
        """
        if not crops:
            return []

        batch = torch.stack(
            [self.tf(self._to_pil(c)) for c in crops]
        ).to(self.device)
        logits = self.model(batch)  # shape [N, NUM_CLASSES]
        # Single-label prediction; tolist() gives plain ints for indexing.
        preds = torch.argmax(logits, dim=-1).tolist()
        return [self.LABELS[i] for i in preds]
# ───────── quick sanity check ─────────
if __name__ == "__main__":
    import random

    img_dir = Path("some/test/crops")  # directory of zebra chip .jpgs
    paths = list(img_dir.glob("*.jpg"))
    if not paths:
        raise SystemExit(f"No .jpg crops found in {img_dir}")

    # Sample at most 4 images; random.sample raises ValueError if asked for
    # more items than are available.
    samples = []
    for p in random.sample(paths, min(4, len(paths))):
        # Image.open is lazy; convert() forces the pixel data to be read so
        # the file handle can be closed straight away.
        with Image.open(p) as im:
            samples.append(im.convert("RGB"))

    clf = ViewPointClassifier(device="cpu")
    print(clf(samples))
|