Add README, training/inference code, and trained DINOv2-small pose-classifier checkpoint
Uploads:
- README.md (Imageomics model-card template, populated from the code)
- pose_classifier.py - inference wrapper (ViewPointClassifier)
- train_pose_classifier.py - training script
- POSE_CLASSIFIER_GUIDE.md - user guide with pose-class reference
- checkpoints/best_pose_model.pth - trained DINOv2-small + MLP head weights (~88 MB, LFS)
Note: README YAML frontmatter still needs `license:` and (optionally) `datasets:` filled in.
- POSE_CLASSIFIER_GUIDE.md +287 -0
- README.md +351 -1
- checkpoints/best_pose_model.pth +3 -0
- pose_classifier.py +134 -0
- train_pose_classifier.py +420 -0
POSE_CLASSIFIER_GUIDE.md
ADDED
|
@@ -0,0 +1,287 @@
|
| 1 |
+
# Pose Classifier Guide
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
The pose classifier predicts the orientation of animals (zebras, giraffes, etc.) relative to the camera position from aerial drone footage. This is critical for navigation and behavior analysis.
|
| 6 |
+
|
| 7 |
+
## 8-Class Pose Classification System
|
| 8 |
+
|
| 9 |
+
### Pose Classes
|
| 10 |
+
|
| 11 |
+
The classifier identifies **8 discrete pose orientations** arranged in a circle around the animal:
|
| 12 |
+
|
| 13 |
+
1. **front** - Animal facing directly toward camera
|
| 14 |
+
2. **front-left** - Animal facing camera, angled to the left (~45°)
|
| 15 |
+
3. **left** - Animal's left side visible, perpendicular to camera
|
| 16 |
+
4. **back-left** - Animal facing away, angled to the left (~45°)
|
| 17 |
+
5. **back** - Animal facing directly away from camera
|
| 18 |
+
6. **back-right** - Animal facing away, angled to the right (~45°)
|
| 19 |
+
7. **right** - Animal's right side visible, perpendicular to camera
|
| 20 |
+
8. **front-right** - Animal facing camera, angled to the right (~45°)
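
Because the eight labels above are listed in order around the circle, the number of ~45° steps between any two labels can be computed from their positions. The following is a minimal sketch, not part of the released code: `POSE_ORDER` and `circular_steps` are hypothetical names, and the index order shown is an assumption that may differ from the class order stored in the checkpoint.

```python
# Labels in the circular order of the list above
# (assumed for illustration; not read from the checkpoint).
POSE_ORDER = [
    "front", "front-left", "left", "back-left",
    "back", "back-right", "right", "front-right",
]


def circular_steps(a: str, b: str) -> int:
    """Return the number of ~45-degree steps between two pose labels (0 to 4)."""
    i, j = POSE_ORDER.index(a), POSE_ORDER.index(b)
    d = abs(i - j)
    # Going around the circle the other way may be shorter.
    return min(d, len(POSE_ORDER) - d)
```

A distance of 1 marks a pair of adjacent classes, which is where most misclassifications are expected to fall.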
|
| 21 |
+
|
| 22 |
+
### Visual Reference
|
| 23 |
+
|
| 24 |
+

|
| 25 |
+
|
| 26 |
+
The diagram shows the 8 pose classes arranged in a circle. The camera is positioned at the bottom, and the animal (zebra) is in the center. Each orange dot represents one of the 8 possible pose classifications.
|
| 27 |
+
|
| 28 |
+
## Example Poses
|
| 29 |
+
|
| 30 |
+
### Front Pose
|
| 31 |
+
**Label:** `front`
|
| 32 |
+
|
| 33 |
+
The animal is facing directly toward the camera, with the head and front body visible.
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
### Front-Left Pose
|
| 40 |
+
**Label:** `front-left`
|
| 41 |
+
|
| 42 |
+
The animal is facing toward the camera but angled to its left (camera's right), showing both the front and left side.
|
| 43 |
+
|
| 44 |
+

|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
### Front-Right Pose
|
| 49 |
+
**Label:** `front-right`
|
| 50 |
+
|
| 51 |
+
The animal is facing toward the camera but angled to its right (camera's left), showing both the front and right side.
|
| 52 |
+
|
| 53 |
+

|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
### Left Pose
|
| 58 |
+
**Label:** `left`
|
| 59 |
+
|
| 60 |
+
The animal's left side is visible, perpendicular to the camera. This is a pure profile view.
|
| 61 |
+
|
| 62 |
+

|
| 63 |
+
|
| 64 |
+
---
|
| 65 |
+
|
| 66 |
+
### Right Pose
|
| 67 |
+
**Label:** `right`
|
| 68 |
+
|
| 69 |
+
The animal's right side is visible, perpendicular to the camera. This is a pure profile view from the opposite side.
|
| 70 |
+
|
| 71 |
+

|
| 72 |
+
|
| 73 |
+
---
|
| 74 |
+
|
| 75 |
+
### Back-Left Pose
|
| 76 |
+
**Label:** `back-left`
|
| 77 |
+
|
| 78 |
+
The animal is facing away from the camera but angled to its left, showing the rear-left quarter.
|
| 79 |
+
|
| 80 |
+

|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
### Back-Right Pose
|
| 85 |
+
**Label:** `back-right`
|
| 86 |
+
|
| 87 |
+
The animal is facing away from the camera but angled to its right, showing the rear-right quarter.
|
| 88 |
+
|
| 89 |
+

|
| 90 |
+
|
| 91 |
+
---
|
| 92 |
+
|
| 93 |
+
### Back Pose
|
| 94 |
+
**Label:** `back`
|
| 95 |
+
|
| 96 |
+
The animal is facing directly away from the camera, with the rear and back visible.
|
| 97 |
+
|
| 98 |
+

|
| 99 |
+
|
| 100 |
+
---
|
| 101 |
+
|
| 102 |
+
## Model Architecture
|
| 103 |
+
|
| 104 |
+
### DINOv2 + MLP Head
|
| 105 |
+
|
| 106 |
+
The pose classifier uses a **frozen DINOv2 backbone** with a **trainable MLP classification head**:
|
| 107 |
+
|
| 108 |
+
```
Input Image (224×224)
    ↓
DINOv2 Vision Transformer (frozen)
  - Small: 384-dim features
  - Base: 768-dim features
  - Large: 1024-dim features
    ↓
MLP Head (trainable)
  - LayerNorm
  - Linear(feat_dim -> 256) + GELU + Dropout(0.3)
  - Linear(256 -> 128) + GELU + Dropout(0.3)
  - Linear(128 -> 8)
    ↓
Output Logits (8 classes)
```
|
| 124 |
+
|
| 125 |
+
### Why DINOv2?
|
| 126 |
+
|
| 127 |
+
- **Self-supervised learning** on diverse images provides strong visual features
|
| 128 |
+
- **Frozen backbone** reduces training time and prevents overfitting
|
| 129 |
+
- **Small memory footprint** suitable for deployment
|
| 130 |
+
- **Robust to varying image quality** from aerial footage
|
| 131 |
+
|
| 132 |
+
## Training Pipeline
|
| 133 |
+
|
| 134 |
+
### Data Organization
|
| 135 |
+
|
| 136 |
+
Training data is organized in a per-class folder structure:
|
| 137 |
+
```
pose_labels/
  _reference.png   # Visual guide
  front/           # Front-facing animals
  front-left/      # Front-left quarter
  left/            # Left profile
  back-left/       # Back-left quarter
  back/            # Back-facing animals
  back-right/      # Back-right quarter
  right/           # Right profile
  front-right/     # Front-right quarter
```
|
| 149 |
+
|
| 150 |
+
Or via CSV files with columns: `image_path, pose`
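
As an illustration of the CSV layout, the sketch below parses text with `image_path` and `pose` columns into `(path, label)` pairs and rejects unknown labels. It assumes a header row with exactly those two names; `read_pose_csv` and `VALID_POSES` are hypothetical names, the file paths are made up, and the training script may parse its CSV differently.

```python
import csv
import io

VALID_POSES = {
    "front", "front-left", "front-right", "left",
    "right", "back-left", "back-right", "back",
}


def read_pose_csv(text: str) -> list[tuple[str, str]]:
    """Parse CSV text with `image_path` and `pose` columns into (path, label) pairs."""
    # skipinitialspace tolerates a space after each comma, including in the header.
    reader = csv.DictReader(io.StringIO(text), skipinitialspace=True)
    rows = []
    for row in reader:
        pose = row["pose"].strip()
        if pose not in VALID_POSES:
            raise ValueError(f"unknown pose label: {pose!r}")
        rows.append((row["image_path"].strip(), pose))
    return rows


# Made-up example rows for illustration only.
example = "image_path, pose\ncrops/z1.jpg, front-left\ncrops/z2.jpg, back\n"
pairs = read_pose_csv(example)
```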
|
| 151 |
+
|
| 152 |
+
### Data Augmentation
|
| 153 |
+
|
| 154 |
+
**Geometric Augmentation with Label Swapping:**
|
| 155 |
+
- Horizontal flip applied with 50% probability
|
| 156 |
+
- When flipped, pose labels are swapped according to symmetry:
|
| 157 |
+
- `left` <-> `right`
|
| 158 |
+
- `front-left` <-> `front-right`
|
| 159 |
+
- `back-left` <-> `back-right`
|
| 160 |
+
- `front` and `back` remain unchanged
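
The label swap described above can be sketched as a lookup table plus a random decision. This is an illustrative sketch only: `FLIP_MAP` and `maybe_flip` are hypothetical names, and the actual training script may organize the augmentation differently (for example, inside the dataset class).

```python
import random

# Label after a horizontal flip of the image.
FLIP_MAP = {
    "left": "right", "right": "left",
    "front-left": "front-right", "front-right": "front-left",
    "back-left": "back-right", "back-right": "back-left",
    "front": "front", "back": "back",
}


def maybe_flip(label: str, rng: random.Random, p: float = 0.5) -> tuple[bool, str]:
    """Decide whether to flip; return (flipped, label to use for this sample)."""
    if rng.random() < p:
        return True, FLIP_MAP[label]
    return False, label
```

When `flipped` is true the caller would mirror the image pixels as well, so the image and its label stay consistent.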
|
| 161 |
+
|
| 162 |
+
**Color/Transform Augmentation:**
|
| 163 |
+
- Random crop (256px -> 224px)
|
| 164 |
+
- Color jitter: brightness (±30%), contrast (±30%), saturation (±20%)
|
| 165 |
+
- Random rotation (±15°)
|
| 166 |
+
|
| 167 |
+
**Class Balancing:**
|
| 168 |
+
- Weighted random sampler ensures equal representation of all 8 classes during training
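
One common way to obtain equal class representation is to give each sample a weight inversely proportional to the size of its class. The sketch below shows that computation under this assumption; `balanced_sample_weights` is a hypothetical helper and is not taken from the training script.

```python
from collections import Counter


def balanced_sample_weights(labels: list[str]) -> list[float]:
    """Return one weight per sample, inversely proportional to its class count."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Made-up label list: three `front` samples and one `back` sample.
labels = ["front", "front", "front", "back"]
weights = balanced_sample_weights(labels)
```

With these weights every class carries the same total sampling mass, so a sampler such as `torch.utils.data.WeightedRandomSampler` draws each class at roughly the same rate.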
|
| 169 |
+
|
| 170 |
+
### Training Configuration
|
| 171 |
+
|
| 172 |
+
```bash
python train_pose_classifier.py \
    --data_dir ./pose_labels \
    --model_size small \
    --epochs 30 \
    --batch_size 32 \
    --lr 1e-3
```
|
| 180 |
+
|
| 181 |
+
**Key Parameters:**
|
| 182 |
+
- **Model size**: `small`, `base`, or `large` (DINOv2 variant)
|
| 183 |
+
- **Optimizer**: AdamW with weight decay 0.01
|
| 184 |
+
- **Loss**: CrossEntropyLoss with label smoothing (0.1)
|
| 185 |
+
- **Scheduler**: CosineAnnealingLR
|
| 186 |
+
- **Mixed precision**: Automatic on GPU
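
To show what the cosine schedule does to the learning rate, here is a small sketch of the standard closed-form cosine annealing curve. It assumes the schedule spans the full run (`t_max` equal to the number of epochs) and a minimum learning rate of 0; `cosine_annealing_lr` is a hypothetical helper, not a function from the training script.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float = 1e-3,
                        t_max: int = 30, eta_min: float = 0.0) -> float:
    """Learning rate at `epoch` under closed-form cosine annealing."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)
```

Under these assumptions the rate starts at `1e-3`, passes through half that value at the midpoint of training, and decays to the minimum at the final epoch.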
|
| 187 |
+
|
| 188 |
+
**Training Output:**
|
| 189 |
+
- Best model saved to `checkpoints/best_pose_model.pth`
|
| 190 |
+
- Includes confusion matrix and per-class accuracy
|
| 191 |
+
- Optional ONNX export for deployment
|
| 192 |
+
|
| 193 |
+
## Usage in Navigation
|
| 194 |
+
|
| 195 |
+
### Integration with Detection Pipeline
|
| 196 |
+
|
| 197 |
+
The pose classifier is used in the navigation system after animal detection:
|
| 198 |
+
|
| 199 |
+
```python
from navigation.policy.pose_classifier import ViewPointClassifier
from PIL import Image

# Initialize classifier
classifier = ViewPointClassifier(
    weight_path="model_weights/best_june_24_2025_IA_classifier_016.pth",
    device="cpu",
    threshold=0.5
)

# Process detected animal crops
crops = [Image.open(path) for path in detection_crops]
poses = classifier(crops)  # Returns list of pose strings

# Use poses for navigation decisions
for pose in poses:
    if "front" in pose:
        print("Animal is facing camera - approach with caution")
    elif "back" in pose:
        print("Animal is facing away - good for following")
```
|
| 221 |
+
|
| 222 |
+
### Multi-Label Pose System (Alternative)
|
| 223 |
+
|
| 224 |
+
The `ViewPointClassifier` in `pose_classifier.py` uses a different approach:
|
| 225 |
+
|
| 226 |
+
- **5 multi-label classes**: `up, front, back, right, left`
|
| 227 |
+
- **EfficientNet-B4** backbone trained on zebra crops
|
| 228 |
+
- **Input size**: 512×512 pixels
|
| 229 |
+
- **Output**: Concatenated string (e.g., `"upfrontright"`)
|
| 230 |
+
- **Threshold**: 0.5 (configurable)
|
| 231 |
+
|
| 232 |
+
This allows detecting compound poses like "animal is facing front-right while looking up."
|
| 233 |
+
|
| 234 |
+
## Performance Considerations
|
| 235 |
+
|
| 236 |
+
### Inference Speed
|
| 237 |
+
- **DINOv2-small**: ~15-20ms per image (CPU)
|
| 238 |
+
- **DINOv2-base**: ~30-40ms per image (CPU)
|
| 239 |
+
- **GPU acceleration**: 5-10x faster
|
| 240 |
+
|
| 241 |
+
### Accuracy Targets
|
| 242 |
+
- **Overall accuracy**: >85% on validation set
|
| 243 |
+
- **Critical classes** (front/back): >90% accuracy
|
| 244 |
+
- **Confusion**: Most errors occur between adjacent classes (e.g., front vs. front-left)
|
| 245 |
+
|
| 246 |
+
### Deployment Notes
|
| 247 |
+
- Model checkpoint: ~88 MB (small), ~350 MB (base)
|
| 248 |
+
- ONNX export available for optimized inference
|
| 249 |
+
- Batch processing recommended for multiple detections
|
| 250 |
+
|
| 251 |
+
## Common Issues & Tips
|
| 252 |
+
|
| 253 |
+
### Issue: Poor performance on occluded animals
|
| 254 |
+
**Solution**: Train with more occluded examples or use confidence thresholding
|
| 255 |
+
|
| 256 |
+
### Issue: Confusion between adjacent poses
|
| 257 |
+
**Solution**: This is expected because orientation is continuous while the labels are discrete; consider using pose groups (front-facing vs. side-facing vs. back-facing) when fine resolution is not needed
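
A grouping of this kind can be sketched as a plain lookup. The assignment of the diagonal classes to the front-facing and back-facing groups is an assumption made here for illustration; `POSE_GROUP` and `to_group` are hypothetical names.

```python
# Coarse grouping of the 8 labels (diagonal assignments are an assumption).
POSE_GROUP = {
    "front": "front-facing",
    "front-left": "front-facing",
    "front-right": "front-facing",
    "left": "side-facing",
    "right": "side-facing",
    "back": "back-facing",
    "back-left": "back-facing",
    "back-right": "back-facing",
}


def to_group(label: str) -> str:
    """Map a fine-grained pose label to its coarse group."""
    return POSE_GROUP[label]
```

A `front` vs. `front-left` confusion then maps to the same group and no longer affects the downstream decision.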
|
| 258 |
+
|
| 259 |
+
### Issue: Inconsistent predictions across frames
|
| 260 |
+
**Solution**: Apply temporal smoothing or majority voting across consecutive frames
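
Majority voting over a trailing window of frames could look like the sketch below. The window length of 5 and the name `smooth_poses` are assumptions for illustration; the right window depends on frame rate and on how quickly animals turn.

```python
from collections import Counter, deque


def smooth_poses(frame_poses: list[str], window: int = 5) -> list[str]:
    """Replace each frame's label with the most common label in a trailing window."""
    recent = deque(maxlen=window)  # keeps only the last `window` labels
    smoothed = []
    for pose in frame_poses:
        recent.append(pose)
        smoothed.append(Counter(recent).most_common(1)[0][0])
    return smoothed
```

A single outlier such as `front-left` in a run of `left` predictions is voted out, at the cost of a short delay when the animal really does turn.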
|
| 261 |
+
|
| 262 |
+
### Issue: Different performance on zebras vs. other species
|
| 263 |
+
**Solution**: Retrain with a dataset balanced across species, or train species-specific models
|
| 264 |
+
|
| 265 |
+
## Dataset Statistics
|
| 266 |
+
|
| 267 |
+
Current training data distribution (from folder structure):
|
| 268 |
+
- Folders: `front`, `front-left`, `front-right`, `left`, `right`, `back-left`, `back-right`, `back`
|
| 269 |
+
- Images per class: Variable (check with `train_pose_classifier.py --data_dir pose_labels`)
|
| 270 |
+
- Species: Primarily zebras and giraffes
|
| 271 |
+
- Source: Aerial drone footage from Mpala and OPC sessions
|
| 272 |
+
|
| 273 |
+
## References
|
| 274 |
+
|
| 275 |
+
- DINOv2 Paper: [https://arxiv.org/abs/2304.07193](https://arxiv.org/abs/2304.07193)
|
| 276 |
+
- VARe-ID (ViewPoint Classifier): [https://github.com/ziesski/VARe-ID](https://github.com/ziesski/VARe-ID)
|
| 277 |
+
- Individual identification of wildlife: [Review on methods used for wildlife species and individual identification](https://doi.org/10.1007/s10344-021-01549-4)
|
| 278 |
+
- Training script: [train_pose_classifier.py](train_pose_classifier.py)
|
| 279 |
+
- Navigation integration: [navigation/policy/pose_classifier.py](../navigation/policy/pose_classifier.py)
|
| 280 |
+
|
| 281 |
+
## Quick Start
|
| 282 |
+
|
| 283 |
+
1. **Prepare data**: Organize images in `pose_labels/` folders by class
|
| 284 |
+
2. **Train model**: `python train_pose_classifier.py --data_dir ./pose_labels --epochs 30`
|
| 285 |
+
3. **Evaluate**: Check confusion matrix and per-class accuracy in output
|
| 286 |
+
4. **Export**: Use `--export_onnx` flag for optimized deployment
|
| 287 |
+
5. **Integrate**: Load checkpoint and use for inference on detection crops
|
README.md
CHANGED
|
@@ -1,3 +1,353 @@
|
|
| 1 |
---
|
| 2 |
-
license:
|
| 3 |
---
|
| 1 |
---
|
| 2 |
+
# TODO: pick an OSI-compatible license tag (see note below) and add it as `license:` here.
|
| 3 |
+
# TODO: if/when a HF dataset is published, add it as `datasets: <org>/<name>` (string, not list).
|
| 4 |
+
language:
|
| 5 |
+
- en
|
| 6 |
+
library_name: pytorch
|
| 7 |
+
tags:
|
| 8 |
+
- biology
|
| 9 |
+
- CV
|
| 10 |
+
- images
|
| 11 |
+
- animals
|
| 12 |
+
- zebra
|
| 13 |
+
- giraffe
|
| 14 |
+
- pose-estimation
|
| 15 |
+
- viewpoint-classification
|
| 16 |
+
- dinov2
|
| 17 |
+
- aerial-imagery
|
| 18 |
+
- drone
|
| 19 |
+
metrics:
|
| 20 |
+
- accuracy
|
| 21 |
+
model_description: An 8-class viewpoint/pose classifier for aerial drone imagery of wildlife (primarily zebras and giraffes). Uses a frozen DINOv2 vision-transformer backbone with a trainable MLP head to predict one of eight canonical orientations of the animal relative to the camera.
|
| 22 |
---
|
| 23 |
+
|
| 24 |
+
<!--
|
| 25 |
+
|
| 26 |
+
NOTE: Add more tags (your particular animal, type of model and use-case, etc.).
|
| 27 |
+
|
| 28 |
+
As with your GitHub Project repo, it is important to choose an appropriate license for your model. Alongside the appropriate stakeholders (e.g., your PI, co-authors), select a license that is [Open Source Initiative](https://opensource.org/licenses) (OSI) compliant. You may also wish to consider adding a [RAIL license](https://www.licenses.ai/ai-licenses), which addresses responsible use.
|
| 29 |
+
For more information on how to choose a license and why it matters, see [Choose A License](https://choosealicense.com) and [A Quick Guide to Software Licensing for the Scientist-Programmer](https://doi.org/10.1371/journal.pcbi.1002598) by A. Morin, et al.
|
| 30 |
+
See the [Imageomics policy for licensing](https://imageomics.github.io/Imageomics-guide/wiki-guide/Digital-products-release-licensing-policy/) for more information.
|
| 31 |
+
|
| 32 |
+
License tags (for the `yaml` above) can be found [here](https://hf.co/docs/hub/repositories-licenses).
|
| 33 |
+
-->
|
| 34 |
+
|
| 35 |
+
# Model Card for DINOv2 8-Class Animal Pose Classifier
|
| 36 |
+
|
| 37 |
+
A lightweight viewpoint/pose classifier that predicts one of **8 canonical orientations** (front, front-left, front-right, left, right, back-left, back-right, back) for an animal crop extracted from aerial drone imagery. It pairs a **frozen DINOv2 vision-transformer backbone** with a small **trainable MLP head**, and is intended for use as a downstream module in a drone-based wildlife detection-and-navigation pipeline.
|
| 38 |
+
|
| 39 |
+
## Model Details
|
| 40 |
+
|
| 41 |
+
### Model Description
|
| 42 |
+
|
| 43 |
+
This model takes a 224×224 RGB image crop of a single animal (typically produced by an upstream detector) and outputs a categorical prediction over 8 viewpoint classes arranged around the animal. The 8 classes form a discretization of the animal's heading relative to the camera, with adjacent classes separated by ~45°.
|
| 44 |
+
|
| 45 |
+
The DINOv2 backbone is loaded via `torch.hub` from `facebookresearch/dinov2` and is kept frozen during training; only the MLP head is updated. This keeps the number of trainable parameters low (well under 1M for the `small` variant), reduces overfitting on small labeled pose datasets, and allows the same self-supervised representation to be reused for related downstream tasks.
|
| 46 |
+
|
| 47 |
+
- **Developed by:** Imageomics Institute — Individual Identification of Zebras project (Claire Sun, et al.)
|
| 48 |
+
- **Model type:** Image classifier (Vision Transformer feature extractor + MLP head)
|
| 49 |
+
- **Language(s) (NLP):** N/A (vision model)
|
| 50 |
+
- **License:** [More Information Needed — choose a license (see above notes)]
|
| 51 |
+
- **Fine-tuned from model:** [facebookresearch/dinov2](https://github.com/facebookresearch/dinov2) (`dinov2_vits14`, `dinov2_vitb14`, or `dinov2_vitl14`)
|
| 52 |
+
|
| 53 |
+
### Model Sources
|
| 54 |
+
|
| 55 |
+
- **Repository:** [individual_id_zebras / Claire / pose_model](.)
|
| 56 |
+
- **Training script:** [train_pose_classifier.py](train_pose_classifier.py)
|
| 57 |
+
- **Inference wrapper:** [pose_classifier.py](pose_classifier.py)
|
| 58 |
+
- **User guide:** [POSE_CLASSIFIER_GUIDE.md](POSE_CLASSIFIER_GUIDE.md)
|
| 59 |
+
- **Paper:** [More Information Needed — optional]
|
| 60 |
+
- **Demo:** [More Information Needed — encouraged]
|
| 61 |
+
|
| 62 |
+
## Uses
|
| 63 |
+
|
| 64 |
+
### Direct Use
|
| 65 |
+
|
| 66 |
+
The model is intended to be applied to **tight, single-animal crops** (e.g., the output of a wildlife detector run on aerial drone frames). For each crop it returns the most likely of 8 viewpoint labels:
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
front, front-left, front-right, left, right, back-left, back-right, back
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
These labels are useful for:
|
| 73 |
+
|
| 74 |
+
- Selecting frames in which an individual is best observed (e.g., side profiles for stripe-based re-identification).
|
| 75 |
+
- Filtering training data for downstream identity models that are viewpoint-sensitive.
|
| 76 |
+
- Behavioral analysis (e.g., orientation of herd members relative to the camera/drone).
|
| 77 |
+
|
| 78 |
+
### Downstream Use
|
| 79 |
+
|
| 80 |
+
This pose classifier is a component of a larger **drone navigation and individual-identification pipeline** for zebras and giraffes. Downstream uses include:
|
| 81 |
+
|
| 82 |
+
- Conditioning a re-identification model on viewpoint.
|
| 83 |
+
- Informing autonomous drone-positioning policies (e.g., maneuver to obtain a side-profile view).
|
| 84 |
+
- Producing per-track viewpoint histograms used for sighting quality scoring.
|
| 85 |
+
|
| 86 |
+
### Out-of-Scope Use
|
| 87 |
+
|
| 88 |
+
- **Non-aerial / ground-level imagery.** The model is trained on top-down/oblique drone footage; predictions on eye-level photos are unlikely to be reliable.
|
| 89 |
+
- **Species the model was not trained on.** Performance has only been characterized for zebras and giraffes. Application to unrelated species is out of scope without retraining.
|
| 90 |
+
- **Continuous heading regression.** The model predicts 1-of-8 discrete classes, not a continuous angle. Adjacent classes (e.g., `front` vs `front-left`) are frequently confused and should not be treated as fully independent.
|
| 91 |
+
- **Identity, species, or behavior inference.** The model does not predict the identity, species, or activity of the animal.
|
| 92 |
+
|
| 93 |
+
## Bias, Risks, and Limitations
|
| 94 |
+
|
| 95 |
+
- **Domain shift:** Training data is drawn primarily from aerial drone footage at two field sites (Mpala and OPC). Performance may degrade on imagery captured at other altitudes, lighting conditions, or camera angles.
|
| 96 |
+
- **Class adjacency confusion:** Because viewpoint is fundamentally continuous, errors are concentrated between neighboring classes (e.g., `front` ↔ `front-left`). The 8-class discretization is a modeling choice, not a property of the underlying phenomenon.
|
| 97 |
+
- **Species imbalance:** Most training samples are zebras; giraffe coverage is smaller and per-class performance has not been independently broken out.
|
| 98 |
+
- **Occlusion sensitivity:** Heavily occluded or truncated crops (animals partially out of frame, overlapping individuals) are not well represented and tend to produce less reliable predictions.
|
| 99 |
+
- **Tight-crop dependence:** The model expects detector-style crops centered on a single animal. Wide-scene images will not produce meaningful predictions.
|
| 100 |
+
|
| 101 |
+
### Recommendations
|
| 102 |
+
|
| 103 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. In particular:
|
| 104 |
+
|
| 105 |
+
- Treat adjacent-class confusion (e.g., `front`/`front-left`) as expected and consider collapsing to coarser bins (front/side/back) for decisions that don't need fine resolution.
|
| 106 |
+
- Apply temporal smoothing or majority voting across consecutive frames when classifying tracked individuals.
|
| 107 |
+
- Confidence-threshold or hold out predictions on visibly occluded crops.
|
| 108 |
+
- Re-evaluate (or retrain) before deploying on a new site, species, or sensor.
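
The confidence-thresholding recommendation above can be sketched as a softmax over the 8 logits followed by a cutoff. This is a hedged illustration: the label order in `POSE_LABELS`, the threshold of 0.6, and the function name `predict_with_threshold` are all assumptions and are not taken from `pose_classifier.py`.

```python
import math

# Assumed index-to-label order; the checkpoint's actual order may differ.
POSE_LABELS = [
    "front", "front-left", "front-right", "left",
    "right", "back-left", "back-right", "back",
]


def predict_with_threshold(logits: list[float], threshold: float = 0.6):
    """Return (label, confidence), or (None, confidence) when below the threshold."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] < threshold:
        return None, probs[best]
    return POSE_LABELS[best], probs[best]
```

Crops that come back as `None` can be skipped or deferred to neighboring frames rather than acted on.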
|
| 109 |
+
|
| 110 |
+
## How to Get Started with the Model
|
| 111 |
+
|
| 112 |
+
The inference wrapper [`ViewPointClassifier`](pose_classifier.py) provides a one-line interface that takes a list of PIL crops and returns a list of pose labels.
|
| 113 |
+
|
| 114 |
+
```python
from PIL import Image
from pose_classifier import ViewPointClassifier

classifier = ViewPointClassifier(
    weight_path="checkpoints/best_pose_model.pth",
    model_size="small",  # must match the trained checkpoint
    device="cpu",        # or "cuda"
)

crops = [Image.open(p).convert("RGB") for p in ["zebra1.jpg", "zebra2.jpg"]]
poses = classifier(crops)
# e.g. ['front-left', 'back']
```
|
| 128 |
+
|
| 129 |
+
The wrapper handles preprocessing (resize to 256, center-crop to 224, ImageNet normalization) and accepts PIL images, NumPy arrays, or torch tensors as input.
|
| 130 |
+
|
| 131 |
+
To **train from scratch** on a new pose-labeled dataset:
|
| 132 |
+
|
| 133 |
+
```bash
python train_pose_classifier.py \
    --data_dir ./pose_labels \
    --model_size small \
    --epochs 30 \
    --batch_size 32 \
    --lr 1e-3
```
|
| 141 |
+
|
| 142 |
+
See [POSE_CLASSIFIER_GUIDE.md](POSE_CLASSIFIER_GUIDE.md) for the full guide, including the visual reference diagram for each pose class.
|
| 143 |
+
|
| 144 |
+
## Training Details
|
| 145 |
+
|
| 146 |
+
### Training Data
|
| 147 |
+
|
| 148 |
+
Pose-labeled crops of zebras and giraffes extracted from aerial drone footage at Mpala (Kenya) and OPC field sites. Data is organized either as a per-class folder hierarchy:
|
| 149 |
+
|
| 150 |
+
```
|
| 151 |
+
pose_labels/
|
| 152 |
+
front/ front-left/ front-right/ left/ right/ back-left/ back-right/ back/
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
or as a CSV with `image_path, pose` columns. Class counts are inherently imbalanced and are handled at the sampler level (see below).
|
| 156 |
+
|
| 157 |
+
[More Information Needed — exact per-class sample counts, splits, and dataset card link]
|
| 158 |
+
|
| 159 |
+
### Training Procedure
|
| 160 |
+
|
| 161 |
+
#### Preprocessing
|
| 162 |
+
|
| 163 |
+
Training-time transforms (applied per image):
|
| 164 |
+
|
| 165 |
+
- Resize shorter side to 256
|
| 166 |
+
- `RandomCrop(224)`
|
| 167 |
+
- `ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2)`
|
| 168 |
+
- `RandomRotation(±15°)`
|
| 169 |
+
- ToTensor + ImageNet normalization (`mean=[0.485, 0.456, 0.406]`, `std=[0.229, 0.224, 0.225]`)
|
| 170 |
+
|
| 171 |
+
Validation-time transforms: `Resize(256) → CenterCrop(224) → ToTensor → Normalize`.
|
| 172 |
+
|
| 173 |
+
**Symmetry-aware horizontal flip:** with p=0.5 the crop is horizontally flipped and the label is swapped according to the canonical symmetry of the 8-class scheme:
|
| 174 |
+
|
| 175 |
+
```
|
| 176 |
+
left ↔ right
|
| 177 |
+
front-left ↔ front-right
|
| 178 |
+
back-left ↔ back-right
|
| 179 |
+
front, back unchanged
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
This effectively doubles the number of distinct training views without breaking label semantics.
|
| 183 |
+
|
| 184 |
+
**Class balancing:** a `WeightedRandomSampler` with weights inversely proportional to per-class frequency ensures all 8 classes are sampled at equal rates during training.
|
| 185 |
+
|
| 186 |
+
#### Training Hyperparameters
|
| 187 |
+
|
| 188 |
+
- **Training regime:** fp16 mixed precision when running on CUDA (via `torch.cuda.amp`); fp32 on CPU.
|
| 189 |
+
- **Optimizer:** AdamW, `lr=1e-3`, `weight_decay=0.01` (head parameters only — backbone is frozen).
|
| 190 |
+
- **Loss:** `CrossEntropyLoss(label_smoothing=0.1)`.
|
| 191 |
+
- **LR schedule:** `CosineAnnealingLR(T_max=epochs)`.
|
| 192 |
+
- **Default epochs / batch size:** 30 / 32.
|
| 193 |
+
- **Backbone:** frozen DINOv2 (`small` = ViT-S/14, 384-dim; `base` = ViT-B/14, 768-dim; `large` = ViT-L/14, 1024-dim).
|
| 194 |
+
- **Head:** `LayerNorm → Linear(feat_dim, 256) → GELU → Dropout(0.3) → Linear(256, 128) → GELU → Dropout(0.3) → Linear(128, 8)`.
|
| 195 |
+
|
| 196 |
+
Only the MLP head is trained — for the `small` variant this is well under 1M trainable parameters.
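
The parameter count can be checked by hand from the head definition. The sketch below assumes each `Linear` layer has a bias and that `LayerNorm` has a learnable scale and shift (the PyTorch defaults); `head_param_count` is a hypothetical helper, not part of the training script.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def head_param_count(feat_dim: int, num_classes: int = 8) -> int:
    """Trainable parameters in LayerNorm -> 256 -> 128 -> num_classes."""
    layer_norm = 2 * feat_dim  # scale and shift vectors
    return (
        layer_norm
        + linear_params(feat_dim, 256)
        + linear_params(256, 128)
        + linear_params(128, num_classes)
    )


small_head = head_param_count(384)  # 133256
```

Under these assumptions the `small` head has roughly 133k trainable parameters, and even the `large` variant's head stays below 1M.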
|
| 197 |
+
|
| 198 |
+
#### Speeds, Sizes, Times
|
| 199 |
+
|
| 200 |
+
- **Checkpoint size:** ~88 MB for the `small` variant (`best_pose_model.pth`), ~350 MB for `base`.
|
| 201 |
+
- **Inference (CPU):** ~15–20 ms/image (`small`), ~30–40 ms/image (`base`).
|
| 202 |
+
- **Inference (GPU):** roughly 5–10× faster than CPU.
|
| 203 |
+
|
| 204 |
+
[More Information Needed — wall-clock training time, throughput per epoch]
|
| 205 |
+
|
| 206 |
+
## Evaluation
|
| 207 |
+
|
| 208 |
+
### Testing Data, Factors & Metrics
|
| 209 |
+
|
| 210 |
+
#### Testing Data
|
| 211 |
+
|
| 212 |
+
When training from a single `--data_dir`, the script performs an 80/20 random split into train/val. When `--train_csv` and `--val_csv` are supplied, those are used directly.
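
An 80/20 random split of this kind can be sketched as follows. The fixed seed and the lack of per-class stratification are assumptions made for the example; `split_train_val` is a hypothetical name and the script's own split may differ in both respects.

```python
import random


def split_train_val(items, val_fraction: float = 0.2, seed: int = 42):
    """Shuffle a copy of `items` and return (train, val) lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # local RNG keeps the split reproducible
    n_val = int(round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


train_items, val_items = split_train_val(range(100))
```

Because a purely random split can leave rare classes thin in the validation set, per-class validation counts are worth checking before reading too much into per-class accuracy.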
|
| 213 |
+
|
| 214 |
+
[More Information Needed — held-out test set details, if any beyond the val split]
|
| 215 |
+
|
| 216 |
+
#### Factors
|
| 217 |
+
|
| 218 |
+
The natural disaggregations of interest are:
|
| 219 |
+
|
| 220 |
+
- **Pose class** (8 categories) — adjacent-class confusion is the dominant error mode.
|
| 221 |
+
- **Species** (zebra vs giraffe) — coverage and accuracy may differ.
|
| 222 |
+
- **Site / session** (e.g., Mpala vs OPC sessions) — proxies for altitude, lighting, and habitat.
|
| 223 |
+
|
| 224 |
+
[More Information Needed — disaggregated numbers]
|
| 225 |
+
|
| 226 |
+
#### Metrics
|
| 227 |
+
|
| 228 |
+
- **Top-1 accuracy** (overall and per-class).
|
| 229 |
+
- **8×8 confusion matrix** (printed by [train_pose_classifier.py](train_pose_classifier.py) at the end of training).
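
Both metrics can be computed from paired lists of true and predicted labels. The sketch below uses rows for true classes and columns for predicted classes, which is an assumption about orientation; `confusion_matrix` and `per_class_accuracy` are hypothetical helpers rather than the script's own functions.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Count matrix with true classes as rows and predicted classes as columns."""
    index = {name: i for i, name in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for truth, pred in zip(y_true, y_pred):
        matrix[index[truth]][index[pred]] += 1
    return matrix


def per_class_accuracy(matrix, labels):
    """Fraction of each true class that was predicted correctly."""
    result = {}
    for i, name in enumerate(labels):
        total = sum(matrix[i])
        result[name] = matrix[i][i] / total if total else float("nan")
    return result
```

Overall top-1 accuracy is the sum of the diagonal divided by the total number of samples.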

### Results

Target performance stated in the user guide (targets, not measured results):

- Overall validation accuracy: **>85%**
- Critical front/back classes: **>90%**

[More Information Needed — actual measured numbers for the released checkpoint, ideally as a confusion matrix figure]

#### Summary

The `small` DINOv2 backbone with the MLP head described above is the released configuration and offers a favorable accuracy/latency trade-off for the drone-navigation use case. The `base` and `large` variants are supported by the same training script for users with more compute and labeled data.

## Model Examination

[More Information Needed — saliency/feature-attribution analysis, if any]

## Environmental Impact

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://doi.org/10.48550/arXiv.1910.09700).

- **Hardware Type:** [More Information Needed — GPU model used for training]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** Ohio Supercomputer Center (OSC)
- **Compute Region:** Ohio, USA
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications

### Model Architecture and Objective

```
Input Image (224×224, RGB, ImageNet-normalized)
        │
        ▼
DINOv2 ViT (frozen)
   - small : ViT-S/14 → 384-d feature vector
   - base  : ViT-B/14 → 768-d feature vector
   - large : ViT-L/14 → 1024-d feature vector
        │
        ▼
MLP head (trainable)
   LayerNorm(feat_dim)
   Linear(feat_dim → 256) + GELU + Dropout(0.3)
   Linear(256 → 128) + GELU + Dropout(0.3)
   Linear(128 → 8)
        │
        ▼
Logits over {front, front-left, front-right, left, right,
             back-left, back-right, back}
```

Training objective: cross-entropy with label smoothing (0.1), optimized only over the MLP head parameters.
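
To make the objective concrete, here is a minimal pure-Python sketch of label-smoothed cross-entropy for a single sample. The training script itself uses `nn.CrossEntropyLoss(label_smoothing=0.1)`; the function name `label_smoothed_cross_entropy` is hypothetical and the logits are made-up values. The sketch assumes the usual definition, in which the one-hot target is mixed with a uniform distribution over the 8 classes.

```python
import math


def label_smoothed_cross_entropy(logits, target, smoothing=0.1):
    """Cross-entropy between softmax(logits) and a smoothed one-hot target."""
    num_classes = len(logits)

    # Numerically stable log-softmax
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_norm for z in logits]

    # Smoothed target: (1 - eps) on the true class plus eps / K on every class
    loss = 0.0
    for k, log_p in enumerate(log_probs):
        q_k = smoothing / num_classes + (1.0 - smoothing) * (k == target)
        loss -= q_k * log_p
    return loss


# Made-up logits for one sample whose true class is index 0 ('front')
confident_logits = [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
plain_loss = label_smoothed_cross_entropy(confident_logits, target=0, smoothing=0.0)
smoothed_loss = label_smoothed_cross_entropy(confident_logits, target=0, smoothing=0.1)
```

With smoothing enabled, a very confident correct prediction still incurs a small loss, which discourages the head from producing extreme logits.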

### Compute Infrastructure

#### Hardware

- **Training:** a single CUDA-capable GPU is sufficient for the `small` variant; automatic mixed precision is enabled whenever training runs on CUDA. Larger DINOv2 variants benefit from more GPU memory.
- **Inference:** runs on CPU or a single GPU. CPU is viable for low-throughput on-board use; GPU is recommended for batched offline processing.

#### Software

- Python 3.x
- PyTorch (with `torch.hub` access to `facebookresearch/dinov2`)
- torchvision
- pandas, numpy, Pillow, tqdm

## Citation

[More Information Needed]

<!--
If you use our model in your work, please cite the model and any associated paper.

**Model**
```
@software{<ref_code>,
  author = {<author1 and author2>},
  doi = {<doi once generated>},
  title = {DINOv2 8-Class Animal Pose Classifier},
  version = {<version#>},
  year = {<year>},
  url = {https://huggingface.co/imageomics/<model_name>}
}
```
-->

Underlying backbone:

```
@article{oquab2023dinov2,
  title = {DINOv2: Learning Robust Visual Features without Supervision},
  author = {Oquab, Maxime and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and others},
  journal = {arXiv preprint arXiv:2304.07193},
  year = {2023},
  url = {https://arxiv.org/abs/2304.07193}
}
```

## Acknowledgements

This work was supported by the [Imageomics Institute](https://imageomics.org), which is funded by the US National Science Foundation's Harnessing the Data Revolution (HDR) program under [Award #2118240](https://www.nsf.gov/awardsearch/showAward?AWD_ID=2118240) (Imageomics: A New Frontier of Biological Information Powered by Knowledge-Guided Machine Learning). Any opinions, findings and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation.

Compute was provided by the [Ohio Supercomputer Center](https://www.osc.edu/). The backbone model is DINOv2 by Meta AI Research.

## Glossary

- **Pose / viewpoint:** the orientation of the animal relative to the camera, discretized here into 8 bins of ~45° each.
- **Frozen backbone:** the DINOv2 weights are fixed during training; gradients flow only through the MLP head.
- **Symmetry-aware flip:** horizontal-flip augmentation paired with a label swap (`left↔right`, `front-left↔front-right`, `back-left↔back-right`) so that flipped images carry geometrically correct labels.
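
The symmetry-aware flip can be sketched in a few lines. The `FLIP_PAIRS` mapping below mirrors the one in `train_pose_classifier.py`; everything else is an assumption made for illustration: the function name `flip_with_label` is hypothetical, and a nested list of pixel values stands in for an image so the example runs without any imaging library (the training script flips PIL images with torchvision instead).

```python
import random

# Label swap applied whenever an image is mirrored left-to-right
FLIP_PAIRS = {
    'front-left': 'front-right',
    'front-right': 'front-left',
    'left': 'right',
    'right': 'left',
    'back-left': 'back-right',
    'back-right': 'back-left',
    'front': 'front',  # symmetric classes keep their label
    'back': 'back',
}


def flip_with_label(pixel_rows, label, p=0.5, rng=random):
    """With probability p, mirror each pixel row and swap the label to match."""
    if rng.random() < p:
        pixel_rows = [row[::-1] for row in pixel_rows]
        label = FLIP_PAIRS[label]
    return pixel_rows, label


# Toy 2x3 "image"; p=1.0 forces the flip so the effect is visible
image = [[1, 2, 3], [4, 5, 6]]
flipped_image, flipped_label = flip_with_label(image, 'front-left', p=1.0)
```

Flipping the image without swapping the label would teach the classifier that a mirrored `front-left` animal is still `front-left`, which is geometrically wrong; pairing the two keeps the augmentation label-preserving.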

## More Information

See [POSE_CLASSIFIER_GUIDE.md](POSE_CLASSIFIER_GUIDE.md) for visual references of each pose class, training tips, and integration notes for the navigation pipeline.

## Model Card Authors

Jenna Kline

## Model Card Contact

Elizabeth Campolongo, campolongo.4@osu.edu
checkpoints/best_pose_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2956a457eabb17260e7da687898042f2376aea178d2c22a13a16f1c12c48d21d
size 88828281
pose_classifier.py
ADDED
@@ -0,0 +1,134 @@
# navigation_scripts/pose_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
from pathlib import Path
import numpy as np
from PIL import Image


# Must match train_pose_classifier.py
POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right', 'back-left', 'back-right', 'back']
NUM_CLASSES = len(POSE_CLASSES)

DINO_MODELS = {
    'small': ('dinov2_vits14', 384),
    'base': ('dinov2_vitb14', 768),
    'large': ('dinov2_vitl14', 1024),
}


class _PoseClassifierModel(nn.Module):
    """DINOv2 + MLP head for 8-class pose classification (mirrors train_pose_classifier.PoseClassifier)."""

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()
        model_name, feat_dim = DINO_MODELS[model_size]

        self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES),
        )

    def forward(self, x):
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)


class ViewPointClassifier:
    """
    Predicts one of 8 canonical zebra viewpoints:
        front, front-left, front-right, left, right, back-left, back-right, back

    Uses a DINOv2-small backbone (frozen) with a trained MLP head.

    __call__(crops) → list[str]
        Each crop may be a PIL.Image, an np.ndarray (3-channel arrays are
        treated as BGR), or a torch.Tensor. Returns the predicted pose label.
    """
    LABELS = POSE_CLASSES

    def _to_pil(self, img):
        """Accept PIL.Image | np.ndarray | torch.Tensor -> PIL.Image (RGB)."""
        if isinstance(img, Image.Image):
            return img.convert("RGB")

        if isinstance(img, np.ndarray):
            # 3-channel arrays are assumed to be OpenCV-style BGR
            if img.ndim == 3 and img.shape[2] == 3:
                img = img[..., ::-1]  # BGR → RGB
            # convert("RGB") keeps grayscale inputs compatible with Normalize
            return Image.fromarray(np.ascontiguousarray(img)).convert("RGB")

        if torch.is_tensor(img):
            return T.ToPILImage()(img.cpu()).convert("RGB")

        raise TypeError(f"Unsupported crop type {type(img)}")

    def __init__(
        self,
        weight_path="checkpoints/best_pose_model.pth",
        model_size: str = "small",
        device: str = "cpu",
    ):
        self.device = torch.device(device)

        # Build the same architecture used in training
        self.model = _PoseClassifierModel(model_size=model_size)

        # Load checkpoint (saved by train_pose_classifier.py)
        ckpt = torch.load(weight_path, map_location=self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.model.eval().to(self.device)

        # Match the validation transforms from training
        self.tf = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    @torch.inference_mode()
    def __call__(self, crops):
        """
        Parameters
        ----------
        crops : list[PIL.Image | np.ndarray | torch.Tensor]
            One crop per detection.

        Returns
        -------
        list[str]
            Predicted pose label for each crop, e.g. 'front', 'back-left'.
        """
        if not crops:
            return []
        pil_crops = [self._to_pil(c) for c in crops]
        batch = torch.stack([self.tf(c) for c in pil_crops]).to(self.device)
        logits = self.model(batch)  # shape [N, 8]
        preds = torch.argmax(logits, dim=-1).tolist()  # single-label class indices
        return [self.LABELS[i] for i in preds]

# ───────── quick sanity check ─────────
if __name__ == "__main__":
    from PIL import Image
    import random

    img_dir = Path("some/test/crops")  # directory of zebra chip .jpgs
    samples = [Image.open(p) for p in random.sample(list(img_dir.glob("*.jpg")), 4)]

    clf = ViewPointClassifier(device="cpu")
    print(clf(samples))
train_pose_classifier.py
ADDED
@@ -0,0 +1,420 @@
#!/usr/bin/env python3
"""
8-Class Pose Classifier Training
================================
Train a classifier for animal pose relative to camera.

Classes:
    front, front-left, front-right, left, right, back-left, back-right, back

Usage:
    python train_pose_classifier.py --data_dir ./pose_labels --epochs 30
    python train_pose_classifier.py --train_csv train.csv --val_csv val.csv --epochs 30
"""

import argparse
import os
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import pandas as pd

# ============================================================
# Configuration
# ============================================================

POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right', 'back-left', 'back-right', 'back']
CLASS_TO_IDX = {c: i for i, c in enumerate(POSE_CLASSES)}
IDX_TO_CLASS = {i: c for c, i in CLASS_TO_IDX.items()}
NUM_CLASSES = len(POSE_CLASSES)

# Horizontal flip swaps these pairs
FLIP_PAIRS = {
    'front-left': 'front-right',
    'front-right': 'front-left',
    'left': 'right',
    'right': 'left',
    'back-left': 'back-right',
    'back-right': 'back-left',
    'front': 'front',
    'back': 'back',
}

# DINOv2 model sizes
DINO_MODELS = {
    'small': ('dinov2_vits14', 384),
    'base': ('dinov2_vitb14', 768),
    'large': ('dinov2_vitl14', 1024),
}


# ============================================================
# Dataset
# ============================================================

class PoseDataset(Dataset):
    """Dataset that supports both folder structure and CSV"""

    def __init__(self, data_source, transform=None, augment_flip=True):
        """
        Args:
            data_source: Either a directory path (folder structure) or CSV path
            transform: Image transforms
            augment_flip: Whether to apply horizontal flip with label swap
        """
        self.transform = transform
        self.augment_flip = augment_flip
        self.samples = []

        data_path = Path(data_source)

        if data_path.is_dir():
            # Load from folder structure
            for cls in POSE_CLASSES:
                cls_dir = data_path / cls
                if cls_dir.exists():
                    for img_path in cls_dir.glob('*'):
                        if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
                            self.samples.append((str(img_path), cls))
        else:
            # Load from CSV
            df = pd.read_csv(data_path)
            img_col = 'image_path' if 'image_path' in df.columns else df.columns[0]
            label_col = 'pose' if 'pose' in df.columns else df.columns[1]

            for _, row in df.iterrows():
                if row[label_col] in POSE_CLASSES:
                    self.samples.append((row[img_col], row[label_col]))

        print(f"Loaded {len(self.samples)} samples")
        self._print_distribution()

    def _print_distribution(self):
        from collections import Counter
        counts = Counter(s[1] for s in self.samples)
        print("Class distribution:")
        for cls in POSE_CLASSES:
            print(f" {cls}: {counts.get(cls, 0)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        # Horizontal flip augmentation with label swap
        do_flip = self.augment_flip and torch.rand(1) < 0.5
        if do_flip:
            image = transforms.functional.hflip(image)
            label = FLIP_PAIRS[label]

        if self.transform:
            image = self.transform(image)

        return image, CLASS_TO_IDX[label]

    def get_sample_weights(self):
        """Weights for balanced sampling"""
        from collections import Counter
        counts = Counter(s[1] for s in self.samples)
        weights = [1.0 / counts[s[1]] for s in self.samples]
        return torch.DoubleTensor(weights)


# ============================================================
# Model
# ============================================================

class PoseClassifier(nn.Module):
    """DINOv2 + MLP head for 8-class pose classification"""

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()

        model_name, feat_dim = DINO_MODELS[model_size]

        # Load frozen DINOv2 backbone
        self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # Trainable MLP head
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES)
        )

    def forward(self, x):
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)

    def predict_proba(self, x):
        logits = self.forward(x)
        return F.softmax(logits, dim=-1)


# ============================================================
# Training
# ============================================================

def get_transforms(train=True):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    if train:
        return transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])


def train_epoch(model, dataloader, optimizer, criterion, device, scaler=None):
    model.train()
    model.backbone.eval()  # Keep backbone frozen

    total_loss = 0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.1f}%'})

    return total_loss / len(dataloader), correct / total


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    model.eval()

    total_loss = 0
    correct = 0
    total = 0
    all_preds, all_labels = [], []

    for images, labels in tqdm(dataloader, desc='Evaluating'):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    return total_loss / len(dataloader), correct / total, all_preds, all_labels


def print_confusion_matrix(preds, labels):
    """Print confusion matrix"""
    from collections import defaultdict

    matrix = defaultdict(lambda: defaultdict(int))
    for p, l in zip(preds, labels):
        matrix[IDX_TO_CLASS[l]][IDX_TO_CLASS[p]] += 1

    print("\nConfusion Matrix (rows=true, cols=pred):")

    # Header: full class names, because truncating to 6 characters renders
    # both 'front-left' and 'front-right' as 'front-'
    header = f"{'':>12}" + "".join(f"{c:>12}" for c in POSE_CLASSES)
    print(header)

    for true_class in POSE_CLASSES:
        row = f"{true_class:>12}"
        for pred_class in POSE_CLASSES:
            count = matrix[true_class][pred_class]
            row += f"{count:>12}"
        print(row)

    # Per-class accuracy
    print("\nPer-class accuracy:")
    for cls in POSE_CLASSES:
        correct = matrix[cls][cls]
        total = sum(matrix[cls].values())
        acc = correct / total * 100 if total > 0 else 0
        print(f" {cls:>12}: {acc:5.1f}% ({correct}/{total})")


def export_onnx(model, output_path, device='cpu'):
    """Export to ONNX"""
    model.eval()
    model.to(device)

    dummy = torch.randn(1, 3, 224, 224).to(device)

    torch.onnx.export(
        model, dummy, output_path,
        export_params=True,
        opset_version=14,
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={'image': {0: 'batch'}, 'logits': {0: 'batch'}}
    )
    print(f"Exported to {output_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, help='Directory with class folders')
    parser.add_argument('--train_csv', type=str, help='Training CSV')
    parser.add_argument('--val_csv', type=str, help='Validation CSV')
    parser.add_argument('--model_size', type=str, default='small', choices=['small', 'base', 'large'])
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--output_dir', type=str, default='./checkpoints')
    parser.add_argument('--export_onnx', action='store_true')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    train_transform = get_transforms(train=True)
    val_transform = get_transforms(train=False)

    if args.train_csv:
        train_dataset = PoseDataset(args.train_csv, train_transform, augment_flip=True)
        val_dataset = PoseDataset(args.val_csv, val_transform, augment_flip=False) if args.val_csv else None
    elif args.data_dir:
        # Build separate train/val views of the same folder. Mutating
        # val_dataset.dataset after random_split would also strip augmentation
        # from the training subset, because both subsets share one dataset object.
        train_full = PoseDataset(args.data_dir, train_transform, augment_flip=True)
        val_full = PoseDataset(args.data_dir, val_transform, augment_flip=False)
        val_full.samples = train_full.samples  # guarantee identical ordering
        # Split 80/20
        n_val = int(0.2 * len(train_full))
        perm = torch.randperm(len(train_full)).tolist()
        train_dataset = torch.utils.data.Subset(train_full, perm[n_val:])
        val_dataset = torch.utils.data.Subset(val_full, perm[:n_val])
    else:
        print("Provide --data_dir or --train_csv")
        return

    # Weighted sampler for class balance
    if hasattr(train_dataset, 'get_sample_weights'):
        weights = train_dataset.get_sample_weights()
        sampler = WeightedRandomSampler(weights, len(weights))
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) if val_dataset else None

    # Model
    print(f"\nLoading DINOv2-{args.model_size}...")
    model = PoseClassifier(model_size=args.model_size).to(device)

    trainable = sum(p.numel() for p in model.head.parameters())
    print(f"Trainable parameters: {trainable:,}")

    # Training
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.head.parameters(), lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    best_acc = 0

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, scaler)

        if val_loader:
            val_loss, val_acc, preds, labels = evaluate(model, val_loader, criterion, device)
            print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")
            print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.1f}%")

            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'head_state_dict': model.head.state_dict(),
                    'val_acc': val_acc,
                    'classes': POSE_CLASSES,
                }, f'{args.output_dir}/best_pose_model.pth')
                print(f" → Saved (acc: {val_acc*100:.1f}%)")
        else:
            print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")

        scheduler.step()

    # Final evaluation
    if val_loader:
        print("\n" + "="*60)
        print("Final Evaluation")
        print("="*60)

        # map_location keeps the reload working on the device used for training
        ckpt = torch.load(f'{args.output_dir}/best_pose_model.pth', map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])

        _, acc, preds, labels = evaluate(model, val_loader, criterion, device)
        print(f"Best Accuracy: {acc*100:.1f}%")
        print_confusion_matrix(preds, labels)

    # Export
    if args.export_onnx:
        export_onnx(model, f'{args.output_dir}/pose_classifier.onnx')

    print("\nDone!")


if __name__ == '__main__':
    main()