jennamk14 commited on
Commit
2be7251
·
verified ·
1 Parent(s): 3237c11

Add README, training/inference code, and trained DINOv2-small pose-classifier checkpoint

Browse files

Uploads:
- README.md (Imageomics model-card template, populated from the code)
- pose_classifier.py - inference wrapper (ViewPointClassifier)
- train_pose_classifier.py - training script
- POSE_CLASSIFIER_GUIDE.md - user guide with pose-class reference
- checkpoints/best_pose_model.pth - trained DINOv2-small + MLP head weights (~88 MB, LFS)

Note: README YAML frontmatter still needs `license:` and (optionally) `datasets:` filled in.

POSE_CLASSIFIER_GUIDE.md ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pose Classifier Guide
2
+
3
+ ## Overview
4
+
5
+ The pose classifier predicts the orientation of animals (zebras, giraffes, etc.) relative to the camera position from aerial drone footage. This is critical for navigation and behavior analysis.
6
+
7
+ ## 8-Class Pose Classification System
8
+
9
+ ### Pose Classes
10
+
11
+ The classifier identifies **8 discrete pose orientations** arranged in a circle around the animal:
12
+
13
+ 1. **front** - Animal facing directly toward camera
14
+ 2. **front-left** - Animal facing camera, angled to the left (~45°)
15
+ 3. **left** - Animal's left side visible, perpendicular to camera
16
+ 4. **back-left** - Animal facing away, angled to the left (~45°)
17
+ 5. **back** - Animal facing directly away from camera
18
+ 6. **back-right** - Animal facing away, angled to the right (~45°)
19
+ 7. **right** - Animal's right side visible, perpendicular to camera
20
+ 8. **front-right** - Animal facing camera, angled to the right (~45°)
21
+
22
+ ### Visual Reference
23
+
24
+ ![Pose Reference Diagram](pose_labels/_reference.png)
25
+
26
+ The diagram shows the 8 pose classes arranged in a circle. The camera is positioned at the bottom, and the animal (zebra) is in the center. Each orange dot represents one of the 8 possible pose classifications.
27
+
28
+ ## Example Poses
29
+
30
+ ### Front Pose
31
+ **Label:** `front`
32
+
33
+ The animal is facing directly toward the camera, with the head and front body visible.
34
+
35
+ ![Front Pose Example](pose_labels/front/mpala_session_1_DJI_0002_partition_1_DJI_0002_000171_c0_004.jpg)
36
+
37
+ ---
38
+
39
+ ### Front-Left Pose
40
+ **Label:** `front-left`
41
+
42
+ The animal is facing toward the camera but angled to its left (camera's right), showing both the front and left side.
43
+
44
+ ![Front-Left Pose Example](pose_labels/front-left/mpala_session_1_DJI_0002_partition_1_DJI_0002_000471_c0_005.jpg)
45
+
46
+ ---
47
+
48
+ ### Front-Right Pose
49
+ **Label:** `front-right`
50
+
51
+ The animal is facing toward the camera but angled to its right (camera's left), showing both the front and right side.
52
+
53
+ ![Front-Right Pose Example](pose_labels/front-right/mpala_session_2_DJI_0006_partition_2_DJI_0006_006552_c0_004.jpg)
54
+
55
+ ---
56
+
57
+ ### Left Pose
58
+ **Label:** `left`
59
+
60
+ The animal's left side is visible, perpendicular to the camera. This is a pure profile view.
61
+
62
+ ![Left Pose Example](pose_labels/left/mpala_session_1_DJI_0002_partition_1_DJI_0002_000321_c1_001.jpg)
63
+
64
+ ---
65
+
66
+ ### Right Pose
67
+ **Label:** `right`
68
+
69
+ The animal's right side is visible, perpendicular to the camera. This is a pure profile view from the opposite side.
70
+
71
+ ![Right Pose Example](pose_labels/right/mpala_session_1_DJI_0002_partition_1_DJI_0002_000171_c0_002.jpg)
72
+
73
+ ---
74
+
75
+ ### Back-Left Pose
76
+ **Label:** `back-left`
77
+
78
+ The animal is facing away from the camera but angled to its left, showing the rear-left quarter.
79
+
80
+ ![Back-Left Pose Example](pose_labels/back-left/mpala_session_1_DJI_0002_partition_1_DJI_0002_000171_c1_001.jpg)
81
+
82
+ ---
83
+
84
+ ### Back-Right Pose
85
+ **Label:** `back-right`
86
+
87
+ The animal is facing away from the camera but angled to its right, showing the rear-right quarter.
88
+
89
+ ![Back-Right Pose Example](pose_labels/back-right/mpala_session_1_DJI_0002_partition_1_DJI_0002_000171_c1_000.jpg)
90
+
91
+ ---
92
+
93
+ ### Back Pose
94
+ **Label:** `back`
95
+
96
+ The animal is facing directly away from the camera, with the rear and back visible.
97
+
98
+ ![Back Pose Example](pose_labels/back/mpala_session_1_DJI_0002_partition_1_DJI_0002_000321_c1_000.jpg)
99
+
100
+ ---
101
+
102
+ ## Model Architecture
103
+
104
+ ### DINOv2 + MLP Head
105
+
106
+ The pose classifier uses a **frozen DINOv2 backbone** with a **trainable MLP classification head**:
107
+
108
+ ```
109
+ Input Image (224×224)
110
+
111
+ DINOv2 Vision Transformer (frozen)
112
+ - Small: 384-dim features
113
+ - Base: 768-dim features
114
+ - Large: 1024-dim features
115
+
116
+ MLP Head (trainable)
117
+ - LayerNorm
118
+ - Linear(feat_dim -> 256) + GELU + Dropout(0.3)
119
+ - Linear(256 -> 128) + GELU + Dropout(0.3)
120
+ - Linear(128 -> 8)
121
+
122
+ Output Logits (8 classes)
123
+ ```
124
+
125
+ ### Why DINOv2?
126
+
127
+ - **Self-supervised learning** on diverse images provides strong visual features
128
+ - **Frozen backbone** reduces training time and prevents overfitting
129
+ - **Small memory footprint** suitable for deployment
130
+ - **Robust to varying image quality** from aerial footage
131
+
132
+ ## Training Pipeline
133
+
134
+ ### Data Organization
135
+
136
+ Training data is organized in folder structure:
137
+ ```
138
+ pose_labels/
139
+ _reference.png # Visual guide
140
+ front/ # Front-facing animals
141
+ front-left/ # Front-left quarter
142
+ left/ # Left profile
143
+ back-left/ # Back-left quarter
144
+ back/ # Back-facing animals
145
+ back-right/ # Back-right quarter
146
+ right/ # Right profile
147
+ front-right/ # Front-right quarter
148
+ ```
149
+
150
+ Or via CSV files with columns: `image_path, pose`
151
+
152
+ ### Data Augmentation
153
+
154
+ **Geometric Augmentation with Label Swapping:**
155
+ - Horizontal flip applied with 50% probability
156
+ - When flipped, pose labels are swapped according to symmetry:
157
+ - `left` <-> `right`
158
+ - `front-left` <-> `front-right`
159
+ - `back-left` <-> `back-right`
160
+ - `front` and `back` remain unchanged
161
+
162
+ **Color/Transform Augmentation:**
163
+ - Random crop (256px -> 224px)
164
+ - Color jitter: brightness (±30%), contrast (±30%), saturation (±20%)
165
+ - Random rotation (±15°)
166
+
167
+ **Class Balancing:**
168
+ - Weighted random sampler ensures equal representation of all 8 classes during training
169
+
170
+ ### Training Configuration
171
+
172
+ ```bash
173
+ python train_pose_classifier.py \
174
+ --data_dir ./pose_labels \
175
+ --model_size small \
176
+ --epochs 30 \
177
+ --batch_size 32 \
178
+ --lr 1e-3
179
+ ```
180
+
181
+ **Key Parameters:**
182
+ - **Model size**: `small`, `base`, or `large` (DINOv2 variant)
183
+ - **Optimizer**: AdamW with weight decay 0.01
184
+ - **Loss**: CrossEntropyLoss with label smoothing (0.1)
185
+ - **Scheduler**: CosineAnnealingLR
186
+ - **Mixed precision**: Automatic on GPU
187
+
188
+ **Training Output:**
189
+ - Best model saved to `checkpoints/best_pose_model.pth`
190
+ - Includes confusion matrix and per-class accuracy
191
+ - Optional ONNX export for deployment
192
+
193
+ ## Usage in Navigation
194
+
195
+ ### Integration with Detection Pipeline
196
+
197
+ The pose classifier is used in the navigation system after animal detection:
198
+
199
+ ```python
200
+ from navigation.policy.pose_classifier import ViewPointClassifier
201
+ from PIL import Image
202
+
203
+ # Initialize classifier
204
+ classifier = ViewPointClassifier(
205
+ weight_path="model_weights/best_june_24_2025_IA_classifier_016.pth",
206
+ device="cpu",
207
+ threshold=0.5
208
+ )
209
+
210
+ # Process detected animal crops
211
+ crops = [Image.open(path) for path in detection_crops]
212
+ poses = classifier(crops) # Returns list of pose strings
213
+
214
+ # Use poses for navigation decisions
215
+ for pose in poses:
216
+ if "front" in pose:
217
+ print("Animal is facing camera - approach with caution")
218
+ elif "back" in pose:
219
+ print("Animal is facing away - good for following")
220
+ ```
221
+
222
+ ### Multi-Label Pose System (Alternative)
223
+
224
+ The `ViewPointClassifier` in `pose_classifier.py` uses a different approach:
225
+
226
+ - **5 multi-label classes**: `up, front, back, right, left`
227
+ - **EfficientNet-B4** backbone trained on zebra crops
228
+ - **Input size**: 512×512 pixels
229
+ - **Output**: Concatenated string (e.g., `"upfrontright"`)
230
+ - **Threshold**: 0.5 (configurable)
231
+
232
+ This allows detecting compound poses like "animal is facing front-right while looking up."
233
+
234
+ ## Performance Considerations
235
+
236
+ ### Inference Speed
237
+ - **DINOv2-small**: ~15-20ms per image (CPU)
238
+ - **DINOv2-base**: ~30-40ms per image (CPU)
239
+ - **GPU acceleration**: 5-10x faster
240
+
241
+ ### Accuracy Targets
242
+ - **Overall accuracy**: >85% on validation set
243
+ - **Critical classes** (front/back): >90% accuracy
244
+ - **Confusion**: Most errors occur between adjacent classes (e.g., front vs. front-left)
245
+
246
+ ### Deployment Notes
247
+ - Model checkpoint: ~150MB (small), ~350MB (base)
248
+ - ONNX export available for optimized inference
249
+ - Batch processing recommended for multiple detections
250
+
251
+ ## Common Issues & Tips
252
+
253
+ ### Issue: Poor performance on occluded animals
254
+ **Solution**: Train with more occluded examples or use confidence thresholding
255
+
256
+ ### Issue: Confusion between adjacent poses
257
+ **Solution**: This is expected due to continuous nature of orientations; consider using pose groups (front-facing vs. side-facing vs. back-facing)
258
+
259
+ ### Issue: Inconsistent predictions across frames
260
+ **Solution**: Apply temporal smoothing or majority voting across consecutive frames
261
+
262
+ ### Issue: Different performance on zebras vs. other species
263
+ **Solution**: Retrain with balanced dataset across species, or train species-specific models
264
+
265
+ ## Dataset Statistics
266
+
267
+ Current training data distribution (from folder structure):
268
+ - Folders: `front`, `front-left`, `front-right`, `left`, `right`, `back-left`, `back-right`, `back`
269
+ - Images per class: Variable (check with `train_pose_classifier.py --data_dir pose_labels`)
270
+ - Species: Primarily zebras and giraffes
271
+ - Source: Aerial drone footage from Mpala and OPC sessions
272
+
273
+ ## References
274
+
275
+ - DINOv2 Paper: [https://arxiv.org/abs/2304.07193](https://arxiv.org/abs/2304.07193)
276
+ - VARe-ID (ViewPoint Classifier): [https://github.com/ziesski/VARe-ID](https://github.com/ziesski/VARe-ID)
277
+ - Individual identification of wildlife: [https://doi.org/10.1007/s10344-021-01549-4](Review on methods used for wildlife species and individual identification)
278
+ - Training script: [train_pose_classifier.py](train_pose_classifier.py)
279
+ - Navigation integration: [navigation/policy/pose_classifier.py](../navigation/policy/pose_classifier.py)
280
+
281
+ ## Quick Start
282
+
283
+ 1. **Prepare data**: Organize images in `pose_labels/` folders by class
284
+ 2. **Train model**: `python train_pose_classifier.py --data_dir ./pose_labels --epochs 30`
285
+ 3. **Evaluate**: Check confusion matrix and per-class accuracy in output
286
+ 4. **Export**: Use `--export_onnx` flag for optimized deployment
287
+ 5. **Integrate**: Load checkpoint and use for inference on detection crops
README.md CHANGED
@@ -1,3 +1,353 @@
1
  ---
2
- license: mit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ # TODO: pick an OSI-compatible license tag (see note below) and add it as `license:` here.
3
+ # TODO: if/when a HF dataset is published, add it as `datasets: <org>/<name>` (string, not list).
4
+ language:
5
+ - en
6
+ library_name: pytorch
7
+ tags:
8
+ - biology
9
+ - CV
10
+ - images
11
+ - animals
12
+ - zebra
13
+ - giraffe
14
+ - pose-estimation
15
+ - viewpoint-classification
16
+ - dinov2
17
+ - aerial-imagery
18
+ - drone
19
+ metrics:
20
+ - accuracy
21
+ model_description: An 8-class viewpoint/pose classifier for aerial drone imagery of wildlife (primarily zebras and giraffes). Uses a frozen DINOv2 vision-transformer backbone with a trainable MLP head to predict one of eight canonical orientations of the animal relative to the camera.
22
  ---
23
+
24
+ <!--
25
+
26
+ NOTE: Add more tags (your particular animal, type of model and use-case, etc.).
27
+
28
+ As with your GitHub Project repo, it is important to choose an appropriate license for your model. Alongside the appropriate stakeholders (e.g., your PI, co-authors), select a license that is [Open Source Initiative](https://opensource.org/licenses) (OSI) compliant. You may also wish to consider adding a [RAIL license](https://www.licenses.ai/ai-licenses), which addresses responsible use.
29
+ For more information on how to choose a license and why it matters, see [Choose A License](https://choosealicense.com) and [A Quick Guide to Software Licensing for the Scientist-Programmer](https://doi.org/10.1371/journal.pcbi.1002598) by A. Morin, et al.
30
+ See the [Imageomics policy for licensing](https://imageomics.github.io/Imageomics-guide/wiki-guide/Digital-products-release-licensing-policy/) for more information.
31
+
32
+ License tags (for the `yaml` above) can be found [here](https://hf.co/docs/hub/repositories-licenses).
33
+ -->
34
+
35
+ # Model Card for DINOv2 8-Class Animal Pose Classifier
36
+
37
+ A lightweight viewpoint/pose classifier that predicts one of **8 canonical orientations** (front, front-left, front-right, left, right, back-left, back-right, back) for an animal crop extracted from aerial drone imagery. It pairs a **frozen DINOv2 vision-transformer backbone** with a small **trainable MLP head**, and is intended for use as a downstream module in a drone-based wildlife detection-and-navigation pipeline.
38
+
39
+ ## Model Details
40
+
41
+ ### Model Description
42
+
43
+ This model takes a 224×224 RGB image crop of a single animal (typically produced by an upstream detector) and outputs a categorical prediction over 8 viewpoint classes arranged around the animal. The 8 classes form a discretization of the animal's heading relative to the camera, with adjacent classes separated by ~45°.
44
+
45
+ The DINOv2 backbone is loaded via `torch.hub` from `facebookresearch/dinov2` and is kept frozen during training; only the MLP head is updated. This keeps the number of trainable parameters low (well under 1M for the `small` variant), reduces overfitting on small labeled pose datasets, and allows the same self-supervised representation to be reused for related downstream tasks.
46
+
47
+ - **Developed by:** Imageomics Institute — Individual Identification of Zebras project (Claire Sun, et al.)
48
+ - **Model type:** Image classifier (Vision Transformer feature extractor + MLP head)
49
+ - **Language(s) (NLP):** N/A (vision model)
50
+ - **License:** [More Information Needed — choose a license (see above notes)]
51
+ - **Fine-tuned from model:** [facebookresearch/dinov2](https://github.com/facebookresearch/dinov2) (`dinov2_vits14`, `dinov2_vitb14`, or `dinov2_vitl14`)
52
+
53
+ ### Model Sources
54
+
55
+ - **Repository:** [individual_id_zebras / Claire / pose_model](.)
56
+ - **Training script:** [train_pose_classifier.py](train_pose_classifier.py)
57
+ - **Inference wrapper:** [pose_classifier.py](pose_classifier.py)
58
+ - **User guide:** [POSE_CLASSIFIER_GUIDE.md](POSE_CLASSIFIER_GUIDE.md)
59
+ - **Paper:** [More Information Needed — optional]
60
+ - **Demo:** [More Information Needed — encouraged]
61
+
62
+ ## Uses
63
+
64
+ ### Direct Use
65
+
66
+ The model is intended to be applied to **tight, single-animal crops** (e.g., the output of a wildlife detector run on aerial drone frames). For each crop it returns the most likely of 8 viewpoint labels:
67
+
68
+ ```
69
+ front, front-left, front-right, left, right, back-left, back-right, back
70
+ ```
71
+
72
+ These labels are useful for:
73
+
74
+ - Selecting frames in which an individual is best observed (e.g., side profiles for stripe-based re-identification).
75
+ - Filtering training data for downstream identity models that are viewpoint-sensitive.
76
+ - Behavioral analysis (e.g., orientation of herd members relative to the camera/drone).
77
+
78
+ ### Downstream Use
79
+
80
+ This pose classifier is a component of a larger **drone navigation and individual-identification pipeline** for zebras and giraffes. Downstream uses include:
81
+
82
+ - Conditioning a re-identification model on viewpoint.
83
+ - Informing autonomous drone-positioning policies (e.g., maneuver to obtain a side-profile view).
84
+ - Producing per-track viewpoint histograms used for sighting quality scoring.
85
+
86
+ ### Out-of-Scope Use
87
+
88
+ - **Non-aerial / ground-level imagery.** The model is trained on top-down/oblique drone footage; predictions on eye-level photos are unlikely to be reliable.
89
+ - **Species the model was not trained on.** Performance has only been characterized for zebras and giraffes. Application to unrelated species is out of scope without retraining.
90
+ - **Continuous heading regression.** The model predicts 1-of-8 discrete classes, not a continuous angle. Adjacent classes (e.g., `front` vs `front-left`) are frequently confused and should not be treated as fully independent.
91
+ - **Identity, species, or behavior inference.** The model does not predict the identity, species, or activity of the animal.
92
+
93
+ ## Bias, Risks, and Limitations
94
+
95
+ - **Domain shift:** Training data is drawn primarily from aerial drone footage at two field sites (Mpala and OPC). Performance may degrade on imagery captured at other altitudes, lighting conditions, or camera angles.
96
+ - **Class adjacency confusion:** Because viewpoint is fundamentally continuous, errors are concentrated between neighboring classes (e.g., `front` ↔ `front-left`). The 8-class discretization is a modeling choice, not a property of the underlying phenomenon.
97
+ - **Species imbalance:** Most training samples are zebras; giraffe coverage is smaller and per-class performance has not been independently broken out.
98
+ - **Occlusion sensitivity:** Heavily occluded or truncated crops (animals partially out of frame, overlapping individuals) are not well represented and tend to produce less reliable predictions.
99
+ - **Tight-crop dependence:** The model expects detector-style crops centered on a single animal. Wide-scene images will not produce meaningful predictions.
100
+
101
+ ### Recommendations
102
+
103
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. In particular:
104
+
105
+ - Treat adjacent-class confusion (e.g., `front`/`front-left`) as expected and consider collapsing to coarser bins (front/side/back) for decisions that don't need fine resolution.
106
+ - Apply temporal smoothing or majority voting across consecutive frames when classifying tracked individuals.
107
+ - Confidence-threshold or hold out predictions on visibly occluded crops.
108
+ - Re-evaluate (or retrain) before deploying on a new site, species, or sensor.
109
+
110
+ ## How to Get Started with the Model
111
+
112
+ The inference wrapper [`ViewPointClassifier`](pose_classifier.py) provides a one-line interface that takes a list of PIL crops and returns a list of pose labels.
113
+
114
+ ```python
115
+ from PIL import Image
116
+ from pose_classifier import ViewPointClassifier
117
+
118
+ classifier = ViewPointClassifier(
119
+ weight_path="checkpoints/best_pose_model.pth",
120
+ model_size="small", # must match the trained checkpoint
121
+ device="cpu", # or "cuda"
122
+ )
123
+
124
+ crops = [Image.open(p).convert("RGB") for p in ["zebra1.jpg", "zebra2.jpg"]]
125
+ poses = classifier(crops)
126
+ # e.g. ['front-left', 'back']
127
+ ```
128
+
129
+ The wrapper handles preprocessing (resize to 256, center-crop to 224, ImageNet normalization) and accepts PIL images, NumPy arrays, or torch tensors as input.
130
+
131
+ To **train from scratch** on a new pose-labeled dataset:
132
+
133
+ ```bash
134
+ python train_pose_classifier.py \
135
+ --data_dir ./pose_labels \
136
+ --model_size small \
137
+ --epochs 30 \
138
+ --batch_size 32 \
139
+ --lr 1e-3
140
+ ```
141
+
142
+ See [POSE_CLASSIFIER_GUIDE.md](POSE_CLASSIFIER_GUIDE.md) for the full guide, including the visual reference diagram for each pose class.
143
+
144
+ ## Training Details
145
+
146
+ ### Training Data
147
+
148
+ Pose-labeled crops of zebras and giraffes extracted from aerial drone footage at Mpala (Kenya) and OPC field sites. Data is organized either as a per-class folder hierarchy:
149
+
150
+ ```
151
+ pose_labels/
152
+ front/ front-left/ front-right/ left/ right/ back-left/ back-right/ back/
153
+ ```
154
+
155
+ or as a CSV with `image_path, pose` columns. Class counts are inherently imbalanced and are handled at the sampler level (see below).
156
+
157
+ [More Information Needed — exact per-class sample counts, splits, and dataset card link]
158
+
159
+ ### Training Procedure
160
+
161
+ #### Preprocessing
162
+
163
+ Training-time transforms (applied per image):
164
+
165
+ - Resize shorter side to 256
166
+ - `RandomCrop(224)`
167
+ - `ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2)`
168
+ - `RandomRotation(±15°)`
169
+ - ToTensor + ImageNet normalization (`mean=[0.485, 0.456, 0.406]`, `std=[0.229, 0.224, 0.225]`)
170
+
171
+ Validation-time transforms: `Resize(256) → CenterCrop(224) → ToTensor → Normalize`.
172
+
173
+ **Symmetry-aware horizontal flip:** with p=0.5 the crop is horizontally flipped and the label is swapped according to the canonical symmetry of the 8-class scheme:
174
+
175
+ ```
176
+ left ↔ right
177
+ front-left ↔ front-right
178
+ back-left ↔ back-right
179
+ front, back unchanged
180
+ ```
181
+
182
+ This effectively doubles training data without breaking label semantics.
183
+
184
+ **Class balancing:** a `WeightedRandomSampler` with weights inversely proportional to per-class frequency ensures all 8 classes are sampled at equal rates during training.
185
+
186
+ #### Training Hyperparameters
187
+
188
+ - **Training regime:** fp16 mixed precision when running on CUDA (via `torch.cuda.amp`); fp32 on CPU.
189
+ - **Optimizer:** AdamW, `lr=1e-3`, `weight_decay=0.01` (head parameters only — backbone is frozen).
190
+ - **Loss:** `CrossEntropyLoss(label_smoothing=0.1)`.
191
+ - **LR schedule:** `CosineAnnealingLR(T_max=epochs)`.
192
+ - **Default epochs / batch size:** 30 / 32.
193
+ - **Backbone:** frozen DINOv2 (`small` = ViT-S/14, 384-dim; `base` = ViT-B/14, 768-dim; `large` = ViT-L/14, 1024-dim).
194
+ - **Head:** `LayerNorm → Linear(feat_dim, 256) → GELU → Dropout(0.3) → Linear(256, 128) → GELU → Dropout(0.3) → Linear(128, 8)`.
195
+
196
+ Only the MLP head is trained — for the `small` variant this is well under 1M trainable parameters.
197
+
198
+ #### Speeds, Sizes, Times
199
+
200
+ - **Checkpoint size:** ~88 MB for the `small` variant (`best_pose_model.pth`), ~350 MB for `base`.
201
+ - **Inference (CPU):** ~15–20 ms/image (`small`), ~30–40 ms/image (`base`).
202
+ - **Inference (GPU):** roughly 5–10× faster than CPU.
203
+
204
+ [More Information Needed — wall-clock training time, throughput per epoch]
205
+
206
+ ## Evaluation
207
+
208
+ ### Testing Data, Factors & Metrics
209
+
210
+ #### Testing Data
211
+
212
+ When training from a single `--data_dir`, the script performs an 80/20 random split into train/val. When `--train_csv` and `--val_csv` are supplied, those are used directly.
213
+
214
+ [More Information Needed — held-out test set details, if any beyond the val split]
215
+
216
+ #### Factors
217
+
218
+ The natural disaggregations of interest are:
219
+
220
+ - **Pose class** (8 categories) — adjacent-class confusion is the dominant error mode.
221
+ - **Species** (zebra vs giraffe) — coverage and accuracy may differ.
222
+ - **Site / session** (e.g., Mpala vs OPC sessions) — proxies for altitude, lighting, and habitat.
223
+
224
+ [More Information Needed — disaggregated numbers]
225
+
226
+ #### Metrics
227
+
228
+ - **Top-1 accuracy** (overall and per-class).
229
+ - **8×8 confusion matrix** (printed by [train_pose_classifier.py](train_pose_classifier.py) at the end of training).
230
+
231
+ ### Results
232
+
233
+ Target performance reported in the user guide:
234
+
235
+ - Overall validation accuracy: **>85%**
236
+ - Critical front/back classes: **>90%**
237
+
238
+ [More Information Needed — actual measured numbers for the released checkpoint, ideally as a confusion matrix figure]
239
+
240
+ #### Summary
241
+
242
+ The `small` DINOv2 backbone with the MLP head described above is the released configuration and offers a favorable accuracy/latency trade-off for the drone-navigation use case. The `base` and `large` variants are supported by the same training script for users with more compute and labeled data.
243
+
244
+ ## Model Examination
245
+
246
+ [More Information Needed — saliency/feature-attribution analysis, if any]
247
+
248
+ ## Environmental Impact
249
+
250
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://doi.org/10.48550/arXiv.1910.09700).
251
+
252
+ - **Hardware Type:** [More Information Needed — GPU model used for training]
253
+ - **Hours used:** [More Information Needed]
254
+ - **Cloud Provider:** Ohio Supercomputer Center (OSC)
255
+ - **Compute Region:** Ohio, USA
256
+ - **Carbon Emitted:** [More Information Needed]
257
+
258
+ ## Technical Specifications
259
+
260
+ ### Model Architecture and Objective
261
+
262
+ ```
263
+ Input Image (224×224, RGB, ImageNet-normalized)
264
+
265
+
266
+ DINOv2 ViT (frozen)
267
+ - small : ViT-S/14 → 384-d feature vector
268
+ - base : ViT-B/14 → 768-d feature vector
269
+ - large : ViT-L/14 → 1024-d feature vector
270
+
271
+
272
+ MLP head (trainable)
273
+ LayerNorm(feat_dim)
274
+ Linear(feat_dim → 256) + GELU + Dropout(0.3)
275
+ Linear(256 → 128) + GELU + Dropout(0.3)
276
+ Linear(128 → 8)
277
+
278
+
279
+ Logits over {front, front-left, front-right, left, right,
280
+ back-left, back-right, back}
281
+ ```
282
+
283
+ Training objective: cross-entropy with label smoothing (0.1), optimized only over the MLP head parameters.
284
+
285
+ ### Compute Infrastructure
286
+
287
+ #### Hardware
288
+
289
+ - **Training:** a single CUDA-capable GPU is sufficient for the `small` variant; mixed precision is enabled automatically. Larger DINOv2 variants benefit from more GPU memory.
290
+ - **Inference:** runs on CPU or a single GPU. CPU is viable for low-throughput on-board use; GPU is recommended for batched offline processing.
291
+
292
+ #### Software
293
+
294
+ - Python 3.x
295
+ - PyTorch (with `torch.hub` access to `facebookresearch/dinov2`)
296
+ - torchvision
297
+ - pandas, numpy, Pillow, tqdm
298
+
299
+ ## Citation
300
+
301
+ [More Information Needed]
302
+
303
+ <!--
304
+ If you use our model in your work, please cite the model and any associated paper.
305
+
306
+ **Model**
307
+ ```
308
+ @software{<ref_code>,
309
+ author = {<author1 and author2>},
310
+ doi = {<doi once generated>},
311
+ title = {DINOv2 8-Class Animal Pose Classifier},
312
+ version = {<version#>},
313
+ year = {<year>},
314
+ url = {https://huggingface.co/imageomics/<model_name>}
315
+ }
316
+ ```
317
+ -->
318
+
319
+ Underlying backbone:
320
+
321
+ ```
322
+ @article{oquab2023dinov2,
323
+ title = {DINOv2: Learning Robust Visual Features without Supervision},
324
+ author = {Oquab, Maxime and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and others},
325
+ journal = {arXiv preprint arXiv:2304.07193},
326
+ year = {2023},
327
+ url = {https://arxiv.org/abs/2304.07193}
328
+ }
329
+ ```
330
+
331
+ ## Acknowledgements
332
+
333
+ This work was supported by the [Imageomics Institute](https://imageomics.org), which is funded by the US National Science Foundation's Harnessing the Data Revolution (HDR) program under [Award #2118240](https://www.nsf.gov/awardsearch/showAward?AWD_ID=2118240) (Imageomics: A New Frontier of Biological Information Powered by Knowledge-Guided Machine Learning). Any opinions, findings and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation.
334
+
335
+ Compute was provided by the [Ohio Supercomputer Center](https://www.osc.edu/). The backbone model is DINOv2 by Meta AI Research.
336
+
337
+ ## Glossary
338
+
339
+ - **Pose / viewpoint:** the orientation of the animal relative to the camera, discretized here into 8 bins of ~45° each.
340
+ - **Frozen backbone:** the DINOv2 weights are fixed during training; gradients flow only through the MLP head.
341
+ - **Symmetry-aware flip:** horizontal-flip augmentation paired with a label swap (`left↔right`, `front-left↔front-right`, `back-left↔back-right`) so that flipped images carry geometrically correct labels.
342
+
343
+ ## More Information
344
+
345
+ See [POSE_CLASSIFIER_GUIDE.md](POSE_CLASSIFIER_GUIDE.md) for visual references of each pose class, training tips, and integration notes for the navigation pipeline.
346
+
347
+ ## Model Card Authors
348
+
349
+ Jenna Kline
350
+
351
+ ## Model Card Contact
352
+
353
+ Elizabeth Campolongo, campolongo.4@osu.edu
checkpoints/best_pose_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2956a457eabb17260e7da687898042f2376aea178d2c22a13a16f1c12c48d21d
3
+ size 88828281
pose_classifier.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # navigation_scripts/pose_classifier.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms as T
6
+ from pathlib import Path
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+
11
+ # Must match train_pose_classifier.py
12
+ POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right', 'back-left', 'back-right', 'back']
13
+ NUM_CLASSES = len(POSE_CLASSES)
14
+
15
+ DINO_MODELS = {
16
+ 'small': ('dinov2_vits14', 384),
17
+ 'base': ('dinov2_vitb14', 768),
18
+ 'large': ('dinov2_vitl14', 1024),
19
+ }
20
+
21
+
22
+ class _PoseClassifierModel(nn.Module):
23
+ """DINOv2 + MLP head for 8-class pose classification (mirrors train_pose_classifier.PoseClassifier)."""
24
+
25
+ def __init__(self, model_size='small', dropout=0.3):
26
+ super().__init__()
27
+ model_name, feat_dim = DINO_MODELS[model_size]
28
+
29
+ self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
30
+ for param in self.backbone.parameters():
31
+ param.requires_grad = False
32
+ self.backbone.eval()
33
+
34
+ self.head = nn.Sequential(
35
+ nn.LayerNorm(feat_dim),
36
+ nn.Linear(feat_dim, 256),
37
+ nn.GELU(),
38
+ nn.Dropout(dropout),
39
+ nn.Linear(256, 128),
40
+ nn.GELU(),
41
+ nn.Dropout(dropout),
42
+ nn.Linear(128, NUM_CLASSES),
43
+ )
44
+
45
+ def forward(self, x):
46
+ with torch.no_grad():
47
+ features = self.backbone(x)
48
+ return self.head(features)
49
+
50
+
51
+ class ViewPointClassifier:
52
+ """
53
+ Predicts one of 8 canonical zebra viewpoints:
54
+ front, front-left, front-right, left, right, back-left, back-right, back
55
+
56
+ Uses a DINOv2-small backbone (frozen) with a trained MLP head.
57
+
58
+ __call__(crops) → list[str]
59
+ Each crop is a PIL.Image (RGB). Returns the predicted pose label.
60
+ """
61
+ LABELS = POSE_CLASSES
62
+
63
+ def _to_pil(self, img):
64
+ """Accept PIL.Image | np.ndarray | torch.Tensor -> PIL.Image (RGB)."""
65
+ if isinstance(img, Image.Image):
66
+ return img.convert("RGB")
67
+
68
+ if isinstance(img, np.ndarray):
69
+ if img.ndim == 3 and img.shape[2] == 3:
70
+ img = img[..., ::-1] # BGR → RGB
71
+ return Image.fromarray(img)
72
+
73
+ if torch.is_tensor(img):
74
+ return T.ToPILImage()(img.cpu())
75
+
76
+ raise TypeError(f"Unsupported crop type {type(img)}")
77
+
78
+ def __init__(
79
+ self,
80
+ weight_path="checkpoints/best_pose_model.pth",
81
+ model_size: str = "small",
82
+ device: str = "cpu",
83
+ ):
84
+ self.device = torch.device(device)
85
+
86
+ # Build the same architecture used in training
87
+ self.model = _PoseClassifierModel(model_size=model_size)
88
+
89
+ # Load checkpoint (saved by train_pose_classifier.py)
90
+ ckpt = torch.load(weight_path, map_location=self.device)
91
+ self.model.load_state_dict(ckpt['model_state_dict'])
92
+ self.model.eval().to(self.device)
93
+
94
+ # Match the validation transforms from training
95
+ self.tf = T.Compose(
96
+ [
97
+ T.Resize(256),
98
+ T.CenterCrop(224),
99
+ T.ToTensor(),
100
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
101
+ ]
102
+ )
103
+
104
+ @torch.inference_mode()
105
+ def __call__(self, crops):
106
+ """
107
+ Parameters
108
+ ----------
109
+ crops : list[PIL.Image]
110
+ One crop per detection.
111
+
112
+ Returns
113
+ -------
114
+ list[str]
115
+ Predicted pose label for each crop, e.g. 'front', 'back-left'.
116
+ """
117
+ if not crops:
118
+ return []
119
+ pil_crops = [self._to_pil(c) for c in crops]
120
+ batch = torch.stack([self.tf(c) for c in pil_crops]).to(self.device)
121
+ logits = self.model(batch) # shape [N, 8]
122
+ preds = torch.argmax(logits, dim=-1).cpu() # single-label
123
+ return [self.LABELS[i] for i in preds]
124
+
125
+ # ───────── quick sanity check ─────────
126
+ if __name__ == "__main__":
127
+ from PIL import Image
128
+ import random
129
+
130
+ img_dir = Path("some/test/crops") # directory of zebra chip .jpgs
131
+ samples = [Image.open(p) for p in random.sample(list(img_dir.glob("*.jpg")), 4)]
132
+
133
+ clf = ViewPointClassifier(device="cpu")
134
+ print(clf(samples))
train_pose_classifier.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 8-Class Pose Classifier Training
4
+ ================================
5
+ Train a classifier for animal pose relative to camera.
6
+
7
+ Classes:
8
+ front, front-left, front-right, left, right, back-left, back-right, back
9
+
10
+ Usage:
11
+ python train_pose_classifier.py --data_dir ./pose_labels --epochs 30
12
+ python train_pose_classifier.py --train_csv train.csv --val_csv val.csv --epochs 30
13
+ """
14
+
15
+ import argparse
16
+ import os
17
+ from pathlib import Path
18
+ import numpy as np
19
+ from PIL import Image
20
+ from tqdm import tqdm
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
26
+ from torchvision import transforms
27
+ import pandas as pd
28
+
29
+ # ============================================================
30
+ # Configuration
31
+ # ============================================================
32
+
33
+ POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right', 'back-left', 'back-right', 'back']
34
+ CLASS_TO_IDX = {c: i for i, c in enumerate(POSE_CLASSES)}
35
+ IDX_TO_CLASS = {i: c for c, i in CLASS_TO_IDX.items()}
36
+ NUM_CLASSES = len(POSE_CLASSES)
37
+
38
+ # Horizontal flip swaps these pairs
39
+ FLIP_PAIRS = {
40
+ 'front-left': 'front-right',
41
+ 'front-right': 'front-left',
42
+ 'left': 'right',
43
+ 'right': 'left',
44
+ 'back-left': 'back-right',
45
+ 'back-right': 'back-left',
46
+ 'front': 'front',
47
+ 'back': 'back',
48
+ }
49
+
50
+ # DINOv2 model sizes
51
+ DINO_MODELS = {
52
+ 'small': ('dinov2_vits14', 384),
53
+ 'base': ('dinov2_vitb14', 768),
54
+ 'large': ('dinov2_vitl14', 1024),
55
+ }
56
+
57
+
58
+ # ============================================================
59
+ # Dataset
60
+ # ============================================================
61
+
62
+ class PoseDataset(Dataset):
63
+ """Dataset that supports both folder structure and CSV"""
64
+
65
+ def __init__(self, data_source, transform=None, augment_flip=True):
66
+ """
67
+ Args:
68
+ data_source: Either a directory path (folder structure) or CSV path
69
+ transform: Image transforms
70
+ augment_flip: Whether to apply horizontal flip with label swap
71
+ """
72
+ self.transform = transform
73
+ self.augment_flip = augment_flip
74
+ self.samples = []
75
+
76
+ data_path = Path(data_source)
77
+
78
+ if data_path.is_dir():
79
+ # Load from folder structure
80
+ for cls in POSE_CLASSES:
81
+ cls_dir = data_path / cls
82
+ if cls_dir.exists():
83
+ for img_path in cls_dir.glob('*'):
84
+ if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
85
+ self.samples.append((str(img_path), cls))
86
+ else:
87
+ # Load from CSV
88
+ df = pd.read_csv(data_path)
89
+ img_col = 'image_path' if 'image_path' in df.columns else df.columns[0]
90
+ label_col = 'pose' if 'pose' in df.columns else df.columns[1]
91
+
92
+ for _, row in df.iterrows():
93
+ if row[label_col] in POSE_CLASSES:
94
+ self.samples.append((row[img_col], row[label_col]))
95
+
96
+ print(f"Loaded {len(self.samples)} samples")
97
+ self._print_distribution()
98
+
99
+ def _print_distribution(self):
100
+ from collections import Counter
101
+ counts = Counter(s[1] for s in self.samples)
102
+ print("Class distribution:")
103
+ for cls in POSE_CLASSES:
104
+ print(f" {cls}: {counts.get(cls, 0)}")
105
+
106
+ def __len__(self):
107
+ return len(self.samples)
108
+
109
+ def __getitem__(self, idx):
110
+ img_path, label = self.samples[idx]
111
+ image = Image.open(img_path).convert('RGB')
112
+
113
+ # Horizontal flip augmentation with label swap
114
+ do_flip = self.augment_flip and torch.rand(1) < 0.5
115
+ if do_flip:
116
+ image = transforms.functional.hflip(image)
117
+ label = FLIP_PAIRS[label]
118
+
119
+ if self.transform:
120
+ image = self.transform(image)
121
+
122
+ return image, CLASS_TO_IDX[label]
123
+
124
+ def get_sample_weights(self):
125
+ """Weights for balanced sampling"""
126
+ from collections import Counter
127
+ counts = Counter(s[1] for s in self.samples)
128
+ weights = [1.0 / counts[s[1]] for s in self.samples]
129
+ return torch.DoubleTensor(weights)
130
+
131
+
132
+ # ============================================================
133
+ # Model
134
+ # ============================================================
135
+
136
+ class PoseClassifier(nn.Module):
137
+ """DINOv2 + MLP head for 8-class pose classification"""
138
+
139
+ def __init__(self, model_size='small', dropout=0.3):
140
+ super().__init__()
141
+
142
+ model_name, feat_dim = DINO_MODELS[model_size]
143
+
144
+ # Load frozen DINOv2 backbone
145
+ self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
146
+ for param in self.backbone.parameters():
147
+ param.requires_grad = False
148
+ self.backbone.eval()
149
+
150
+ # Trainable MLP head
151
+ self.head = nn.Sequential(
152
+ nn.LayerNorm(feat_dim),
153
+ nn.Linear(feat_dim, 256),
154
+ nn.GELU(),
155
+ nn.Dropout(dropout),
156
+ nn.Linear(256, 128),
157
+ nn.GELU(),
158
+ nn.Dropout(dropout),
159
+ nn.Linear(128, NUM_CLASSES)
160
+ )
161
+
162
+ def forward(self, x):
163
+ with torch.no_grad():
164
+ features = self.backbone(x)
165
+ return self.head(features)
166
+
167
+ def predict_proba(self, x):
168
+ logits = self.forward(x)
169
+ return F.softmax(logits, dim=-1)
170
+
171
+
172
+ # ============================================================
173
+ # Training
174
+ # ============================================================
175
+
176
+ def get_transforms(train=True):
177
+ normalize = transforms.Normalize(
178
+ mean=[0.485, 0.456, 0.406],
179
+ std=[0.229, 0.224, 0.225]
180
+ )
181
+
182
+ if train:
183
+ return transforms.Compose([
184
+ transforms.Resize(256),
185
+ transforms.RandomCrop(224),
186
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
187
+ transforms.RandomRotation(15),
188
+ transforms.ToTensor(),
189
+ normalize,
190
+ ])
191
+ else:
192
+ return transforms.Compose([
193
+ transforms.Resize(256),
194
+ transforms.CenterCrop(224),
195
+ transforms.ToTensor(),
196
+ normalize,
197
+ ])
198
+
199
+
200
+ def train_epoch(model, dataloader, optimizer, criterion, device, scaler=None):
201
+ model.train()
202
+ model.backbone.eval() # Keep backbone frozen
203
+
204
+ total_loss = 0
205
+ correct = 0
206
+ total = 0
207
+
208
+ pbar = tqdm(dataloader, desc='Training')
209
+ for images, labels in pbar:
210
+ images, labels = images.to(device), labels.to(device)
211
+
212
+ optimizer.zero_grad()
213
+
214
+ if scaler:
215
+ with torch.cuda.amp.autocast():
216
+ outputs = model(images)
217
+ loss = criterion(outputs, labels)
218
+ scaler.scale(loss).backward()
219
+ scaler.step(optimizer)
220
+ scaler.update()
221
+ else:
222
+ outputs = model(images)
223
+ loss = criterion(outputs, labels)
224
+ loss.backward()
225
+ optimizer.step()
226
+
227
+ total_loss += loss.item()
228
+ _, predicted = outputs.max(1)
229
+ total += labels.size(0)
230
+ correct += predicted.eq(labels).sum().item()
231
+
232
+ pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.1f}%'})
233
+
234
+ return total_loss / len(dataloader), correct / total
235
+
236
+
237
+ @torch.no_grad()
238
+ def evaluate(model, dataloader, criterion, device):
239
+ model.eval()
240
+
241
+ total_loss = 0
242
+ correct = 0
243
+ total = 0
244
+ all_preds, all_labels = [], []
245
+
246
+ for images, labels in tqdm(dataloader, desc='Evaluating'):
247
+ images, labels = images.to(device), labels.to(device)
248
+
249
+ outputs = model(images)
250
+ loss = criterion(outputs, labels)
251
+
252
+ total_loss += loss.item()
253
+ _, predicted = outputs.max(1)
254
+ total += labels.size(0)
255
+ correct += predicted.eq(labels).sum().item()
256
+
257
+ all_preds.extend(predicted.cpu().numpy())
258
+ all_labels.extend(labels.cpu().numpy())
259
+
260
+ return total_loss / len(dataloader), correct / total, all_preds, all_labels
261
+
262
+
263
+ def print_confusion_matrix(preds, labels):
264
+ """Print confusion matrix"""
265
+ from collections import defaultdict
266
+
267
+ matrix = defaultdict(lambda: defaultdict(int))
268
+ for p, l in zip(preds, labels):
269
+ matrix[IDX_TO_CLASS[l]][IDX_TO_CLASS[p]] += 1
270
+
271
+ print("\nConfusion Matrix (rows=true, cols=pred):")
272
+
273
+ # Header
274
+ header = f"{'':>12}" + "".join(f"{c[:6]:>8}" for c in POSE_CLASSES)
275
+ print(header)
276
+
277
+ for true_class in POSE_CLASSES:
278
+ row = f"{true_class:>12}"
279
+ for pred_class in POSE_CLASSES:
280
+ count = matrix[true_class][pred_class]
281
+ row += f"{count:>8}"
282
+ print(row)
283
+
284
+ # Per-class accuracy
285
+ print("\nPer-class accuracy:")
286
+ for cls in POSE_CLASSES:
287
+ correct = matrix[cls][cls]
288
+ total = sum(matrix[cls].values())
289
+ acc = correct / total * 100 if total > 0 else 0
290
+ print(f" {cls:>12}: {acc:5.1f}% ({correct}/{total})")
291
+
292
+
293
+ def export_onnx(model, output_path, device='cpu'):
294
+ """Export to ONNX"""
295
+ model.eval()
296
+ model.to(device)
297
+
298
+ dummy = torch.randn(1, 3, 224, 224).to(device)
299
+
300
+ torch.onnx.export(
301
+ model, dummy, output_path,
302
+ export_params=True,
303
+ opset_version=14,
304
+ input_names=['image'],
305
+ output_names=['logits'],
306
+ dynamic_axes={'image': {0: 'batch'}, 'logits': {0: 'batch'}}
307
+ )
308
+ print(f"Exported to {output_path}")
309
+
310
+
311
+ def main():
312
+ parser = argparse.ArgumentParser()
313
+ parser.add_argument('--data_dir', type=str, help='Directory with class folders')
314
+ parser.add_argument('--train_csv', type=str, help='Training CSV')
315
+ parser.add_argument('--val_csv', type=str, help='Validation CSV')
316
+ parser.add_argument('--model_size', type=str, default='small', choices=['small', 'base', 'large'])
317
+ parser.add_argument('--epochs', type=int, default=30)
318
+ parser.add_argument('--batch_size', type=int, default=32)
319
+ parser.add_argument('--lr', type=float, default=1e-3)
320
+ parser.add_argument('--output_dir', type=str, default='./checkpoints')
321
+ parser.add_argument('--export_onnx', action='store_true')
322
+ args = parser.parse_args()
323
+
324
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
325
+ print(f"Device: {device}")
326
+
327
+ os.makedirs(args.output_dir, exist_ok=True)
328
+
329
+ # Load data
330
+ train_transform = get_transforms(train=True)
331
+ val_transform = get_transforms(train=False)
332
+
333
+ if args.train_csv:
334
+ train_dataset = PoseDataset(args.train_csv, train_transform, augment_flip=True)
335
+ val_dataset = PoseDataset(args.val_csv, val_transform, augment_flip=False) if args.val_csv else None
336
+ elif args.data_dir:
337
+ full_dataset = PoseDataset(args.data_dir, train_transform, augment_flip=True)
338
+ # Split 80/20
339
+ n_val = int(0.2 * len(full_dataset))
340
+ n_train = len(full_dataset) - n_val
341
+ train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [n_train, n_val])
342
+ # Wrap val with no augmentation
343
+ val_dataset.dataset.augment_flip = False
344
+ val_dataset.dataset.transform = val_transform
345
+ else:
346
+ print("Provide --data_dir or --train_csv")
347
+ return
348
+
349
+ # Weighted sampler for class balance
350
+ if hasattr(train_dataset, 'get_sample_weights'):
351
+ weights = train_dataset.get_sample_weights()
352
+ sampler = WeightedRandomSampler(weights, len(weights))
353
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4)
354
+ else:
355
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
356
+
357
+ val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) if val_dataset else None
358
+
359
+ # Model
360
+ print(f"\nLoading DINOv2-{args.model_size}...")
361
+ model = PoseClassifier(model_size=args.model_size).to(device)
362
+
363
+ trainable = sum(p.numel() for p in model.head.parameters())
364
+ print(f"Trainable parameters: {trainable:,}")
365
+
366
+ # Training
367
+ criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
368
+ optimizer = torch.optim.AdamW(model.head.parameters(), lr=args.lr, weight_decay=0.01)
369
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
370
+ scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
371
+
372
+ best_acc = 0
373
+
374
+ for epoch in range(args.epochs):
375
+ print(f"\nEpoch {epoch+1}/{args.epochs}")
376
+
377
+ train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, scaler)
378
+
379
+ if val_loader:
380
+ val_loss, val_acc, preds, labels = evaluate(model, val_loader, criterion, device)
381
+ print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")
382
+ print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.1f}%")
383
+
384
+ if val_acc > best_acc:
385
+ best_acc = val_acc
386
+ torch.save({
387
+ 'epoch': epoch,
388
+ 'model_state_dict': model.state_dict(),
389
+ 'head_state_dict': model.head.state_dict(),
390
+ 'val_acc': val_acc,
391
+ 'classes': POSE_CLASSES,
392
+ }, f'{args.output_dir}/best_pose_model.pth')
393
+ print(f" → Saved (acc: {val_acc*100:.1f}%)")
394
+ else:
395
+ print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")
396
+
397
+ scheduler.step()
398
+
399
+ # Final evaluation
400
+ if val_loader:
401
+ print("\n" + "="*60)
402
+ print("Final Evaluation")
403
+ print("="*60)
404
+
405
+ ckpt = torch.load(f'{args.output_dir}/best_pose_model.pth')
406
+ model.load_state_dict(ckpt['model_state_dict'])
407
+
408
+ _, acc, preds, labels = evaluate(model, val_loader, criterion, device)
409
+ print(f"Best Accuracy: {acc*100:.1f}%")
410
+ print_confusion_matrix(preds, labels)
411
+
412
+ # Export
413
+ if args.export_onnx:
414
+ export_onnx(model, f'{args.output_dir}/pose_classifier.onnx')
415
+
416
+ print("\nDone!")
417
+
418
+
419
+ if __name__ == '__main__':
420
+ main()