Spaces:
Running
Running
Commit ·
4bd136e
1
Parent(s): bf06606
Deploy SignMotionGPT Demo with LFS
Browse files- .gitattributes +4 -0
- INFERENCE_AND_VIS.md +253 -0
- README.md +116 -12
- app.py +166 -0
- collators.py +75 -0
- config.py +87 -0
- data.py +169 -0
- data/motion_llm_dataset.json +3 -0
- data/smplx_models/SMPLX_NEUTRAL.npz +3 -0
- data/vqvae_model.pt +3 -0
- data/vqvae_stats.pt +3 -0
- generate.py +194 -0
- inference.py +244 -0
- mGPT/__init__.py +0 -0
- mGPT/archs/__init__.py +0 -0
- mGPT/archs/mgpt_vq.py +189 -0
- mGPT/archs/tools/__init__.py +0 -0
- mGPT/archs/tools/quantize_cnn.py +410 -0
- mGPT/archs/tools/resnet.py +81 -0
- metrics.py +731 -0
- model.py +152 -0
- requirements.txt +32 -0
- setup_env.sh +70 -0
- templates.py +133 -0
- test_dataset_eval.py +534 -0
- test_overfit.py +1562 -0
- train.py +744 -0
- train_mgpt_vqvae.py +438 -0
- train_pipeline.py +264 -0
- train_vqvae.py +421 -0
- visualize.py +681 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/motion_llm_dataset.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/vqvae_model.pt filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/smplx_models/SMPLX_NEUTRAL.npz filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
data/vqvae_stats.pt filter=lfs diff=lfs merge=lfs -text
|
INFERENCE_AND_VIS.md
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference & Visualization Quick Reference
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
After training your 3-stage SignMotionGPT model, use these scripts to generate and visualize motions.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## 1. Inference (Generate Motion Tokens)
|
| 9 |
+
|
| 10 |
+
### Basic Usage
|
| 11 |
+
```bash
|
| 12 |
+
# Generate from Stage 3 model (recommended)
|
| 13 |
+
python inference.py --prompt "walking forward"
|
| 14 |
+
|
| 15 |
+
# Try different stages
|
| 16 |
+
python inference.py --prompt "dancing" --stage 1 # Motion-only LM
|
| 17 |
+
python inference.py --prompt "dancing" --stage 2 # Multi-task
|
| 18 |
+
python inference.py --prompt "dancing" --stage 3 # T2M SFT (best quality)
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### Save Output
|
| 22 |
+
```bash
|
| 23 |
+
python inference.py --prompt "jumping" --output my_motion.txt
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### With Participant ID
|
| 27 |
+
```bash
|
| 28 |
+
python inference.py --prompt "yoga pose" --pid P40
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### Expected Output
|
| 32 |
+
```
|
| 33 |
+
============================================================
|
| 34 |
+
Motion Generation Inference - Stage 3
|
| 35 |
+
============================================================
|
| 36 |
+
Prompt: 'walking forward'
|
| 37 |
+
Device: cuda
|
| 38 |
+
|
| 39 |
+
Loading Stage 3 model from: /kaggle/working/SignMotionGPT/stage3_t2m_sft
|
| 40 |
+
✅ Stage 3 model loaded successfully
|
| 41 |
+
|
| 42 |
+
Generating motion for: 'walking forward'
|
| 43 |
+
|
| 44 |
+
============================================================
|
| 45 |
+
Generated Motion:
|
| 46 |
+
============================================================
|
| 47 |
+
<MOT_BEGIN><motion_224><motion_39><motion_76>...<MOT_END>
|
| 48 |
+
============================================================
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
## 2. Visualization (Motion Tokens → 3D Animation)
|
| 54 |
+
|
| 55 |
+
### Prerequisites
|
| 56 |
+
|
| 57 |
+
#### Option A: Use Google Drive (Colab/Kaggle)
|
| 58 |
+
Edit `setup_env.sh` and add your Google Drive file IDs:
|
| 59 |
+
```bash
|
| 60 |
+
VQVAE_MODEL_ID="1AbCdEfGhIj" # VQ-VAE checkpoint (.pt)
|
| 61 |
+
VQVAE_STATS_ID="2KlMnOpQrSt" # Normalization stats (.pt)
|
| 62 |
+
SMPLX_MODELS_ID="3UvWxYzAbCd" # SMPL-X models (.zip)
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Then run:
|
| 66 |
+
```bash
|
| 67 |
+
bash setup_env.sh
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
#### Option B: Manual Setup (Local)
|
| 71 |
+
```bash
|
| 72 |
+
export VQVAE_CHECKPOINT=/path/to/vqvae_model.pt
|
| 73 |
+
export VQVAE_STATS_PATH=/path/to/vqvae_stats.pt
|
| 74 |
+
export SMPLX_MODEL_DIR=/path/to/smplx_models
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Basic Usage
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
# Visualize token string
|
| 81 |
+
python visualize.py --tokens "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
|
| 82 |
+
|
| 83 |
+
# Visualize from file
|
| 84 |
+
python visualize.py --input my_motion.txt
|
| 85 |
+
|
| 86 |
+
# Generate + visualize in one command
|
| 87 |
+
python visualize.py --prompt "walking" --stage 3
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Custom Output
|
| 91 |
+
```bash
|
| 92 |
+
python visualize.py \
|
| 93 |
+
--input motion_tokens.txt \
|
| 94 |
+
--output walk_animation.html \
|
| 95 |
+
--title "Walking Forward" \
|
| 96 |
+
--fps 30
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
### With Custom Paths
|
| 100 |
+
```bash
|
| 101 |
+
python visualize.py \
|
| 102 |
+
--tokens "<MOT_BEGIN>..." \
|
| 103 |
+
--vqvae-ckpt /custom/vqvae.pt \
|
| 104 |
+
--stats /custom/stats.pt \
|
| 105 |
+
--smplx-dir /custom/smplx_models \
|
| 106 |
+
--output animation.html
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
### Expected Output
|
| 110 |
+
```
|
| 111 |
+
============================================================
|
| 112 |
+
Motion Visualization Pipeline
|
| 113 |
+
============================================================
|
| 114 |
+
|
| 115 |
+
[1/5] Parsing tokens...
|
| 116 |
+
Parsed 15 tokens
|
| 117 |
+
|
| 118 |
+
[2/5] Loading VQ-VAE...
|
| 119 |
+
✅ VQ-VAE loaded (codebook size: 512)
|
| 120 |
+
|
| 121 |
+
[3/5] Loading normalization stats...
|
| 122 |
+
✅ Stats loaded (mean shape: (182,))
|
| 123 |
+
|
| 124 |
+
[4/5] Loading SMPL-X model...
|
| 125 |
+
✅ SMPL-X loaded
|
| 126 |
+
|
| 127 |
+
[5/5] Decoding and rendering...
|
| 128 |
+
Decoding tokens to SMPL-X parameters...
|
| 129 |
+
Decoded params shape: (16, 182)
|
| 130 |
+
Converting parameters to vertices...
|
| 131 |
+
Vertices shape: (16, 10475, 3), Faces: (20908, 3)
|
| 132 |
+
Creating animation...
|
| 133 |
+
✅ Animation saved to: motion_animation.html
|
| 134 |
+
|
| 135 |
+
============================================================
|
| 136 |
+
✅ Visualization complete!
|
| 137 |
+
============================================================
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## 3. Complete Workflow Example
|
| 143 |
+
|
| 144 |
+
### A. Train (already done)
|
| 145 |
+
```bash
|
| 146 |
+
python train_pipeline.py
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
### B. Generate Motion Tokens
|
| 150 |
+
```bash
|
| 151 |
+
python inference.py --prompt "college" --stage 3 --output college_motion.txt
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### C. Visualize
|
| 155 |
+
```bash
|
| 156 |
+
python visualize.py --input college_motion.txt --output college_animation.html
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### D. View Animation
|
| 160 |
+
Open `college_animation.html` in a browser. You'll see an interactive 3D SMPL-X character performing the motion. Use mouse to rotate/zoom, and click Play/Pause buttons.
|
| 161 |
+
|
| 162 |
+
---
|
| 163 |
+
|
| 164 |
+
## 4. Troubleshooting
|
| 165 |
+
|
| 166 |
+
### Inference Issues
|
| 167 |
+
|
| 168 |
+
**"Checkpoint not found"**
|
| 169 |
+
- Ensure you've trained all stages first: `python train_pipeline.py`
|
| 170 |
+
- Check that `OUT_S1`, `OUT_S2`, `OUT_S3` directories exist in `WORK_DIR`
|
| 171 |
+
|
| 172 |
+
**"Dataset not found"**
|
| 173 |
+
- Inference needs the dataset to build vocabulary
|
| 174 |
+
- Set `DATA_JSON_PATH` in `config.py` or via environment variable
|
| 175 |
+
|
| 176 |
+
### Visualization Issues
|
| 177 |
+
|
| 178 |
+
**"VQ-VAE checkpoint not found"**
|
| 179 |
+
- Download VQ-VAE model or set `VQVAE_CHECKPOINT` path
|
| 180 |
+
- The VQ-VAE is separate from LLM training (used to decode tokens to SMPL-X params)
|
| 181 |
+
|
| 182 |
+
**"SMPL-X models not found"**
|
| 183 |
+
- Download SMPL-X models from https://smpl-x.is.tue.mpg.de/
|
| 184 |
+
- Extract to a directory and set `SMPLX_MODEL_DIR`
|
| 185 |
+
|
| 186 |
+
**"No tokens to visualize"**
|
| 187 |
+
- Check token format: should contain `<motion_ID>` tags or space-separated numbers
|
| 188 |
+
- Example valid formats:
|
| 189 |
+
- `<MOT_BEGIN><motion_177><motion_135><MOT_END>`
|
| 190 |
+
- `177 135 152 200 46 142`
|
| 191 |
+
|
| 192 |
+
**"Shape mismatch" or "Decoding errors"**
|
| 193 |
+
- Ensure VQ-VAE checkpoint matches the codebook size used in LLM training
|
| 194 |
+
- Check `CODEBOOK_SIZE`, `CODE_DIM`, `SMPL_DIM` in `visualize.py` match training
|
| 195 |
+
|
| 196 |
+
---
|
| 197 |
+
|
| 198 |
+
## 5. Configuration
|
| 199 |
+
|
| 200 |
+
### Key Environment Variables
|
| 201 |
+
|
| 202 |
+
| Variable | Purpose | Default |
|
| 203 |
+
|----------|---------|---------|
|
| 204 |
+
| `VQVAE_CHECKPOINT` | VQ-VAE model path | `./data/vqvae_model.pt` |
|
| 205 |
+
| `VQVAE_STATS_PATH` | Normalization stats | `./data/vqvae_stats.pt` |
|
| 206 |
+
| `SMPLX_MODEL_DIR` | SMPL-X models directory | `./data/smplx_models` |
|
| 207 |
+
| `VIS_OUTPUT_DIR` | Output directory for animations | `WORK_DIR` |
|
| 208 |
+
|
| 209 |
+
### VQ-VAE Architecture (must match training)
|
| 210 |
+
In `visualize.py`:
|
| 211 |
+
```python
|
| 212 |
+
SMPL_DIM = 182 # SMPL-X parameter dimension
|
| 213 |
+
CODEBOOK_SIZE = 512 # Motion vocabulary size
|
| 214 |
+
CODE_DIM = 512 # Latent code dimension
|
| 215 |
+
VQ_ARGS = dict(
|
| 216 |
+
width=512,
|
| 217 |
+
depth=3,
|
| 218 |
+
down_t=2,
|
| 219 |
+
stride_t=2,
|
| 220 |
+
...
|
| 221 |
+
)
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
## 6. Tips
|
| 227 |
+
|
| 228 |
+
### Inference
|
| 229 |
+
- **Stage 3** generally produces best quality for text-to-motion
|
| 230 |
+
- **Stage 2** can handle M2T and denoising (but inference.py only does T2M)
|
| 231 |
+
- **Stage 1** generates motion without text conditioning (still needs prompt for length)
|
| 232 |
+
- Use `--no-per-prompt-vocab` to allow novel combinations (less constrained)
|
| 233 |
+
|
| 234 |
+
### Visualization
|
| 235 |
+
- **FPS 20-30** works well for most motions
|
| 236 |
+
- Longer sequences may take a few seconds to render
|
| 237 |
+
- The HTML file is self-contained and can be shared
|
| 238 |
+
- 3D mesh has ~10K vertices; animations can be large for long sequences
|
| 239 |
+
|
| 240 |
+
### Performance
|
| 241 |
+
- Inference: ~1-2 seconds per generation (depends on length)
|
| 242 |
+
- Visualization: ~3-10 seconds (depends on sequence length and batch size)
|
| 243 |
+
- Both run on GPU if available, fall back to CPU otherwise
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
## 7. Next Steps
|
| 248 |
+
|
| 249 |
+
- **Batch Inference**: Loop over multiple prompts and save outputs
|
| 250 |
+
- **Evaluate Quality**: Compare generated tokens to ground truth using edit distance
|
| 251 |
+
- **Fine-tune Generation**: Adjust `GEN_TEMPERATURE`, `GEN_TOP_P` in `config.py`
|
| 252 |
+
- **Export to Other Formats**: Extend `visualize.py` to export BVH, FBX, or USD
|
| 253 |
+
|
README.md
CHANGED
|
@@ -1,12 +1,116 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
### 1) Configure setup script (one time)
|
| 3 |
+
|
| 4 |
+
Run the setup:
|
| 5 |
+
|
| 6 |
+
```bash
|
| 7 |
+
bash setup_env.sh
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
After setup, defaults are:
|
| 11 |
+
- `WORK_DIR` = current directory
|
| 12 |
+
- `DATA_JSON_PATH` = `./data/motion_llm_dataset.json`
|
| 13 |
+
|
| 14 |
+
You can override via environment variables if needed:
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
export WORK_DIR=/path/to/workdir
|
| 18 |
+
export DATA_JSON_PATH=/path/to/motion_llm_dataset.json
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Overview
|
| 22 |
+
|
| 23 |
+
This repository implements a robust 2-stage training pipeline for motion generation, replicating the high-performance "overfit" test setup:
|
| 24 |
+
- **Stage 1**: Motion-only Language Model (MLM) - Pre-training on motion token sequences to learn the "language of motion".
|
| 25 |
+
- **Stage 2**: Text-to-Motion Fine-Tuning (T2M) - Supervised fine-tuning to align text prompts with motion sequences.
|
| 26 |
+
|
| 27 |
+
Key features:
|
| 28 |
+
- **Integrated Evaluation**: Automatically computes FID, Diversity, and Multimodality (MIM) metrics.
|
| 29 |
+
- **Side-by-Side Visualization**: Generates HTML comparisons of Ground Truth vs Generated motions.
|
| 30 |
+
- **Test Set Evaluation**: Can optionally run evaluation on a held-out test set (SMPL-X data).
|
| 31 |
+
- **Hugging Face Integration**: Automatic checkpointing and resuming from the Hub.
|
| 32 |
+
|
| 33 |
+
## Installation
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
# Clone the repository
|
| 37 |
+
git clone https://github.com/rajvizala/SignMotionGPT.git
|
| 38 |
+
cd SignMotionGPT
|
| 39 |
+
|
| 40 |
+
# Setup Everything
|
| 41 |
+
bash setup_env.sh
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Dataset Format
|
| 45 |
+
|
| 46 |
+
Your dataset should be a JSON file with the following structure:
|
| 47 |
+
|
| 48 |
+
```json
|
| 49 |
+
[
|
| 50 |
+
{
|
| 51 |
+
"text_query": "a person walks forward",
|
| 52 |
+
"motion_tokens": "42 18 91 ...",
|
| 53 |
+
"participant_id": "P001" // Optional
|
| 54 |
+
},
|
| 55 |
+
...
|
| 56 |
+
]
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Quick Start
|
| 60 |
+
|
| 61 |
+
### 1. Configure Training
|
| 62 |
+
|
| 63 |
+
Edit `config.py` to set your paths and hyperparameters. Key settings include:
|
| 64 |
+
- `DATA_JSON_PATH`: Path to your dataset.
|
| 65 |
+
- `MODEL_NAME`: Base model (e.g., "Qwen/Qwen3-0.6B").
|
| 66 |
+
- `PIPELINE_OUTPUT_DIR`: Directory for checkpoints and results.
|
| 67 |
+
- `HF_TOKEN`: Your Hugging Face token (or set via env var).
|
| 68 |
+
|
| 69 |
+
### 2. Run Full Pipeline
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
python train_pipeline.py
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
This script orchestrates the entire process:
|
| 76 |
+
1. **Data Loading & Cleaning**: Deduplicates samples and builds vocabulary.
|
| 77 |
+
2. **Stage 1 Training**: Motion Language Modeling (Pre-training).
|
| 78 |
+
3. **Stage 2 Training**: Text-to-Motion Fine-Tuning.
|
| 79 |
+
4. **Evaluation**: Runs inference on specific words, computes metrics (FID, Diversity, MIM), and generates visualizations.
|
| 80 |
+
5. **Test Set Evaluation**: (Optional) Runs evaluation on held-out test data if configured.
|
| 81 |
+
|
| 82 |
+
### 3. Environment Variables
|
| 83 |
+
|
| 84 |
+
You can control many aspects via environment variables without editing code:
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
# Training Config
|
| 88 |
+
export PIPELINE_S1_EPOCHS=20
|
| 89 |
+
export PIPELINE_S2_EPOCHS=20
|
| 90 |
+
export PIPELINE_S1_BATCH=8
|
| 91 |
+
export PIPELINE_S2_BATCH=8
|
| 92 |
+
|
| 93 |
+
# Hugging Face
|
| 94 |
+
export HUGGINGFACE_HUB_TOKEN="your_token"
|
| 95 |
+
export HF_UPLOAD_INTERVAL_EPOCHS=2
|
| 96 |
+
|
| 97 |
+
# Evaluation
|
| 98 |
+
export EVALUATION_WORDS="passport,send,library"
|
| 99 |
+
export TEST_EVAL_SAMPLE_LIMIT=100
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
## Held-out Test Dataset Evaluation
|
| 103 |
+
|
| 104 |
+
The pipeline includes integration with `test_dataset_eval.py` to measure performance on an unseen SMPL-X test dataset.
|
| 105 |
+
|
| 106 |
+
To enable this, ensure `TEST_EVAL_DOWNLOAD_DIR` or `TEST_EVAL_EXTRACT_DIR` are configured in `config.py` or via env vars. The pipeline will attempt to run this after training if data is available.
|
| 107 |
+
|
| 108 |
+
## Visualization
|
| 109 |
+
|
| 110 |
+
The pipeline automatically generates side-by-side HTML visualizations in the output directory (`html_visualizations` folder). You can open these in any browser to compare Ground Truth motions with the model's generations.
|
| 111 |
+
|
| 112 |
+
To manually visualize tokens:
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
python visualize.py --tokens "<MOT_BEGIN><motion_177>...<MOT_END>" --output my_anim.html
|
| 116 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr
import torch
import os
import sys
import warnings
from pathlib import Path

# Make project-root imports resolve whether the app is launched from the
# repository root or from a subdirectory such as demo-code/.
_app_dir = os.path.dirname(os.path.abspath(__file__))
for _p in (_app_dir, os.path.dirname(_app_dir)):
    sys.path.append(_p)

# Import project modules. Failures are reported (not raised) so the printed
# message points at the real cause: a broken project layout.
try:
    from inference import load_trained_model, inference as run_inference_cmd
    from visualize import visualize
    from model import setup_model_and_tokenizer, get_motion_token_info
    from generate import generate_t2m
    from data import compute_length_stats, build_prompt_vocab, check_has_participant_id, load_dataset
    import config
except ImportError as e:
    print(f"Error importing project modules: {e}")
    print("Make sure you are running this from the project root or have the project structure intact.")

# Constants
HF_REPO_ID = "rdz-falcon/SignMotionGPT"
EPOCH_SUBFOLDER = "stage2/epoch-030"
def load_model_from_hf(repo_id, subfolder, token=None):
    """Fetch a causal-LM checkpoint and its tokenizer from the Hugging Face Hub.

    Returns (model, tokenizer) on success, or (None, None) on any failure so
    the caller can decide how to surface the error.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from HF: {repo_id}/{subfolder}")
    # Both loads share the same Hub options.
    hub_kwargs = dict(subfolder=subfolder, token=token, trust_remote_code=True)
    try:
        tok = AutoTokenizer.from_pretrained(repo_id, **hub_kwargs)
        mdl = AutoModelForCausalLM.from_pretrained(repo_id, **hub_kwargs)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None
    return mdl, tok
+
# Global model cache
|
| 43 |
+
MODEL = None
|
| 44 |
+
TOKENIZER = None
|
| 45 |
+
MOTION_TOKEN_IDS = None
|
| 46 |
+
MOT_BEGIN_ID = None
|
| 47 |
+
MOT_END_ID = None
|
| 48 |
+
CODEBOOK_SIZE = 512
|
| 49 |
+
|
def init_model():
    """Load the Stage-2 checkpoint from the Hub and populate the module-level
    cache (MODEL, TOKENIZER, motion-token ids).

    Idempotent: returns immediately if the model is already loaded.

    Raises:
        RuntimeError: if the checkpoint could not be downloaded/loaded.
    """
    global MODEL, TOKENIZER, MOTION_TOKEN_IDS, MOT_BEGIN_ID, MOT_END_ID
    if MODEL is not None:
        return

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")

    # Load model/tokenizer
    MODEL, TOKENIZER = load_model_from_hf(HF_REPO_ID, EPOCH_SUBFOLDER, token)

    if MODEL is None:
        raise RuntimeError(f"Failed to load model from {HF_REPO_ID}/{EPOCH_SUBFOLDER}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(device)
    MODEL.eval()

    # Collect the ids of the <motion_i> tokens that actually exist in the
    # vocabulary. get_vocab() builds a fresh dict on every call, so fetch it
    # ONCE instead of once per candidate token (the original looped 512x).
    vocab = TOKENIZER.get_vocab()
    MOTION_TOKEN_IDS = [
        TOKENIZER.convert_tokens_to_ids(f"<motion_{i}>")
        for i in range(CODEBOOK_SIZE)
        if f"<motion_{i}>" in vocab
    ]
    MOT_BEGIN_ID = TOKENIZER.convert_tokens_to_ids("<MOT_BEGIN>") if "<MOT_BEGIN>" in vocab else None
    MOT_END_ID = TOKENIZER.convert_tokens_to_ids("<MOT_END>") if "<MOT_END>" in vocab else None

    print("Model initialized.")
def _find_missing_vis_assets(data_dir):
    """Resolve the visualization asset paths under *data_dir* and report
    which are missing. Returns (missing, vqvae_ckpt, stats_path, smplx_dir)."""
    vqvae_ckpt = os.path.join(data_dir, "vqvae_model.pt")
    stats_path = os.path.join(data_dir, "vqvae_stats.pt")
    smplx_dir = os.path.join(data_dir, "smplx_models")
    missing = [p for p in (vqvae_ckpt, stats_path, smplx_dir) if not os.path.exists(p)]
    return missing, vqvae_ckpt, stats_path, smplx_dir


def generate_motion_app(text_prompt):
    """Gradio handler: text prompt -> (plotly Figure, status string).

    Lazily initialises the model on first call. Any failure is returned as
    (None, message) so the UI shows the error instead of crashing.
    """
    if not text_prompt:
        return None, "Please enter a prompt."

    if MODEL is None:
        try:
            init_model()
        except Exception as e:
            return None, f"Model Initialization Failed: {e}"

    print(f"Generating for: {text_prompt}")

    try:
        generated_tokens = generate_t2m(
            model=MODEL,
            tokenizer=TOKENIZER,
            prompt_text=text_prompt,
            mot_begin_id=MOT_BEGIN_ID,
            mot_end_id=MOT_END_ID,
            motion_token_ids=MOTION_TOKEN_IDS,
            length_stats_by_text={},  # Fallback to global_median_len
            global_median_len=100,  # Reasonable default
            prompt_vocab=None,
            has_pid=False,
            per_prompt_vocab=False  # Allow all tokens
        )
    except Exception as e:
        return None, f"Generation Error: {e}"

    # Visualization: decode the tokens with the VQ-VAE and render the SMPL-X
    # mesh. In HF Spaces the assets are assumed to ship in the repo (./data).
    try:
        data_dir = os.environ.get("DATA_DIR", "data")
        missing, vqvae_ckpt, stats_path, smplx_dir = _find_missing_vis_assets(data_dir)
        if missing:
            return None, f"Missing visualization files in {data_dir}: {missing}. Please ensure they are uploaded to the Space."

        # visualize() writes a standalone HTML file and returns a Figure;
        # Gradio's Plot component consumes the Figure directly.
        output_html = "temp_viz.html"

        fig = visualize(
            tokens=generated_tokens,
            vqvae_ckpt=vqvae_ckpt,
            stats_path=stats_path,
            smplx_dir=smplx_dir,
            output_html=output_html,
            title=f"Motion: {text_prompt}",
            fps=20
        )

        if fig is None:
            return None, "Visualization failed (no frames produced)."

        return fig, f"Success! Generated tokens length: {len(generated_tokens.split())}"

    except Exception as e:
        return None, f"Visualization Error: {e}"
# Gradio UI.
# gr.Interface builds its own layout on construction; assign it directly.
# The original `with gr.Interface(...) as demo: pass` re-entered the already
# rendered Interface as a Blocks context, which adds nothing and is not the
# documented usage.
demo = gr.Interface(
    fn=generate_motion_app,
    inputs=gr.Textbox(label="Enter Motion Prompt", placeholder="e.g. walking forward"),
    outputs=[
        gr.Plot(label="Motion Visualization"),
        gr.Textbox(label="Status/Output")
    ],
    title="SignMotionGPT Demo",
    description="Generate Sign Language/Motion Avatars from Text. Using model checkpoint: epoch 30."
)

if __name__ == "__main__":
    demo.launch()
collators.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Data collators with label masking for training
"""
import torch


class AssistantSpanCollator:
    """Collator that restricts training labels to the assistant response.

    Per example, ``where`` selects the supervised span:
      - ``"text"``: the whole assistant turn (for M2T tasks)
      - anything else (``"mot"``): only the tokens between <MOT_BEGIN> and
        <MOT_END> inside the assistant turn, inclusive of both markers.
    All other positions get the ignore label -100.
    """

    def __init__(self, tokenizer, max_length):
        self.tok = tokenizer
        self.max_len = max_length
        # Resolve the special-token ids once up front.
        self.im_start = self.tok.convert_tokens_to_ids("<|im_start|>")
        self.im_end = self.tok.convert_tokens_to_ids("<|im_end|>")
        self.mot_beg = self.tok.convert_tokens_to_ids("<MOT_BEGIN>")
        self.mot_end = self.tok.convert_tokens_to_ids("<MOT_END>")

    def __call__(self, examples):
        # Batch-tokenize all sample texts with padding/truncation.
        enc = self.tok(
            [ex["text"] for ex in examples],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        input_ids = enc["input_ids"]
        labels = torch.full_like(input_ids, -100)

        for row, mode in enumerate(ex["where"] for ex in examples):
            seq = input_ids[row]

            # The assistant turn starts at the LAST <|im_start|> in the sample.
            start_hits = (seq == self.im_start).nonzero(as_tuple=True)[0]
            if start_hits.numel() == 0:
                continue
            a_start = int(start_hits[-1].item())

            # Its matching <|im_end|>; if truncated away, run to the
            # penultimate position.
            tail = seq[a_start + 1:]
            close_hits = (tail == self.im_end).nonzero(as_tuple=True)[0]
            if close_hits.numel() > 0:
                a_end = a_start + 1 + int(close_hits[0].item())
            else:
                a_end = seq.size(0) - 1

            if mode == "text":
                # Supervise the whole assistant span.
                labels[row, a_start + 1:a_end] = seq[a_start + 1:a_end]
            else:
                # Supervise only <MOT_BEGIN>...<MOT_END> (inclusive).
                span = seq[a_start + 1:a_end]
                beg_hits = (span == self.mot_beg).nonzero(as_tuple=True)[0]
                end_hits = (span == self.mot_end).nonzero(as_tuple=True)[0]
                if beg_hits.numel() > 0 and end_hits.numel() > 0 and end_hits[0] >= beg_hits[0]:
                    lo = a_start + 1 + int(beg_hits[0].item())
                    hi = a_start + 1 + int(end_hits[0].item())
                    labels[row, lo:hi + 1] = seq[lo:hi + 1]

        return {
            "input_ids": input_ids,
            "attention_mask": enc["attention_mask"],
            "labels": labels,
        }
config.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration file for Motion LLM training
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
# Random seed
|
| 8 |
+
SEED = 42
|
| 9 |
+
|
| 10 |
+
# Paths
|
| 11 |
+
# WORK_DIR defaults to current working directory if not explicitly set
|
| 12 |
+
WORK_DIR = os.environ.get("WORK_DIR", os.getcwd())
|
| 13 |
+
DATA_DIR = os.environ.get("DATA_DIR", os.path.join(WORK_DIR, "data"))
|
| 14 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 15 |
+
|
| 16 |
+
# Single-file JSON dataset path (can be overridden via env)
|
| 17 |
+
DATA_JSON_PATH = os.environ.get(
|
| 18 |
+
"DATA_JSON_PATH",
|
| 19 |
+
os.path.join(DATA_DIR, "motion_llm_dataset.json"),
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Directory Configuration
|
| 23 |
+
# PIPELINE_OUTPUT_DIR matches test_overfit's default "./motion_gpt_full_model"
|
| 24 |
+
PIPELINE_OUTPUT_DIR = os.environ.get("PIPELINE_OUTPUT_DIR", "./motion_gpt_full_model")
|
| 25 |
+
METRICS_JSON_PATH = os.path.join(PIPELINE_OUTPUT_DIR, "metrics.json")
|
| 26 |
+
CHECKPOINTS_DIR = os.path.join(PIPELINE_OUTPUT_DIR, "checkpoints")
|
| 27 |
+
|
| 28 |
+
# Model configuration
|
| 29 |
+
MODEL_NAME = "Qwen/Qwen3-0.6B" # Matches test_overfit.py
|
| 30 |
+
MAX_SEQ_LEN = 512 # Kept from previous config, though test_overfit uses 256 in datasets
|
| 31 |
+
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
|
| 32 |
+
|
| 33 |
+
# Evaluation Words (matches test_overfit.py)
|
| 34 |
+
EVALUATION_WORDS = ["passport", "send", "library", "push"]
|
| 35 |
+
|
| 36 |
+
# Training Hyperparameters (matches test_overfit.py)
|
| 37 |
+
# Stage 1
|
| 38 |
+
S1_EPOCHS = 20
|
| 39 |
+
S1_LR = 5e-5
|
| 40 |
+
S1_BATCH_SIZE = 8
|
| 41 |
+
|
| 42 |
+
# Stage 2
|
| 43 |
+
S2_EPOCHS = 20
|
| 44 |
+
S2_LR = 2e-5
|
| 45 |
+
S2_BATCH_SIZE = 8
|
| 46 |
+
|
| 47 |
+
# Inference Hyperparameters (matches test_overfit.py)
|
| 48 |
+
INFERENCE_REPETITION_PENALTY = 1.2
|
| 49 |
+
INFERENCE_TEMPERATURE = 0.7
|
| 50 |
+
INFERENCE_TOP_K = 50
|
| 51 |
+
|
| 52 |
+
# Special Tokens (matches test_overfit.py)
|
| 53 |
+
M_START = "<M_START>"
|
| 54 |
+
M_END = "<M_END>"
|
| 55 |
+
PAD_TOKEN = "<PAD>"
|
| 56 |
+
|
| 57 |
+
# Hugging Face Hub Configuration
|
| 58 |
+
HF_USE_HUB = True
|
| 59 |
+
HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("hf_auth_token")
|
| 60 |
+
HF_USER = os.environ.get("HF_USER", "rdz-falcon") # Derived from test_overfit.py repo ids
|
| 61 |
+
HF_STAGE1_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
|
| 62 |
+
HF_STAGE2_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
|
| 63 |
+
HF_PRIVATE_REPO = os.environ.get("HF_PRIVATE", "true").lower() != "false"
|
| 64 |
+
FORCE_STAGE2_FROM_STAGE1_RAW = os.environ.get("FORCE_STAGE2_FROM_STAGE1", "false")
|
| 65 |
+
FORCE_STAGE2_FROM_STAGE1 = str(FORCE_STAGE2_FROM_STAGE1_RAW).strip().lower() not in ("0", "false", "no", "off")
|
| 66 |
+
HF_STAGE2_SAVE_SUBDIR = os.environ.get("HF_STAGE2_SAVE_SUBDIR", "stage2_v2")
|
| 67 |
+
CHECKPOINT_UPLOAD_INTERVAL_EPOCHS = int(os.environ.get("HF_UPLOAD_INTERVAL_EPOCHS", "2"))
|
| 68 |
+
HF_DISABLE_PROGRESS = os.environ.get("HF_DISABLE_PROGRESS", "true").lower() != "false"
|
| 69 |
+
|
| 70 |
+
# Evaluation controls
|
| 71 |
+
RUN_EVALS_ONLY = False
|
| 72 |
+
EVAL_SAMPLE_LIMIT = 100
|
| 73 |
+
|
| 74 |
+
# Test Eval Config (from test_dataset_eval.py defaults)
|
| 75 |
+
TEST_EVAL_OUTPUT_DIR = os.environ.get("TEST_EVAL_OUTPUT_DIR", PIPELINE_OUTPUT_DIR)
|
| 76 |
+
TEST_EVAL_DOWNLOAD_DIR = os.environ.get(
|
| 77 |
+
"TEST_EVAL_DOWNLOAD_DIR", os.path.join(WORK_DIR, "test_data", "downloads")
|
| 78 |
+
)
|
| 79 |
+
TEST_EVAL_EXTRACT_DIR = os.environ.get(
|
| 80 |
+
"TEST_EVAL_EXTRACT_DIR", os.path.join(WORK_DIR, "test_data", "extracted")
|
| 81 |
+
)
|
| 82 |
+
TEST_EVAL_SAMPLE_LIMIT = int(os.environ.get("TEST_EVAL_SAMPLE_LIMIT", "300"))
|
| 83 |
+
TEST_EVAL_MAX_ZIPS = int(os.environ.get("TEST_EVAL_MAX_ZIPS", "500"))
|
| 84 |
+
TEST_EVAL_HF_REPO = os.environ.get("TEST_EVAL_HF_REPO", "rdz-falcon/SignMotionGPTfit-archive")
|
| 85 |
+
TEST_EVAL_HF_SUBFOLDER = os.environ.get(
|
| 86 |
+
"TEST_EVAL_HF_SUBFOLDER", f"{HF_STAGE2_SAVE_SUBDIR}/latest"
|
| 87 |
+
)
|
data.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset loading and vocabulary building utilities
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
from typing import List, Dict, Tuple, Any
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
from config import M_START, M_END, PAD_TOKEN
|
| 13 |
+
|
| 14 |
+
# ======================================================================================
|
| 15 |
+
# Logic from test_overfit.py
|
| 16 |
+
# ======================================================================================
|
| 17 |
+
|
| 18 |
+
def read_json_data(json_path: str) -> List[Dict[str, Any]]:
    """Load and return the dataset entries stored at *json_path*.

    Raises:
        FileNotFoundError: if no file exists at the given path.
    """
    exists = os.path.exists(json_path)
    if not exists:
        raise FileNotFoundError(f"Dataset not found at: {json_path}")
    handle = open(json_path, "r", encoding="utf-8")
    try:
        return json.load(handle)
    finally:
        handle.close()
|
| 24 |
+
|
| 25 |
+
def deduplicate_and_prepare_data(entries: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
| 26 |
+
"""
|
| 27 |
+
Cleans the entire dataset by ensuring each (word, participant_id) pair is unique.
|
| 28 |
+
If a conflict is found (same pair, different motion), it keeps only the first one encountered.
|
| 29 |
+
Then, it prepares the full list of motion tokens from the cleaned data.
|
| 30 |
+
"""
|
| 31 |
+
print("\n---> Cleaning dataset by removing ambiguous (word, participant_id) pairs...")
|
| 32 |
+
|
| 33 |
+
unique_samples = {}
|
| 34 |
+
conflicts_found = 0
|
| 35 |
+
|
| 36 |
+
for entry in entries:
|
| 37 |
+
word = entry.get("word", "").lower()
|
| 38 |
+
pid = entry.get("participant_id", "")
|
| 39 |
+
key = (word, pid)
|
| 40 |
+
|
| 41 |
+
if key not in unique_samples:
|
| 42 |
+
unique_samples[key] = entry
|
| 43 |
+
else:
|
| 44 |
+
# A sample for this key already exists. We only care if it's a conflict.
|
| 45 |
+
existing_tokens = unique_samples[key].get("motion_tokens")
|
| 46 |
+
current_tokens = entry.get("motion_tokens")
|
| 47 |
+
if existing_tokens != current_tokens:
|
| 48 |
+
conflicts_found += 1
|
| 49 |
+
# We do nothing, effectively discarding this new conflicting sample.
|
| 50 |
+
|
| 51 |
+
cleaned_data = list(unique_samples.values())
|
| 52 |
+
|
| 53 |
+
print(f"Original samples: {len(entries)}")
|
| 54 |
+
print(f"Cleaned samples (unique (word, pid) pairs): {len(cleaned_data)}")
|
| 55 |
+
print(f"Removed {len(entries) - len(cleaned_data)} total samples. ({conflicts_found} were direct conflicts).")
|
| 56 |
+
|
| 57 |
+
print("\n---> Extracting motion tokens from the full cleaned dataset...")
|
| 58 |
+
all_motion_tokens = set()
|
| 59 |
+
for entry in cleaned_data:
|
| 60 |
+
motion_tokens = entry.get("motion_tokens", "").strip().split()
|
| 61 |
+
for token in motion_tokens:
|
| 62 |
+
all_motion_tokens.add(f"<M{token}>")
|
| 63 |
+
|
| 64 |
+
unique_tokens = sorted(list(all_motion_tokens))
|
| 65 |
+
print(f"Found {len(unique_tokens)} unique motion tokens in the entire dataset.")
|
| 66 |
+
|
| 67 |
+
return cleaned_data, unique_tokens
|
| 68 |
+
|
| 69 |
+
class MotionDataset(Dataset):
    """Stage-1 dataset: each item is a standalone motion-token sequence."""

    def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Pre-render every sample as "<M_START> <M..> ... <M_END>".
        self.sequences = [
            f"{M_START} "
            + " ".join(f"<M{tok}>" for tok in item.get("motion_tokens", "").split())
            + f" {M_END}"
            for item in data
        ]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Tokenize lazily, padded/truncated to a fixed length.
        encoded = self.tokenizer(
            self.sequences[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )
        return encoded
|
| 93 |
+
|
| 94 |
+
class TextMotionDataset(Dataset):
    """Dataset for Stage 2: Contains (prompt, motion_sequence) pairs.

    Each item is fully tokenized up front in __init__; __getitem__ just
    indexes the cached tensors. Labels mask the prompt portion with -100
    (the default ignore index of cross-entropy loss) so training only
    supervises the motion-token target.
    """
    def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # List of dicts with "input_ids" / "attention_mask" / "labels" tensors.
        self.items = []

        for item in data:
            # Fixed instruction template; requires "word" and "participant_id"
            # keys on every entry (KeyError otherwise).
            prompt = f"Instruction: Generate motion for word '{item['word']}' with variant '{item['participant_id']}'.\nMotion: "

            tokens_str = item.get("motion_tokens", "")
            wrapped_tokens = " ".join([f"<M{t}>" for t in tokens_str.split()])
            target_sequence = f"{M_START} {wrapped_tokens} {M_END}"

            full_text = prompt + target_sequence

            tokenized = self.tokenizer(
                full_text,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt"
            )

            # Prompt length is measured by tokenizing the prompt alone.
            # NOTE(review): assumes tokenizing prompt+target yields the prompt's
            # tokens as an exact prefix (no special tokens inserted between) —
            # confirm for the tokenizer in use, otherwise the mask is offset.
            prompt_tokenized = self.tokenizer(prompt, return_tensors="pt")
            prompt_len = prompt_tokenized.input_ids.shape[1]

            labels = tokenized['input_ids'].clone()
            # Ignore loss on the prompt span; only the motion target is scored.
            labels[0, :prompt_len] = -100

            self.items.append({
                "input_ids": tokenized['input_ids'].squeeze(0),
                "attention_mask": tokenized['attention_mask'].squeeze(0),
                "labels": labels.squeeze(0)
            })

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
|
| 135 |
+
|
| 136 |
+
# ======================================================================================
|
| 137 |
+
# Legacy utilities (kept for compatibility if needed, but mostly superseded)
|
| 138 |
+
# ======================================================================================
|
| 139 |
+
|
| 140 |
+
def build_motion_vocab(dataset):
    """
    Build motion vocabulary by finding the maximum token ID in the dataset.

    Args:
        dataset: iterable of examples, each a mapping whose "motion_tokens"
            value is a whitespace-separated string of integer IDs.

    Returns:
        (codebook_size, max_token_id) where codebook_size == max_token_id + 1.
        An empty dataset (or one with only empty token strings) yields (1, 0).
    """
    global_max_id = 0
    for ex in dataset:
        # `default=0` keeps an example with an empty "motion_tokens" string
        # from raising ValueError (the original `max()` had no default).
        example_max = max(
            (int(x) for x in ex.get("motion_tokens", "").split()),
            default=0,
        )
        global_max_id = max(global_max_id, example_max)

    codebook_size = global_max_id + 1
    return codebook_size, global_max_id
|
| 154 |
+
|
| 155 |
+
def motion_specials_to_ids(s: str) -> List[int]:
    """
    Extract integer motion IDs from a string of special tokens.

    Supports both the legacy "<motion_ID>" format and the compact "<MID>"
    format (e.g. "<M42>"). Malformed tokens — wrong brackets or a non-integer
    payload — are silently skipped.
    """
    ids: List[int] = []
    for tok in s.strip().split():
        # Both formats must be properly closed; the original code sliced
        # "<motion_..." tokens without checking for the trailing ">", which
        # silently truncated the last digit of unclosed tokens.
        if tok.startswith("<motion_") and tok.endswith(">"):
            payload = tok[8:-1]
        elif tok.startswith("<M") and tok.endswith(">"):
            payload = tok[2:-1]
        else:
            continue
        try:
            ids.append(int(payload))
        except ValueError:
            # Narrowed from the original bare `except: pass`.
            continue
    return ids
|
data/motion_llm_dataset.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba9a0521241d7c72d0759c739ea323eee47e04cf41a5a7b756b9e083b40bc4e1
|
| 3 |
+
size 16798494
|
data/smplx_models/SMPLX_NEUTRAL.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:376021446ddc86e99acacd795182bbef903e61d33b76b9d8b359c2b0865bd992
|
| 3 |
+
size 108752058
|
data/vqvae_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fadbf3fb4ded1c6fe7752e7e381b627a46fa37787d051d969b73d97f81b278fb
|
| 3 |
+
size 231392924
|
data/vqvae_stats.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa86de891dd702ca71f0006cfbf68839c5eba35fb728891ab9f1890949dca943
|
| 3 |
+
size 2876
|
generate.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generation and inference utilities with constrained decoding
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import LogitsProcessor, LogitsProcessorList
|
| 6 |
+
from typing import Dict
|
| 7 |
+
from config import (
|
| 8 |
+
SYSTEM_MSG, GEN_MAX_NEW_TOKENS, GEN_TEMPERATURE,
|
| 9 |
+
GEN_TOP_P, GEN_TOP_K, GEN_NO_REPEAT_NGRAM_SIZE,
|
| 10 |
+
GEN_REPETITION_PENALTY, GEN_END_LOGIT_SLOPE
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LengthAwareMotionLogitsProcessor(LogitsProcessor):
    """
    Constrained decoding processor that:
    1. Enforces motion token vocabulary
    2. Controls sequence length (min/soft_target/max)
    3. Biases toward ending at soft_target length

    Masking is done by adding 0 (allowed) or -inf (disallowed) to the logits,
    so sampling can only ever pick an allowed token.
    """

    def __init__(self, prompt_len, mot_begin_id, mot_end_id, motion_ids,
                 hard_min, soft_target, hard_max, end_logit_slope=0.25):
        """
        Args:
            prompt_len: number of prompt tokens; generation starts after them.
            mot_begin_id / mot_end_id: ids of the motion-span delimiter tokens.
            motion_ids: iterable of allowed motion-token ids.
            hard_min: minimum motion tokens before MOT_END may be emitted.
            soft_target: preferred motion length; ending is biased past it.
            hard_max: motion length at which MOT_END is forced.
            end_logit_slope: per-token logit bonus added to MOT_END for every
                token generated beyond soft_target.
        """
        super().__init__()
        self.prompt_len = int(prompt_len)
        self.mot_begin_id = int(mot_begin_id)
        self.mot_end_id = int(mot_end_id)
        # Deduplicated, sorted id tensors built once (kept on CPU; moved to the
        # score tensor's device per call).
        self.motion_ids = torch.tensor(sorted(set(int(x) for x in motion_ids)))
        self.motion_plus_end = torch.tensor(
            sorted(set(list(self.motion_ids.tolist()) + [self.mot_end_id]))
        )
        self.hard_min = int(hard_min)
        self.soft_target = int(soft_target)
        self.hard_max = int(hard_max)
        self.end_logit_slope = float(end_logit_slope)

    def __call__(self, input_ids, scores):
        """Return `scores` with disallowed vocabulary entries set to -inf."""
        device = scores.device
        bs = scores.size(0)
        # Start with everything disallowed; branches below zero out the
        # allowed entries for each batch row.
        mask = torch.full_like(scores, float("-inf"))

        for b in range(bs):
            # Tokens generated so far (everything after the prompt).
            gen = input_ids[b, self.prompt_len:]

            # No tokens generated yet - must start with MOT_BEGIN
            if gen.numel() == 0:
                allowed = torch.tensor([self.mot_begin_id], device=device)
                mask[b].index_fill_(0, allowed, 0.0)
                continue

            # Find MOT_BEGIN position
            begin_pos = (gen == self.mot_begin_id).nonzero(as_tuple=True)[0]
            if begin_pos.numel() == 0:
                allowed = torch.tensor([self.mot_begin_id], device=device)
                mask[b].index_fill_(0, allowed, 0.0)
                continue

            # Already generated MOT_END - force EOS
            # (generate() is expected to stop here via eos_token_id=mot_end_id)
            if (gen == self.mot_end_id).any():
                allowed = torch.tensor([self.mot_end_id], device=device)
                mask[b].index_fill_(0, allowed, 0.0)
                continue

            # Count motion tokens after MOT_BEGIN
            after_begin = gen[begin_pos[0].item() + 1:]
            cur_len = after_begin.numel()

            # Before minimum length - only allow motion tokens
            if cur_len < self.hard_min:
                allowed = self.motion_ids.to(device)
                mask[b].index_fill_(0, allowed, 0.0)

            # After maximum length - force end
            elif cur_len >= self.hard_max:
                allowed = torch.tensor([self.mot_end_id], device=device)
                mask[b].index_fill_(0, allowed, 0.0)

            # Between min and max - allow motion tokens or end
            else:
                allowed = self.motion_plus_end.to(device)
                mask[b].index_fill_(0, allowed, 0.0)

                # Bias toward ending at soft_target: the further past the
                # target we are, the larger the MOT_END logit bonus.
                distance = max(0, cur_len - self.soft_target)
                bias = self.end_logit_slope * float(distance)
                scores[b, self.mot_end_id] = scores[b, self.mot_end_id] + bias

        return scores + mask
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_len_controls(prompt_text: str, length_stats_by_text: Dict, global_median_len: int):
    """Return (hard_min, soft_target, hard_max) generation-length bounds.

    Uses the prompt's recorded median length when available, otherwise falls
    back to the dataset-wide median. Bounds are 60%/140% of the median, with
    a floor of 1 on the minimum and (min + 4) on the maximum.
    """
    stats = length_stats_by_text.get(prompt_text)
    med = global_median_len if stats is None else stats["median"]

    lower = max(1, int(0.6 * med))
    upper = max(lower + 4, int(1.4 * med))

    return lower, med, upper
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def generate_t2m(
    model,
    tokenizer,
    prompt_text: str,
    mot_begin_id: int,
    mot_end_id: int,
    motion_token_ids: list,
    length_stats_by_text: Dict,
    global_median_len: int,
    prompt_vocab: Dict = None,
    pid: str = None,
    has_pid: bool = False,
    max_new_tokens: int = None,
    per_prompt_vocab: bool = True
):
    """
    Generate a motion-token sequence from a text prompt with constrained decoding.

    Args:
        model: causal LM; its `generate` is driven by a length-aware processor.
        tokenizer: tokenizer matching `model` (chat-markup `<|im_start|>` style).
        prompt_text: natural-language description of the desired motion.
        mot_begin_id / mot_end_id: ids of the motion-span delimiter tokens.
        motion_token_ids: full list of allowed motion-token ids.
        length_stats_by_text: per-prompt length stats (see get_len_controls).
        global_median_len: fallback median length for unseen prompts.
        prompt_vocab: optional map prompt -> restricted motion-id list.
        pid: optional participant id, inserted as a "<PID_..." token.
        has_pid: whether participant-id tokens are in use.
        max_new_tokens: cap on generated tokens (defaults to GEN_MAX_NEW_TOKENS).
        per_prompt_vocab: if True and prompt_vocab is given, constrain decoding
            to the token ids recorded for this prompt.

    Returns:
        The assistant reply string, special tokens included (the caller is
        expected to parse the <M...> tokens out of it).
    """
    model.eval()
    device = next(model.parameters()).device

    if max_new_tokens is None:
        max_new_tokens = GEN_MAX_NEW_TOKENS

    # Build prompt
    pid_tok = ""
    if has_pid and pid is not None:
        pid_tok = f"<PID_{pid}>"

    user_text = f"<T2M>{pid_tok}\n\n" + prompt_text
    prompt = (
        "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
        + "<|im_start|>user\n" + user_text + "\n<|im_end|>\n"
        + "<|im_start|>assistant\n"
    )

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].size(1)

    # Get length controls
    hard_min, soft_tgt, hard_max = get_len_controls(
        prompt_text, length_stats_by_text, global_median_len
    )

    # Get allowed motion tokens
    if per_prompt_vocab and prompt_vocab:
        allowed_motion_ids = prompt_vocab.get(prompt_text, motion_token_ids)
    else:
        allowed_motion_ids = motion_token_ids

    # Setup constrained decoding
    processors = LogitsProcessorList([
        LengthAwareMotionLogitsProcessor(
            prompt_len=prompt_len,
            mot_begin_id=mot_begin_id,
            mot_end_id=mot_end_id,
            motion_ids=allowed_motion_ids,
            hard_min=hard_min,
            soft_target=soft_tgt,
            hard_max=hard_max,
            end_logit_slope=GEN_END_LOGIT_SLOPE,
        )
    ])

    # Generate
    with torch.no_grad():
        out = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # +4 leaves room for the MOT_BEGIN/MOT_END delimiters around the
            # hard_max motion tokens.
            max_new_tokens=min(max_new_tokens, hard_max + 4),
            do_sample=True,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            top_k=GEN_TOP_K,
            no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM_SIZE,
            repetition_penalty=GEN_REPETITION_PENALTY,
            logits_processor=processors,
            # Stop as soon as the motion-end token is emitted.
            eos_token_id=mot_end_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode; keep special tokens so the motion tokens survive, then slice
    # out just the assistant turn.
    decoded = tokenizer.decode(out[0], skip_special_tokens=False)
    reply = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]

    return reply
|
inference.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference script for generating motion tokens from text prompts.
|
| 3 |
+
Run after training to generate motion sequences from any text description.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python inference.py --prompt "walking forward" --stage 3
|
| 7 |
+
python inference.py --prompt "dancing" --stage 2 --output motion_output.txt
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
import torch
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
from config import (
|
| 15 |
+
OUT_S1, OUT_S2, OUT_S3, MAX_SEQ_LEN, DATA_JSON_PATH,
|
| 16 |
+
WORK_DIR
|
| 17 |
+
)
|
| 18 |
+
from data import (
|
| 19 |
+
load_dataset, compute_length_stats, build_prompt_vocab,
|
| 20 |
+
check_has_participant_id
|
| 21 |
+
)
|
| 22 |
+
from model import setup_model_and_tokenizer, get_motion_token_info
|
| 23 |
+
from generate import generate_t2m
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_trained_model(stage: int, device: torch.device):
    """
    Load a trained model from a specific stage checkpoint.

    Also reloads the dataset, because the motion-token vocabulary size must
    match the one used at training time before the tokenizer can be rebuilt.

    Args:
        stage: Stage number (1, 2, or 3)
        device: Device to load model on

    Returns:
        model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id, raw_ds
        (note: the raw dataset is returned too, so callers can compute
        length stats / prompt vocab without re-reading the file).

    Raises:
        FileNotFoundError: if the stage checkpoint directory or the dataset
            JSON is missing.
    """
    stage_dirs = {1: OUT_S1, 2: OUT_S2, 3: OUT_S3}
    stage_dir = stage_dirs.get(stage)

    if not stage_dir or not os.path.exists(stage_dir):
        raise FileNotFoundError(
            f"Stage {stage} checkpoint not found at {stage_dir}. "
            f"Train stage {stage} first."
        )

    print(f"\nLoading Stage {stage} model from: {stage_dir}")

    # Load dataset to build vocab (needed for model setup)
    if not os.path.exists(DATA_JSON_PATH):
        raise FileNotFoundError(f"Dataset not found: {DATA_JSON_PATH}")

    raw_ds = load_dataset(DATA_JSON_PATH)

    # Build motion vocab: codebook size is 1 + the largest token id seen.
    def max_token_in_example(ex):
        return max(int(x) for x in ex["motion_tokens"].split())

    global_max_id = max(max_token_in_example(ex) for ex in raw_ds)
    codebook_size = global_max_id + 1

    # Check for participant IDs
    has_pid = check_has_participant_id(raw_ds)
    unique_pids = None
    if has_pid:
        unique_pids = sorted({str(ex["participant_id"]) for ex in raw_ds})

    # Setup model and tokenizer with same config as training
    model, tokenizer, _ = setup_model_and_tokenizer(codebook_size, unique_pids)

    # Load trained weights from checkpoint
    # Try different checkpoint naming patterns
    possible_ckpts = [
        os.path.join(stage_dir, "pytorch_model.bin"),
        os.path.join(stage_dir, "model.safetensors"),
        os.path.join(stage_dir, "adapter_model.bin"),
    ]

    # NOTE(review): this loop only *detects* a checkpoint file — no weights are
    # explicitly loaded here; it relies on setup_model_and_tokenizer having
    # loaded from the directory. Confirm that the stage weights actually end
    # up in `model`.
    loaded = False
    for ckpt_path in possible_ckpts:
        if os.path.exists(ckpt_path):
            print(f"Loading checkpoint: {ckpt_path}")
            # Unsloth/PEFT models save adapters separately
            # The model will auto-load from the directory
            loaded = True
            break

    if not loaded:
        print(f"⚠️ No explicit checkpoint file found, using model directory: {stage_dir}")

    # Move model to device
    model.to(device)
    model.eval()

    # Get motion token info
    motion_token_ids, mot_begin_id, mot_end_id = get_motion_token_info(
        tokenizer, codebook_size
    )

    print(f"✅ Stage {stage} model loaded successfully")
    print(f"   Vocabulary size: {len(tokenizer)}")
    print(f"   Motion tokens: {len(motion_token_ids)}")

    return model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id, raw_ds
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def inference(
    prompt: str,
    stage: int = 3,
    pid: str = None,
    output_file: str = None,
    per_prompt_vocab: bool = True,
    device: torch.device = None
):
    """
    Generate motion tokens from a text prompt.

    Loads the model and dataset on every call, so repeated calls are
    expensive — cache the results of load_trained_model if you need a loop.

    Args:
        prompt: Text description of desired motion
        stage: Which training stage model to use (1, 2, or 3)
        pid: Optional participant ID for personalization
        output_file: Optional file to save output tokens
        per_prompt_vocab: Whether to use per-prompt vocabulary constraints
        device: Device to run inference on (auto-detected when None)

    Returns:
        Generated motion token string
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("="*60)
    print(f"Motion Generation Inference - Stage {stage}")
    print("="*60)
    print(f"Prompt: '{prompt}'")
    print(f"Device: {device}")

    # Load model and dataset (raw_ds is reused below for statistics).
    model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id, raw_ds = load_trained_model(stage, device)

    # Compute length stats and prompt vocab
    print("\nComputing dataset statistics...")
    length_stats_by_text, global_median_len = compute_length_stats(raw_ds)
    prompt_vocab = build_prompt_vocab(raw_ds)
    has_pid = check_has_participant_id(raw_ds)

    # Generate motion tokens
    print(f"\nGenerating motion for: '{prompt}'")
    print(f"Per-prompt vocabulary: {per_prompt_vocab}")

    generated = generate_t2m(
        model=model,
        tokenizer=tokenizer,
        prompt_text=prompt,
        mot_begin_id=mot_begin_id,
        mot_end_id=mot_end_id,
        motion_token_ids=motion_token_ids,
        length_stats_by_text=length_stats_by_text,
        global_median_len=global_median_len,
        prompt_vocab=prompt_vocab,
        has_pid=has_pid,
        per_prompt_vocab=per_prompt_vocab,
        pid=pid
    )

    print("\n" + "="*60)
    print("Generated Motion:")
    print("="*60)
    print(generated)
    print("="*60)

    # Optionally save to file (parent directories are created as needed).
    if output_file:
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w') as f:
            f.write(generated)
        print(f"\n✅ Output saved to: {output_file}")

    return generated
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def main():
    """CLI entry point: parse command-line arguments and run inference."""
    cli = argparse.ArgumentParser(
        description="Generate motion tokens from text prompts using trained SignMotionGPT model"
    )
    cli.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text description of the desired motion (e.g., 'walking forward', 'dancing')",
    )
    cli.add_argument(
        "--stage",
        type=int,
        default=3,
        choices=[1, 2, 3],
        help="Which training stage model to use (1=motion-only, 2=multi-task, 3=T2M SFT, default=3)",
    )
    cli.add_argument(
        "--pid",
        type=str,
        default=None,
        help="Optional participant ID for personalized generation (e.g., 'P40')",
    )
    cli.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional output file to save generated tokens",
    )
    cli.add_argument(
        "--no-per-prompt-vocab",
        action="store_true",
        help="Disable per-prompt vocabulary constraints (allows all motion tokens)",
    )
    cli.add_argument(
        "--device",
        type=str,
        default=None,
        choices=["cpu", "cuda", "cuda:0", "cuda:1"],
        help="Device to run inference on (default: auto-detect)",
    )

    args = cli.parse_args()

    # Resolve device: explicit flag wins, otherwise auto-detect.
    if args.device:
        chosen_device = torch.device(args.device)
    else:
        chosen_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inference(
        prompt=args.prompt,
        stage=args.stage,
        pid=args.pid,
        output_file=args.output,
        per_prompt_vocab=not args.no_per_prompt_vocab,
        device=chosen_device,
    )
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
|
| 244 |
+
main()
|
mGPT/__init__.py
ADDED
|
File without changes
|
mGPT/archs/__init__.py
ADDED
|
File without changes
|
mGPT/archs/mgpt_vq.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from typing import List, Optional, Union
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
from torch.distributions.distribution import Distribution
|
| 7 |
+
from .tools.resnet import Resnet1D
|
| 8 |
+
from .tools.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VQVae(nn.Module):
    """Motion VQ-VAE: Conv1d Encoder -> vector quantizer -> Conv1d Decoder.

    The public interface works on motion features shaped (bs, T, nfeats);
    ``preprocess``/``postprocess`` convert to/from the channel-first layout
    (bs, nfeats, T) that the 1D convolutions expect.
    """

    def __init__(self,
                 nfeats: int,
                 quantizer: str = "ema_reset",
                 code_num=512,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 norm=None,
                 activation: str = "relu",
                 **kwargs) -> None:
        """Build the encoder, decoder and the requested quantizer.

        Args:
            nfeats: per-frame motion feature dimension.
            quantizer: one of "ema_reset", "orig", "ema", "reset".
            code_num: number of codebook entries.
            code_dim: dimensionality of each codebook entry.
            output_emb_width: channel width of the encoder output / decoder input.
            down_t: number of temporal down/up-sampling stages.
            stride_t: temporal stride per stage.
            width: hidden channel width of the conv stacks.
            depth: residual blocks per stage.
            dilation_growth_rate: dilation multiplier per residual block.
            norm: normalization type forwarded to the conv stacks.
            activation: activation name forwarded to the conv stacks.
            **kwargs: ignored (accepted for config compatibility).

        Raises:
            ValueError: if ``quantizer`` is not a known type.
        """
        super().__init__()

        self.code_dim = code_dim

        self.encoder = Encoder(nfeats,
                               output_emb_width,
                               down_t,
                               stride_t,
                               width,
                               depth,
                               dilation_growth_rate,
                               activation=activation,
                               norm=norm)

        self.decoder = Decoder(nfeats,
                               output_emb_width,
                               down_t,
                               stride_t,
                               width,
                               depth,
                               dilation_growth_rate,
                               activation=activation,
                               norm=norm)

        if quantizer == "ema_reset":
            self.quantizer = QuantizeEMAReset(code_num, code_dim, mu=0.99)
        elif quantizer == "orig":
            self.quantizer = Quantizer(code_num, code_dim, beta=1.0)
        elif quantizer == "ema":
            self.quantizer = QuantizeEMA(code_num, code_dim, mu=0.99)
        elif quantizer == "reset":
            self.quantizer = QuantizeReset(code_num, code_dim)
        else:
            # Fail fast: previously an unknown name silently left
            # self.quantizer unset and crashed later with an AttributeError.
            raise ValueError(f"Unknown quantizer type: {quantizer!r}")

    def preprocess(self, x):
        # (bs, T, Jx3) -> (bs, Jx3, T)
        return x.permute(0, 2, 1)

    def postprocess(self, x):
        # (bs, Jx3, T) -> (bs, T, Jx3)
        return x.permute(0, 2, 1)

    def forward(self, features: Tensor):
        """Encode, quantize and reconstruct.

        Returns:
            (reconstruction, commitment_loss, perplexity); reconstruction has
            the same (bs, T, nfeats) layout as the input.
        """
        x_in = self.preprocess(features)

        x_encoder = self.encoder(x_in)

        # Vector quantization bottleneck.
        x_quantized, loss, perplexity = self.quantizer(x_encoder)

        x_decoder = self.decoder(x_quantized)
        x_out = self.postprocess(x_decoder)

        return x_out, loss, perplexity

    def encode(
        self,
        features: Tensor,
    ) -> Union[Tensor, Distribution]:
        """Map motion features (N, T, nfeats) to codebook indices (N, T').

        Returns ``(code_idx, None)``; the second element mirrors the
        (latent, dist) interface of VAE-style models.
        """
        N, T, _ = features.shape
        x_in = self.preprocess(features)
        x_encoder = self.encoder(x_in)
        x_encoder = self.postprocess(x_encoder)
        x_encoder = x_encoder.contiguous().view(-1,
                                                x_encoder.shape[-1])  # (NT, C)
        code_idx = self.quantizer.quantize(x_encoder)
        code_idx = code_idx.view(N, -1)

        # latent, dist
        return code_idx, None

    def decode(self, z: Tensor):
        """Reconstruct motion features from codebook indices ``z``.

        NOTE: the hard-coded batch dimension of 1 below means this decodes a
        single sequence at a time.
        """
        x_d = self.quantizer.dequantize(z)
        x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()

        x_decoder = self.decoder(x_d)
        x_out = self.postprocess(x_decoder)
        return x_out
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Encoder(nn.Module):
    """Temporal encoder: input projection, then ``down_t`` strided
    downsampling stages (each a strided Conv1d followed by a dilated residual
    stack), then a projection to ``output_emb_width`` channels.

    Operates on channel-first sequences (bs, C, T); each stage reduces the
    time axis by roughly a factor of ``stride_t``.
    """

    def __init__(self,
                 input_emb_width=3,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        super().__init__()

        kernel_size, padding = stride_t * 2, stride_t // 2

        layers = [
            nn.Conv1d(input_emb_width, width, 3, 1, 1),
            nn.ReLU(),
        ]
        for _ in range(down_t):
            stage = nn.Sequential(
                nn.Conv1d(width, width, kernel_size, stride_t, padding),
                Resnet1D(width,
                         depth,
                         dilation_growth_rate,
                         activation=activation,
                         norm=norm),
            )
            layers.append(stage)
        layers.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a (bs, C, T) sequence into a downsampled latent sequence."""
        return self.model(x)
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class Decoder(nn.Module):
    """Temporal decoder mirroring ``Encoder``: ``down_t`` stages of dilated
    residual stack + nearest-neighbour upsampling + Conv1d, then a projection
    back to ``input_emb_width`` feature channels. Operates on (bs, C, T).

    NOTE(review): the upsampling factor is fixed at 2, which only inverts the
    encoder for the default ``stride_t=2`` — confirm before using other
    strides.
    """

    def __init__(self,
                 input_emb_width=3,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        super().__init__()

        layers = [
            nn.Conv1d(output_emb_width, width, 3, 1, 1),
            nn.ReLU(),
        ]
        for _ in range(down_t):
            layers.append(
                nn.Sequential(
                    Resnet1D(width,
                             depth,
                             dilation_growth_rate,
                             reverse_dilation=True,
                             activation=activation,
                             norm=norm),
                    nn.Upsample(scale_factor=2, mode='nearest'),
                    nn.Conv1d(width, width, 3, 1, 1),
                ))
        layers.extend([
            nn.Conv1d(width, width, 3, 1, 1),
            nn.ReLU(),
            nn.Conv1d(width, input_emb_width, 3, 1, 1),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a quantized latent sequence back to motion features."""
        return self.model(x)
|
mGPT/archs/tools/__init__.py
ADDED
|
File without changes
|
mGPT/archs/tools/quantize_cnn.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
class QuantizeEMAReset(nn.Module):
    """Vector quantizer with EMA codebook updates plus dead-code resets.

    During training each codebook entry tracks an exponential moving average
    (decay ``mu``) of the encoder vectors assigned to it; entries whose EMA
    usage count drops below 1 are re-seeded from (jittered) encoder outputs.
    """

    def __init__(self, nb_code, code_dim, mu):
        # nb_code: number of codebook entries; code_dim: entry dimensionality;
        # mu: EMA decay factor for the running codebook statistics.
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = mu
        self.reset_codebook()

    def reset_codebook(self):
        # The codebook contents are lazily seeded from the first training
        # batch (see init_codebook); until then it is all zeros.
        self.init = False
        self.code_sum = None
        self.code_count = None
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Buffer (not Parameter): updated by EMA, saved with state_dict.
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))

    def _tile(self, x):
        # Repeat rows (with small Gaussian jitter) until there are at least
        # nb_code of them, so the codebook can be (re)seeded from a small batch.
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else :
            out = x
        return out

    def init_codebook(self, x):
        # Seed the codebook and its EMA statistics from the first batch.
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx) :
        # Perplexity = exp(entropy) of the empirical code-usage distribution.
        # Calculate new centres
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)  # nb_code
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        # EMA update of per-code sums/counts, then recompute the centres;
        # codes with EMA count < 1 are replaced by random (jittered) inputs.
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # nb_code, w
        code_count = code_onehot.sum(dim=-1)  # nb_code

        out = self._tile(x)
        code_rand = out[:self.nb_code]

        # Update centres
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum  # w, nb_code
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count  # nb_code

        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)

        self.codebook = usage * code_update + (1 - usage) * code_rand
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))


        return perplexity

    def preprocess(self, x):
        # NCT -> NTC -> [NT, C]
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        # Nearest-codebook-entry index for each row of x (squared Euclidean).
        # Calculate latent code x_l
        k_w = self.codebook.t()
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
                                                                                            keepdim=True)  # (N * L, b)
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        # Index lookup into the codebook: (..., ) -> (..., code_dim).
        x = F.embedding(code_idx, self.codebook)
        return x


    def forward(self, x):
        """Quantize a (N, C, T) batch; returns (x_d, commit_loss, perplexity)
        with x_d shaped (N, C, T) and gradients passed straight through."""
        N, width, T = x.shape

        # Preprocess
        x = self.preprocess(x)

        # Init codebook if not inited
        if self.training and not self.init:
            self.init_codebook(x)

        # quantize and dequantize through bottleneck
        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        # Update embeddings
        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else :
            perplexity = self.compute_perplexity(code_idx)

        # Loss (commitment only; the codebook itself is updated by EMA).
        commit_loss = F.mse_loss(x, x_d.detach())

        # Passthrough (straight-through gradient estimator).
        x_d = x + (x_d - x).detach()

        # Postprocess
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()   #(N, DIM, T)

        return x_d, commit_loss, perplexity
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Quantizer(nn.Module):
    """Original (non-EMA) VQ layer: a learnable ``nn.Embedding`` codebook
    trained with a codebook loss plus a ``beta``-weighted commitment loss."""

    def __init__(self, n_e, e_dim, beta):
        super(Quantizer, self).__init__()

        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta

        # Codebook, uniformly initialized in [-1/n_e, 1/n_e].
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """Quantize a (N, C, T) batch; returns (z_q, loss, perplexity) with
        z_q shaped (N, C, T) and gradients passed straight through."""
        batch, channels, steps = z.shape
        z = self.preprocess(z)
        assert z.shape[-1] == self.e_dim
        flat = z.contiguous().view(-1, self.e_dim)

        # Squared Euclidean distance of every vector to every codebook entry.
        dists = (
            torch.sum(flat ** 2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight ** 2, dim=1)
            - 2 * torch.matmul(flat, self.embedding.weight.t())
        )
        nearest = torch.argmin(dists, dim=1)
        z_q = self.embedding(nearest).view(z.shape)

        # Codebook loss + beta-weighted commitment loss.
        loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean((z_q.detach() - z) ** 2)

        # Straight-through estimator: copy gradients from z_q to z.
        z_q = z + (z_q - z).detach()
        z_q = z_q.view(batch, steps, -1).permute(0, 2, 1).contiguous()  # (N, DIM, T)

        # Perplexity = exp(entropy) of the empirical code-usage distribution.
        one_hot = F.one_hot(nearest, self.n_e).type(z.dtype)
        usage = torch.mean(one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))
        return z_q, loss, perplexity

    def quantize(self, z):
        """Return the nearest codebook index for each row of a (B, e_dim) batch."""
        assert z.shape[-1] == self.e_dim

        dists = (
            torch.sum(z ** 2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight ** 2, dim=1)
            - 2 * torch.matmul(z, self.embedding.weight.t())
        )
        return torch.argmin(dists, dim=1)

    def dequantize(self, indices):
        """Look up codebook vectors; output shape is ``indices.shape + (e_dim,)``."""
        flat_idx = indices.view(-1)
        vectors = self.embedding(flat_idx)
        return vectors.view(indices.shape + (self.e_dim, )).contiguous()

    def preprocess(self, x):
        # NCT -> NTC -> [NT, C]
        return x.permute(0, 2, 1).contiguous().view(-1, x.shape[1])
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class QuantizeReset(nn.Module):
    """Vector quantizer with a learnable codebook and hard resets.

    The codebook is an ``nn.Parameter`` trained by gradient descent; after
    every training step, entries that received no assignments in the current
    batch are replaced by random (jittered) encoder outputs.
    """

    def __init__(self, nb_code, code_dim):
        # nb_code: number of codebook entries; code_dim: entry dimensionality.
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.reset_codebook()
        # Learnable codebook (random init; re-seeded from data on first batch).
        self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))

    def reset_codebook(self):
        # Statistics are lazily initialized on the first training batch.
        self.init = False
        self.code_count = None

    def _tile(self, x):
        # Repeat rows (with small Gaussian jitter) until there are at least
        # nb_code of them, so the codebook can be (re)seeded from a small batch.
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else :
            out = x
        return out

    def init_codebook(self, x):
        # Re-seed the learnable codebook from the first training batch.
        out = self._tile(x)
        self.codebook = nn.Parameter(out[:self.nb_code])
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx) :
        # Perplexity = exp(entropy) of the empirical code-usage distribution.
        # Calculate new centres
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)  # nb_code
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    def update_codebook(self, x, code_idx):
        # Replace codes unused in this batch by random (jittered) inputs;
        # mutates self.codebook.data directly, bypassing autograd.
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)  # nb_code

        out = self._tile(x)
        code_rand = out[:self.nb_code]

        # Update centres
        self.code_count = code_count  # nb_code
        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()

        self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))


        return perplexity

    def preprocess(self, x):
        # NCT -> NTC -> [NT, C]
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        # Nearest-codebook-entry index for each row of x (squared Euclidean).
        # Calculate latent code x_l
        k_w = self.codebook.t()
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
                                                                                            keepdim=True)  # (N * L, b)
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        # Index lookup into the codebook: (..., ) -> (..., code_dim).
        x = F.embedding(code_idx, self.codebook)
        return x


    def forward(self, x):
        """Quantize a (N, C, T) batch; returns (x_d, commit_loss, perplexity)
        with x_d shaped (N, C, T) and gradients passed straight through."""
        N, width, T = x.shape
        # Preprocess
        x = self.preprocess(x)
        # Init codebook if not inited
        if self.training and not self.init:
            self.init_codebook(x)
        # quantize and dequantize through bottleneck
        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)
        # Update embeddings
        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else :
            perplexity = self.compute_perplexity(code_idx)

        # Loss (commitment term; codebook itself learns via gradients/resets).
        commit_loss = F.mse_loss(x, x_d.detach())

        # Passthrough (straight-through gradient estimator).
        x_d = x + (x_d - x).detach()

        # Postprocess
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()   #(N, DIM, T)

        return x_d, commit_loss, perplexity
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class QuantizeEMA(nn.Module):
    """Vector quantizer with pure EMA codebook updates (no dead-code resets).

    Identical to ``QuantizeEMAReset`` except that unused codes are never
    re-seeded: every centre is the EMA (decay ``mu``) of its assigned vectors.
    """

    def __init__(self, nb_code, code_dim, mu):
        # nb_code: number of codebook entries; code_dim: entry dimensionality;
        # mu: EMA decay factor for the running codebook statistics.
        super().__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = mu
        self.reset_codebook()

    def reset_codebook(self):
        # Lazily seeded from the first training batch (see init_codebook).
        self.init = False
        self.code_sum = None
        self.code_count = None
        # Fix: allocate on GPU only when one exists. The previous
        # unconditional .cuda() crashed on CPU-only machines; this now
        # matches QuantizeEMAReset's device handling.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))

    def _tile(self, x):
        # Repeat rows (with small Gaussian jitter) until there are at least
        # nb_code of them, so the codebook can be seeded from a small batch.
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else :
            out = x
        return out

    def init_codebook(self, x):
        # Seed the codebook and its EMA statistics from the first batch.
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    @torch.no_grad()
    def compute_perplexity(self, code_idx) :
        # Perplexity = exp(entropy) of the empirical code-usage distribution.
        # Calculate new centres
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)  # nb_code
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        # EMA update of per-code sums/counts, then recompute every centre.
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # nb_code, w
        code_count = code_onehot.sum(dim=-1)  # nb_code

        # Update centres
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum  # w, nb_code
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count  # nb_code

        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)

        self.codebook = code_update
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))

        return perplexity

    def preprocess(self, x):
        # NCT -> NTC -> [NT, C]
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        # Nearest-codebook-entry index for each row of x (squared Euclidean).
        # Calculate latent code x_l
        k_w = self.codebook.t()
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
                                                                                            keepdim=True)  # (N * L, b)
        _, code_idx = torch.min(distance, dim=-1)
        return code_idx

    def dequantize(self, code_idx):
        # Index lookup into the codebook: (..., ) -> (..., code_dim).
        x = F.embedding(code_idx, self.codebook)
        return x


    def forward(self, x):
        """Quantize a (N, C, T) batch; returns (x_d, commit_loss, perplexity)
        with x_d shaped (N, C, T) and gradients passed straight through."""
        N, width, T = x.shape

        # Preprocess
        x = self.preprocess(x)

        # Init codebook if not inited
        if self.training and not self.init:
            self.init_codebook(x)

        # quantize and dequantize through bottleneck
        code_idx = self.quantize(x)
        x_d = self.dequantize(code_idx)

        # Update embeddings
        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else :
            perplexity = self.compute_perplexity(code_idx)

        # Loss (commitment only; the codebook itself is updated by EMA).
        commit_loss = F.mse_loss(x, x_d.detach())

        # Passthrough (straight-through gradient estimator).
        x_d = x + (x_d - x).detach()

        # Postprocess
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()   #(N, DIM, T)

        return x_d, commit_loss, perplexity
|
mGPT/archs/tools/resnet.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class nonlinearity(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # swish
        return torch.sigmoid(x).mul(x)
|
| 12 |
+
|
| 13 |
+
class ResConv1DBlock(nn.Module):
    """Pre-activation residual block:
    norm -> act -> dilated 3-tap Conv1d -> norm -> act -> 1x1 Conv1d, + skip.

    Args:
        n_in: input/output channel count (the skip connection requires them
            to match).
        n_state: hidden channel count of the dilated convolution.
        dilation: dilation of the 3-tap convolution; padding matches it so
            the temporal length is preserved.
        activation: 'relu', 'silu' (swish) or 'gelu'.
        norm: None, 'LN', 'GN' or 'BN'.
        dropout: accepted for config compatibility but currently unused.

    Raises:
        ValueError: for an unknown ``activation`` name.
    """

    def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
        super().__init__()
        padding = dilation
        self.norm = norm
        if norm == "LN":
            self.norm1 = nn.LayerNorm(n_in)
            self.norm2 = nn.LayerNorm(n_in)
        elif norm == "GN":
            self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
            self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
        elif norm == "BN":
            self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
            self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        if activation == "relu":
            self.activation1 = nn.ReLU()
            self.activation2 = nn.ReLU()
        elif activation == "silu":
            self.activation1 = nonlinearity()
            self.activation2 = nonlinearity()
        elif activation == "gelu":
            self.activation1 = nn.GELU()
            self.activation2 = nn.GELU()
        else:
            # Fail fast: previously an unknown name left activation1/2 unset
            # and crashed later in forward with an AttributeError.
            raise ValueError(f"Unknown activation: {activation!r}")

        self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
        self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,)

    def forward(self, x):
        x_orig = x
        # LayerNorm normalizes the last dim, so transpose to (N, T, C) for it
        # and back afterwards.
        if self.norm == "LN":
            x = self.norm1(x.transpose(-2, -1))
            x = self.activation1(x.transpose(-2, -1))
        else:
            x = self.norm1(x)
            x = self.activation1(x)

        x = self.conv1(x)

        if self.norm == "LN":
            x = self.norm2(x.transpose(-2, -1))
            x = self.activation2(x.transpose(-2, -1))
        else:
            x = self.norm2(x)
            x = self.activation2(x)

        x = self.conv2(x)
        x = x + x_orig
        return x
|
| 69 |
+
|
| 70 |
+
class Resnet1D(nn.Module):
    """Stack of ``n_depth`` dilated residual blocks; the dilation of block
    ``d`` is ``dilation_growth_rate ** d``, and the stack order is reversed
    when ``reverse_dilation`` is set."""

    def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
        super().__init__()

        stages = []
        for level in range(n_depth):
            stages.append(
                ResConv1DBlock(n_in,
                               n_in,
                               dilation=dilation_growth_rate ** level,
                               activation=activation,
                               norm=norm))
        if reverse_dilation:
            stages.reverse()

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)
|
metrics.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation metrics for motion generation
|
| 3 |
+
"""
|
| 4 |
+
import random
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import json
|
| 8 |
+
import numpy as np
|
| 9 |
+
import scipy.linalg
|
| 10 |
+
import torch
|
| 11 |
+
from typing import List, Tuple, Dict, Optional, Any
|
| 12 |
+
from rapidfuzz.distance import Levenshtein
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
from data import motion_specials_to_ids
|
| 15 |
+
from config import (
|
| 16 |
+
SEED, PIPELINE_OUTPUT_DIR, M_START, M_END,
|
| 17 |
+
INFERENCE_TEMPERATURE, INFERENCE_TOP_K, INFERENCE_REPETITION_PENALTY
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
random.seed(SEED)
|
| 21 |
+
|
| 22 |
+
# ======================================================================================
|
| 23 |
+
# Logic from test_overfit.py (Metrics & Visualization)
|
| 24 |
+
# ======================================================================================
|
| 25 |
+
|
| 26 |
+
def calculate_activation_statistics_np(activations: np.ndarray):
    """Compute the Gaussian statistics of a batch of activation vectors.

    Params:
        activations: num_samples x dim_feat (numpy)
    Returns:
        mu: mean vector of shape (dim_feat,)
        cov: covariance matrix of shape (dim_feat, dim_feat)
    """
    # Rows are samples, columns are features (rowvar=False).
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
|
| 37 |
+
|
| 38 |
+
def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance between two Gaussians.

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2))
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
    mean_diff = mu1 - mu2
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Singular product: retry with a small diagonal offset for stability.
        jitter = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))
    if np.iscomplexobj(covmean):
        # Tiny imaginary parts are numerical noise; large ones are an error.
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f"Imaginary component {m}")
        covmean = covmean.real
    trace_term = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return mean_diff.dot(mean_diff) + trace_term
|
| 58 |
+
|
| 59 |
+
def calculate_diversity_np(activation: np.ndarray, diversity_times: int = 200) -> float:
    """Mean pairwise L2 distance across randomly chosen sample pairs."""
    assert len(activation.shape) == 2
    if activation.shape[0] < 2:
        # Diversity is undefined for fewer than two samples.
        return 0.0
    population = activation.shape[0]
    draws = min(diversity_times, max(1, population - 1))
    # Two independent index draws (without replacement within each draw).
    idx_a = np.random.choice(population, draws, replace=False)
    idx_b = np.random.choice(population, draws, replace=False)
    gaps = activation[idx_a] - activation[idx_b]
    return float(np.linalg.norm(gaps, axis=1).mean())
|
| 71 |
+
|
| 72 |
+
def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int = 20) -> float:
    """Within-label diversity over a [num_labels, num_per_label, D] tensor.

    Returns mean pairwise within-label distance (higher = more multimodal),
    or NaN when there are fewer than two samples per label.
    """
    assert len(activation.shape) == 3
    _, per_label, _ = activation.shape
    if per_label < 2:
        return float("nan")
    draws = min(multimodality_times, max(1, per_label - 1))
    # Random pairings within each label, shared across labels.
    picks_a = np.random.choice(per_label, draws, replace=False)
    picks_b = np.random.choice(per_label, draws, replace=False)
    gaps = activation[:, picks_a] - activation[:, picks_b]
    return float(np.linalg.norm(gaps, axis=2).mean())
|
| 87 |
+
|
| 88 |
+
# --------------------------------------------------------------------------------------
|
| 89 |
+
# Token sequence → activation (bag-of-motion-tokens) helpers
|
| 90 |
+
# --------------------------------------------------------------------------------------
|
| 91 |
+
def _extract_motion_tokens_from_sequence(seq: str) -> list[str]:
|
| 92 |
+
# Expect tokens like <M123>, within M_START/M_END fences; keep only <M...>
|
| 93 |
+
return [tok for tok in seq.split() if tok.startswith("<M") and tok.endswith(">")]
|
| 94 |
+
|
| 95 |
+
def _extract_ids_from_sequence(seq: str) -> list[int]:
|
| 96 |
+
return [int(t[2:-1]) for t in _extract_motion_tokens_from_sequence(seq) if t[2:-1].isdigit()]
|
| 97 |
+
|
| 98 |
+
def _build_token_index(tokens_vocab: list[str]) -> Dict[str, int]:
|
| 99 |
+
return {tok: idx for idx, tok in enumerate(tokens_vocab)}
|
| 100 |
+
|
| 101 |
+
def _sequence_to_activation(seq: str, token_to_index: Dict[str, int]) -> np.ndarray:
|
| 102 |
+
vec = np.zeros((len(token_to_index),), dtype=np.float32)
|
| 103 |
+
for tok in _extract_motion_tokens_from_sequence(seq):
|
| 104 |
+
idx = token_to_index.get(tok)
|
| 105 |
+
if idx is not None:
|
| 106 |
+
vec[idx] += 1.0
|
| 107 |
+
# Normalize to unit length to reduce length bias
|
| 108 |
+
norm = np.linalg.norm(vec)
|
| 109 |
+
if norm > 0:
|
| 110 |
+
vec = vec / norm
|
| 111 |
+
return vec
|
| 112 |
+
|
| 113 |
+
def generate_motion(model, tokenizer, prompt, device):
    """Sample a motion-token sequence from the LM for `prompt`.

    Uses temperature/top-k sampling with a repetition penalty (values from
    config) and stops at the motion-end special token. Returns the text after
    the "Motion: " marker when present, stripped of surrounding whitespace.
    """
    model.eval()
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    end_token_id = tokenizer.convert_tokens_to_ids(M_END)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=100,
            do_sample=True,
            temperature=INFERENCE_TEMPERATURE,
            top_k=INFERENCE_TOP_K,
            repetition_penalty=INFERENCE_REPETITION_PENALTY,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=end_token_id,
            early_stopping=True,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=False)
    # Keep only what follows the last "Motion: " marker (the prompt echoes it).
    motion_part = decoded.split("Motion: ")[-1] if "Motion: " in decoded else decoded
    return motion_part.strip()
|
| 137 |
+
|
| 138 |
+
def _collect_eval_pairs(model, tokenizer, data, device) -> list[Tuple[str, str, str, str]]:
    """
    Run generation for every sample and pair it with its ground truth.

    Fix: the return annotation previously declared 3-tuples while the function
    actually returns 4-tuples.

    Returns:
        List of (word, participant_id, gt_sequence, generated_sequence),
        one entry per sample in `data`.
    """
    results = []
    for sample in data:
        gt_tokens_str = sample.get("motion_tokens", "")
        # Wrap raw ids ("12 34") as special tokens ("<M12> <M34>") inside fences.
        gt_wrapped = " ".join(f"<M{t}>" for t in gt_tokens_str.split())
        gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
        prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
        generated_sequence = generate_motion(model, tokenizer, prompt, device)
        pid = str(sample.get("participant_id", ""))
        results.append((sample["word"], pid, gt_sequence, generated_sequence))
    return results
|
| 152 |
+
|
| 153 |
+
def _activations_from_pairs(pairs: list[Tuple[str, str, str]], vocab_tokens: list[str]):
    """
    Build numpy activations and labels from GT/GEN token sequences.

    Returns:
        gt_acts: (N, D) ground-truth bag-of-token activations
        gen_acts: (N, D) generated activations
        labels: list[str] of length N (the word of each pair)
    """
    token_to_index = _build_token_index(vocab_tokens)
    gt_vecs, gen_vecs, labels = [], [], []
    for pair in pairs:
        # Accept legacy 3-tuples (word, gt, gen) and current 4-tuples (word, pid, gt, gen).
        if len(pair) == 4:
            word, _pid, gt_seq, gen_seq = pair
        else:
            word, gt_seq, gen_seq = pair
        labels.append(word)
        gt_vecs.append(_sequence_to_activation(gt_seq, token_to_index))
        gen_vecs.append(_sequence_to_activation(gen_seq, token_to_index))
    return np.stack(gt_vecs, axis=0), np.stack(gen_vecs, axis=0), labels
|
| 175 |
+
|
| 176 |
+
def _to_label_tensor3(acts: np.ndarray, labels: list[str]) -> np.ndarray:
|
| 177 |
+
"""
|
| 178 |
+
Convert N x D activations with string labels to [L, K, D] by truncating each label
|
| 179 |
+
to the minimum count across labels.
|
| 180 |
+
"""
|
| 181 |
+
label_to_indices: Dict[str, list[int]] = {}
|
| 182 |
+
for i, lbl in enumerate(labels):
|
| 183 |
+
label_to_indices.setdefault(lbl, []).append(i)
|
| 184 |
+
per_label_counts = [len(idxs) for idxs in label_to_indices.values()]
|
| 185 |
+
if len(per_label_counts) == 0:
|
| 186 |
+
raise ValueError("No labels found for multimodality computation.")
|
| 187 |
+
min_count = max(2, min(per_label_counts))
|
| 188 |
+
label_names = sorted(label_to_indices.keys())
|
| 189 |
+
stacked = []
|
| 190 |
+
for lbl in label_names:
|
| 191 |
+
idxs = label_to_indices[lbl][:min_count]
|
| 192 |
+
stacked.append(acts[idxs])
|
| 193 |
+
return np.stack(stacked, axis=0) # [L, K, D]
|
| 194 |
+
|
| 195 |
+
def evaluate_metrics_motiongpt_style(model, tokenizer, eval_data, all_motion_tokens, device):
    """
    Computes:
        - Diversity: GT vs GEN (pair)
        - Multimodality (MIM): GT vs GEN (pair)
        - FID: between GT and GEN

    Activations here are normalized bag-of-motion-token histograms over
    `all_motion_tokens` (see `_activations_from_pairs`), not learned features.

    Args:
        model / tokenizer: causal LM and tokenizer used to generate motions.
        eval_data: samples with 'word', 'participant_id', 'motion_tokens'.
        all_motion_tokens: <M...> vocabulary defining the activation dimensions.
        device: torch device used for generation.

    Returns:
        Dict with diversity/MIM/FID values plus the raw eval `pairs`.
    """
    print("\n" + "="*80)
    print(" METRICS EVALUATION (FID, Diversity, Multimodality)")
    print("="*80)
    # One generation per eval sample; this is the expensive step.
    pairs = _collect_eval_pairs(model, tokenizer, eval_data, device)
    gt_acts, gen_acts, labels = _activations_from_pairs(pairs, all_motion_tokens)
    # Diversity: cap the number of sampled pairs by the population size.
    diversity_times = min(200, max(4, gt_acts.shape[0] - 1))
    diversity_gt = calculate_diversity_np(gt_acts, diversity_times=diversity_times)
    diversity_gen = calculate_diversity_np(gen_acts, diversity_times=diversity_times)
    # Multimodality (MIM): requires >= 2 samples per word label; fall back to NaN.
    try:
        gt_lbl_tensor = _to_label_tensor3(gt_acts, labels)
        gen_lbl_tensor = _to_label_tensor3(gen_acts, labels)
        multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
        mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
        mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = float("nan")
        mim_gen = float("nan")
    # FID between the GT and GEN activation distributions.
    mu_gen, cov_gen = calculate_activation_statistics_np(gen_acts)
    mu_gt, cov_gt = calculate_activation_statistics_np(gt_acts)
    fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
    print(f"Diversity: GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
    print(f"Multimodality (MIM): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
    print(f"FID (GT vs GEN): {fid:.4f}")
    return {
        "diversity_gt": diversity_gt,
        "diversity_gen": diversity_gen,
        "mim_gt": mim_gt,
        "mim_gen": mim_gen,
        "fid": fid,
        "pairs": pairs,  # for visualization usage
    }
|
| 237 |
+
|
| 238 |
+
def _encode_params_to_feature(params: np.ndarray, vq_model, mean, std, device) -> np.ndarray:
    """
    Convert SMPL-X parameter sequence (T, D) into a single clip feature using
    the VQ-VAE encoder output BEFORE quantization. Average-pool over time to get (D_embed,).

    The encoder's input/output layouts are probed heuristically because this
    function supports multiple VQ-VAE implementations:
      - preprocessing: `vqvae.preprocess` is tried first; otherwise the input
        is normalized with (mean, std) and transposed to [N, D, T].
      - output: the time axis of a 3-D latent is guessed from
        `vqvae.output_emb_width` (or, failing that, from the smaller axis).
    The pooled feature is L2-normalized before being returned.
    """
    if params.size == 0:
        # Empty clip: return a zero vector of the (assumed) embedding width.
        return np.zeros((getattr(vq_model.vqvae, "output_emb_width", 512),), dtype=np.float32)
    x = torch.from_numpy(params.astype(np.float32)).to(device)  # [T, D]
    x = x.unsqueeze(0)  # [1, T, D]
    with torch.no_grad():
        # Normalize / preprocess
        x_pre = None
        if hasattr(vq_model.vqvae, "preprocess"):
            try:
                x_pre = vq_model.vqvae.preprocess(x)  # expected to return tensor ready for encoder
            except Exception:
                # Fall through to manual normalization below.
                x_pre = None
        if x_pre is None:
            # Manual normalization with provided mean/std
            if mean is not None and std is not None:
                mean_t = torch.from_numpy(np.array(mean, dtype=np.float32)).to(device).view(1, 1, -1)
                std_t = torch.from_numpy(np.array(std, dtype=np.float32)).to(device).view(1, 1, -1)
                x_norm = (x - mean_t) / (std_t + 1e-8)  # epsilon guards zero std
            else:
                x_norm = x
            # Some encoders expect [N, D, T]
            x_pre = x_norm.transpose(1, 2).contiguous()  # [1, D, T]
        # Encode to get pre-quant latent
        z_e = vq_model.vqvae.encoder(x_pre)
        # z_e could be [N, D_embed, T_q] or [N, T_q, D_embed]
        if z_e.dim() == 3:
            embed_dim_known = getattr(vq_model.vqvae, "output_emb_width", None)
            if embed_dim_known is not None:
                if z_e.shape[1] == embed_dim_known:
                    time_axis = 2  # [N, D_embed, T_q]
                elif z_e.shape[2] == embed_dim_known:
                    time_axis = 1  # [N, T_q, D_embed]
                else:
                    # Neither axis matches: assume the shorter axis is time.
                    time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
            else:
                # No declared width: same shorter-axis-is-time heuristic.
                time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
            feat = z_e.mean(dim=time_axis).squeeze(0)  # pool over time
        elif z_e.dim() == 2:
            feat = z_e.squeeze(0)
        else:
            # Unexpected rank; collapse everything into one vector.
            feat = z_e.view(1, -1).mean(dim=0)
        feat_np = feat.detach().cpu().numpy().astype(np.float32)
    # L2 normalize
    norm = np.linalg.norm(feat_np)
    if norm > 0:
        feat_np = feat_np / norm
    return feat_np
|
| 290 |
+
|
| 291 |
+
def evaluate_metrics_encoder_style(
    model,
    tokenizer,
    eval_data,
    device,
    vqvae_ckpt: Optional[str] = None,
    stats_path: Optional[str] = None,
    sample_limit: int = 100,
):
    """
    Computes FID, Diversity, and MIM using VQ-VAE encoder pre-quantization features.

    Pipeline per sample: generate tokens with the LM, decode GT and generated
    tokens to SMPL-X parameters via the VQ-VAE decoder, then re-encode both to
    pooled pre-quant features (`_encode_params_to_feature`) and compute the
    metrics on those features.

    Args:
        vqvae_ckpt / stats_path: optional overrides; otherwise env vars
            VQVAE_CHECKPOINT / VQVAE_STATS_PATH, then `visualize` defaults.
        sample_limit: cap on eval samples to bound generation cost.

    Returns:
        Dict with diversity/MIM/FID and the raw `pairs`; empty dict when the
        VQ-VAE setup fails (missing checkpoints/dependencies).
    """
    print("\n" + "="*80)
    print(" METRICS EVALUATION (VQ-VAE Encoder Features)")
    print("="*80)
    # Lazy import to reuse your visualization utilities and stats
    try:
        from visualize import load_vqvae, load_stats, VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS
        vq_ckpt = vqvae_ckpt or os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
        stats_p = stats_path or os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
        vq_model = load_vqvae(vq_ckpt, device=device)
        mean, std = load_stats(stats_p)
        from visualize import decode_tokens_to_params
    except Exception as exc:
        # Best-effort: metrics are skipped entirely if the VQ-VAE is unavailable.
        print(f"⚠️ Could not set up VQ-VAE encoder metrics: {exc}")
        return {}
    # Collect GT/GEN token sequences for pairs (limit to speed-up)
    pairs = _collect_eval_pairs(model, tokenizer, eval_data[:sample_limit], device)
    # Build features
    gt_feats = []
    gen_feats = []
    labels = []
    for pair in pairs:
        # Support legacy 3-tuples and current 4-tuples.
        if len(pair) == 4:
            word, _pid, gt_seq, gen_seq = pair
        else:
            word, gt_seq, gen_seq = pair
        # Decode to SMPL-X
        tokens_gt = _extract_ids_from_sequence(gt_seq)
        tokens_gen = _extract_ids_from_sequence(gen_seq)
        try:
            params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std, device=device)  # (T, D) denorm
        except Exception:
            # Decode failure yields an empty clip; 182 is the expected SMPL-X
            # parameter width — TODO confirm against visualize.py.
            params_gt = np.zeros((0, 182), dtype=np.float32)
        try:
            params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std, device=device)  # (T, D) denorm
        except Exception:
            params_gen = np.zeros((0, 182), dtype=np.float32)
        # Encode (pre-quant) -> pooled feature
        feat_gt = _encode_params_to_feature(params_gt, vq_model, mean, std, device)
        feat_gen = _encode_params_to_feature(params_gen, vq_model, mean, std, device)
        gt_feats.append(feat_gt)
        gen_feats.append(feat_gen)
        labels.append(word)
    gt_feats = np.stack(gt_feats, axis=0)
    gen_feats = np.stack(gen_feats, axis=0)
    # Diversity
    diversity_times = min(200, max(4, gt_feats.shape[0] - 1))
    diversity_gt = calculate_diversity_np(gt_feats, diversity_times=diversity_times)
    diversity_gen = calculate_diversity_np(gen_feats, diversity_times=diversity_times)
    # Multimodality (MIM)
    try:
        gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
        gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
        multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
        mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
        mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = float("nan")
        mim_gen = float("nan")
    # FID (on encoder features)
    mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
    mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
    fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
    print(f"Diversity (encoder feats): GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
    print(f"Multimodality (MIM, encoder): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
    print(f"FID (encoder feats, GT vs GEN): {fid:.4f}")
    return {
        "diversity_gt": diversity_gt,
        "diversity_gen": diversity_gen,
        "mim_gt": mim_gt,
        "mim_gen": mim_gen,
        "fid": fid,
        "pairs": pairs,
    }
|
| 377 |
+
|
| 378 |
+
def save_side_by_side_visualizations(pairs: list[Tuple[str, str, str]], output_dir: str, limit: int = 4):
    """
    Generate side-by-side 3D animations for GT vs GEN.

    For up to `limit` distinct words (in first-seen order), each pair's token
    sequences are decoded through the VQ-VAE to SMPL-X parameters, converted
    to mesh vertices, and written as an interactive Plotly HTML animation in
    `output_dir`. Pairs may be (word, gt, gen) or (word, pid, gt, gen).
    Errors per pair are reported and skipped; missing dependencies abort early.
    """
    try:
        # Lazy imports: visualization stack is optional at runtime.
        from visualize import (
            load_vqvae, load_stats, load_smplx_model,
            decode_tokens_to_params, params_to_vertices,
            VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS, SMPLX_MODEL_DIR as DEFAULT_SMPLX
        )
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except Exception as exc:
        print(f"⚠️ Visualization skipped (missing dependencies): {exc}")
        return

    os.makedirs(output_dir, exist_ok=True)
    # Environment variables override the defaults baked into visualize.py.
    vqvae_ckpt = os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
    stats_path = os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
    smplx_dir = os.getenv("SMPLX_MODEL_DIR", DEFAULT_SMPLX)

    print("Loading VQ-VAE, stats, SMPL-X ...")
    vq_model = load_vqvae(vqvae_ckpt)
    mean, std = load_stats(stats_path)
    smplx_model = load_smplx_model(smplx_dir)

    def animate_side_by_side(verts_left, faces, verts_right, fps=20, titles=("Ground Truth", "LLM Generated"), output_html=None):
        # Build a two-scene Plotly animation; clips are truncated to the
        # shorter of the two sequences so frames stay in lockstep.
        T = min(verts_left.shape[0], verts_right.shape[0])
        verts_left, verts_right = verts_left[:T], verts_right[:T]
        i, j, k = faces.T.tolist()  # mesh triangle indices
        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{'type': 'scene'}, {'type': 'scene'}]],
            horizontal_spacing=0.05,
            subplot_titles=list(titles)
        )
        # Initial frame (t=0) meshes for both panels.
        left_mesh = go.Mesh3d(x=verts_left[0,:,0], y=verts_left[0,:,1], z=verts_left[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
        right_mesh = go.Mesh3d(x=verts_right[0,:,0], y=verts_right[0,:,1], z=verts_right[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
        fig.add_trace(left_mesh, row=1, col=1)
        fig.add_trace(right_mesh, row=1, col=2)
        frames = []
        for t in range(T):
            frames.append(go.Frame(
                name=str(t),
                data=[
                    go.Mesh3d(x=verts_left[t,:,0], y=verts_left[t,:,1], z=verts_left[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene"),
                    go.Mesh3d(x=verts_right[t,:,0], y=verts_right[t,:,1], z=verts_right[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene2")
                ]
            ))
        fig.frames = frames
        fig.update_layout(
            showlegend=False,
            margin=dict(l=10, r=10, t=50, b=10),
            # Hide axes; front-facing camera on both scenes.
            scene=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
                       camera=dict(eye=dict(x=0,y=-2,z=0.7))),
            scene2=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
                        camera=dict(eye=dict(x=0,y=-2,z=0.7))),
            updatemenus=[dict(
                type="buttons", x=0.5, xanchor="center", y=1.15, yanchor="top",
                buttons=[
                    dict(label="Play", method="animate", args=[None, {"frame": {"duration": max(1,1000//fps), "redraw": True}, "fromcurrent": True}]),
                    dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}}])
                ]
            )]
        )
        if output_html:
            fig.write_html(output_html)
            print(f"✅ Saved: {output_html}")
        return fig

    # Determine which words to include (up to `limit` distinct words)
    allowed_words = None
    if isinstance(limit, int) and limit > 0:
        ordered_unique_words = []
        for pair in pairs:
            word = pair[0]
            if word not in ordered_unique_words:
                ordered_unique_words.append(word)
            if len(ordered_unique_words) >= limit:
                break
        allowed_words = set(ordered_unique_words)

    for pair in pairs:
        try:
            # Support legacy 3-tuples and current 4-tuples.
            if len(pair) == 4:
                word, pid, gt_seq, gen_seq = pair
            else:
                word, gt_seq, gen_seq = pair
                pid = "unknown"
            if allowed_words is not None and word not in allowed_words:
                continue
            tokens_gt = _extract_ids_from_sequence(gt_seq)
            tokens_gen = _extract_ids_from_sequence(gen_seq)
            # tokens -> SMPL-X parameters -> mesh vertices
            params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std)
            params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std)
            verts_gt, faces = params_to_vertices(params_gt, smplx_model)
            verts_gen, _ = params_to_vertices(params_gen, smplx_model)
            out_dir = os.path.join(output_dir)
            os.makedirs(out_dir, exist_ok=True)
            # Sanitize for filesystem safety
            safe_word = re.sub(r'[^A-Za-z0-9_-]+', '_', str(word))
            safe_pid = re.sub(r'[^A-Za-z0-9_-]+', '_', str(pid))
            output_html = os.path.join(out_dir, f"word_{safe_word}_{safe_pid}_side_by_side.html")
            animate_side_by_side(
                verts_left=verts_gt,
                faces=faces,
                verts_right=verts_gen,
                fps=20,
                titles=("Ground Truth", "LLM Generated"),
                output_html=output_html
            )
        except Exception as exc:
            # One bad pair should not abort the remaining visualizations.
            print(f"⚠️ Error creating visualization for word '{pair[0]}': {exc}")
|
| 491 |
+
|
| 492 |
+
def run_inference_on_all_samples(model, tokenizer, data, device):
    """
    Runs inference on ALL available samples for the trained words and compares
    each one to its specific ground truth.

    Prints a per-variant token diff and a per-word exact-match summary to
    stdout; nothing is returned. "Correct" means the generated sequence is
    byte-identical (after strip) to the fenced ground-truth sequence.
    """
    print("\n" + "="*80)
    print(" INFERENCE AND EVALUATION (ALL SAMPLES)")
    print(" Goal: Test the model's performance on every variant.")
    print("="*80)

    def compare_sequences(gt: str, gen: str):
        """Provides a simple visual diff of two sequences without external libraries."""
        gt_tokens = gt.split()
        gen_tokens = gen.split()

        print("\nDetailed Comparison (✅ = Match, ❌ = Mismatch/Missing/Added):")

        gt_str = " GT: "
        gen_str = " GEN: "
        diff_str = " "

        max_len = max(len(gt_tokens), len(gen_tokens))

        for i in range(max_len):
            # "___" marks a position where one side has no token.
            gt_tok = gt_tokens[i] if i < len(gt_tokens) else "___"
            gen_tok = gen_tokens[i] if i < len(gen_tokens) else "___"

            # Pad to a common width so the three rows stay column-aligned.
            max_tok_len = max(len(gt_tok), len(gen_tok))
            gt_tok_padded = gt_tok.ljust(max_tok_len)
            gen_tok_padded = gen_tok.ljust(max_tok_len)

            gt_str += gt_tok_padded + " "
            gen_str += gen_tok_padded + " "

            if gt_tok == gen_tok:
                diff_str += "✅".ljust(max_tok_len) + " "
            else:
                diff_str += "❌".ljust(max_tok_len) + " "

        print(gt_str)
        print(gen_str)
        print(diff_str)

    # Group samples by word so results are reported per word.
    data_by_word = {}
    for item in data:
        word = item['word']
        if word not in data_by_word:
            data_by_word[word] = []
        data_by_word[word].append(item)

    for word, samples in data_by_word.items():
        print(f"\n\n{'='*25} TESTING WORD: '{word}' {'='*25}")
        num_correct = 0

        for i, sample in enumerate(samples):
            print(f"\n--- Testing Variant {i+1}/{len(samples)}: '{sample['participant_id']}' ---")

            # Rebuild the fenced ground-truth sequence from raw token ids.
            gt_tokens_str = sample.get("motion_tokens", "")
            gt_wrapped = " ".join([f"<M{t}>" for t in gt_tokens_str.split()])
            gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
            print(f"Ground Truth:\n{gt_sequence}")

            prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
            generated_sequence = generate_motion(model, tokenizer, prompt, device)
            print(f"\nLLM Generated:\n{generated_sequence}")

            compare_sequences(gt_sequence, generated_sequence)

            # Exact sequence match counts as correct.
            if gt_sequence.strip() == generated_sequence.strip():
                num_correct += 1

            print("-" * 80)

        accuracy = (num_correct / len(samples)) * 100
        print(f"\nSUMMARY FOR '{word}': {num_correct}/{len(samples)} correct ({accuracy:.1f}%)")
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
# ======================================================================================
|
| 570 |
+
# Existing Utilities (Compatibility)
|
| 571 |
+
# ======================================================================================
|
| 572 |
+
def seq_edit_distance(a_ids: List[int], b_ids: List[int]) -> int:
    """Token-level Levenshtein distance"""
    # rapidfuzz operates on arbitrary sequences, so integer id lists work directly.
    return Levenshtein.distance(a_ids, b_ids)
|
| 575 |
+
|
| 576 |
+
def best_ref_distance(pred_ids: List[int], refs: List[List[int]]) -> int:
    """Minimum edit distance from `pred_ids` to any reference sequence.

    With no references, the distance is len(pred_ids) (delete everything).
    """
    return min((seq_edit_distance(pred_ids, ref) for ref in refs), default=len(pred_ids))
|
| 581 |
+
|
| 582 |
+
def build_text_to_refs(dataset):
    """
    Build mapping from text prompts to list of reference motion sequences
    """
    text_to_refs = defaultdict(list)
    for example in dataset:
        # "12 34 5" -> [12, 34, 5]
        ids = [int(tok) for tok in example["motion_tokens"].split()]
        text_to_refs[example["text_query"]].append(ids)
    return text_to_refs
|
| 592 |
+
|
| 593 |
+
def _concat(ids_list: List[List[int]]) -> List[int]:
|
| 594 |
+
out = []
|
| 595 |
+
for s in ids_list:
|
| 596 |
+
out.extend(s)
|
| 597 |
+
return out
|
| 598 |
+
|
| 599 |
+
def _distinct_n(ids_list: List[List[int]], n: int) -> float:
|
| 600 |
+
if n <= 0:
|
| 601 |
+
return 0.0
|
| 602 |
+
total = 0
|
| 603 |
+
uniq = set()
|
| 604 |
+
for seq in ids_list:
|
| 605 |
+
if len(seq) < n:
|
| 606 |
+
continue
|
| 607 |
+
total += (len(seq) - n + 1)
|
| 608 |
+
for i in range(len(seq) - n + 1):
|
| 609 |
+
uniq.add(tuple(seq[i:i+n]))
|
| 610 |
+
if total == 0:
|
| 611 |
+
return 0.0
|
| 612 |
+
return len(uniq) / float(total)
|
| 613 |
+
|
| 614 |
+
def token_fid_diag(gens: List[List[int]], refs: List[List[int]], codebook_size: int) -> float:
    """
    Diagonal-covariance Fréchet distance between histograms of token usage.
    This is a lightweight proxy for FID using token distributions.
    """
    if not gens or not refs:
        return float("nan")

    def histogram_feats(batch: List[List[int]]) -> np.ndarray:
        # Each sequence -> normalized token-frequency histogram over the codebook.
        rows = []
        for seq in batch:
            valid = [tok for tok in seq if 0 <= tok < codebook_size]
            hist = np.bincount(valid, minlength=codebook_size).astype(np.float64)
            mass = hist.sum()
            if mass > 0:
                hist /= mass
            rows.append(hist)
        return np.stack(rows, axis=0)

    gen_feats = histogram_feats(gens)
    ref_feats = histogram_feats(refs)
    mean_term = np.sum((gen_feats.mean(axis=0) - ref_feats.mean(axis=0)) ** 2)
    var_g = gen_feats.var(axis=0)
    var_r = ref_feats.var(axis=0)
    # Diagonal covariance approximation of the trace term.
    cov_term = np.sum(var_g + var_r - 2.0 * np.sqrt(np.clip(var_g * var_r, 0.0, None)))
    return float(mean_term + cov_term)
|
| 642 |
+
|
| 643 |
+
def compute_token_metrics(
    gen_by_text: Dict[str, List[int]],
    text_to_refs: Dict[str, List[List[int]]],
    codebook_size: int,
) -> Dict[str, float]:
    """
    Compute token-level metrics:
        - FID_diag: Fréchet distance between token histograms (diag cov)
        - MIM: average min edit distance to references
        - Diversity: distinct-1 and distinct-2

    Fix: removed a dead `_concat(...)` call whose result was never used
    (the old comment beside it already noted it was the wrong shape).
    """
    gens = list(gen_by_text.values())
    # Flatten the per-prompt reference lists into one list of sequences.
    ref_seqs = [r for refs in text_to_refs.values() for r in refs]

    fid_diag = token_fid_diag(gens, ref_seqs, codebook_size)

    # MIM: average best edit distance per prompt (only over prompts we generated)
    mim_dists = [
        best_ref_distance(gen_ids, text_to_refs.get(text, []))
        for text, gen_ids in gen_by_text.items()
    ]
    mim = float(sum(mim_dists) / len(mim_dists)) if mim_dists else float("nan")

    return {
        "FID_diag": fid_diag,
        "MIM": mim,
        "distinct_1": _distinct_n(gens, 1),
        "distinct_2": _distinct_n(gens, 2),
    }
|
| 677 |
+
|
| 678 |
+
def eval_t2m_set(
    model,
    tokenizer,
    sample_pairs: List[Tuple[str, List[List[int]]]],
    mot_begin_id: int,
    mot_end_id: int,
    motion_token_ids: list,
    length_stats_by_text: dict,
    global_median_len: int,
    prompt_vocab: dict = None,
    has_pid: bool = False,
    per_prompt_vocab: bool = True,
    n_eval: int = 100
):
    """
    Evaluate text-to-motion generation on a set of samples.

    Generates a motion span per prompt with `generate_t2m`, extracts the
    token ids between <MOT_BEGIN> and <MOT_END>, and scores each against
    its reference sequences via `best_ref_distance`.

    Args:
        sample_pairs: list of (prompt_text, [reference token-id lists]).
        n_eval: maximum number of prompts to evaluate (random subset).

    Returns:
        {"avg_edit_dist", "median_len", "n_samples"} on success, {} when
        there are no samples. Kept compact for pipeline compatibility.
    """
    # Draw a random subset WITHOUT mutating the caller's list.
    # (Previously random.shuffle(sample_pairs) reordered the caller's data
    # in place as a hidden side effect.)
    subset = random.sample(sample_pairs, min(n_eval, len(sample_pairs)))

    dists = []
    lens = []

    for text, ref_list in subset:
        gen = generate_t2m(
            model=model,
            tokenizer=tokenizer,
            prompt_text=text,
            mot_begin_id=mot_begin_id,
            mot_end_id=mot_end_id,
            motion_token_ids=motion_token_ids,
            length_stats_by_text=length_stats_by_text,
            global_median_len=global_median_len,
            prompt_vocab=prompt_vocab,
            pid=None,
            has_pid=has_pid,
            per_prompt_vocab=per_prompt_vocab
        )
        # Keep only the motion span between the boundary markers.
        span = gen.split("<MOT_BEGIN>")[-1]
        span = span.split("<MOT_END>")[0]
        pred_ids = motion_specials_to_ids(span)
        d = best_ref_distance(pred_ids, ref_list)
        dists.append(d)
        lens.append(len(pred_ids))

    if dists:
        avg_dist = sum(dists) / len(dists)
        # lens is non-empty whenever dists is (one append per iteration)
        median_len = sorted(lens)[len(lens) // 2]
        print(f"Eval T2M: avg_edit_dist={avg_dist:.2f}, median_len={median_len}, n={len(dists)}")
        return {"avg_edit_dist": avg_dist, "median_len": median_len, "n_samples": len(dists)}
    else:
        print("Eval T2M: no samples")
        return {}
|
model.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model and tokenizer initialization
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from typing import List, Set, Tuple
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 7 |
+
from unsloth import FastLanguageModel
|
| 8 |
+
from config import (
|
| 9 |
+
MODEL_NAME, MAX_SEQ_LEN, DTYPE,
|
| 10 |
+
LORA_R, LORA_ALPHA, LORA_DROPOUT,
|
| 11 |
+
LORA_TARGET_MODULES, LORA_MODULES_TO_SAVE,
|
| 12 |
+
PAD_TOKEN, M_START, M_END
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# ======================================================================================
|
| 16 |
+
# Logic from test_overfit.py (Standard Transformers)
|
| 17 |
+
# ======================================================================================
|
| 18 |
+
|
| 19 |
+
def setup_model_and_tokenizer_raw(model_name: str, motion_tokens: List[str]) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Loads the model and tokenizer, adding special and motion tokens (Standard Transformers)."""
    print(f"\n---> Loading base model and tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # Register pad + boundary markers (matches test_overfit.py), then the
    # full motion vocabulary.
    special_map = {"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]}
    tokenizer.add_special_tokens(special_map)

    print(f"Adding {len(motion_tokens)} motion tokens to the tokenizer.")
    tokenizer.add_tokens(motion_tokens, special_tokens=True)

    # Grow the embedding matrix to match the enlarged vocabulary and keep
    # the model's pad id in sync with the tokenizer's.
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer
|
| 35 |
+
|
| 36 |
+
def ensure_tokenizer_has_motion_tokens(tokenizer: AutoTokenizer, motion_tokens: List[str]) -> int:
    """
    Adds any missing motion tokens to the tokenizer. Returns number of tokens added.
    """
    specials = {"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]}
    tokenizer.add_special_tokens(specials)
    # add_tokens skips tokens that already exist and reports only new ones.
    return tokenizer.add_tokens(motion_tokens, special_tokens=True)
|
| 43 |
+
|
| 44 |
+
# ======================================================================================
|
| 45 |
+
# Existing Logic (Unsloth / LoRA)
|
| 46 |
+
# ======================================================================================
|
| 47 |
+
|
| 48 |
+
def build_special_tokens(codebook_size: int, unique_pids: List[str] = None) -> List[str]:
    """
    Assemble the full special-token vocabulary for motion modeling.

    Order is: boundary tokens, <motion_i> codebook tokens, task tokens,
    then (optionally) participant-ID tokens.

    Args:
        codebook_size: number of <motion_i> tokens to create.
        unique_pids: optional participant IDs; when non-empty, adds
            <PID_NULL> plus one <PID_*> token per participant.
    """
    tokens = ["<MOT_BEGIN>", "<MOT_END>"]
    tokens.extend(f"<motion_{i}>" for i in range(codebook_size))
    tokens.extend(["<T2M>", "<M2T>", "<DENOISE>", "<MOTION_MASK>"])
    if unique_pids:
        tokens.append("<PID_NULL>")
        tokens.extend(f"<PID_{pid}>" for pid in unique_pids)
    return tokens
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def setup_model_and_tokenizer(codebook_size: int, unique_pids: List[str] = None):
    """
    Initialize model and tokenizer with custom tokens (Unsloth LoRA).

    Order matters here: special tokens are added and embeddings are resized
    BEFORE LoRA is applied, so the PEFT wrapper (and the modules listed in
    LORA_MODULES_TO_SAVE) see the final vocabulary size.

    Args:
        codebook_size: Number of VQ-VAE codebook entries -> <motion_i> tokens.
        unique_pids: Optional participant IDs; adds <PID_*> tokens when given.

    Returns:
        (model, tokenizer, new_token_ids) where new_token_ids is the set of
        vocabulary ids of the added tokens, used for gradient masking.
    """
    # Build special tokens
    additional_special_tokens = build_special_tokens(codebook_size, unique_pids)

    # Load base model
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LEN,
        dtype=DTYPE,
        load_in_4bit=False,  # full-precision base weights; LoRA is applied below
        trust_remote_code=True,
    )

    # Configure tokenizer
    tokenizer.padding_side = "right"

    # Add only the special tokens that are not already registered
    existing = set(tokenizer.special_tokens_map_extended.get("additional_special_tokens", []))
    to_add = [t for t in additional_special_tokens if t not in existing]

    if to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": to_add})

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Resize embeddings to cover the enlarged vocabulary
    model.resize_token_embeddings(len(tokenizer))

    # Apply LoRA
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        target_modules=LORA_TARGET_MODULES,
        modules_to_save=LORA_MODULES_TO_SAVE,
        use_gradient_checkpointing="unsloth",
    )

    # Get new token IDs for gradient masking (ALL special tokens, not just
    # the newly-added subset, so masking is consistent across reruns)
    new_token_ids = set(tokenizer.convert_tokens_to_ids(additional_special_tokens))

    # Apply gradient mask to prevent base vocab drift
    apply_gradient_mask(model, new_token_ids)

    return model, tokenizer, new_token_ids
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def apply_gradient_mask(model, new_token_ids: Set[int]):
    """
    Apply gradient mask so only new token embeddings are updated.

    Registers a backward hook on both the input-embedding matrix and the
    output (LM head) weight that zeroes gradients for every row except
    those in `new_token_ids`, preventing drift of the base vocabulary.
    """
    def mask_rows_hook(param, rows: set):
        # 0/1 row mask; broadcast over the embedding dim inside the hook.
        mask = torch.zeros(param.size(0), device=param.device, dtype=param.dtype)
        idxs = sorted(list(rows))
        if len(idxs) > 0:
            mask[idxs] = 1.0
        # Hook fires on backward and keeps only the new-token rows' grads.
        param.register_hook(lambda g: g * mask.unsqueeze(1))

    with torch.no_grad():
        emb = model.get_input_embeddings().weight
        head = model.get_output_embeddings().weight

        # NOTE(review): if input/output embeddings are tied these are the
        # same tensor and the hook runs twice — harmless since mask is 0/1.
        mask_rows_hook(emb, new_token_ids)
        mask_rows_hook(head, new_token_ids)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_motion_token_info(tokenizer, codebook_size: int):
    """
    Look up the vocabulary ids of the motion vocabulary.

    Returns:
        (motion_token_ids, mot_begin_id, mot_end_id) where motion_token_ids
        is the list of ids for <motion_0> .. <motion_{codebook_size-1}>.
    """
    names = [f"<motion_{i}>" for i in range(codebook_size)]
    motion_ids = tokenizer.convert_tokens_to_ids(names)
    begin_id = tokenizer.convert_tokens_to_ids("<MOT_BEGIN>")
    end_id = tokenizer.convert_tokens_to_ids("<MOT_END>")
    return motion_ids, begin_id, end_id
|
requirements.txt
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
torch>=2.0.0
transformers>=4.40.0
datasets>=2.14.0
accelerate>=0.20.0

# Demo app
gradio
plotly
smplx
scipy

# Unsloth for efficient training
unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git

# Training utilities
bitsandbytes>=0.41.0
peft>=0.4.0
trl>=0.4.7

# Evaluation
rapidfuzz>=3.0.0

# Utilities
numpy>=1.24.0
tqdm>=4.65.0
huggingface_hub>=0.22.0
gdown>=5.2.0
+
gdown>=5.2.0
|
setup_env.sh
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash

set -euo pipefail

# Usage:
#   bash setup_env.sh
#
# - Installs Python dependencies from requirements.txt
# - Downloads a public Google Drive dataset file into ./data/motion_llm_dataset.json
# - Exports env vars for this session (optional) and prints instructions

THIS_DIR="$(pwd)"
DATA_DIR="$THIS_DIR/data"
mkdir -p "$DATA_DIR"

# --- Explicit placeholders (replace these later) ---
# Training dataset
GDRIVE_ID="11711RgTmzauXpYVFoqLF8DZXiZlZovfn"

# Visualization assets (optional - only needed for visualize.py)
VQVAE_MODEL_ID="1JEMKVZWFG4Ue7k3Nm7q1o7-uBVsVricY"
VQVAE_STATS_ID="1WTwP5DdBl4c-X5Kj7jXtlEHofOX2BifZ"
SMPLX_MODELS_ID="1tZEfqw9zHgOaBEw5X_oazAEnesRtE9ky"

# Hugging Face token
HF_TOKEN_IN=""
# ---------------------------------------------------

echo "Installing Python dependencies..."
python -m pip install --upgrade pip
pip install -r requirements.txt

# gdown's `--id` flag is deprecated (and removed in recent releases);
# pass a full uc?id= URL instead, which works across gdown versions.
gdrive_fetch() {
  local file_id="$1"
  local dest="$2"
  gdown "https://drive.google.com/uc?id=${file_id}" -O "$dest"
}

if [[ -n "$GDRIVE_ID" ]] && [[ "$GDRIVE_ID" != "YOUR_GOOGLE_DRIVE_FILE_ID_HERE" ]]; then
  echo "Downloading training dataset from Google Drive (file id: $GDRIVE_ID)..."
  gdrive_fetch "$GDRIVE_ID" "$DATA_DIR/motion_llm_dataset.json"
else
  echo "No training dataset Google Drive ID provided. Skipping dataset download."
fi

# Download visualization assets if IDs are provided
if [[ -n "$VQVAE_MODEL_ID" ]] && [[ "$VQVAE_MODEL_ID" != "YOUR_VQVAE_CHECKPOINT_GDRIVE_ID_HERE" ]]; then
  echo "Downloading VQ-VAE model from Google Drive (file id: $VQVAE_MODEL_ID)..."
  gdrive_fetch "$VQVAE_MODEL_ID" "$DATA_DIR/vqvae_model.pt"
fi

if [[ -n "$VQVAE_STATS_ID" ]] && [[ "$VQVAE_STATS_ID" != "YOUR_VQVAE_STATS_GDRIVE_ID_HERE" ]]; then
  echo "Downloading VQ-VAE stats from Google Drive (file id: $VQVAE_STATS_ID)..."
  gdrive_fetch "$VQVAE_STATS_ID" "$DATA_DIR/vqvae_stats.pt"
fi

if [[ -n "$SMPLX_MODELS_ID" ]] && [[ "$SMPLX_MODELS_ID" != "YOUR_SMPLX_MODELS_GDRIVE_ID_HERE" ]]; then
  echo "Downloading SMPL-X neutral model (.npz) from Google Drive (file id: $SMPLX_MODELS_ID)..."
  mkdir -p "$DATA_DIR/smplx_models"
  gdrive_fetch "$SMPLX_MODELS_ID" "$DATA_DIR/smplx_models/SMPLX_NEUTRAL.npz"
  echo "Saved SMPLX_NEUTRAL.npz to $DATA_DIR/smplx_models"
fi

if [[ -n "$HF_TOKEN_IN" ]]; then
  echo "Exporting HUGGINGFACE_HUB_TOKEN for this shell session..."
  export HUGGINGFACE_HUB_TOKEN="$HF_TOKEN_IN"
fi

echo
echo "Environment setup complete."
echo "- WORK_DIR defaults to: $THIS_DIR"
echo "- DATA_JSON_PATH defaults to: $DATA_DIR/motion_llm_dataset.json"
echo "- To persist HF token, set an environment variable before running:"
echo "  export HUGGINGFACE_HUB_TOKEN=hf_..."
echo
echo "You can now run your training scripts."
|
templates.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prompt templates and mapping functions for different training stages
|
| 3 |
+
"""
|
| 4 |
+
import random
|
| 5 |
+
from data import ids_to_motion_specials
|
| 6 |
+
from config import SYSTEM_MSG, SEED
|
| 7 |
+
|
| 8 |
+
random.seed(SEED)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def pid_token_from_example(ex, has_pid: bool):
    """Return the participant-ID token for *ex*, or "" when PIDs are disabled."""
    if not has_pid:
        return ""
    pid = ex.get("participant_id")
    # Missing participant id falls back to the null-PID token.
    return "<PID_NULL>" if pid is None else f"<PID_{pid}>"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def map_stage1(ex, has_pid: bool):
    """
    Stage 1: Word + optional PID conditioning to learn motion language.
    The user turn explicitly provides the word (+PID); the assistant turn
    is the motion-token span delimited by <MOT_BEGIN>/<MOT_END>.
    """
    pid_tok = pid_token_from_example(ex, has_pid)
    word = ex.get("word", ex.get("text_query", ""))
    motion_span = ids_to_motion_specials(ex["motion_tokens"])

    # Compact word+PID conditioning (no natural-language chatter)
    user_turn = f"<T2M>{pid_tok}\nword: {word}"
    assistant_turn = f"<MOT_BEGIN> {motion_span} <MOT_END>"

    text = (
        f"<|im_start|>system\n{SYSTEM_MSG}<|im_end|>\n"
        f"<|im_start|>user\n{user_turn}\n<|im_end|>\n"
        f"<|im_start|>assistant\n{assistant_turn}\n<|im_end|>\n"
    )

    return {"text": text, "where": "mot"}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def map_stage2(ex, has_pid: bool):
    """
    Stage 2: Multi-task (T2M/M2T/DENOISE).

    Randomly chooses text-to-motion (50%), motion-to-text (30%) or
    denoising (20%) and builds one chat-formatted training string.

    Returns:
        dict with "text" (full transcript), "where" ("mot" when the motion
        span is supervised, "text" for the m2t task), and "text_query".

    NOTE: draws from the module-level seeded `random` state; the number and
    order of random calls here determine reproducibility across runs.
    """
    t = ex["text_query"]
    mot = ids_to_motion_specials(ex["motion_tokens"])
    pid_tok = pid_token_from_example(ex, has_pid)

    # Sample task type
    task = random.choices(["t2m", "m2t", "denoise"], weights=[0.5, 0.3, 0.2], k=1)[0]

    if task == "t2m":
        # Text to motion: supervise the assistant's motion span
        assistant = f"<MOT_BEGIN> {mot} <MOT_END>"
        text = (
            "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
            + "<|im_start|>user\n" + f"<T2M>{pid_tok}\n\n" + t + "\n<|im_end|>\n"
            + "<|im_start|>assistant\n" + assistant + "\n<|im_end|>\n"
        )
        where = "mot"

    elif task == "m2t":
        # Motion to text: supervise the natural-language answer
        user = f"<M2T>{pid_tok}\n\n<MOT_BEGIN> {mot} <MOT_END>"
        text = (
            "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
            + "<|im_start|>user\n" + user + "\n<|im_end|>\n"
            + "<|im_start|>assistant\n" + t + "\n<|im_end|>\n"
        )
        where = "text"

    else:
        # Denoising: mask ~15% of the motion tokens in the user span; the
        # model must reconstruct the clean sequence.
        toks = mot.split()
        noisy = []
        for tok in toks:
            if random.random() < 0.15:
                noisy.append("<MOTION_MASK>")
            else:
                noisy.append(tok)

        user = f"<DENOISE>{pid_tok}\n\n<MOT_BEGIN> {' '.join(noisy)} <MOT_END>"
        assistant = f"<MOT_BEGIN> {mot} <MOT_END>"
        text = (
            "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
            + "<|im_start|>user\n" + user + "\n<|im_end|>\n"
            + "<|im_start|>assistant\n" + assistant + "\n<|im_end|>\n"
        )
        where = "mot"

    return {"text": text, "where": where, "text_query": t}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def map_stage3(ex, has_pid: bool):
    """
    Stage 3 (Instruct): Word-only request, no participant ID.
    The system prompt directs: "Output motion tokens for the given word".
    Note: `has_pid` is accepted for mapper-signature uniformity but unused.
    """
    t = ex["text_query"]
    motion_span = ids_to_motion_specials(ex["motion_tokens"])

    # Instruct-style turns, no PID token
    user_turn = f"<T2M>\nword: {t}"
    assistant_turn = f"<MOT_BEGIN> {motion_span} <MOT_END>"

    text = (
        f"<|im_start|>system\n{SYSTEM_MSG}<|im_end|>\n"
        f"<|im_start|>user\n{user_turn}\n<|im_end|>\n"
        f"<|im_start|>assistant\n{assistant_turn}\n<|im_end|>\n"
    )

    return {
        "text": text,
        "where": "mot",
        "text_query": t,
        "motion_tokens": ex["motion_tokens"]
    }
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def create_mapper(stage: int, has_pid: bool):
    """
    Create a mapper function for a specific training stage (1, 2 or 3).

    Raises:
        ValueError: when `stage` is not one of the known stages.
    """
    if stage not in (1, 2, 3):
        raise ValueError(f"Unknown stage: {stage}")
    if stage == 1:
        return lambda ex: map_stage1(ex, has_pid)
    if stage == 2:
        return lambda ex: map_stage2(ex, has_pid)
    return lambda ex: map_stage3(ex, has_pid)
|
test_dataset_eval.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate the SignMotionGPT model on a held-out SMPL-X test dataset.
|
| 3 |
+
|
| 4 |
+
The script can download Google Drive archives or consume an already extracted
|
| 5 |
+
directory of `video_data.pkl` files. Each sequence is converted into encoder
|
| 6 |
+
features via the project VQ-VAE utilities and compared against motions generated
|
| 7 |
+
by the LLM to compute FID/Diversity/Multimodality metrics.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import pickle
|
| 16 |
+
import random
|
| 17 |
+
import sys
|
| 18 |
+
import zipfile
|
| 19 |
+
from typing import Dict, List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 24 |
+
|
| 25 |
+
from config import (
|
| 26 |
+
TEST_EVAL_DOWNLOAD_DIR,
|
| 27 |
+
TEST_EVAL_EXTRACT_DIR,
|
| 28 |
+
TEST_EVAL_HF_REPO,
|
| 29 |
+
TEST_EVAL_HF_SUBFOLDER,
|
| 30 |
+
TEST_EVAL_MAX_ZIPS,
|
| 31 |
+
TEST_EVAL_OUTPUT_DIR,
|
| 32 |
+
TEST_EVAL_SAMPLE_LIMIT,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
M_START = "<M_START>"
|
| 36 |
+
M_END = "<M_END>"
|
| 37 |
+
PAD_TOKEN = "<PAD>"
|
| 38 |
+
|
| 39 |
+
INFERENCE_REPETITION_PENALTY = 1.2
|
| 40 |
+
INFERENCE_TEMPERATURE = 0.7
|
| 41 |
+
INFERENCE_TOP_K = 50
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# -----------------------------------------------------------------------------
|
| 45 |
+
# Download / extraction helpers
|
| 46 |
+
# -----------------------------------------------------------------------------
|
| 47 |
+
def try_import_gdown() -> bool:
    """Report whether the optional `gdown` dependency is importable."""
    try:
        import gdown  # noqa: F401
    except Exception:
        return False
    return True
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def download_drive_folder(folder_url_or_id: str, dest_dir: str) -> None:
    """Download a public Google Drive folder into *dest_dir* via gdown."""
    os.makedirs(dest_dir, exist_ok=True)
    if not try_import_gdown():
        raise RuntimeError("gdown is required for Drive downloads. Install with `pip install gdown`.")
    import gdown

    # Accept either a full folder URL or a bare folder id.
    url = (
        folder_url_or_id
        if "drive.google.com" in folder_url_or_id
        else f"https://drive.google.com/drive/folders/{folder_url_or_id}"
    )
    print(f"Downloading Drive folder to {dest_dir} ...")
    gdown.download_folder(url=url, output=dest_dir, quiet=False, use_cookies=False)
    print("Download complete.")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def list_zip_files(download_dir: str) -> List[str]:
    """Recursively collect every .zip file under *download_dir*, sorted."""
    found: List[str] = []
    for root, _dirs, files in os.walk(download_dir):
        found.extend(
            os.path.join(root, name)
            for name in files
            if name.lower().endswith(".zip")
        )
    return sorted(found)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def extract_zip_files(zip_paths: List[str], extract_dir: str, limit: Optional[int]) -> List[str]:
    """
    Extract each archive into its own subdirectory of *extract_dir*.

    Stops after *limit* archives when limit is not None; failed extractions
    are reported and skipped. Returns the list of created root directories.
    """
    os.makedirs(extract_dir, exist_ok=True)
    roots: List[str] = []
    for idx, zip_path in enumerate(zip_paths):
        if limit is not None and idx >= limit:
            break
        try:
            with zipfile.ZipFile(zip_path, "r") as archive:
                # One subdirectory per archive, named after the zip file.
                subdir = os.path.splitext(os.path.basename(zip_path))[0]
                target = os.path.join(extract_dir, subdir)
                os.makedirs(target, exist_ok=True)
                archive.extractall(target)
        except Exception as exc:
            print(f"⚠️ Failed to extract {zip_path}: {exc}")
        else:
            roots.append(target)
    print(f"Extracted {len(roots)} archives.")
    return roots
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def find_video_pkl_paths(extracted_root: str) -> List[str]:
    """Locate every file literally named ``video_data.pkl`` under *extracted_root*."""
    return [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(extracted_root)
        for name in files
        if name == "video_data.pkl"
    ]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def parse_word_from_path(path: str) -> str:
    """
    Derive the gloss/word from the parent directory name of *path*.

    Directory names like ``001-hello`` yield ``hello``; names without a
    dash are used as-is. The result is stripped and lower-cased.
    """
    base = os.path.basename(os.path.dirname(path))
    _head, sep, tail = base.partition("-")
    word = tail if sep else base
    return word.strip().lower()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# -----------------------------------------------------------------------------
|
| 118 |
+
# SMPL-X helpers
|
| 119 |
+
# -----------------------------------------------------------------------------
|
| 120 |
+
def try_to_array(value) -> Optional[np.ndarray]:
    """Coerce *value* to ``np.ndarray``; return None when conversion fails."""
    if isinstance(value, np.ndarray):
        return value
    try:
        arr = np.asarray(value)
    except Exception:
        return None
    return arr
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def load_smplx_params_from_pkl(pkl_path: str) -> Optional[np.ndarray]:
    """
    Load per-frame SMPL-X parameters from a ``video_data.pkl`` file.

    Each frame dict is flattened into a fixed 182-dim float32 vector:
    shape(10) + body_pose(63) + lhand_pose(45) + rhand_pose(45) +
    cam_trans(3) + expression(10) + jaw_pose(3) + eye-pose zeros(3).

    Returns a (num_frames, 182) float32 array, or None when the file can't
    be read, the payload isn't a non-empty list/tuple, or no frame is a dict.

    NOTE(review): pickle.load on untrusted files can execute arbitrary
    code — only feed this trusted dataset archives.
    """
    try:
        with open(pkl_path, "rb") as handle:
            payload = pickle.load(handle)
    except Exception as exc:
        print(f"⚠️ Could not read {pkl_path}: {exc}")
        return None

    if not isinstance(payload, (list, tuple)) or len(payload) == 0:
        return None

    def get_vec(frame: dict, key: str, expected: int, allow_trim: bool = True) -> np.ndarray:
        # Fetch frame[key] as a flat float32 vector of exactly `expected`
        # entries: zero-fill when missing/short, trim when too long.
        val = frame.get(key)
        arr = try_to_array(val)
        if arr is None:
            return np.zeros((expected,), dtype=np.float32)
        arr = np.array(arr, dtype=np.float32).reshape(-1)
        if arr.size == expected:
            return arr
        if allow_trim and arr.size > expected:
            # 66-dim body_pose includes a 3-dim global orient prefix;
            # drop it to get the 63-dim body-only pose.
            if key == "body_pose" and arr.size == 66 and expected == 63:
                return arr[3:3 + 63]
            return arr[:expected]
        if arr.size < expected:
            out = np.zeros((expected,), dtype=np.float32)
            out[: arr.size] = arr
            return out
        return arr[:expected]

    sequences: List[np.ndarray] = []
    for frame in payload:
        if not isinstance(frame, dict):
            continue
        vec = np.concatenate(
            [
                get_vec(frame, "shape", 10),
                get_vec(frame, "body_pose", 63),
                get_vec(frame, "lhand_pose", 45),
                get_vec(frame, "rhand_pose", 45),
                get_vec(frame, "cam_trans", 3),
                get_vec(frame, "expression", 10),
                get_vec(frame, "jaw_pose", 3),
                np.zeros((3,), dtype=np.float32),  # eye pose placeholder
            ],
            axis=0,
        )
        sequences.append(vec)
    if not sequences:
        return None
    return np.stack(sequences, axis=0).astype(np.float32)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def import_visualize_helpers():
    """
    Import the VQ-VAE helper functions from the project ``visualize`` module.

    Raises:
        RuntimeError: when the module (or any required name) is unavailable.
    """
    try:
        from visualize import (
            load_vqvae,
            load_stats,
            decode_tokens_to_params,
            VQVAE_CHECKPOINT as DEFAULT_VQ,
            STATS_PATH as DEFAULT_STATS,
        )
    except Exception as exc:
        raise RuntimeError(f"Failed to import visualize helpers: {exc}") from exc
    return load_vqvae, load_stats, decode_tokens_to_params, DEFAULT_VQ, DEFAULT_STATS
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _encode_params_to_feature(
|
| 197 |
+
params: np.ndarray,
|
| 198 |
+
vq_model,
|
| 199 |
+
mean,
|
| 200 |
+
std,
|
| 201 |
+
device: torch.device,
|
| 202 |
+
) -> Optional[np.ndarray]:
|
| 203 |
+
if params is None or params.size == 0:
|
| 204 |
+
return None
|
| 205 |
+
clip = torch.from_numpy(params.astype(np.float32)).unsqueeze(0).to(device)
|
| 206 |
+
with torch.no_grad():
|
| 207 |
+
x_pre = None
|
| 208 |
+
if hasattr(vq_model.vqvae, "preprocess"):
|
| 209 |
+
try:
|
| 210 |
+
x_pre = vq_model.vqvae.preprocess(clip)
|
| 211 |
+
except Exception:
|
| 212 |
+
x_pre = None
|
| 213 |
+
if x_pre is None:
|
| 214 |
+
if mean is not None and std is not None:
|
| 215 |
+
mean_t = torch.from_numpy(np.array(mean, dtype=np.float32)).to(device).view(1, 1, -1)
|
| 216 |
+
std_t = torch.from_numpy(np.array(std, dtype=np.float32)).to(device).view(1, 1, -1)
|
| 217 |
+
clip = (clip - mean_t) / (std_t + 1e-8)
|
| 218 |
+
x_pre = clip.transpose(1, 2).contiguous()
|
| 219 |
+
latent = vq_model.vqvae.encoder(x_pre)
|
| 220 |
+
if latent.dim() == 3:
|
| 221 |
+
embed_dim = getattr(vq_model.vqvae, "output_emb_width", None)
|
| 222 |
+
if embed_dim is not None:
|
| 223 |
+
if latent.shape[1] == embed_dim:
|
| 224 |
+
axis = 2
|
| 225 |
+
elif latent.shape[2] == embed_dim:
|
| 226 |
+
axis = 1
|
| 227 |
+
else:
|
| 228 |
+
axis = 2 if latent.shape[2] < latent.shape[1] else 1
|
| 229 |
+
else:
|
| 230 |
+
axis = 2 if latent.shape[2] < latent.shape[1] else 1
|
| 231 |
+
feat = latent.mean(dim=axis).squeeze(0)
|
| 232 |
+
elif latent.dim() == 2:
|
| 233 |
+
feat = latent.squeeze(0)
|
| 234 |
+
else:
|
| 235 |
+
feat = latent.view(1, -1).mean(dim=0)
|
| 236 |
+
vec = feat.detach().cpu().numpy().astype(np.float32)
|
| 237 |
+
norm = np.linalg.norm(vec)
|
| 238 |
+
if norm > 0:
|
| 239 |
+
vec = vec / norm
|
| 240 |
+
return vec
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# -----------------------------------------------------------------------------
|
| 244 |
+
# Metrics helpers
|
| 245 |
+
# -----------------------------------------------------------------------------
|
| 246 |
+
def calculate_activation_statistics_np(activations: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-dimension mean and covariance matrix of (N, D) features."""
    return activations.mean(axis=0), np.cov(activations, rowvar=False)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6) -> float:
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Standard FID formula: ||mu1 - mu2||^2 + Tr(s1 + s2 - 2*sqrt(s1 @ s2)).
    A small diagonal offset (``eps``) regularises the matrix square root when
    the covariance product is near-singular.
    """
    from scipy.linalg import sqrtm

    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, "Mean vectors must match"
    assert sigma1.shape == sigma2.shape, "Covariance matrices must match"

    covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Retry with a regularised product when sqrtm produced inf/nan.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            raise ValueError("Covmean contains large imaginary components")
        covmean = covmean.real

    delta = mu1 - mu2
    return float(delta.dot(delta) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def calculate_diversity_np(activation: np.ndarray, diversity_times: int = 200) -> float:
    """Mean pairwise L2 distance between randomly paired rows of (N, D) features.

    Returns NaN when fewer than two activations are available. Uses the
    module-level NumPy RNG, so results are reproducible under np.random.seed.
    """
    assert activation.ndim == 2
    count = activation.shape[0]
    if count < 2:
        return float("nan")
    draws = min(diversity_times, max(1, count - 1))
    left = np.random.choice(count, draws, replace=False)
    right = np.random.choice(count, draws, replace=False)
    gaps = activation[left] - activation[right]
    return float(np.linalg.norm(gaps, axis=1).mean())
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _to_label_tensor3(acts: np.ndarray, labels: List[str]) -> np.ndarray:
|
| 286 |
+
label_to_indices: Dict[str, List[int]] = {}
|
| 287 |
+
for idx, lbl in enumerate(labels):
|
| 288 |
+
label_to_indices.setdefault(lbl, []).append(idx)
|
| 289 |
+
counts = [len(v) for v in label_to_indices.values()]
|
| 290 |
+
if not counts:
|
| 291 |
+
raise ValueError("No labels available for multimodality computation.")
|
| 292 |
+
min_count = max(2, min(counts))
|
| 293 |
+
stacked = []
|
| 294 |
+
for lbl in sorted(label_to_indices.keys()):
|
| 295 |
+
stacked.append(acts[label_to_indices[lbl][:min_count]])
|
| 296 |
+
return np.stack(stacked, axis=0)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int = 20) -> float:
    """Mean distance between random same-label activation pairs.

    Expects a (num_labels, per_label, D) tensor as produced by
    ``_to_label_tensor3``; returns NaN when fewer than two samples per label
    are available. Uses the module-level NumPy RNG.
    """
    assert activation.ndim == 3
    per_label = activation.shape[1]
    if per_label < 2:
        return float("nan")
    draws = min(multimodality_times, max(1, per_label - 1))
    first = np.random.choice(per_label, draws, replace=False)
    second = np.random.choice(per_label, draws, replace=False)
    gaps = activation[:, first] - activation[:, second]
    return float(np.linalg.norm(gaps, axis=2).mean())
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# -----------------------------------------------------------------------------
|
| 312 |
+
# Generation helpers
|
| 313 |
+
# -----------------------------------------------------------------------------
|
| 314 |
+
def extract_ids_from_sequence(seq: str) -> List[int]:
    """Pull integer codebook ids out of a generated '<M12> <M3> ...' string.

    When both motion markers (M_START/M_END) are present, only the text
    between them is scanned; malformed or non-numeric tokens are ignored.
    """
    content = seq
    if M_START in seq and M_END in seq:
        content = seq.split(M_START, 1)[-1].split(M_END, 1)[0]
    payloads = (
        tok[2:-1]
        for tok in content.split()
        if tok.startswith("<M") and tok.endswith(">")
    )
    return [int(p) for p in payloads if p.isdigit()]
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def generate_motion_text(model, tokenizer, word: str, device: torch.device) -> str:
    """Sample a motion-token string for *word* from the fine-tuned LM.

    Builds the Stage-2 instruction prompt, samples up to 100 new tokens
    (stopping at the M_END token), and returns the text after the
    "Motion: " marker — or the full decode if the marker is absent.
    Special tokens are kept so the <M...> ids survive decoding.
    """
    model.eval()
    prompt = f"Instruction: Generate motion for word '{word}' with variant 'unknown'.\nMotion: "
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=100,
            do_sample=True,
            temperature=INFERENCE_TEMPERATURE,
            top_k=INFERENCE_TOP_K,
            repetition_penalty=INFERENCE_REPETITION_PENALTY,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
        )
    text = tokenizer.decode(generated[0], skip_special_tokens=False)
    marker = "Motion: "
    if marker in text:
        return text.split(marker, 1)[-1].strip()
    return text.strip()
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# -----------------------------------------------------------------------------
|
| 349 |
+
# Core evaluation
|
| 350 |
+
# -----------------------------------------------------------------------------
|
| 351 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for the unseen-test evaluation script."""
    parser = argparse.ArgumentParser(
        "Evaluate the trained Stage 2 model on an unseen SMPL-X test dataset."
    )

    # Exactly one data source must be provided.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--drive-url", type=str, help="Google Drive folder URL to download archives from.")
    source.add_argument("--drive-id", type=str, help="Google Drive folder ID to download archives from.")
    source.add_argument(
        "--local-extracted-dir",
        type=str,
        help="Use an existing directory that already contains extracted `video_data.pkl` files.",
    )

    add = parser.add_argument

    # Download / extraction controls.
    add("--max-zips", type=int, default=TEST_EVAL_MAX_ZIPS, help="Maximum number of zip files to extract.")
    add("--download-dir", type=str, default=TEST_EVAL_DOWNLOAD_DIR, help="Directory to store downloaded zips.")
    add("--extract-dir", type=str, default=TEST_EVAL_EXTRACT_DIR, help="Directory to extract archives into.")

    # Model locations.
    add("--hf-repo-id", type=str, default=TEST_EVAL_HF_REPO, help="Hugging Face repo containing the Stage 2 checkpoint.")
    add(
        "--hf-subfolder",
        type=str,
        default=TEST_EVAL_HF_SUBFOLDER,
        help="Subfolder inside the repo that hosts the Stage 2 model (e.g., `stage2_v2/epoch-020`).",
    )
    add("--vqvae-ckpt", type=str, default=None, help="Optional override for VQ-VAE checkpoint path.")
    add("--stats-path", type=str, default=None, help="Optional override for VQ-VAE stats file.")

    # Output and sampling controls.
    add("--output-dir", type=str, default=TEST_EVAL_OUTPUT_DIR, help="Directory to write metrics JSON.")
    add("--sample-limit", type=int, default=TEST_EVAL_SAMPLE_LIMIT, help="Maximum number of samples to evaluate.")
    add("--seed", type=int, default=42, help="Random seed.")
    return parser.parse_args()
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def run_evaluation(args: argparse.Namespace) -> Dict[str, object]:
    """Run the full unseen-test evaluation and write metrics to JSON.

    Pipeline: seed RNGs, load the Stage 2 LM and VQ-VAE, collect
    `video_data.pkl` samples (local dir or Google Drive), encode ground-truth
    and generated motions into encoder features, then compute FID, diversity
    and multimodality and write them to `<output_dir>/metrics_test.json`.

    Returns:
        The metrics payload that was written to disk.

    Raises:
        FileNotFoundError: if `--local-extracted-dir` does not exist.
        RuntimeError: if no archives/samples are found, or if no sample
            survives filtering (previously this surfaced as a cryptic
            ``np.stack`` ValueError on an empty list).
    """
    # Deterministic sampling/generation for reproducible metrics.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)
    metrics_path = os.path.join(args.output_dir, "metrics_test.json")

    print(f"Loading Stage 2 model from HF: {args.hf_repo_id} (subfolder='{args.hf_subfolder}')")
    tokenizer = AutoTokenizer.from_pretrained(args.hf_repo_id, subfolder=args.hf_subfolder, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.hf_repo_id, subfolder=args.hf_subfolder, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Keep the embedding table in sync with the enlarged vocabulary.
        tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    model.to(device)

    load_vqvae, load_stats, decode_tokens_to_params, DEFAULT_VQ, DEFAULT_STATS = import_visualize_helpers()
    # CLI overrides win over environment variables, which win over defaults.
    vq_ckpt = args.vqvae_ckpt if args.vqvae_ckpt else os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
    stats_path = args.stats_path if args.stats_path else os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
    print(f"Loading VQ-VAE from: {vq_ckpt}")
    vq_model = load_vqvae(vq_ckpt, device=device)
    print(f"Loading stats from: {stats_path}")
    mean, std = load_stats(stats_path)

    # Resolve directories containing extracted `video_data.pkl` files.
    extracted_dirs: List[str] = []
    if args.local_extracted_dir:
        if not os.path.isdir(args.local_extracted_dir):
            raise FileNotFoundError(f"Local extracted dir not found: {args.local_extracted_dir}")
        extracted_dirs = [args.local_extracted_dir]
    else:
        folder_ref = args.drive_url if args.drive_url else args.drive_id
        download_drive_folder(folder_ref, args.download_dir)
        zips = list_zip_files(args.download_dir)
        if not zips:
            raise RuntimeError("No zip files found after download.")
        extracted_dirs = extract_zip_files(zips, args.extract_dir, limit=args.max_zips)

    samples: List[Tuple[str, str]] = []
    for root in extracted_dirs:
        for pkl_path in find_video_pkl_paths(root):
            samples.append((parse_word_from_path(pkl_path), pkl_path))
    if not samples:
        raise RuntimeError("No `video_data.pkl` files discovered in the extracted directories.")

    random.shuffle(samples)
    samples = samples[: args.sample_limit]
    print(f"Found {len(samples)} samples to evaluate.")

    gt_features: List[np.ndarray] = []
    gen_features: List[np.ndarray] = []
    labels: List[str] = []

    for idx, (word, pkl_path) in enumerate(samples, 1):
        # --- Ground-truth feature ---
        params_gt = load_smplx_params_from_pkl(pkl_path)
        if params_gt is None or params_gt.ndim != 2:
            print(f"Skipping {pkl_path}: invalid SMPL-X payload.")
            continue
        try:
            feat_gt = _encode_params_to_feature(params_gt, vq_model, mean, std, device)
        except Exception as exc:
            print(f"Skipping {pkl_path}: encoder failed ({exc}).")
            continue
        if feat_gt is None:
            print(f"Skipping {pkl_path}: empty GT feature.")
            continue

        # --- Generated feature ---
        gen_text = generate_motion_text(model, tokenizer, word, device)
        token_ids = extract_ids_from_sequence(gen_text)
        if not token_ids:
            print(f"Skipping GEN for '{word}': no motion tokens produced.")
            continue
        try:
            params_gen = decode_tokens_to_params(token_ids, vq_model, mean, std, device=device)
        except Exception as exc:
            print(f"Skipping GEN for '{word}': decode failed ({exc}).")
            continue
        feat_gen = _encode_params_to_feature(params_gen, vq_model, mean, std, device)
        if feat_gen is None:
            print(f"Skipping GEN for '{word}': empty GEN feature.")
            continue

        # Only keep pairs where both sides produced a feature.
        gt_features.append(feat_gt)
        gen_features.append(feat_gen)
        labels.append(word)
        if idx % 25 == 0:
            print(f"Processed {idx} samples...")

    # Fail fast with a clear message instead of letting np.stack([]) raise.
    if not gt_features or not gen_features:
        raise RuntimeError("No usable samples survived filtering; cannot compute metrics.")
    if len(gt_features) < 5 or len(gen_features) < 5:
        print("⚠️ Not enough samples to compute stable metrics; results may be noisy.")

    gt_feats = np.stack(gt_features, axis=0)
    gen_feats = np.stack(gen_features, axis=0)

    diversity_gt = calculate_diversity_np(gt_feats, diversity_times=min(200, max(4, gt_feats.shape[0] - 1)))
    diversity_gen = calculate_diversity_np(gen_feats, diversity_times=min(200, max(4, gen_feats.shape[0] - 1)))

    try:
        gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
        gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
        mim_gt = calculate_multimodality_np(
            gt_lbl_tensor, multimodality_times=min(20, max(3, gt_lbl_tensor.shape[1] - 1))
        )
        mim_gen = calculate_multimodality_np(
            gen_lbl_tensor, multimodality_times=min(20, max(3, gen_lbl_tensor.shape[1] - 1))
        )
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = float("nan")
        mim_gen = float("nan")

    mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
    mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
    fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)

    metrics_payload = {
        "source": "test_raw_smplx_encoder_features",
        "counts": {
            "samples_total": len(samples),
            "samples_used": int(gt_feats.shape[0]),
        },
        "fid": fid,
        "diversity": {
            "ground_truth": diversity_gt,
            "model": diversity_gen,
        },
        "multimodality": {
            "ground_truth": mim_gt,
            "model": mim_gen,
        },
    }
    with open(metrics_path, "w", encoding="utf-8") as handle:
        json.dump(metrics_payload, handle, ensure_ascii=False, indent=2)
    print(f"\n✅ Saved test metrics to {metrics_path}")
    return metrics_payload
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def main() -> None:
    """CLI entry point: parse args, run the evaluation, exit 1 on failure."""
    cli_args = parse_args()
    try:
        run_evaluation(cli_args)
    except Exception as exc:
        print(f"Evaluation failed: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|
| 534 |
+
|
test_overfit.py
ADDED
|
@@ -0,0 +1,1562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
from typing import Dict, List, Tuple, Any, Optional
|
| 7 |
+
import shutil
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data import Dataset, DataLoader
|
| 13 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 14 |
+
from torch.optim import AdamW
|
| 15 |
+
from huggingface_hub import HfApi, upload_folder, hf_hub_download
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import scipy.linalg
|
| 19 |
+
# ======================================================================================
# 0. Configuration
# ======================================================================================

# --- Paths and Words ---
DATASET_PATH = "/content/SignMotionGPT/data/motion_llm_dataset.json"
MODEL_NAME = "Qwen/Qwen3-0.6B"
# Training uses the full dataset; these words are reserved for the final evaluation.
EVALUATION_WORDS = ["passport", "send", "library", "push"]
OUTPUT_DIR = "./motion_gpt_full_model"

# --- Evaluation controls ---
# RUN_EVALS_ONLY=True: after training, compute only the FID/Diversity/MIM metrics
# and save them to JSON, skipping per-sample inference logs and HTML visualizations.
# RUN_EVALS_ONLY=False: run the existing flow AND compute these 3 metrics.
RUN_EVALS_ONLY = False
EVAL_SAMPLE_LIMIT = 100
METRICS_JSON_PATH = ""  # derived from OUTPUT_DIR by _refresh_runtime_paths()

# --- Training Hyperparameters ---
# Full-dataset training takes longer; these epoch counts are starting points.
S1_EPOCHS = 20
S1_LR = 5e-5
S1_BATCH_SIZE = 8  # Kept small for Colab VRAM

S2_EPOCHS = 20
S2_LR = 2e-5
S2_BATCH_SIZE = 8

# --- Inference Hyperparameters ---
INFERENCE_REPETITION_PENALTY = 1.2
INFERENCE_TEMPERATURE = 0.7
INFERENCE_TOP_K = 50

# --- Special Tokens ---
M_START = "<M_START>"
M_END = "<M_END>"
PAD_TOKEN = "<PAD>"

# --- Hugging Face Hub Configuration ---
# Provide HUGGINGFACE_HUB_TOKEN or hf_auth_token in the environment for private repos.
HF_USE_HUB = True
hf_auth_token = os.getenv("hf_auth_token")
if hf_auth_token is None:
    raise ValueError("hf_auth_token environment variable is not set")
HF_STAGE1_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
HF_STAGE2_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
HF_PRIVATE_REPO = os.environ.get("HF_PRIVATE", "true").lower() != "false"
FORCE_STAGE2_FROM_STAGE1_RAW = os.environ.get("FORCE_STAGE2_FROM_STAGE1", "false")
FORCE_STAGE2_FROM_STAGE1 = str(FORCE_STAGE2_FROM_STAGE1_RAW).strip().lower() not in ("0", "false", "no", "off")
# Stage 2 checkpoints go to a fresh subfolder so older stage2 checkpoints survive.
HF_STAGE2_SAVE_SUBDIR = os.environ.get("HF_STAGE2_SAVE_SUBDIR", "stage2_v2")

# --- Local Checkpoint Root (derived from OUTPUT_DIR) ---
CHECKPOINTS_DIR = ""

# --- Upload frequency and progress control ---
# Push to Hugging Face only every N epochs (still save locally every epoch).
CHECKPOINT_UPLOAD_INTERVAL_EPOCHS = int(os.environ.get("HF_UPLOAD_INTERVAL_EPOCHS", "2"))
# Disable HF Hub progress bars for quieter logs (set HF_DISABLE_PROGRESS=false to re-enable).
HF_DISABLE_PROGRESS = os.environ.get("HF_DISABLE_PROGRESS", "true").lower() != "false"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _refresh_runtime_paths() -> None:
    """Recompute paths that are derived from OUTPUT_DIR (call after changing it)."""
    global METRICS_JSON_PATH, CHECKPOINTS_DIR
    base = OUTPUT_DIR
    METRICS_JSON_PATH = os.path.join(base, "metrics.json")
    CHECKPOINTS_DIR = os.path.join(base, "checkpoints")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _apply_progress_setting() -> None:
    """Enable or disable huggingface_hub progress bars per HF_DISABLE_PROGRESS."""
    if not HF_DISABLE_PROGRESS:
        os.environ.pop("HF_HUB_DISABLE_PROGRESS_BARS", None)
        return
    try:
        # Also respected by huggingface_hub's internal progress usage.
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        from huggingface_hub.utils import disable_progress_bars  # type: ignore

        disable_progress_bars()
    except Exception:
        # Best effort: an old/absent huggingface_hub just keeps its bars.
        pass
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def apply_config_overrides(overrides: Optional[Dict[str, Any]] = None) -> None:
    """
    Allow external callers to override module-level configuration prior to running main().

    Unknown keys are reported and skipped; ``hf_auth_token`` is handled
    specially. Derived paths and the progress-bar setting are refreshed when
    OUTPUT_DIR or HF_DISABLE_PROGRESS change.
    """
    global hf_auth_token, HF_DISABLE_PROGRESS, OUTPUT_DIR
    if not overrides:
        return

    paths_dirty = False
    progress_dirty = False
    for name, new_value in overrides.items():
        if name == "hf_auth_token":
            hf_auth_token = new_value
            continue
        if name not in globals():
            print(f"[config] Unknown override ignored: {name}")
            continue
        globals()[name] = new_value
        paths_dirty = paths_dirty or name == "OUTPUT_DIR"
        progress_dirty = progress_dirty or name == "HF_DISABLE_PROGRESS"

    if paths_dirty:
        _refresh_runtime_paths()
    if progress_dirty:
        _apply_progress_setting()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
_refresh_runtime_paths()
|
| 133 |
+
_apply_progress_setting()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# ======================================================================================
|
| 137 |
+
# 1. Data Loading and Preparation (NEW & IMPROVED)
|
| 138 |
+
# ======================================================================================
|
| 139 |
+
def read_json_data(json_path: str) -> List[Dict[str, Any]]:
    """Read and return the dataset stored as JSON at ``json_path``.

    Raises:
        FileNotFoundError: if no file exists at the given path.
    """
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"Dataset not found at: {json_path}")
    with open(json_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload
|
| 145 |
+
|
| 146 |
+
def deduplicate_and_prepare_data(entries: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
| 147 |
+
"""
|
| 148 |
+
Cleans the entire dataset by ensuring each (word, participant_id) pair is unique.
|
| 149 |
+
If a conflict is found (same pair, different motion), it keeps only the first one encountered.
|
| 150 |
+
Then, it prepares the full list of motion tokens from the cleaned data.
|
| 151 |
+
"""
|
| 152 |
+
print("\n---> Cleaning dataset by removing ambiguous (word, participant_id) pairs...")
|
| 153 |
+
|
| 154 |
+
unique_samples = {}
|
| 155 |
+
conflicts_found = 0
|
| 156 |
+
|
| 157 |
+
for entry in entries:
|
| 158 |
+
word = entry.get("word", "").lower()
|
| 159 |
+
pid = entry.get("participant_id", "")
|
| 160 |
+
key = (word, pid)
|
| 161 |
+
|
| 162 |
+
if key not in unique_samples:
|
| 163 |
+
unique_samples[key] = entry
|
| 164 |
+
else:
|
| 165 |
+
# A sample for this key already exists. We only care if it's a conflict.
|
| 166 |
+
existing_tokens = unique_samples[key].get("motion_tokens")
|
| 167 |
+
current_tokens = entry.get("motion_tokens")
|
| 168 |
+
if existing_tokens != current_tokens:
|
| 169 |
+
conflicts_found += 1
|
| 170 |
+
# We do nothing, effectively discarding this new conflicting sample.
|
| 171 |
+
|
| 172 |
+
cleaned_data = list(unique_samples.values())
|
| 173 |
+
|
| 174 |
+
print(f"Original samples: {len(entries)}")
|
| 175 |
+
print(f"Cleaned samples (unique (word, pid) pairs): {len(cleaned_data)}")
|
| 176 |
+
print(f"Removed {len(entries) - len(cleaned_data)} total samples. ({conflicts_found} were direct conflicts).")
|
| 177 |
+
|
| 178 |
+
print("\n---> Extracting motion tokens from the full cleaned dataset...")
|
| 179 |
+
all_motion_tokens = set()
|
| 180 |
+
for entry in cleaned_data:
|
| 181 |
+
motion_tokens = entry.get("motion_tokens", "").strip().split()
|
| 182 |
+
for token in motion_tokens:
|
| 183 |
+
all_motion_tokens.add(f"<M{token}>")
|
| 184 |
+
|
| 185 |
+
unique_tokens = sorted(list(all_motion_tokens))
|
| 186 |
+
print(f"Found {len(unique_tokens)} unique motion tokens in the entire dataset.")
|
| 187 |
+
|
| 188 |
+
return cleaned_data, unique_tokens
|
| 189 |
+
|
| 190 |
+
# ======================================================================================
|
| 191 |
+
# 2. Model and Tokenizer Setup
|
| 192 |
+
# ======================================================================================
|
| 193 |
+
def setup_model_and_tokenizer(model_name: str, motion_tokens: List[str]) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the base LM + tokenizer and register pad/fence/motion tokens."""
    print(f"\n---> Loading base model and tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # Register the pad token and the motion fence markers first...
    tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]})

    # ...then the motion vocabulary itself.
    print(f"Adding {len(motion_tokens)} motion tokens to the tokenizer.")
    tokenizer.add_tokens(motion_tokens, special_tokens=True)

    # Grow the embedding matrix to cover the enlarged vocabulary.
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer
|
| 208 |
+
|
| 209 |
+
# ======================================================================================
|
| 210 |
+
# 2b. Hugging Face Hub Utilities and Checkpointing
|
| 211 |
+
# ======================================================================================
|
| 212 |
+
def _format_seconds(seconds: float) -> str:
|
| 213 |
+
"""Formats seconds into H:MM:SS or M:SS."""
|
| 214 |
+
seconds = int(max(0, seconds))
|
| 215 |
+
h = seconds // 3600
|
| 216 |
+
m = (seconds % 3600) // 60
|
| 217 |
+
s = seconds % 60
|
| 218 |
+
if h > 0:
|
| 219 |
+
return f"{h:d}:{m:02d}:{s:02d}"
|
| 220 |
+
return f"{m:d}:{s:02d}"
|
| 221 |
+
|
| 222 |
+
def _ensure_dir(path: str) -> None:
|
| 223 |
+
os.makedirs(path, exist_ok=True)
|
| 224 |
+
|
| 225 |
+
def _resolve_and_ensure_repo(repo_id: str) -> Optional[str]:
    """
    Ensures the HF repo exists. Returns the fully-qualified repo_id (namespace/repo)
    when a token is available; otherwise returns None (Hub disabled / no token).
    """
    if not HF_USE_HUB:
        return None
    if hf_auth_token is None:
        print("⚠️ HF token not found. Set HUGGINGFACE_HUB_TOKEN or hf_auth_token to enable Hub sync.")
        return None

    api = HfApi()
    # Resolve the user/org namespace so a bare repo name can be qualified.
    try:
        identity = api.whoami(token=hf_auth_token)
        namespace = identity.get("name") or (
            identity.get("orgs", [None])[0] if isinstance(identity.get("orgs"), list) else None
        )
    except Exception as exc:
        print(f"⚠️ Unable to resolve HF namespace: {exc}")
        namespace = None

    full_repo_id = f"{namespace}/{repo_id}" if ("/" not in repo_id and namespace) else repo_id

    try:
        api.create_repo(
            repo_id=full_repo_id,
            token=hf_auth_token,
            repo_type="model",
            private=HF_PRIVATE_REPO,
            exist_ok=True,
        )
    except Exception as exc:
        # exist_ok should prevent most failures; log and carry on regardless.
        print(f"⚠️ create_repo failed (may already exist): {exc}")
    return full_repo_id
|
| 257 |
+
|
| 258 |
+
def _repo_has_stage_latest(repo_id: str, stage: str) -> bool:
    """Return True when a '{stage}/latest/...config.json' exists in the HF model repo."""
    if not HF_USE_HUB or hf_auth_token is None:
        return False
    api = HfApi()
    try:
        repo_files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=hf_auth_token)
    except Exception as exc:
        print(f"⚠️ Could not list files for {repo_id}: {exc}")
        return False
    prefix = f"{stage}/latest/"
    return any(name.startswith(prefix) and name.endswith("config.json") for name in repo_files)
|
| 269 |
+
|
| 270 |
+
def _repo_list_epoch_numbers(repo_id: str, stage: str) -> List[int]:
    """
    Returns sorted list of epoch numbers available under {stage}/epoch-XXX/ by scanning files.
    Works even if 'latest' does not exist.
    """
    if not HF_USE_HUB or hf_auth_token is None:
        return []
    api = HfApi()
    try:
        repo_files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=hf_auth_token)
    except Exception as exc:
        print(f"⚠️ Could not list files for {repo_id}: {exc}")
        return []

    # An epoch is considered present when its config.json exists.
    epoch_re = re.compile(rf"^{re.escape(stage)}/epoch-(\d+)/config\.json$")
    found = set()
    for path in repo_files:
        match = epoch_re.match(path)
        if match:
            try:
                found.add(int(match.group(1)))
            except ValueError:
                pass
    return sorted(found)
|
| 293 |
+
|
| 294 |
+
def _repo_get_latest_epoch_subfolder(repo_id: str, stage: str) -> Optional[str]:
    """Return '{stage}/epoch-XXX' for the highest available epoch, or None if none exist."""
    available = _repo_list_epoch_numbers(repo_id, stage)
    if not available:
        return None
    return f"{stage}/epoch-{max(available):03d}"
|
| 303 |
+
|
| 304 |
+
def _load_model_and_tokenizer_from_hf_subfolder(repo_id: str, subfolder: str) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
    """
    Loads model and tokenizer from HF under a specific subfolder (e.g., 'stage1/epoch-020').
    Returns None when Hub use is disabled, the token is missing, or loading fails.
    """
    if not HF_USE_HUB or hf_auth_token is None:
        return None
    print(f"\n---> Loading checkpoint from Hugging Face: {repo_id} (subfolder='{subfolder}')")
    try:
        tok = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
        lm = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
    except Exception as exc:
        print(f"⚠️ Failed to load model/tokenizer from subfolder '{subfolder}': {exc}")
        return None
    if tok.pad_token is None:
        # Checkpoints saved without a pad token get one re-added here.
        tok.add_special_tokens({"pad_token": PAD_TOKEN})
        lm.resize_token_embeddings(len(tok))
    lm.config.pad_token_id = tok.pad_token_id
    return lm, tok
|
| 322 |
+
|
| 323 |
+
def _download_training_state_from_subfolder(repo_id: str, subfolder: str) -> Optional[Dict[str, Any]]:
    """
    Downloads and parses training_state.json from a specific subfolder
    (e.g., 'stage1/epoch-020'). Returns None on any failure.
    """
    if not HF_USE_HUB or hf_auth_token is None:
        return None
    try:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{subfolder}/training_state.json",
            repo_type="model",
            token=hf_auth_token,
        )
        with open(local_path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except Exception:
        return None
|
| 340 |
+
|
| 341 |
+
def _download_training_state(repo_id: str, stage: str) -> Optional[Dict[str, Any]]:
    """Download and parse '{stage}/latest/training_state.json' from HF, or None on failure."""
    if not HF_USE_HUB or hf_auth_token is None:
        return None
    try:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{stage}/latest/training_state.json",
            repo_type="model",
            token=hf_auth_token,
        )
        with open(local_path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except Exception:
        return None
|
| 356 |
+
|
| 357 |
+
def _download_optimizer_state(repo_id: str, stage: str) -> Optional[str]:
    """Download '{stage}/latest/optimizer.pt' and return its local path, or None on failure."""
    if not HF_USE_HUB or hf_auth_token is None:
        return None
    try:
        return hf_hub_download(
            repo_id=repo_id,
            filename=f"{stage}/latest/optimizer.pt",
            repo_type="model",
            token=hf_auth_token,
        )
    except Exception:
        return None
|
| 371 |
+
|
| 372 |
+
def _load_model_and_tokenizer_from_hf(repo_id: str, stage: str) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
    """
    Loads model and tokenizer from HF under subfolder {stage}/latest if available.
    """
    if not _repo_has_stage_latest(repo_id, stage):
        return None
    subfolder = f"{stage}/latest"
    print(f"\n---> Loading checkpoint from Hugging Face: {repo_id} (subfolder='{subfolder}')")
    tok = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
    if tok.pad_token is None:
        # Checkpoints saved without a pad token get one re-added here.
        tok.add_special_tokens({"pad_token": PAD_TOKEN})
        lm.resize_token_embeddings(len(tok))
    lm.config.pad_token_id = tok.pad_token_id
    return lm, tok
|
| 386 |
+
|
| 387 |
+
def _ensure_tokenizer_has_motion_tokens(tokenizer: AutoTokenizer, motion_tokens: List[str]) -> int:
    """
    Adds the pad/fence specials and any missing motion tokens to the tokenizer.
    Returns the number of motion tokens that were newly added.
    """
    tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]})
    return tokenizer.add_tokens(motion_tokens, special_tokens=True)
|
| 394 |
+
|
| 395 |
+
def _save_and_push_checkpoint(
    stage: str,
    epoch_index_zero_based: int,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    optimizer: AdamW,
    avg_loss: float,
    dataloader_len: int,
    batch_size: int,
    total_epochs: int,
    repo_id: Optional[str],
) -> None:
    """
    Saves a checkpoint locally (per-epoch and 'latest') and pushes it to HF under:
      - {stage}/epoch-XXX
      - {stage}/latest
    Also saves optimizer.pt and training_state.json so training can resume with
    full optimizer state and epoch bookkeeping.

    When ``repo_id`` is None (or the Hub is disabled / token missing), the Hub
    push is skipped and only the local checkpoint is written.
    """
    from datetime import timezone  # local import: avoids touching module-level imports

    epoch_number = epoch_index_zero_based + 1
    stage_dir = os.path.join(CHECKPOINTS_DIR, stage)
    epoch_dir_name = f"epoch-{epoch_number:03d}"
    epoch_dir = os.path.join(stage_dir, epoch_dir_name)
    latest_dir = os.path.join(stage_dir, "latest")
    _ensure_dir(epoch_dir)
    _ensure_dir(stage_dir)

    # Save model + tokenizer
    model.save_pretrained(epoch_dir)
    tokenizer.save_pretrained(epoch_dir)

    # Save optimizer state
    torch.save(optimizer.state_dict(), os.path.join(epoch_dir, "optimizer.pt"))

    # Save resume metadata
    training_state = {
        "stage": stage,
        "epoch_completed": epoch_number,
        "total_epochs_for_stage": total_epochs,
        "global_step": epoch_number * dataloader_len,
        "avg_loss": float(avg_loss),
        "batch_size": batch_size,
        # datetime.utcnow() is deprecated since Python 3.12; build the same
        # naive-UTC "...Z" string from an aware timestamp instead.
        "saved_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
    }
    with open(os.path.join(epoch_dir, "training_state.json"), "w", encoding="utf-8") as f:
        json.dump(training_state, f, ensure_ascii=False, indent=2)

    # Update "latest" as a full copy of this epoch's directory.
    if os.path.exists(latest_dir):
        shutil.rmtree(latest_dir)
    shutil.copytree(epoch_dir, latest_dir)

    # Push both the epoch snapshot and 'latest' to Hugging Face.
    if HF_USE_HUB and repo_id and hf_auth_token:
        try:
            upload_folder(
                repo_id=repo_id,
                folder_path=epoch_dir,
                path_in_repo=f"{stage}/{epoch_dir_name}",
                repo_type="model",
                token=hf_auth_token,
                commit_message=f"{stage}: save {epoch_dir_name}",
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=latest_dir,
                path_in_repo=f"{stage}/latest",
                repo_type="model",
                token=hf_auth_token,
                commit_message=f"{stage}: update latest -> {epoch_dir_name}",
            )
            print(f"☁️ Pushed checkpoint to HF: {repo_id} ({stage}/{epoch_dir_name} and {stage}/latest)")
        except Exception as exc:
            print(f"⚠️ Failed to push checkpoint to HF: {exc}")
    else:
        print("ℹ️ Skipped HF push (Hub disabled or token/repo missing).")
|
| 470 |
+
|
| 471 |
+
# ======================================================================================
|
| 472 |
+
# 3. Training Stage 1: Motion Language Modeling
|
| 473 |
+
# ======================================================================================
|
| 474 |
+
class MotionDataset(Dataset):
    """Stage 1 dataset: motion-token sequences only, wrapped in fence markers."""

    def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Pre-render each sample as "M_START <M1> <M2> ... M_END".
        self.sequences = [
            f"{M_START} " + " ".join(f"<M{t}>" for t in item.get("motion_tokens", "").split()) + f" {M_END}"
            for item in data
        ]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Returns tensors shaped (1, max_length); the training loop squeezes dim 1.
        return self.tokenizer(
            self.sequences[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
|
| 498 |
+
|
| 499 |
+
def train_stage1(
    model,
    tokenizer,
    data,
    device,
    start_epoch: int = 0,
    hf_repo_id: Optional[str] = None,
):
    """Pre-train the LM on motion-token sequences alone (Stage 1).

    Optionally resumes the optimizer state from the Hugging Face Hub when a
    repo id is given and training restarts past epoch 0.  A checkpoint is
    saved locally every epoch; Hub pushes happen only on the configured
    interval or at the final epoch.
    """
    banner = "=" * 80
    print("\n" + banner)
    print(" STAGE 1: MOTION LANGUAGE MODELING (PRE-TRAINING)")
    print(f" Training on {len(data)} samples.")
    print(banner)

    loader = DataLoader(MotionDataset(data, tokenizer), batch_size=S1_BATCH_SIZE, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=S1_LR)
    model.to(device)
    model.train()

    # When resuming mid-run, also restore the optimizer moments from the Hub.
    if hf_repo_id and start_epoch > 0 and HF_USE_HUB and hf_auth_token:
        opt_path = _download_optimizer_state(hf_repo_id, "stage1")
        if opt_path is not None:
            try:
                optimizer.load_state_dict(torch.load(opt_path, map_location=device))
                print("↩️ Resumed optimizer state for Stage 1 from HF.")
            except Exception as exc:
                print(f"⚠️ Failed to load optimizer state for Stage 1: {exc}")

    for epoch in range(start_epoch, S1_EPOCHS):
        running_loss = 0
        num_batches = len(loader)
        started = time.time()
        report_every = max(1, num_batches // 50)  # ~2% progress updates
        for step, batch in enumerate(loader, 1):
            optimizer.zero_grad()

            # MotionDataset yields (1, max_length) tensors; drop the inner dim.
            input_ids = batch['input_ids'].squeeze(1).to(device)
            attention_mask = batch['attention_mask'].squeeze(1).to(device)

            # Causal LM objective over the whole sequence (labels == inputs).
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Progress with ETA
            if step == 1 or (step % report_every == 0) or (step == num_batches):
                elapsed = time.time() - started
                eta = (elapsed / step) * num_batches - elapsed
                pct = (step / num_batches) * 100.0
                print(
                    f"\r[Stage 1] Epoch {epoch+1}/{S1_EPOCHS} - "
                    f"{step}/{num_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
                    end="",
                    flush=True,
                )

        print()  # terminate the \r progress line
        avg_loss = running_loss / len(loader)
        print(f"--- End of Epoch {epoch+1}/{S1_EPOCHS}, Average Loss: {avg_loss:.4f} ---")

        # Checkpoint locally every epoch; push to HF only at the interval
        # or at the final epoch.
        push_now = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S1_EPOCHS)
        _save_and_push_checkpoint(
            stage="stage1",
            epoch_index_zero_based=epoch,
            model=model,
            tokenizer=tokenizer,
            optimizer=optimizer,
            avg_loss=avg_loss,
            dataloader_len=len(loader),
            batch_size=S1_BATCH_SIZE,
            total_epochs=S1_EPOCHS,
            repo_id=hf_repo_id if push_now else None,
        )

    print("\n✅ Stage 1 Training Complete.")
    return model
|
| 584 |
+
|
| 585 |
+
# ======================================================================================
|
| 586 |
+
# 4. Training Stage 2: Text-to-Motion Fine-Tuning
|
| 587 |
+
# ======================================================================================
|
| 588 |
+
class TextMotionDataset(Dataset):
    """Stage 2 dataset: instruction prompts paired with motion-token targets.

    Every sample is fully tokenized up front; the prompt positions are masked
    out of the labels with -100 so the loss covers only the motion target.
    NOTE(review): padding positions after the target are not masked — confirm
    whether loss on pad tokens is intended.
    """

    def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.items = []

        for sample in data:
            prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "

            wrapped = " ".join(f"<M{t}>" for t in sample.get("motion_tokens", "").split())
            target = f"{M_START} {wrapped} {M_END}"

            encoded = self.tokenizer(
                prompt + target,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            )

            # Mask the prompt span so only the motion tokens are supervised.
            prompt_len = self.tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
            labels = encoded['input_ids'].clone()
            labels[0, :prompt_len] = -100

            self.items.append({
                "input_ids": encoded['input_ids'].squeeze(0),
                "attention_mask": encoded['attention_mask'].squeeze(0),
                "labels": labels.squeeze(0),
            })

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
|
| 629 |
+
|
| 630 |
+
def train_stage2(
    model,
    tokenizer,
    data,
    device,
    start_epoch: int = 0,
    hf_repo_id: Optional[str] = None,
    hf_stage_subdir: str = "stage2",
):
    """Fine-tune the motion-aware LM on (prompt -> motion) pairs (Stage 2).

    Resumes the optimizer state from the Hugging Face Hub when possible,
    checkpoints every epoch (Hub pushes only at the configured interval or
    final epoch), and finally exports model + tokenizer to OUTPUT_DIR.
    """
    banner = "=" * 80
    print("\n" + banner)
    print(" STAGE 2: TEXT-TO-MOTION FINE-TUNING")
    print(f" Training on {len(data)} samples.")
    print(banner)

    loader = DataLoader(TextMotionDataset(data, tokenizer), batch_size=S2_BATCH_SIZE, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=S2_LR)
    model.to(device)
    model.train()

    # When resuming mid-run, restore the optimizer moments from the Hub too.
    if hf_repo_id and start_epoch > 0 and HF_USE_HUB and hf_auth_token:
        opt_path = _download_optimizer_state(hf_repo_id, hf_stage_subdir)
        if opt_path is not None:
            try:
                optimizer.load_state_dict(torch.load(opt_path, map_location=device))
                print("↩️ Resumed optimizer state for Stage 2 from HF.")
            except Exception as exc:
                print(f"⚠️ Failed to load optimizer state for Stage 2: {exc}")

    for epoch in range(start_epoch, S2_EPOCHS):
        running_loss = 0
        num_batches = len(loader)
        started = time.time()
        report_every = max(1, num_batches // 50)  # ~2% progress updates
        for step, batch in enumerate(loader, 1):
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)  # prompt span already -100

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Progress with ETA
            if step == 1 or (step % report_every == 0) or (step == num_batches):
                elapsed = time.time() - started
                eta = (elapsed / step) * num_batches - elapsed
                pct = (step / num_batches) * 100.0
                print(
                    f"\r[Stage 2] Epoch {epoch+1}/{S2_EPOCHS} - "
                    f"{step}/{num_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
                    end="",
                    flush=True,
                )

        print()  # terminate the \r progress line
        avg_loss = running_loss / len(loader)
        print(f"--- End of Epoch {epoch+1}/{S2_EPOCHS}, Average Loss: {avg_loss:.4f} ---")

        # Checkpoint locally every epoch; push to HF only at the interval
        # or at the final epoch.
        push_now = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S2_EPOCHS)
        _save_and_push_checkpoint(
            stage=hf_stage_subdir,
            epoch_index_zero_based=epoch,
            model=model,
            tokenizer=tokenizer,
            optimizer=optimizer,
            avg_loss=avg_loss,
            dataloader_len=len(loader),
            batch_size=S2_BATCH_SIZE,
            total_epochs=S2_EPOCHS,
            repo_id=hf_repo_id if push_now else None,
        )

    print("\n✅ Stage 2 Training Complete.")
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Model saved to {OUTPUT_DIR}")
    return model
|
| 722 |
+
|
| 723 |
+
# ======================================================================================
|
| 724 |
+
# 5. Inference and Comparison
|
| 725 |
+
# ======================================================================================
|
| 726 |
+
def generate_motion(model, tokenizer, prompt, device):
    """Sample a motion-token sequence from the model for the given prompt.

    Returns the decoded text that follows the final "Motion: " marker,
    i.e. the generated motion fence plus motion tokens, stripped.
    """
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=INFERENCE_TEMPERATURE,
            top_k=INFERENCE_TOP_K,
            repetition_penalty=INFERENCE_REPETITION_PENALTY,
            pad_token_id=tokenizer.pad_token_id,
            # Stop as soon as the motion end fence is emitted.
            eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
            # NOTE: `early_stopping=True` was removed — it only applies to beam
            # search and makes transformers' generate() warn under sampling.
        )

    decoded = tokenizer.decode(output[0], skip_special_tokens=False)
    # Keep only the text after the last "Motion: " marker (the prompt echoes it).
    motion_part = decoded.split("Motion: ")[-1]
    return motion_part.strip()
|
| 747 |
+
|
| 748 |
+
def compare_sequences(gt: str, gen: str):
|
| 749 |
+
"""Provides a simple visual diff of two sequences without external libraries."""
|
| 750 |
+
gt_tokens = gt.split()
|
| 751 |
+
gen_tokens = gen.split()
|
| 752 |
+
|
| 753 |
+
print("\nDetailed Comparison (✅ = Match, ❌ = Mismatch/Missing/Added):")
|
| 754 |
+
|
| 755 |
+
gt_str = " GT: "
|
| 756 |
+
gen_str = " GEN: "
|
| 757 |
+
diff_str = " "
|
| 758 |
+
|
| 759 |
+
max_len = max(len(gt_tokens), len(gen_tokens))
|
| 760 |
+
|
| 761 |
+
for i in range(max_len):
|
| 762 |
+
gt_tok = gt_tokens[i] if i < len(gt_tokens) else "___"
|
| 763 |
+
gen_tok = gen_tokens[i] if i < len(gen_tokens) else "___"
|
| 764 |
+
|
| 765 |
+
max_tok_len = max(len(gt_tok), len(gen_tok))
|
| 766 |
+
gt_tok_padded = gt_tok.ljust(max_tok_len)
|
| 767 |
+
gen_tok_padded = gen_tok.ljust(max_tok_len)
|
| 768 |
+
|
| 769 |
+
gt_str += gt_tok_padded + " "
|
| 770 |
+
gen_str += gen_tok_padded + " "
|
| 771 |
+
|
| 772 |
+
if gt_tok == gen_tok:
|
| 773 |
+
diff_str += "✅".ljust(max_tok_len) + " "
|
| 774 |
+
else:
|
| 775 |
+
diff_str += "❌".ljust(max_tok_len) + " "
|
| 776 |
+
|
| 777 |
+
print(gt_str)
|
| 778 |
+
print(gen_str)
|
| 779 |
+
print(diff_str)
|
| 780 |
+
|
| 781 |
+
def run_inference_on_all_samples(model, tokenizer, data, device):
    """
    Runs inference on ALL available samples for the trained words and compares
    each one to its specific ground truth, reporting per-word exact-match accuracy.
    """
    print("\n" + "="*80)
    print(" INFERENCE AND EVALUATION (ALL SAMPLES)")
    print(" Goal: Test the model's performance on every variant.")
    print("="*80)

    # Group the samples per word so accuracy can be summarized per word.
    grouped: Dict[str, List[Dict[str, Any]]] = {}
    for item in data:
        grouped.setdefault(item['word'], []).append(item)

    for word, samples in grouped.items():
        print(f"\n\n{'='*25} TESTING WORD: '{word}' {'='*25}")
        num_correct = 0

        for idx, sample in enumerate(samples):
            print(f"\n--- Testing Variant {idx+1}/{len(samples)}: '{sample['participant_id']}' ---")

            wrapped = " ".join(f"<M{t}>" for t in sample.get("motion_tokens", "").split())
            gt_sequence = f"{M_START} {wrapped} {M_END}"
            print(f"Ground Truth:\n{gt_sequence}")

            prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
            generated_sequence = generate_motion(model, tokenizer, prompt, device)
            print(f"\nLLM Generated:\n{generated_sequence}")

            compare_sequences(gt_sequence, generated_sequence)

            if gt_sequence.strip() == generated_sequence.strip():
                num_correct += 1

            print("-" * 80)

        accuracy = (num_correct / len(samples)) * 100
        print(f"\nSUMMARY FOR '{word}': {num_correct}/{len(samples)} correct ({accuracy:.1f}%)")
|
| 823 |
+
|
| 824 |
+
# ======================================================================================
|
| 825 |
+
# 5b. Metrics: FID, Diversity, Multimodality (MIM) using MotionGPT-style utils
|
| 826 |
+
# ======================================================================================
|
| 827 |
+
def calculate_activation_statistics_np(activations: np.ndarray):
    """Return the (mean, covariance) statistics of a set of activations.

    Params:
        activations: num_samples x dim_feat array.
    Returns:
        mu: (dim_feat,) mean vector.
        sigma: (dim_feat, dim_feat) sample covariance matrix.
    """
    return activations.mean(axis=0), np.cov(activations, rowvar=False)
|
| 838 |
+
|
| 839 |
+
def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance between two Gaussians."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    mean_diff = mu1 - mu2
    # Matrix square root of the covariance product; may be numerically singular.
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Regularize with a small diagonal offset and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # Small imaginary residue is numerical noise; large residue is an error.
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f"Imaginary component {m}")
        covmean = covmean.real
    return mean_diff.dot(mean_diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
|
| 859 |
+
|
| 860 |
+
def calculate_diversity_np(activation: np.ndarray, diversity_times: int = 200) -> float:
    """Mean pairwise L2 distance between randomly paired samples."""
    assert activation.ndim == 2
    assert activation.shape[0] > max(2, diversity_times)
    sample_count = activation.shape[0]
    idx_a = np.random.choice(sample_count, diversity_times, replace=False)
    idx_b = np.random.choice(sample_count, diversity_times, replace=False)
    pair_dists = np.linalg.norm(activation[idx_a] - activation[idx_b], axis=1)
    return float(np.mean(pair_dists))
|
| 870 |
+
|
| 871 |
+
def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int = 20) -> float:
    """
    activation: [num_labels, num_per_label, D]
    Returns mean pairwise within-label diversity (higher = more multimodal).
    """
    assert activation.ndim == 3
    num_per_label = activation.shape[1]
    assert num_per_label > multimodality_times
    picks_a = np.random.choice(num_per_label, multimodality_times, replace=False)
    picks_b = np.random.choice(num_per_label, multimodality_times, replace=False)
    within_label = np.linalg.norm(activation[:, picks_a] - activation[:, picks_b], axis=2)
    return float(within_label.mean())
|
| 884 |
+
|
| 885 |
+
# --------------------------------------------------------------------------------------
|
| 886 |
+
# Token sequence → activation (bag-of-motion-tokens) helpers
|
| 887 |
+
# --------------------------------------------------------------------------------------
|
| 888 |
+
def _extract_motion_tokens_from_sequence(seq: str) -> list[str]:
|
| 889 |
+
# Expect tokens like <M123>, within M_START/M_END fences; keep only <M...>
|
| 890 |
+
return [tok for tok in seq.split() if tok.startswith("<M") and tok.endswith(">")]
|
| 891 |
+
|
| 892 |
+
def _build_token_index(tokens_vocab: list[str]) -> Dict[str, int]:
|
| 893 |
+
return {tok: idx for idx, tok in enumerate(tokens_vocab)}
|
| 894 |
+
|
| 895 |
+
def _sequence_to_activation(seq: str, token_to_index: Dict[str, int]) -> np.ndarray:
    """Bag-of-motion-tokens histogram for ``seq``, L2-normalized (zero vector if empty)."""
    counts = np.zeros((len(token_to_index),), dtype=np.float32)
    for token in _extract_motion_tokens_from_sequence(seq):
        position = token_to_index.get(token)
        if position is not None:
            counts[position] += 1.0
    # Unit-normalize so sequence length does not bias pairwise distances.
    magnitude = np.linalg.norm(counts)
    return counts / magnitude if magnitude > 0 else counts
|
| 906 |
+
|
| 907 |
+
def _collect_eval_pairs(model, tokenizer, data, device) -> list[Tuple[str, str, str, str]]:
    """
    Generate one motion sequence per dataset sample and pair it with ground truth.

    Returns:
        A list of 4-tuples (word, participant_id, gt_sequence, generated_sequence),
        one per sample in ``data``.
    """
    # Fix: the previous return annotation declared 3-tuples while the function
    # has always appended 4-tuples; downstream consumers already unpack 4 items.
    results = []
    for sample in data:
        gt_tokens_str = sample.get("motion_tokens", "")
        # Wrap raw token ids as <Mxxx> and fence with M_START/M_END to mirror
        # the training-time sequence format.
        gt_wrapped = " ".join(f"<M{t}>" for t in gt_tokens_str.split())
        gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
        prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
        generated_sequence = generate_motion(model, tokenizer, prompt, device)
        pid = str(sample.get("participant_id", ""))
        results.append((sample["word"], pid, gt_sequence, generated_sequence))
    return results
|
| 921 |
+
|
| 922 |
+
def _activations_from_pairs(pairs: list[Tuple[str, str, str]], vocab_tokens: list[str]):
    """
    Build numpy activations and labels arrays from sequences.
    Returns:
        gt_acts: (N, D)
        gen_acts: (N, D)
        labels: list[str] length N (word labels)
    """
    token_to_index = _build_token_index(vocab_tokens)
    gt_vecs, gen_vecs, labels = [], [], []
    for pair in pairs:
        # Accept legacy 3-tuples (word, gt, gen) and new 4-tuples (word, pid, gt, gen).
        if len(pair) == 4:
            word, _pid, gt_seq, gen_seq = pair
        else:
            word, gt_seq, gen_seq = pair
        labels.append(word)
        gt_vecs.append(_sequence_to_activation(gt_seq, token_to_index))
        gen_vecs.append(_sequence_to_activation(gen_seq, token_to_index))
    return np.stack(gt_vecs, axis=0), np.stack(gen_vecs, axis=0), labels
|
| 944 |
+
|
| 945 |
+
def _to_label_tensor3(acts: np.ndarray, labels: list[str]) -> np.ndarray:
|
| 946 |
+
"""
|
| 947 |
+
Convert N x D activations with string labels to [L, K, D] by truncating each label
|
| 948 |
+
to the minimum count across labels.
|
| 949 |
+
"""
|
| 950 |
+
label_to_indices: Dict[str, list[int]] = {}
|
| 951 |
+
for i, lbl in enumerate(labels):
|
| 952 |
+
label_to_indices.setdefault(lbl, []).append(i)
|
| 953 |
+
per_label_counts = [len(idxs) for idxs in label_to_indices.values()]
|
| 954 |
+
if len(per_label_counts) == 0:
|
| 955 |
+
raise ValueError("No labels found for multimodality computation.")
|
| 956 |
+
min_count = max(2, min(per_label_counts))
|
| 957 |
+
label_names = sorted(label_to_indices.keys())
|
| 958 |
+
stacked = []
|
| 959 |
+
for lbl in label_names:
|
| 960 |
+
idxs = label_to_indices[lbl][:min_count]
|
| 961 |
+
stacked.append(acts[idxs])
|
| 962 |
+
return np.stack(stacked, axis=0) # [L, K, D]
|
| 963 |
+
|
| 964 |
+
def evaluate_metrics_motiongpt_style(model, tokenizer, eval_data, all_motion_tokens, device):
    """
    Computes:
      - Diversity: GT vs GEN (pair)
      - Multimodality (MIM): GT vs GEN (pair)
      - FID: between GT and GEN
    """
    print("\n" + "="*80)
    print(" METRICS EVALUATION (FID, Diversity, Multimodality)")
    print("="*80)
    eval_pairs = _collect_eval_pairs(model, tokenizer, eval_data, device)
    gt_acts, gen_acts, word_labels = _activations_from_pairs(eval_pairs, all_motion_tokens)

    # Diversity over bag-of-token activations; cap replicates by sample count.
    n_div = min(200, max(4, gt_acts.shape[0] - 1))
    diversity_gt = calculate_diversity_np(gt_acts, diversity_times=n_div)
    diversity_gen = calculate_diversity_np(gen_acts, diversity_times=n_div)

    # Multimodality (MIM); may legitimately fail on sparse labels, so guard it.
    try:
        gt_by_label = _to_label_tensor3(gt_acts, word_labels)
        gen_by_label = _to_label_tensor3(gen_acts, word_labels)
        n_mim = min(20, max(3, gt_by_label.shape[1] - 1))
        mim_gt = calculate_multimodality_np(gt_by_label, multimodality_times=n_mim)
        mim_gen = calculate_multimodality_np(gen_by_label, multimodality_times=n_mim)
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = float("nan")
        mim_gen = float("nan")

    # FID between the GT and GEN activation distributions.
    mu_gen, cov_gen = calculate_activation_statistics_np(gen_acts)
    mu_gt, cov_gt = calculate_activation_statistics_np(gt_acts)
    fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)

    print(f"Diversity: GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
    print(f"Multimodality (MIM): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
    print(f"FID (GT vs GEN): {fid:.4f}")
    return {
        "diversity_gt": diversity_gt,
        "diversity_gen": diversity_gen,
        "mim_gt": mim_gt,
        "mim_gen": mim_gen,
        "fid": fid,
        "pairs": eval_pairs,  # for visualization usage
    }
|
| 1006 |
+
|
| 1007 |
+
# ======================================================================================
|
| 1008 |
+
# 5b-ALT. Metrics using VQ-VAE codebook embeddings (near-standard activations)
|
| 1009 |
+
# ======================================================================================
|
| 1010 |
+
def _sequence_to_codebook_feature(seq: str, vq_model) -> np.ndarray:
    """
    Build a single clip feature by mean-pooling VQ-VAE codebook embeddings
    corresponding to the token ids in the sequence. L2-normalized.

    Strategy: prefer quantizer.dequantize() (time-resolved embeddings, pooled
    over time); fall back to a direct codebook table lookup pooled over ids.
    Raises RuntimeError if neither path is available.
    """
    token_ids = _extract_ids_from_sequence(seq)
    # Resolve code dimension and codebook availability
    quantizer = getattr(vq_model.vqvae, "quantizer", None)
    if quantizer is None:
        raise RuntimeError("VQ-VAE quantizer missing; cannot extract codebook embeddings.")
    # Try dequantize -> mean over time (preferred)
    feat_vec = None
    if hasattr(quantizer, "dequantize") and token_ids:
        try:
            idx = torch.tensor(token_ids, dtype=torch.long, device=next(vq_model.parameters()).device).unsqueeze(0)
            with torch.no_grad():
                dq = quantizer.dequantize(idx)
            if dq is not None:
                # Expect shape [N, code_dim, T]; average over T.
                # NOTE(review): layout is inferred heuristically below — confirm
                # against the quantizer implementation.
                if dq.ndim == 3:
                    if dq.shape[0] == 1:
                        x = dq.squeeze(0)  # [code_dim, T] or [T, code_dim]
                    else:
                        x = dq.mean(dim=0)
                    # Heuristic: assume the smaller leading dim is the channel axis.
                    if x.shape[0] < x.shape[1]:
                        # [code_dim, T]
                        feat = x.mean(dim=1)
                    else:
                        # [T, code_dim]
                        feat = x.mean(dim=0)
                    feat_vec = feat.detach().cpu().numpy().astype(np.float32)
        except Exception:
            # Best-effort: any dequantize failure falls through to codebook lookup.
            feat_vec = None
    # Fallback: direct codebook lookup -> mean over token ids
    if feat_vec is None:
        codebook = getattr(quantizer, "codebook", None)
        if codebook is None:
            raise RuntimeError("Quantizer has neither dequantize() nor codebook.")
        code_np = codebook.detach().cpu().numpy()  # [K, D]
        if not token_ids:
            # Empty sequence -> zero vector of codebook width.
            feat_vec = np.zeros((code_np.shape[1],), dtype=np.float32)
        else:
            ids = np.asarray(token_ids, dtype=np.int64)
            # Clamp out-of-range ids rather than raising on bad generations.
            ids = np.clip(ids, 0, code_np.shape[0] - 1)
            feat_vec = code_np[ids].mean(axis=0).astype(np.float32)
    # L2-normalize to reduce length/scale bias
    norm = np.linalg.norm(feat_vec)
    if norm > 0:
        feat_vec = feat_vec / norm
    return feat_vec
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
def _activations_from_pairs_codebook(pairs: list[Tuple[str, str, str]], vq_model):
    """
    Produce codebook-embedding features for GT and GEN sequences and their labels.
    Returns:
        gt_feats: (N, D)
        gen_feats: (N, D)
        labels: list[str] of length N (word labels)
    """
    gt_rows, gen_rows, word_labels = [], [], []
    for entry in pairs:
        # 4-tuples carry a participant id; legacy 3-tuples do not.
        if len(entry) == 4:
            word, _pid, gt_seq, gen_seq = entry
        else:
            word, gt_seq, gen_seq = entry
        word_labels.append(word)
        gt_rows.append(_sequence_to_codebook_feature(gt_seq, vq_model))
        gen_rows.append(_sequence_to_codebook_feature(gen_seq, vq_model))
    return np.stack(gt_rows, axis=0), np.stack(gen_rows, axis=0), word_labels
|
| 1082 |
+
|
| 1083 |
+
|
| 1084 |
+
def evaluate_metrics_codebook_style(model, tokenizer, eval_data, device, vqvae_ckpt: Optional[str] = None):
    """
    Computes FID, Diversity, and MIM using features derived from the VQ-VAE codebook:
      - Feature per clip = mean-pooled codebook embeddings over token sequence, L2-normalized
      - Diversity/MIM computed exactly as in MotionGPT-style helpers but on these features
      - FID computed via full covariance Fréchet distance on these features
    Returns a dict mirroring evaluate_metrics_motiongpt_style (empty dict if the
    VQ-VAE cannot be loaded).
    """
    print("\n" + "="*80)
    print(" METRICS EVALUATION (Codebook-Embedding Features)")
    print("="*80)
    # Lazy import to avoid hard dependency at module import time
    try:
        from visualize import load_vqvae, VQVAE_CHECKPOINT as DEFAULT_VQ
        # Checkpoint precedence: explicit arg > env var > module default.
        vq_ckpt = vqvae_ckpt or os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
        vq_model = load_vqvae(vq_ckpt, device=device)
    except Exception as exc:
        print(f"⚠️ Could not load VQ-VAE for codebook metrics: {exc}")
        return {}
    # Collect pairs and build features
    pairs = _collect_eval_pairs(model, tokenizer, eval_data, device)
    gt_feats, gen_feats, labels = _activations_from_pairs_codebook(pairs, vq_model)
    # Diversity (replicate count capped by available samples)
    diversity_times = min(200, max(4, gt_feats.shape[0] - 1))
    diversity_gt = calculate_diversity_np(gt_feats, diversity_times=diversity_times)
    diversity_gen = calculate_diversity_np(gen_feats, diversity_times=diversity_times)
    # Multimodality (MIM) — guarded: can fail when per-label counts are too small
    try:
        gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
        gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
        multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
        mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
        mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = float("nan")
        mim_gen = float("nan")
    # FID (on codebook features)
    mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
    mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
    fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
    print(f"Diversity (codebook feats): GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
    print(f"Multimodality (MIM, codebook): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
    print(f"FID (codebook feats, GT vs GEN): {fid:.4f}")
    return {
        "diversity_gt": diversity_gt,
        "diversity_gen": diversity_gen,
        "mim_gt": mim_gt,
        "mim_gen": mim_gen,
        "fid": fid,
        "pairs": pairs,
    }
|
| 1136 |
+
|
| 1137 |
+
# ======================================================================================
|
| 1138 |
+
# 5b-ALT2. Metrics using VQ-VAE encoder pre-quantization features (as described)
|
| 1139 |
+
# ======================================================================================
|
| 1140 |
+
def _encode_params_to_feature(params: np.ndarray, vq_model, mean, std, device) -> np.ndarray:
    """
    Convert SMPL-X parameter sequence (T, D) into a single clip feature using
    the VQ-VAE encoder output BEFORE quantization. Average-pool over time to get (D_embed,).
    - Attempts to use vq_model.vqvae.preprocess; otherwise applies manual normalization with mean/std.
    - Handles encoder outputs shaped as [N, D, T] or [N, T, D_embed].
    Returns a zero vector (width output_emb_width, default 512) for empty input.
    """
    if params.size == 0:
        return np.zeros((getattr(vq_model.vqvae, "output_emb_width", 512),), dtype=np.float32)
    x = torch.from_numpy(params.astype(np.float32)).to(device)  # [T, D]
    x = x.unsqueeze(0)  # [1, T, D]
    with torch.no_grad():
        # Normalize / preprocess
        x_pre = None
        if hasattr(vq_model.vqvae, "preprocess"):
            try:
                x_pre = vq_model.vqvae.preprocess(x)  # expected to return tensor ready for encoder
            except Exception:
                # Fall through to manual normalization below.
                x_pre = None
        if x_pre is None:
            # Manual normalization with provided mean/std
            if mean is not None and std is not None:
                mean_t = torch.from_numpy(np.array(mean, dtype=np.float32)).to(device).view(1, 1, -1)
                std_t = torch.from_numpy(np.array(std, dtype=np.float32)).to(device).view(1, 1, -1)
                x_norm = (x - mean_t) / (std_t + 1e-8)  # epsilon avoids divide-by-zero
            else:
                x_norm = x
            # Some encoders expect [N, D, T]
            x_pre = x_norm.transpose(1, 2).contiguous()  # [1, D, T]
        # Encode to get pre-quant latent
        z_e = vq_model.vqvae.encoder(x_pre)
        # z_e could be [N, D_embed, T_q] or [N, T_q, D_embed]
        if z_e.dim() == 3:
            # Determine which axis is time by comparing to known embed dim when available,
            # otherwise assume time is the smaller dimension (varies per clip).
            # NOTE(review): the size-based fallback is a heuristic and can misfire
            # when T_q >= D_embed — confirm against the encoder's output layout.
            embed_dim_known = getattr(vq_model.vqvae, "output_emb_width", None)
            if embed_dim_known is not None:
                if z_e.shape[1] == embed_dim_known:
                    time_axis = 2  # [N, D_embed, T_q]
                elif z_e.shape[2] == embed_dim_known:
                    time_axis = 1  # [N, T_q, D_embed]
                else:
                    time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
            else:
                time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
            feat = z_e.mean(dim=time_axis).squeeze(0)
        elif z_e.dim() == 2:
            feat = z_e.squeeze(0)
        else:
            # Fallback: flatten then reduce
            feat = z_e.view(1, -1).mean(dim=0)
        feat_np = feat.detach().cpu().numpy().astype(np.float32)
    # L2 normalize
    norm = np.linalg.norm(feat_np)
    if norm > 0:
        feat_np = feat_np / norm
    return feat_np
|
| 1197 |
+
|
| 1198 |
+
|
| 1199 |
+
def evaluate_metrics_encoder_style(
    model,
    tokenizer,
    eval_data,
    device,
    vqvae_ckpt: Optional[str] = None,
    stats_path: Optional[str] = None,
    sample_limit: int = 100,
):
    """
    Computes FID, Diversity, and MIM using VQ-VAE encoder pre-quantization features:
      - For each sample, decode tokens -> SMPL-X params, then run through VQ-VAE encoder,
        average-pool across time, L2-normalize to get a clip feature.
      - Diversity/MIM identical formulations but on these encoder features.
      - FID via full covariance Fréchet distance on these encoder features.
    Evaluates on up to 'sample_limit' samples for speed. Returns an empty dict
    if the VQ-VAE / stats cannot be loaded.
    """
    print("\n" + "="*80)
    print(" METRICS EVALUATION (VQ-VAE Encoder Features)")
    print("="*80)
    # Lazy import to reuse your visualization utilities and stats
    try:
        from visualize import load_vqvae, load_stats, VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS
        # Precedence: explicit args > env vars > module defaults.
        vq_ckpt = vqvae_ckpt or os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
        stats_p = stats_path or os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
        vq_model = load_vqvae(vq_ckpt, device=device)
        mean, std = load_stats(stats_p)
        from visualize import decode_tokens_to_params
    except Exception as exc:
        print(f"⚠️ Could not set up VQ-VAE encoder metrics: {exc}")
        return {}
    # Collect GT/GEN token sequences for pairs (limit to speed-up)
    pairs = _collect_eval_pairs(model, tokenizer, eval_data[:sample_limit], device)
    # Build features
    gt_feats = []
    gen_feats = []
    labels = []
    for pair in pairs:
        # Support both legacy 3-tuples and new 4-tuples with participant id.
        if len(pair) == 4:
            word, _pid, gt_seq, gen_seq = pair
        else:
            word, gt_seq, gen_seq = pair
        # Decode to SMPL-X
        tokens_gt = _extract_ids_from_sequence(gt_seq)
        tokens_gen = _extract_ids_from_sequence(gen_seq)
        try:
            params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std, device=device)  # (T, D) denorm
        except Exception:
            # NOTE(review): 182 assumed to be the SMPL-X param dimension here —
            # confirm against decode_tokens_to_params' output width.
            params_gt = np.zeros((0, 182), dtype=np.float32)
        try:
            params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std, device=device)  # (T, D) denorm
        except Exception:
            params_gen = np.zeros((0, 182), dtype=np.float32)
        # Encode (pre-quant) -> pooled feature
        feat_gt = _encode_params_to_feature(params_gt, vq_model, mean, std, device)
        feat_gen = _encode_params_to_feature(params_gen, vq_model, mean, std, device)
        gt_feats.append(feat_gt)
        gen_feats.append(feat_gen)
        labels.append(word)
    gt_feats = np.stack(gt_feats, axis=0)
    gen_feats = np.stack(gen_feats, axis=0)
    # Diversity (replicate count capped by available samples)
    diversity_times = min(200, max(4, gt_feats.shape[0] - 1))
    diversity_gt = calculate_diversity_np(gt_feats, diversity_times=diversity_times)
    diversity_gen = calculate_diversity_np(gen_feats, diversity_times=diversity_times)
    # Multimodality (MIM) — guarded: can fail when per-label counts are too small
    try:
        gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
        gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
        multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
        mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
        mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = float("nan")
        mim_gen = float("nan")
    # FID (on encoder features)
    mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
    mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
    fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
    print(f"Diversity (encoder feats): GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
    print(f"Multimodality (MIM, encoder): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
    print(f"FID (encoder feats, GT vs GEN): {fid:.4f}")
    return {
        "diversity_gt": diversity_gt,
        "diversity_gen": diversity_gen,
        "mim_gt": mim_gt,
        "mim_gen": mim_gen,
        "fid": fid,
        "pairs": pairs,
    }
|
| 1290 |
+
|
| 1291 |
+
# ======================================================================================
|
| 1292 |
+
# 5c. Side-by-side visualization (4 samples)
|
| 1293 |
+
# ======================================================================================
|
| 1294 |
+
def _extract_ids_from_sequence(seq: str) -> list[int]:
    """Numeric ids from <Mxxx> tokens in ``seq``; non-numeric payloads are skipped."""
    ids = []
    for token in _extract_motion_tokens_from_sequence(seq):
        payload = token[2:-1]
        if payload.isdigit():
            ids.append(int(payload))
    return ids
|
| 1296 |
+
|
| 1297 |
+
def save_side_by_side_visualizations(pairs: list[Tuple[str, str, str]], output_dir: str, limit: int = 4):
    """
    Generate side-by-side 3D animations for GT vs GEN, saving one HTML per sample
    using filename scheme: word_PID_side_by_side.html.
    - Processes ALL samples for up to `limit` distinct words (if provided).
    - Requires visualize.py utilities and plotly.
    Returns None; all output is written to disk (best-effort, errors are logged).
    """
    try:
        from visualize import (
            load_vqvae, load_stats, load_smplx_model,
            decode_tokens_to_params, params_to_vertices,
            VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS, SMPLX_MODEL_DIR as DEFAULT_SMPLX
        )
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except Exception as exc:
        # Visualization is optional: missing plotly/visualize.py just skips it.
        print(f"⚠️ Visualization skipped (missing dependencies): {exc}")
        return

    os.makedirs(output_dir, exist_ok=True)
    # Env vars override the module-level default paths.
    vqvae_ckpt = os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
    stats_path = os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
    smplx_dir = os.getenv("SMPLX_MODEL_DIR", DEFAULT_SMPLX)

    print("Loading VQ-VAE, stats, SMPL-X ...")
    vq_model = load_vqvae(vqvae_ckpt)
    mean, std = load_stats(stats_path)
    smplx_model = load_smplx_model(smplx_dir)

    def animate_side_by_side(verts_left, faces, verts_right, fps=20, titles=("Ground Truth", "LLM Generated"), output_html=None):
        # Build a two-scene plotly animation; truncate both clips to the shorter length.
        T = min(verts_left.shape[0], verts_right.shape[0])
        verts_left, verts_right = verts_left[:T], verts_right[:T]
        # Shared triangle indices for both meshes (same SMPL-X topology).
        i, j, k = faces.T.tolist()
        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{'type': 'scene'}, {'type': 'scene'}]],
            horizontal_spacing=0.05,
            subplot_titles=list(titles)
        )
        # Initial frame (t=0) meshes for both panels.
        left_mesh = go.Mesh3d(x=verts_left[0,:,0], y=verts_left[0,:,1], z=verts_left[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
        right_mesh = go.Mesh3d(x=verts_right[0,:,0], y=verts_right[0,:,1], z=verts_right[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
        fig.add_trace(left_mesh, row=1, col=1)
        fig.add_trace(right_mesh, row=1, col=2)
        # One animation frame per timestep, targeting scene/scene2 explicitly.
        frames = []
        for t in range(T):
            frames.append(go.Frame(
                name=str(t),
                data=[
                    go.Mesh3d(x=verts_left[t,:,0], y=verts_left[t,:,1], z=verts_left[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene"),
                    go.Mesh3d(x=verts_right[t,:,0], y=verts_right[t,:,1], z=verts_right[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene2")
                ]
            ))
        fig.frames = frames
        fig.update_layout(
            showlegend=False,
            margin=dict(l=10, r=10, t=50, b=10),
            scene=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
                       camera=dict(eye=dict(x=0,y=-2,z=0.7))),
            scene2=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
                        camera=dict(eye=dict(x=0,y=-2,z=0.7))),
            updatemenus=[dict(
                type="buttons", x=0.5, xanchor="center", y=1.15, yanchor="top",
                buttons=[
                    dict(label="Play", method="animate", args=[None, {"frame": {"duration": max(1,1000//fps), "redraw": True}, "fromcurrent": True}]),
                    dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}}])
                ]
            )]
        )
        if output_html:
            fig.write_html(output_html)
            print(f"✅ Saved: {output_html}")
        return fig

    # Determine which words to include (up to `limit` distinct words)
    allowed_words = None
    if isinstance(limit, int) and limit > 0:
        ordered_unique_words = []
        for pair in pairs:
            word = pair[0]
            if word not in ordered_unique_words:
                ordered_unique_words.append(word)
            if len(ordered_unique_words) >= limit:
                break
        allowed_words = set(ordered_unique_words)

    for pair in pairs:
        try:
            # Support both legacy 3-tuples and new 4-tuples with participant id.
            if len(pair) == 4:
                word, pid, gt_seq, gen_seq = pair
            else:
                word, gt_seq, gen_seq = pair
                pid = "unknown"
            if allowed_words is not None and word not in allowed_words:
                continue
            # Token ids -> SMPL-X params -> mesh vertices for both clips.
            tokens_gt = _extract_ids_from_sequence(gt_seq)
            tokens_gen = _extract_ids_from_sequence(gen_seq)
            params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std)
            params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std)
            verts_gt, faces = params_to_vertices(params_gt, smplx_model)
            verts_gen, _ = params_to_vertices(params_gen, smplx_model)
            out_dir = os.path.join(output_dir)
            os.makedirs(out_dir, exist_ok=True)
            # Sanitize for filesystem safety
            safe_word = re.sub(r'[^A-Za-z0-9_-]+', '_', str(word))
            safe_pid = re.sub(r'[^A-Za-z0-9_-]+', '_', str(pid))
            output_html = os.path.join(out_dir, f"word_{safe_word}_{safe_pid}_side_by_side.html")
            animate_side_by_side(
                verts_left=verts_gt,
                faces=faces,
                verts_right=verts_gen,
                fps=20,
                titles=("Ground Truth", "LLM Generated"),
                output_html=output_html
            )
        except Exception as exc:
            # Best-effort: one bad sample must not abort the whole batch.
            print(f"⚠️ Error creating visualization for word '{pair[0]}': {exc}")
|
| 1413 |
+
|
| 1414 |
+
# ======================================================================================
|
| 1415 |
+
# 6. Main Execution Block (UPDATED)
|
| 1416 |
+
# ======================================================================================
|
| 1417 |
+
def main(config_overrides: Optional[Dict[str, Any]] = None):
    """Main function to run the entire pipeline.

    Orchestrates: data loading/cleaning, Stage 1 (motion LM) training with
    optional Hugging Face resume, Stage 2 (text-to-motion) training, and
    evaluation/visualization. ``config_overrides`` is applied before any
    other work; override values whose key contains "token" are never printed.
    """
    apply_config_overrides(config_overrides)
    if config_overrides:
        # Hide credential-like overrides from the log.
        printable = {k: v for k, v in config_overrides.items() if "token" not in k.lower()}
        if printable:
            print("\nApplied config overrides:")
            for key, value in printable.items():
                print(f" - {key} = {value}")
    # Fixed seeds for reproducible shuffling/initialization.
    random.seed(42)
    torch.manual_seed(42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load ALL data
    all_entries = read_json_data(DATASET_PATH)

    # 2. Clean the ENTIRE dataset and get all tokens
    cleaned_data, all_motion_tokens = deduplicate_and_prepare_data(all_entries)

    # 3. Stage 1: Initialize or resume from HF, then train
    resolved_stage1_repo = _resolve_and_ensure_repo(HF_STAGE1_REPO_ID) if HF_USE_HUB else None
    start_epoch_s1 = 0
    stage1_loaded = None
    if resolved_stage1_repo:
        if _repo_has_stage_latest(resolved_stage1_repo, "stage1"):
            # Preferred resume point: the 'stage1/latest' mirror on the Hub.
            stage1_loaded = _load_model_and_tokenizer_from_hf(resolved_stage1_repo, "stage1")
            state_s1 = _download_training_state(resolved_stage1_repo, "stage1")
            if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
                start_epoch_s1 = state_s1["epoch_completed"]
        else:
            # Fallback: no 'latest' folder; select highest epoch-XXX
            latest_s1_sub = _repo_get_latest_epoch_subfolder(resolved_stage1_repo, "stage1")
            if latest_s1_sub:
                stage1_loaded = _load_model_and_tokenizer_from_hf_subfolder(resolved_stage1_repo, latest_s1_sub)
                state_s1 = _download_training_state_from_subfolder(resolved_stage1_repo, latest_s1_sub)
                if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
                    start_epoch_s1 = state_s1["epoch_completed"]

    if stage1_loaded:
        base_model, tokenizer = stage1_loaded
        # Ensure tokenizer contains all motion tokens (add missing if dataset expanded)
        added = _ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
        if added > 0:
            base_model.resize_token_embeddings(len(tokenizer))
    else:
        # Cold start: fresh base model with motion tokens added to the vocab.
        base_model, tokenizer = setup_model_and_tokenizer(MODEL_NAME, all_motion_tokens)

    print(f"\nStarting Stage 1 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s1}).")
    motion_model = train_stage1(
        base_model,
        tokenizer,
        cleaned_data,
        device,
        start_epoch=start_epoch_s1,
        hf_repo_id=resolved_stage1_repo,
    )

    # 4. Stage 2: Initialize or resume from HF, then train
    resolved_stage2_repo = _resolve_and_ensure_repo(HF_STAGE2_REPO_ID) if HF_USE_HUB else None
    start_epoch_s2 = 0
    stage2_loaded = None
    print(f"\nStage 2 resume policy: FORCE_STAGE2_FROM_STAGE1={FORCE_STAGE2_FROM_STAGE1}, save_subdir='{HF_STAGE2_SAVE_SUBDIR}'")
    # For this run we want Stage 2 to start from Stage 1 epoch-20 even if an old stage2 exists.
    # Only resume Stage 2 if explicitly allowed and if there is a checkpoint under the save subdir.
    if not FORCE_STAGE2_FROM_STAGE1 and resolved_stage2_repo:
        # Prefer loading from the configured Stage 2 save subdir (e.g., 'stage2_v2')
        if _repo_has_stage_latest(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR):
            stage2_loaded = _load_model_and_tokenizer_from_hf(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR)
            state_s2 = _download_training_state(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR)
            if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
                start_epoch_s2 = state_s2["epoch_completed"]
                print(f"Resuming Stage 2 from HF subfolder: {HF_STAGE2_SAVE_SUBDIR}/latest (epoch_completed={start_epoch_s2})")
        else:
            # Same epoch-folder fallback as Stage 1.
            latest_s2_sub = _repo_get_latest_epoch_subfolder(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR)
            if latest_s2_sub:
                stage2_loaded = _load_model_and_tokenizer_from_hf_subfolder(resolved_stage2_repo, latest_s2_sub)
                state_s2 = _download_training_state_from_subfolder(resolved_stage2_repo, latest_s2_sub)
                if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
                    start_epoch_s2 = state_s2["epoch_completed"]
                    print(f"Resuming Stage 2 from HF subfolder: {latest_s2_sub} (epoch_completed={start_epoch_s2})")

    if stage2_loaded:
        stage2_model, tokenizer = stage2_loaded
        added2 = _ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
        if added2 > 0:
            stage2_model.resize_token_embeddings(len(tokenizer))
    else:
        stage2_model = motion_model  # Start Stage 2 from Stage 1 model

    print(f"\nStarting Stage 2 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s2}).")
    final_model = train_stage2(
        stage2_model,
        tokenizer,
        cleaned_data,
        device,
        start_epoch=start_epoch_s2,
        hf_repo_id=resolved_stage2_repo,
        hf_stage_subdir=HF_STAGE2_SAVE_SUBDIR,
    )

    # 5. Filter the cleaned data to get a smaller set for evaluation
    # This keeps the evaluation focused on our benchmark words and the logs readable
    print("\n--- Filtering data for evaluation on specific words ---")
    evaluation_data = [item for item in cleaned_data if item['word'].lower() in EVALUATION_WORDS]
    print(f"Found {len(evaluation_data)} samples for evaluation words: {EVALUATION_WORDS}")

    # 6. Metrics-only mode or full flow
    if RUN_EVALS_ONLY:
        # Compute the 3 metrics using VQ-VAE encoder features and save to JSON
        metrics_enc = evaluate_metrics_encoder_style(
            final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
        )
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        metrics_payload = {
            "source": "vqvae_encoder",
            "fid": metrics_enc.get("fid"),
            "diversity": {
                "ground_truth": metrics_enc.get("diversity_gt"),
                "model": metrics_enc.get("diversity_gen"),
            },
            "multimodality": {
                "ground_truth": metrics_enc.get("mim_gt"),
                "model": metrics_enc.get("mim_gen"),
            },
            "num_pairs": len(metrics_enc.get("pairs", [])),
        }
        with open(METRICS_JSON_PATH, "w", encoding="utf-8") as f:
            json.dump(metrics_payload, f, ensure_ascii=False, indent=2)
        print(f"\n✅ Saved metrics to {METRICS_JSON_PATH}")
        return

    # Full flow: inference logs + MotionGPT-style metrics + encoder metrics + visualizations
    run_inference_on_all_samples(final_model, tokenizer, evaluation_data, device)
    metrics_token = evaluate_metrics_motiongpt_style(final_model, tokenizer, evaluation_data, all_motion_tokens, device)
    # Also compute encoder-based 3 metrics
    metrics_enc = evaluate_metrics_encoder_style(
        final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
    )
    # Visualizations (skip if metrics-only)
    viz_dir = os.path.join(OUTPUT_DIR, "html_visualizations")
    save_side_by_side_visualizations(metrics_token["pairs"], viz_dir, limit=4)
|
| 1560 |
+
|
| 1561 |
+
# Script entry point: run the full pipeline with default configuration.
if __name__ == "__main__":
    main()
|
train.py
ADDED
|
@@ -0,0 +1,744 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training utilities and functions
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import time
|
| 8 |
+
import json
|
| 9 |
+
import shutil
|
| 10 |
+
import torch
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from typing import Optional, Dict, Any, List, Tuple
|
| 13 |
+
from torch.optim import AdamW
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
|
| 16 |
+
from transformers import TrainingArguments, Trainer, AutoModelForCausalLM, AutoTokenizer
|
| 17 |
+
from transformers.trainer_callback import TrainerCallback
|
| 18 |
+
from huggingface_hub import HfApi, upload_folder, snapshot_download, hf_hub_download
|
| 19 |
+
|
| 20 |
+
from config import (
|
| 21 |
+
BATCH_TRAIN, BATCH_EVAL, GRAD_ACCUM, LR, WARMUP,
|
| 22 |
+
LOG_STEPS, EVAL_STEPS, SAVE_STEPS, SEED, DTYPE,
|
| 23 |
+
HUB_REPO_S1, HUB_REPO_S2, HUB_REPO_S3, HF_TOKEN,
|
| 24 |
+
CHECKPOINTS_DIR, HF_USE_HUB, CHECKPOINT_UPLOAD_INTERVAL_EPOCHS,
|
| 25 |
+
S1_BATCH_SIZE, S1_LR, S1_EPOCHS, S2_BATCH_SIZE, S2_LR, S2_EPOCHS,
|
| 26 |
+
PAD_TOKEN, M_START, M_END
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# ======================================================================================
|
| 30 |
+
# Logic from test_overfit.py (Raw Training Loops & HF Utils)
|
| 31 |
+
# ======================================================================================
|
| 32 |
+
|
| 33 |
+
def _format_seconds(seconds: float) -> str:
|
| 34 |
+
"""Formats seconds into H:MM:SS or M:SS."""
|
| 35 |
+
seconds = int(max(0, seconds))
|
| 36 |
+
h = seconds // 3600
|
| 37 |
+
m = (seconds % 3600) // 60
|
| 38 |
+
s = seconds % 60
|
| 39 |
+
if h > 0:
|
| 40 |
+
return f"{h:d}:{m:02d}:{s:02d}"
|
| 41 |
+
return f"{m:d}:{s:02d}"
|
| 42 |
+
|
| 43 |
+
def _ensure_dir(path: str) -> None:
|
| 44 |
+
os.makedirs(path, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
def resolve_and_ensure_repo(repo_id: str, hf_auth_token: Optional[str] = None) -> Optional[str]:
    """
    Ensures the HF repo exists. Returns the fully-qualified repo_id (namespace/repo)
    when token is available; otherwise returns the input repo_id.

    Returns None when Hub sync is disabled (``HF_USE_HUB`` false) or no token
    is available. ``create_repo`` failures are logged but not raised, so the
    resolved id is returned even when creation could not be confirmed.
    """
    if not HF_USE_HUB:
        return None
    # Explicit argument wins over the configured HF_TOKEN.
    token = hf_auth_token or HF_TOKEN
    if not token:
        print("⚠️ HF token not found. Set HUGGINGFACE_HUB_TOKEN to enable Hub sync.")
        return None
    api = HfApi()
    try:
        # Resolve the account namespace; fall back to the first org if the
        # user name is missing and an org list is present.
        who = api.whoami(token=token)
        namespace = who.get("name") or (who.get("orgs", [None])[0] if isinstance(who.get("orgs"), list) else None)
    except Exception as exc:
        print(f"⚠️ Unable to resolve HF namespace: {exc}")
        namespace = None
    # Only qualify bare repo names; ids already containing '/' pass through.
    if "/" not in repo_id and namespace:
        full_repo_id = f"{namespace}/{repo_id}"
    else:
        full_repo_id = repo_id
    try:
        api.create_repo(
            repo_id=full_repo_id,
            token=token,
            repo_type="model",
            private=True, # Default to private as in test_overfit config if not specified
            exist_ok=True,
        )
    except Exception as exc:
        # Best-effort creation: an existing repo or transient error is tolerated.
        print(f"⚠️ create_repo failed (may already exist): {exc}")
    return full_repo_id
|
| 79 |
+
|
| 80 |
+
def repo_has_stage_latest(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> bool:
    """Checks if a stage/latest checkpoint exists in the HF repo.

    A checkpoint is considered present when any file under ``{stage}/latest/``
    ends with ``config.json``. Returns False when the Hub is disabled, no
    token is available, or the file listing fails.
    """
    token = hf_auth_token or HF_TOKEN
    if not HF_USE_HUB or not token:
        return False
    api = HfApi()
    try:
        files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
        # config.json is used as the marker that a full model save exists there.
        return any(path.startswith(f"{stage}/latest/") and path.endswith("config.json") for path in files)
    except Exception as exc:
        print(f"⚠️ Could not list files for {repo_id}: {exc}")
        return False
|
| 92 |
+
|
| 93 |
+
def repo_list_epoch_numbers(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> List[int]:
    """
    Returns sorted list of epoch numbers available under {stage}/epoch-XXX/ by scanning files.

    Only folders containing a ``config.json`` count as valid epoch checkpoints.
    Returns an empty list when the Hub is disabled, no token is available, or
    listing fails. Duplicates are removed before sorting.
    """
    token = hf_auth_token or HF_TOKEN
    if not HF_USE_HUB or not token:
        return []
    api = HfApi()
    try:
        files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
    except Exception as exc:
        print(f"⚠️ Could not list files for {repo_id}: {exc}")
        return []
    epoch_numbers: List[int] = []
    # Match e.g. "stage1/epoch-007/config.json" and capture the epoch number.
    pattern = re.compile(rf"^{re.escape(stage)}/epoch-(\d+)/config\.json$")
    for path in files:
        m = pattern.match(path)
        if m:
            try:
                epoch_numbers.append(int(m.group(1)))
            except ValueError:
                # Defensive: \d+ should always parse, but never fail the scan.
                pass
    return sorted(set(epoch_numbers))
|
| 116 |
+
|
| 117 |
+
def repo_get_latest_epoch_subfolder(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[str]:
    """Return ``'{stage}/epoch-XXX'`` for the highest epoch on the Hub, or None.

    Relies on :func:`repo_list_epoch_numbers` for discovery, so it inherits
    that function's Hub/token preconditions.
    """
    available = repo_list_epoch_numbers(repo_id, stage, hf_auth_token)
    if not available:
        return None
    return f"{stage}/epoch-{max(available):03d}"
|
| 126 |
+
|
| 127 |
+
def load_model_and_tokenizer_from_hf_subfolder(repo_id: str, subfolder: str, hf_auth_token: Optional[str] = None) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
    """
    Loads model and tokenizer from HF under a specific subfolder.

    Returns ``(model, tokenizer)`` on success, or None when the Hub is
    disabled, no token is available, or loading fails. Guarantees the returned
    tokenizer has a pad token (adding PAD_TOKEN and resizing embeddings if
    the checkpoint lacked one) and that ``model.config.pad_token_id`` matches.
    """
    if not HF_USE_HUB or (not hf_auth_token and not HF_TOKEN):
        return None
    print(f"\n---> Loading checkpoint from Hugging Face: {repo_id} (subfolder='{subfolder}')")
    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
    except Exception as exc:
        # A missing/corrupt subfolder is treated as "no checkpoint" rather than fatal.
        print(f"⚠️ Failed to load model/tokenizer from subfolder '{subfolder}': {exc}")
        return None
    if tokenizer.pad_token is None:
        # Older checkpoints may predate the pad token; add it and grow embeddings.
        tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer
|
| 145 |
+
|
| 146 |
+
def download_training_state_from_subfolder(repo_id: str, subfolder: str, hf_auth_token: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """
    Downloads training_state.json from a specific subfolder.

    Returns the parsed JSON dict, or None when the Hub is disabled, no token
    is available, the file is absent, or download/parsing fails (all failures
    are silent by design — a missing state simply means "start fresh").
    """
    token = hf_auth_token or HF_TOKEN
    if not HF_USE_HUB or not token:
        return None
    try:
        state_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{subfolder}/training_state.json",
            repo_type="model",
            token=token,
        )
        with open(state_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        # Missing file or network error: treat as "no resumable state".
        return None
|
| 164 |
+
|
| 165 |
+
def download_training_state(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Downloads training_state.json from HF if present.

    Looks under ``{stage}/latest/`` and returns the parsed JSON dict, or None
    when the Hub is disabled, no token is available, or the file is missing.

    Implementation note: this previously duplicated
    :func:`download_training_state_from_subfolder` verbatim; it now delegates
    to it so the download/error-handling logic lives in one place.
    """
    return download_training_state_from_subfolder(repo_id, f"{stage}/latest", hf_auth_token)
|
| 181 |
+
|
| 182 |
+
def download_optimizer_state(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[str]:
    """Downloads optimizer.pt for resuming optimizer state.

    Fetches ``{stage}/latest/optimizer.pt`` and returns the local cache path,
    or None when the Hub is disabled, no token is available, or the file is
    absent (failures are silent; callers fall back to a fresh optimizer).
    """
    token = hf_auth_token or HF_TOKEN
    if not HF_USE_HUB or not token:
        return None
    try:
        opt_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{stage}/latest/optimizer.pt",
            repo_type="model",
            token=token,
        )
        return opt_path
    except Exception:
        # Missing optimizer state is not an error: training restarts the optimizer.
        return None
|
| 197 |
+
|
| 198 |
+
def load_model_and_tokenizer_from_hf(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
    """
    Loads model and tokenizer from HF under subfolder {stage}/latest if available.

    Returns ``(model, tokenizer)`` with pad-token handling applied, or None
    when no ``{stage}/latest`` checkpoint exists or loading fails.

    Fix: previously the ``from_pretrained`` calls were unguarded, so a
    corrupt/partial checkpoint crashed the pipeline; delegating to
    :func:`load_model_and_tokenizer_from_hf_subfolder` returns None instead,
    matching the epoch-folder fallback path and what callers already expect.
    """
    if not repo_has_stage_latest(repo_id, stage, hf_auth_token):
        return None
    # Delegate so pad-token setup and error recovery stay in one place.
    return load_model_and_tokenizer_from_hf_subfolder(repo_id, f"{stage}/latest", hf_auth_token)
|
| 212 |
+
|
| 213 |
+
def save_and_push_checkpoint(
    stage: str,
    epoch_index_zero_based: int,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    optimizer: AdamW,
    avg_loss: float,
    dataloader_len: int,
    batch_size: int,
    total_epochs: int,
    repo_id: Optional[str],
    hf_auth_token: Optional[str] = None
) -> None:
    """
    Saves checkpoint locally and pushes to HF.

    Writes model, tokenizer, optimizer state, and a ``training_state.json``
    summary under ``CHECKPOINTS_DIR/{stage}/epoch-XXX``, mirrors that folder
    to ``{stage}/latest``, and (when the Hub is enabled and both ``repo_id``
    and a token are available) uploads both folders. Upload failures are
    logged, never raised — local checkpoints are the source of truth.

    Fix: ``datetime.utcnow()`` is deprecated (naive UTC, removed-path since
    Python 3.12); the timestamp now uses timezone-aware UTC while keeping the
    same ``...Z``-suffixed ISO format.
    """
    from datetime import timezone  # local import: module top only imports `datetime`

    token = hf_auth_token or HF_TOKEN
    epoch_number = epoch_index_zero_based + 1  # human-facing epochs are 1-based
    stage_dir = os.path.join(CHECKPOINTS_DIR, stage)
    epoch_dir_name = f"epoch-{epoch_number:03d}"
    epoch_dir = os.path.join(stage_dir, epoch_dir_name)
    latest_dir = os.path.join(stage_dir, "latest")
    _ensure_dir(epoch_dir)
    _ensure_dir(stage_dir)

    # Save model + tokenizer
    model.save_pretrained(epoch_dir)
    tokenizer.save_pretrained(epoch_dir)

    # Save optimizer state
    torch.save(optimizer.state_dict(), os.path.join(epoch_dir, "optimizer.pt"))

    # Save training state (read back on resume to recover epoch_completed)
    training_state = {
        "stage": stage,
        "epoch_completed": epoch_number,
        "total_epochs_for_stage": total_epochs,
        "global_step": epoch_number * dataloader_len,
        "avg_loss": float(avg_loss),
        "batch_size": batch_size,
        "saved_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    }
    with open(os.path.join(epoch_dir, "training_state.json"), "w", encoding="utf-8") as f:
        json.dump(training_state, f, ensure_ascii=False, indent=2)

    # Update "latest" as a full copy of the newest epoch folder
    if os.path.exists(latest_dir):
        shutil.rmtree(latest_dir)
    shutil.copytree(epoch_dir, latest_dir)

    # Push to Hugging Face (best effort; failures do not abort training)
    if HF_USE_HUB and repo_id and token:
        try:
            upload_folder(
                repo_id=repo_id,
                folder_path=epoch_dir,
                path_in_repo=f"{stage}/{epoch_dir_name}",
                repo_type="model",
                token=token,
                commit_message=f"{stage}: save {epoch_dir_name}",
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=latest_dir,
                path_in_repo=f"{stage}/latest",
                repo_type="model",
                token=token,
                commit_message=f"{stage}: update latest -> {epoch_dir_name}",
            )
            print(f"☁️ Pushed checkpoint to HF: {repo_id} ({stage}/{epoch_dir_name} and {stage}/latest)")
        except Exception as exc:
            print(f"⚠️ Failed to push checkpoint to HF: {exc}")
    else:
        print("ℹ️ Skipped HF push (Hub disabled or token/repo missing).")
|
| 287 |
+
|
| 288 |
+
def train_stage1_raw(
    model,
    tokenizer,
    data: List[Dict[str, Any]],
    device,
    start_epoch: int = 0,
    hf_repo_id: Optional[str] = None,
):
    """Trains the model on motion sequences only to learn the 'language of motion'.

    Plain PyTorch loop (no HF Trainer): causal LM objective where labels equal
    input_ids. Resumes from ``start_epoch`` and, when ``hf_repo_id`` is set,
    restores the optimizer state from the Hub and saves/pushes a checkpoint
    after every epoch. Returns the trained model.
    """
    from data import MotionDataset # Import here to avoid circular imports

    print("\n" + "="*80)
    print(" STAGE 1: MOTION LANGUAGE MODELING (PRE-TRAINING)")
    print(f" Training on {len(data)} samples.")
    print("="*80)

    dataset = MotionDataset(data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=S1_BATCH_SIZE, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=S1_LR)
    model.to(device)
    model.train()

    # Try to resume optimizer if we resumed from HF
    token = HF_TOKEN
    if hf_repo_id and start_epoch > 0 and HF_USE_HUB and token:
        opt_path = download_optimizer_state(hf_repo_id, "stage1", token)
        if opt_path is not None:
            try:
                optimizer.load_state_dict(torch.load(opt_path, map_location=device))
                print("↩️ Resumed optimizer state for Stage 1 from HF.")
            except Exception as exc:
                # Optimizer resume is best-effort; training continues with a fresh one.
                print(f"⚠️ Failed to load optimizer state for Stage 1: {exc}")

    for epoch in range(start_epoch, S1_EPOCHS):
        total_loss = 0
        total_batches = len(dataloader)
        epoch_start_time = time.time()
        step_interval = max(1, total_batches // 50) # ~2% progress updates
        for i, batch in enumerate(dataloader, 1):
            optimizer.zero_grad()

            # squeeze(1) drops the extra dim from per-sample tokenizer output —
            # assumes MotionDataset yields (1, seq_len) tensors; confirm in data.py.
            input_ids = batch['input_ids'].squeeze(1).to(device)
            attention_mask = batch['attention_mask'].squeeze(1).to(device)

            # Self-supervised objective: labels are the inputs themselves.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)

            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Progress with ETA
            if i == 1 or (i % step_interval == 0) or (i == total_batches):
                elapsed = time.time() - epoch_start_time
                est_total = (elapsed / i) * total_batches
                eta = est_total - elapsed
                pct = (i / total_batches) * 100.0
                print(
                    f"\r[Stage 1] Epoch {epoch+1}/{S1_EPOCHS} - "
                    f"{i}/{total_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
                    end="",
                    flush=True,
                )

        # Finish the progress line
        print()
        avg_loss = total_loss / len(dataloader)
        print(f"--- End of Epoch {epoch+1}/{S1_EPOCHS}, Average Loss: {avg_loss:.4f} ---")
        # Save checkpoint locally every epoch; push to HF only at interval or final epoch
        push_this_epoch = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S1_EPOCHS)
        repo_for_epoch = hf_repo_id if push_this_epoch else None
        save_and_push_checkpoint(
            stage="stage1",
            epoch_index_zero_based=epoch,
            model=model,
            tokenizer=tokenizer,
            optimizer=optimizer,
            avg_loss=avg_loss,
            dataloader_len=len(dataloader),
            batch_size=S1_BATCH_SIZE,
            total_epochs=S1_EPOCHS,
            repo_id=repo_for_epoch,
            hf_auth_token=token
        )

    print("\n✅ Stage 1 Training Complete.")
    return model
|
| 376 |
+
|
| 377 |
+
def train_stage2_raw(
    model,
    tokenizer,
    data: List[Dict[str, Any]],
    device,
    start_epoch: int = 0,
    hf_repo_id: Optional[str] = None,
    hf_stage_subdir: str = "stage2",
):
    """Fine-tunes the motion-aware model to connect text prompts to motions.

    Same loop shape as Stage 1, but batches come from TextMotionDataset and
    carry explicit ``labels`` (prompt tokens are presumably masked there —
    confirm in data.py). Checkpoints are saved/pushed under
    ``hf_stage_subdir`` so several Stage 2 runs can coexist in one repo.
    Returns the trained model.
    """
    from data import TextMotionDataset # Import here to avoid circular imports

    print("\n" + "="*80)
    print(" STAGE 2: TEXT-TO-MOTION FINE-TUNING")
    print(f" Training on {len(data)} samples.")
    print("="*80)

    dataset = TextMotionDataset(data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=S2_BATCH_SIZE, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=S2_LR)
    model.to(device)
    model.train()

    # Try to resume optimizer if we resumed from HF
    token = HF_TOKEN
    if hf_repo_id and start_epoch > 0 and HF_USE_HUB and token:
        opt_path = download_optimizer_state(hf_repo_id, hf_stage_subdir, token)
        if opt_path is not None:
            try:
                optimizer.load_state_dict(torch.load(opt_path, map_location=device))
                print("↩️ Resumed optimizer state for Stage 2 from HF.")
            except Exception as exc:
                # Optimizer resume is best-effort; training continues with a fresh one.
                print(f"⚠️ Failed to load optimizer state for Stage 2: {exc}")

    for epoch in range(start_epoch, S2_EPOCHS):
        total_loss = 0
        total_batches = len(dataloader)
        epoch_start_time = time.time()
        step_interval = max(1, total_batches // 50) # ~2% progress updates
        for i, batch in enumerate(dataloader, 1):
            optimizer.zero_grad()

            # Unlike Stage 1, batches arrive pre-shaped with separate labels.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Progress with ETA
            if i == 1 or (i % step_interval == 0) or (i == total_batches):
                elapsed = time.time() - epoch_start_time
                est_total = (elapsed / i) * total_batches
                eta = est_total - elapsed
                pct = (i / total_batches) * 100.0
                print(
                    f"\r[Stage 2] Epoch {epoch+1}/{S2_EPOCHS} - "
                    f"{i}/{total_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
                    end="",
                    flush=True,
                )

        # Finish the progress line
        print()
        avg_loss = total_loss / len(dataloader)
        print(f"--- End of Epoch {epoch+1}/{S2_EPOCHS}, Average Loss: {avg_loss:.4f} ---")
        # Save checkpoint locally every epoch; push to HF only at interval or final epoch
        push_this_epoch = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S2_EPOCHS)
        repo_for_epoch = hf_repo_id if push_this_epoch else None
        save_and_push_checkpoint(
            stage=hf_stage_subdir,
            epoch_index_zero_based=epoch,
            model=model,
            tokenizer=tokenizer,
            optimizer=optimizer,
            avg_loss=avg_loss,
            dataloader_len=len(dataloader),
            batch_size=S2_BATCH_SIZE,
            total_epochs=S2_EPOCHS,
            repo_id=repo_for_epoch,
            hf_auth_token=token
        )

    print("\n✅ Stage 2 Training Complete.")
    return model
|
| 467 |
+
|
| 468 |
+
# ======================================================================================
|
| 469 |
+
# Existing Utilities
|
| 470 |
+
# ======================================================================================
|
| 471 |
+
|
| 472 |
+
def make_training_args(out_dir: str, epochs: int, two_point_hub: bool = False) -> TrainingArguments:
    """Build the TrainingArguments shared by every fine-tuning stage.

    Args:
        out_dir: Local directory for logs and checkpoints.
        epochs: Number of training epochs for this stage.
        two_point_hub: When True, periodic local saving is effectively
            disabled (astronomical save_steps) so the two-point Hub
            checkpoint callback fully controls when saves happen.
    """
    # "Never" save on a schedule when the hub uploader owns checkpointing.
    local_save_steps = 10**12 if two_point_hub else SAVE_STEPS
    return TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=BATCH_TRAIN,
        per_device_eval_batch_size=BATCH_EVAL,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        lr_scheduler_type="cosine",
        warmup_ratio=WARMUP,
        optim="adamw_torch",
        logging_steps=LOG_STEPS,
        eval_strategy="steps",
        eval_steps=EVAL_STEPS,
        save_steps=local_save_steps,
        save_total_limit=2,
        # Mixed precision follows the globally-configured dtype.
        bf16=(DTYPE == torch.bfloat16),
        fp16=(DTYPE == torch.float16),
        report_to="none",
        seed=SEED,
        remove_unused_columns=False,
    )
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def latest_hub_checkpoint(repo_id: str) -> Optional[str]:
    """Download the newest ``checkpoint-*`` folder from a Hub model repo.

    Returns the local path of the downloaded checkpoint directory, or None
    when the repo cannot be listed or contains no checkpoints.
    """
    try:
        repo_files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model")
    except Exception as e:
        print(f"Hub list failed for {repo_id}: {e}")
        return None

    def _trailing_step(dirname: str) -> int:
        # Order checkpoint dirs by the last number in their name.
        digits = re.findall(r"\d+", dirname)
        return int(digits[-1]) if digits else -1

    candidates = {p.split('/')[0] for p in repo_files if p.startswith("checkpoint-")}
    if not candidates:
        return None

    newest = max(candidates, key=_trailing_step)
    local_root = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        # Fetch only the chosen checkpoint plus trainer state, not the whole repo.
        allow_patterns=[f"{newest}/**", "trainer_state.json"],
        local_dir_use_symlinks=False,
    )
    return os.path.join(local_root, newest)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class TwoPointHubCheckpointCallback(TrainerCallback):
    """
    Save to Hugging Face Hub exactly twice per training: halfway and at final step.
    Keeps only the most recent N checkpoints on Hub.
    """

    def __init__(self, repo_id: str, keep_last: int = 2, token: Optional[str] = None):
        # Target Hub repo; created as private below if it doesn't exist yet.
        self.repo_id = repo_id
        # How many checkpoint-* folders to retain on the Hub after each push.
        self.keep_last = keep_last
        self.api = HfApi()
        # Fall back to the standard env var so Kaggle/Colab runs need no prompt.
        self.token = token or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        # Step numbers at which uploads should happen; filled in on_train_begin.
        self._half_step: Optional[int] = None
        self._final_step: Optional[int] = None
        # Steps already uploaded — guards against double pushes.
        self._saved_steps = set()
        # Set by on_step_end when we forced a save; consumed by on_save.
        self._pending_push_for_step: Optional[int] = None
        try:
            self.api.create_repo(repo_id=self.repo_id, private=True, exist_ok=True, token=self.token)
        except Exception as e:
            # Best-effort: training should not fail just because repo creation did.
            print(f"Could not ensure repo exists: {e}")

    def _enforce_keep_last(self) -> None:
        # Delete all but the `keep_last` newest checkpoint-* folders on the Hub.
        try:
            files = self.api.list_repo_files(repo_id=self.repo_id, repo_type="model", token=self.token)

            def _step_key(dirname: str) -> int:
                # Sort by the last number embedded in the directory name.
                nums = re.findall(r"\d+", dirname)
                return int(nums[-1]) if nums else -1

            ckpt_dirs = sorted(
                {p.split('/')[0] for p in files if p.startswith("checkpoint-")},
                key=_step_key,
            )
            if len(ckpt_dirs) <= self.keep_last:
                return
            to_delete = ckpt_dirs[:-self.keep_last]
            for d in to_delete:
                # The Hub API deletes file-by-file, not folder-by-folder.
                for f in [p for p in files if p.startswith(f"{d}/")]:
                    try:
                        self.api.delete_file(path=f, repo_id=self.repo_id, repo_type="model", token=self.token)
                    except Exception as e:
                        print(f"Failed deleting {f}: {e}")
        except Exception as e:
            print(f"Keep-last enforcement failed: {e}")

    def on_train_begin(self, args, state, control, **kwargs):
        # Prefer Trainer-computed max_steps
        if state.max_steps and state.max_steps > 0:
            self._half_step = max(1, state.max_steps // 2)
            self._final_step = state.max_steps
            print(f"Two-point checkpointing: half={self._half_step}, final={self._final_step}")
        else:
            # Best-effort fallback using dataloader length and grad accumulation if available
            td = kwargs.get("train_dataloader")
            if td is not None and args.gradient_accumulation_steps > 0:
                steps_per_epoch = math.ceil(len(td) / args.gradient_accumulation_steps)
                self._final_step = steps_per_epoch * int(args.num_train_epochs)
                self._half_step = max(1, self._final_step // 2)
                print(f"Two-point checkpointing (approx): half={self._half_step}, final={self._final_step}")

    def on_step_end(self, args, state, control, **kwargs):
        # If step targets were never determined, do nothing for the whole run.
        if not self._final_step:
            return control
        gs = state.global_step
        # Force a local save at the halfway point (Trainer performs the save;
        # on_save then pushes it to the Hub).
        if gs == self._half_step and gs not in self._saved_steps:
            control.should_save = True
            self._pending_push_for_step = gs
        if gs == self._final_step and gs not in self._saved_steps:
            control.should_save = True
            self._pending_push_for_step = gs
        return control

    def on_save(self, args, state, control, **kwargs):
        # Push only when we triggered this save
        if self._pending_push_for_step is None:
            return control
        step = self._pending_push_for_step
        self._pending_push_for_step = None
        self._saved_steps.add(step)

        ckpt_dirname = f"checkpoint-{step}"
        try:
            upload_folder(
                repo_id=self.repo_id,
                folder_path=args.output_dir,
                repo_type="model",
                token=self.token,
                commit_message=f"upload {ckpt_dirname}",
                # Only the checkpoint folder and trainer state, not the full output dir.
                allow_patterns=[f"{ckpt_dirname}/**", "trainer_state.json"],
            )
            self._enforce_keep_last()
            print(f"Pushed {ckpt_dirname} to {self.repo_id}")
        except Exception as e:
            # Upload failures are logged but never abort training.
            print(f"Hub upload failed for {ckpt_dirname}: {e}")
        return control
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def train_stage(
    stage_name: str,
    model,
    tokenizer,
    train_dataset,
    eval_dataset,
    data_collator,
    out_dir: str,
    epochs: int,
    hub_repo: Optional[str] = None,
):
    """
    Train a single stage.

    Builds a HF Trainer, optionally resumes from the latest Hub checkpoint,
    trains with a train-loss early stop, evaluates, and returns the eval
    metrics. `hub_repo` defaults to the stage-specific repo inferred from
    `stage_name` ("stage1"/"stage2"/"stage3" prefixes).
    """
    print(f"\n{'='*60}")
    print(f"Training {stage_name}")
    print(f"{'='*60}")

    # Auto-select Hub repo by stage if not provided
    if hub_repo is None:
        s = (stage_name or "").lower()
        if s.startswith("stage1"):
            hub_repo = HUB_REPO_S1
        elif s.startswith("stage2"):
            hub_repo = HUB_REPO_S2
        elif s.startswith("stage3"):
            hub_repo = HUB_REPO_S3

    # two_point_hub disables periodic local saves when a Hub repo is in play.
    args = make_training_args(out_dir, epochs, two_point_hub=(hub_repo is not None))

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=args,
        data_collator=data_collator,
    )

    # Train-loss early stop (match test_overfit behavior)
    class TrainLossStopCallback(TrainerCallback):
        # Stops training the first time the logged train loss dips below
        # `threshold` (fires at logging steps only).
        def __init__(self, threshold: float = 1.0):
            self.threshold = float(threshold)
            self.triggered = False

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is None:
                return control
            loss = logs.get("loss")
            if loss is not None and loss < self.threshold and state.global_step > 0 and not self.triggered:
                self.triggered = True
                print(f"\nTrain-loss early stop: loss={loss:.4f} < {self.threshold}")
                control.should_training_stop = True
            return control

    trainer.add_callback(TrainLossStopCallback(threshold=1.0))

    # Add two-point Hub checkpoint uploader if configured
    if hub_repo:
        # Pass token if available to avoid auth prompts in Kaggle/Colab
        token = HF_TOKEN if isinstance(HF_TOKEN, str) and len(HF_TOKEN) > 0 else None
        trainer.add_callback(TwoPointHubCheckpointCallback(hub_repo, token=token))

    # Train (with auto-resume from Hub if available)
    resume_path = latest_hub_checkpoint(hub_repo) if hub_repo else None
    if resume_path:
        print(f"Resuming from Hub checkpoint: {resume_path}")
        trainer.train(resume_from_checkpoint=resume_path)
    else:
        print(f"Starting training for {stage_name}...")
        trainer.train()

    # Evaluate
    print(f"Evaluating {stage_name}...")
    metrics = trainer.evaluate()

    # Compute perplexity
    eval_loss = metrics.get("eval_loss", float("nan"))
    # exp(loss) is the standard LM perplexity; NaN propagates when eval_loss is missing.
    ppl = math.exp(eval_loss) if not math.isnan(eval_loss) else float("nan")

    print(f"\n{stage_name} Results:")
    print(f"  eval_loss: {eval_loss:.4f}")
    print(f"  perplexity: {ppl:.3f}")

    # Save model (optional - can be commented out to save space)
    # trainer.save_model(out_dir)
    # print(f"Model saved to {out_dir}")

    return metrics
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def save_model_to_hub(model, tokenizer, repo_id: str, stage_name: str):
    """Push a trained model and its tokenizer to a HuggingFace Hub repo."""
    print(f"\nSaving {stage_name} to HuggingFace Hub: {repo_id}")
    # Model first, then tokenizer — same order and commit message as before.
    for artifact in (model, tokenizer):
        artifact.push_to_hub(repo_id, commit_message=f"Upload {stage_name}")
    print(f"Successfully saved {stage_name}")
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def load_model_from_hub(repo_id: str):
    """Fetch a 4-bit quantized model + tokenizer from the Hub via Unsloth.

    Returns the (model, tokenizer) pair loaded with the globally configured
    sequence length and dtype.
    """
    # Local imports keep heavy deps lazy and avoid import cycles with config.
    from unsloth import FastLanguageModel
    from config import MAX_SEQ_LEN, DTYPE

    print(f"\nLoading model from HuggingFace Hub: {repo_id}")
    loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
        model_name=repo_id,
        max_seq_length=MAX_SEQ_LEN,
        dtype=DTYPE,
        load_in_4bit=True,
    )
    print(f"Successfully loaded model from {repo_id}")
    return loaded_model, loaded_tokenizer
|
train_mgpt_vqvae.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import zipfile
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
import glob
|
| 10 |
+
import warnings
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import random
|
| 15 |
+
import math
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
# Add the mGPT directory to the path
|
| 20 |
+
sys.path.append('/kaggle/working')
|
| 21 |
+
|
| 22 |
+
from mGPT.archs.mgpt_vq import VQVae
|
| 23 |
+
|
| 24 |
+
warnings.filterwarnings("ignore")
|
| 25 |
+
|
| 26 |
+
# Configuration
|
| 27 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 28 |
+
DATA_ROOT = '/kaggle/working/extracted_files'
|
| 29 |
+
CHECKPOINT_DIR = '/kaggle/working/checkpoints_mgpt'
|
| 30 |
+
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
| 31 |
+
print("Device:", DEVICE)
|
| 32 |
+
|
| 33 |
+
# ──────────────────────────────────────────────────────────
|
| 34 |
+
# Enhanced Dataset with File Tracking and Batching (UNCHANGED)
|
| 35 |
+
# ──────────────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
def load_smplx_from_folder(folder_path):
    """Load all SMPL-X frames from a folder of .pkl files as one tensor.

    Files are read in sorted order; each may hold a single frame dict or a
    list of frame dicts. Every frame becomes a fixed 182-dim vector built
    from eight parameter groups (missing or short parameters are zero-padded,
    over-long ones truncated). Returns a float32 tensor of shape
    (num_frames, 182), or None when the folder yields no frames.
    """
    keys_and_dims = (
        ('shape', 10), ('body_pose', 63), ('lhand_pose', 45), ('rhand_pose', 45),
        ('jaw_pose', 3), ('expression', 10), ('root_pose', 3), ('cam_trans', 3),
    )

    frames = []
    for path in sorted(glob.glob(os.path.join(folder_path, '*.pkl'))):
        try:
            with open(path, 'rb') as handle:
                payload = pickle.load(handle)
            if isinstance(payload, list):
                frames.extend(payload)
            elif isinstance(payload, dict):
                frames.append(payload)
        except Exception:
            # Unreadable/corrupt file: skip it, as before.
            continue

    if not frames:
        return None

    def _frame_vector(frame):
        # Concatenate all parameter groups into one fixed-size vector.
        parts = []
        for key, dim in keys_and_dims:
            slot = np.zeros(dim)
            value = frame.get(key)
            if value is not None:
                flat = np.array(value).flatten()
                slot[:min(len(flat), dim)] = flat[:dim]
            parts.append(slot)
        return np.concatenate(parts)

    stacked = np.stack([_frame_vector(f) for f in frames])
    return torch.tensor(stacked, dtype=torch.float32)
|
| 66 |
+
|
| 67 |
+
class EnhancedMotionDataset(Dataset):
    """Folder-backed motion dataset that walks the corpus in resumable batches.

    Each sample is one folder of SMPL-X ``.pkl`` frame files, loaded via
    ``load_smplx_from_folder``. The list of already-processed folders is
    persisted to a JSON file so a restarted run skips them.
    """

    def __init__(self, root_dir, processed_files_path, batch_folders=1000):
        """
        Args:
            root_dir: Directory whose immediate sub-directories are samples.
            processed_files_path: JSON file tracking already-processed folders.
            batch_folders: Maximum number of folders exposed per batch.
        """
        self.root_dir = root_dir
        self.processed_files_path = processed_files_path
        self.batch_folders = batch_folders

        print(f"\n[DEBUG] Initializing Dataset.")
        print(f"[DEBUG] Root directory: '{self.root_dir}'")

        if not os.path.exists(self.root_dir):
            print(f"[DEBUG] ERROR: The root directory '{self.root_dir}' does not exist!")
            self.all_folders = []
        else:
            print(f"[DEBUG] Root directory exists.")
            glob_path = os.path.join(root_dir, '*')
            print(f"[DEBUG] Using glob pattern: '{glob_path}'")
            all_paths = glob.glob(glob_path)
            print(f"[DEBUG] Glob found {len(all_paths)} total paths.")
            self.all_folders = [d for d in all_paths if os.path.isdir(d)]
            print(f"[DEBUG] Found {len(self.all_folders)} directories.")

        self.processed = self._load_processed()
        print(f"[DEBUG] Loaded {len(self.processed)} processed folder paths.")

        # Fix: membership tests against the processed *list* were O(n) each,
        # making this filter O(n^2) over large corpora. A set makes it O(n).
        done = set(self.processed)
        self.unprocessed = [f for f in self.all_folders if f not in done]
        print(f"[DEBUG] Found {len(self.unprocessed)} unprocessed folders.")

        self._prep_batch()

    def _load_processed(self):
        """Return the persisted list of processed folder paths (or [])."""
        if os.path.exists(self.processed_files_path):
            with open(self.processed_files_path, 'r') as f:
                return json.load(f)
        return []

    def _save_processed(self):
        """Persist the processed-folder list to disk."""
        with open(self.processed_files_path, 'w') as f:
            json.dump(self.processed, f)

    def _prep_batch(self):
        """Stage the next `batch_folders` unprocessed folders as samples."""
        self.current = self.unprocessed[:self.batch_folders]
        self.samples = self.current.copy()
        print(f"→ Loading {len(self.samples)} folders this batch")

    def mark_batch_as_processed(self):
        """Record the current batch as done and persist the bookkeeping."""
        self.processed += self.current
        self._save_processed()

    def get_next_batch(self):
        """Rescan the root dir and stage the next batch of unprocessed folders.

        Returns:
            True when a new batch is staged, False once everything is processed.
        """
        all_folders = [d for d in glob.glob(os.path.join(self.root_dir, '*')) if os.path.isdir(d)]
        self.processed = self._load_processed()
        # Same O(n^2) -> O(n) fix as in __init__.
        done = set(self.processed)
        self.unprocessed = [f for f in all_folders if f not in done]

        if not self.unprocessed:
            print("✅ All data processed")
            return False
        self._prep_batch()
        return True

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Sequences shorter than 64 frames are rejected (None is filtered in
        # the collate function).
        seq = load_smplx_from_folder(self.samples[idx])
        if seq is None or seq.shape[0] < 64:
            return None
        return seq
|
| 134 |
+
|
| 135 |
+
# ──────────────────────────────────────────────────────────
|
| 136 |
+
# Checkpoint Management (UNCHANGED)
|
| 137 |
+
# ──────────────────────────────────────────────────────────
|
| 138 |
+
|
| 139 |
+
class CheckpointManager:
    """Writes training checkpoints to disk and prunes all but the newest few."""

    def __init__(self, checkpoint_dir, max_checkpoints=2):
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints

    def _find_checkpoints(self):
        # All checkpoint files this manager owns.
        return glob.glob(os.path.join(self.checkpoint_dir, 'mgpt_vqvae_epoch_*.pt'))

    def save_checkpoint(self, model, optimizer, epoch, batch_idx, loss, metadata=None):
        """Serialize model/optimizer state plus bookkeeping; returns the path."""
        state = {
            'epoch': epoch,
            'batch_idx': batch_idx,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'timestamp': datetime.now().isoformat(),
            'metadata': metadata or {}
        }
        path = os.path.join(
            self.checkpoint_dir,
            f'mgpt_vqvae_epoch_{epoch:03d}_batch_{batch_idx:04d}.pt'
        )
        torch.save(state, path)
        print(f"Saved checkpoint: {path}")
        self.cleanup_old_checkpoints()
        return path

    def cleanup_old_checkpoints(self):
        """Delete everything beyond the `max_checkpoints` most recent files."""
        newest_first = sorted(self._find_checkpoints(), key=os.path.getmtime, reverse=True)
        for stale in newest_first[self.max_checkpoints:]:
            os.remove(stale)
            print(f"Removed old checkpoint: {stale}")

    def load_latest_checkpoint(self):
        """Load the most recently written checkpoint, or None if none exist."""
        found = self._find_checkpoints()
        if not found:
            return None
        newest = max(found, key=os.path.getmtime)
        print(f"Loading checkpoint: {newest}")
        return torch.load(newest, map_location=DEVICE)

    def get_checkpoint_info(self):
        """Return (count, paths) of checkpoints currently on disk."""
        found = self._find_checkpoints()
        return len(found), found
|
| 182 |
+
|
| 183 |
+
# ──────────────────────────────────────────────────────────
|
| 184 |
+
# Enhanced Training Function with MotionGPT VQ-VAE
|
| 185 |
+
# ──────────────────────────────────────────────────────────
|
| 186 |
+
|
| 187 |
+
def train_mgpt_vqvae(vq_model, dataset, epochs_per_batch=20, batch_size=16, lr=1e-4):
    """Train the MotionGPT VQ-VAE over the dataset's resumable folder batches.

    Outer loop: consume one folder-batch of the dataset at a time, training
    `epochs_per_batch` epochs on it, then mark it processed and advance.
    Resumes model/optimizer state from the latest local checkpoint if present.
    Returns the trained model.
    """
    print("\n" + "="*70)
    print(" STARTING MGPT VQ-VAE TRAINING WITH CHECKPOINTING ")
    print("="*70)

    optimizer = torch.optim.AdamW(vq_model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    # Per-element loss so we can apply per-feature weights and a length mask.
    loss_fn = nn.SmoothL1Loss(reduction='none')
    checkpoint_manager = CheckpointManager(CHECKPOINT_DIR)

    checkpoint = checkpoint_manager.load_latest_checkpoint()
    global_epoch = 1
    if checkpoint:
        vq_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Prefer the explicitly stored global epoch; fall back to the raw epoch.
        global_epoch = checkpoint.get('metadata', {}).get('global_epoch', checkpoint['epoch'])
        print(f"Resumed from GLOBAL epoch {global_epoch}")

    vq_model.to(DEVICE).train()

    # Define loss weights for SMPL parameters
    # Layout: shape(10), body(63), lhand(45), rhand(45), jaw(3), expr(10), root(3), cam(3)
    param_dims = [10, 63, 45, 45, 3, 10, 3, 3]
    param_starts = np.cumsum([0] + param_dims[:-1]).tolist()
    smpl_dim = sum(param_dims)
    loss_weights = torch.ones(smpl_dim, device=DEVICE)
    loss_weights[param_starts[1]:param_starts[5]] = 10.0  # pose parameters
    loss_weights[param_starts[0]:param_starts[1]] = 5.0   # shape parameters
    loss_weights[param_starts[5]:param_starts[6]] = 8.0   # expression parameters

    def log_codebook_analysis(x_recon, loss, perplexity, epoch, batch_idx):
        """Report how many codebook entries the first sample activates."""
        # Extract encoded indices for analysis
        with torch.no_grad():
            x_in = vq_model.preprocess(x_recon[:1])  # Use reconstructed sample for analysis
            x_encoder = vq_model.encoder(x_in)
            x_flat = vq_model.quantizer.preprocess(x_encoder)
            indices = vq_model.quantizer.quantize(x_flat)

        unique_codes = torch.unique(indices)
        usage_percentage = (len(unique_codes) / vq_model.quantizer.nb_code) * 100

        print(f"[ANALYSIS] Epoch {epoch}, Batch {batch_idx}")
        print(f"Unique codes used: {len(unique_codes)}/{vq_model.quantizer.nb_code} ({usage_percentage:.1f}%)")
        print(f"Perplexity: {perplexity:.2f}")
        return usage_percentage, indices

    def save_reconstruction_sample(x, x_recon, lengths, epoch):
        """Dump the first (original, reconstruction) pair to an .npz for inspection."""
        original_seq = x[0, :lengths[0]].cpu().numpy()
        recon_seq = x_recon[0, :lengths[0]].cpu().numpy()
        filename = os.path.join(CHECKPOINT_DIR, f'mgpt_recon_epoch_{epoch}.npz')
        np.savez(filename, original=original_seq, reconstructed=recon_seq)
        # NOTE(review): this message prints the literal "(unknown)"; it was
        # presumably meant to interpolate {filename} — confirm and fix.
        print(f"Saved reconstruction sample to (unknown)")
        mse = ((original_seq - recon_seq) ** 2).mean()
        print(f"Reconstruction MSE: {mse:.6f}")
        return mse

    def collate_fn_enhanced(batch):
        """Pad variable-length sequences to a multiple of the temporal stride.

        Drops None samples; caps length at 256 frames; rounds the padded
        length up to a multiple of 8 (the encoder's downsampling factor).
        Returns (padded_batch, lengths) or None for an empty batch.
        """
        batch = [item for item in batch if item is not None]
        if not batch:
            return None
        batch.sort(key=lambda x: x.shape[0], reverse=True)
        max_len = batch[0].shape[0]
        max_len = min(max_len, 256)
        downsampling_factor = 8
        padded_max_len = math.ceil(max_len / downsampling_factor) * downsampling_factor
        padded_batch = torch.zeros(len(batch), padded_max_len, batch[0].shape[1])
        lengths = []
        for i, x in enumerate(batch):
            length = min(x.shape[0], padded_max_len)
            padded_batch[i, :length, :] = x[:length, :]
            lengths.append(length)
        return padded_batch, torch.tensor(lengths)

    while True:
        print(f"\n{'='*50}")
        print(f"Processing file batch with {len(dataset)} files")
        print(f"{'='*50}")

        if len(dataset) == 0:
            if not dataset.get_next_batch():
                print("✅ All data processed! Training complete.")
                break
            continue

        dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True,
            num_workers=0, collate_fn=collate_fn_enhanced, drop_last=True
        )

        for epoch in range(global_epoch, global_epoch + epochs_per_batch):
            epoch_losses, epoch_vq_losses, epoch_rec_losses = [], [], []
            codebook_usage_history = []
            epoch_indices = []

            for batch_idx, batch_data in enumerate(dataloader):
                if batch_data is None:
                    continue

                motion_batch, lengths = batch_data
                x = motion_batch.to(DEVICE)

                # Forward pass through MotionGPT VQ-VAE
                x_recon, vq_loss, perplexity = vq_model(x)

                # Periodic codebook-usage diagnostics.
                if batch_idx % 50 == 0:
                    usage_pct, indices = log_codebook_analysis(x_recon, vq_loss, perplexity, epoch, batch_idx)
                    epoch_indices.append(indices.cpu().numpy().flatten())

                # Calculate reconstruction loss with weighted parameters
                rec_loss_unreduced = loss_fn(x_recon, x) * loss_weights.unsqueeze(0).unsqueeze(0)
                # Mask out padded frames so they don't contribute to the loss.
                mask = torch.zeros_like(x[:, :, 0])
                for i, length in enumerate(lengths):
                    mask[i, :length] = 1.0
                mask = mask.unsqueeze(-1).expand_as(rec_loss_unreduced)
                rec_loss = (rec_loss_unreduced * mask).sum() / mask.sum()

                vq_weight = 1.0
                total_loss = rec_loss + vq_weight * vq_loss

                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(vq_model.parameters(), max_norm=1.0)
                optimizer.step()
                # Scheduler steps per optimizer step (cosine warm restarts).
                scheduler.step()

                epoch_losses.append(total_loss.item())
                epoch_vq_losses.append(vq_loss.item())
                epoch_rec_losses.append(rec_loss.item())

                if batch_idx % 20 == 0:
                    current_lr = optimizer.param_groups[0]['lr']
                    print(f"[E:{epoch:03d}] B:{batch_idx:03d} | "
                          f"Loss: {total_loss.item():.4f} "
                          f"(Rec: {rec_loss.item():.4f}, VQ: {vq_loss.item():.4f}) | "
                          f"Perplexity: {perplexity:.2f} | "
                          f"LR: {current_lr:.2e}")

            if epoch_losses:
                avg_loss = np.mean(epoch_losses)
                avg_vq_loss = np.mean(epoch_vq_losses)
                avg_rec_loss = np.mean(epoch_rec_losses)

                print(f"\n[EPOCH {epoch:03d} SUMMARY]")
                print(f"Avg Loss: {avg_loss:.4f} (Rec: {avg_rec_loss:.4f}, VQ: {avg_vq_loss:.4f})")

            # Create histogram if we collected indices
            if epoch_indices:
                all_epoch_indices = np.concatenate(epoch_indices)
                plt.figure(figsize=(12, 6))
                plt.hist(all_epoch_indices, bins=vq_model.quantizer.nb_code,
                         range=(0, vq_model.quantizer.nb_code-1))
                plt.title(f'MotionGPT Codebook Usage Distribution - Epoch {epoch}')
                plt.xlabel('Codebook Index')
                plt.ylabel('Frequency')
                hist_path = os.path.join(CHECKPOINT_DIR, f'mgpt_codebook_usage_epoch_{epoch:03d}.png')
                plt.savefig(hist_path)
                plt.close()
                print(f"Saved codebook usage histogram to {hist_path}")

            # Every 5 epochs: save one reconstruction sample for inspection.
            if epoch > 0 and epoch % 5 == 0:
                vq_model.eval()
                with torch.no_grad():
                    for val_data in dataloader:
                        if val_data is not None:
                            motion_batch, lengths = val_data
                            x = motion_batch.to(DEVICE)
                            x_recon, _, _ = vq_model(x)
                            save_reconstruction_sample(x, x_recon, lengths, epoch)
                            break
                vq_model.train()

            # Every 10 epochs: persist a resumable checkpoint (batch_idx=-1
            # marks an end-of-epoch save).
            if epoch > 0 and epoch % 10 == 0:
                checkpoint_manager.save_checkpoint(
                    vq_model, optimizer, epoch, -1, np.mean(epoch_losses),
                    metadata={'global_epoch': epoch}
                )

        global_epoch += epochs_per_batch

        dataset.mark_batch_as_processed()

        if not dataset.get_next_batch():
            print("✅ All data processed! Training complete.")
            break

    return vq_model
|
| 372 |
+
|
| 373 |
+
# ──────────────────────────────────────────────────────────
|
| 374 |
+
# Main Training Script
|
| 375 |
+
# ──────────────────────────────────────────────────────────
|
| 376 |
+
|
| 377 |
+
def main():
    """Entry point: build the VQ-VAE, stream the dataset in folder batches,
    train it, and persist the final weights alongside their config."""
    print("Starting MotionGPT VQ-VAE Training System")
    print(f"Checkpoint directory: {CHECKPOINT_DIR}")

    feature_dim = 182    # flattened SMPL-X parameter vector size
    num_codes = 512      # codebook entries
    embedding_dim = 512  # code embedding width

    # Initialize MotionGPT VQ-VAE
    net = VQVae(
        nfeats=feature_dim,
        quantizer="ema_reset",  # Options: "ema_reset", "orig", "ema", "reset"
        code_num=num_codes,
        code_dim=embedding_dim,
        output_emb_width=embedding_dim,
        down_t=3,
        stride_t=2,
        width=512,
        depth=3,
        dilation_growth_rate=3,
        norm=None,
        activation="relu"
    ).to(DEVICE)

    n_total = sum(p.numel() for p in net.parameters())
    n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"Total parameters: {n_total:,}")
    print(f"Trainable parameters: {n_trainable:,}")

    motion_dataset = EnhancedMotionDataset(
        root_dir=DATA_ROOT,
        processed_files_path=os.path.join(CHECKPOINT_DIR, 'processed_folders_mgpt.json'),
        batch_folders=800
    )

    net = train_mgpt_vqvae(
        net,
        motion_dataset,
        epochs_per_batch=15,
        batch_size=12,
        lr=2e-4
    )

    print("\n" + "="*70)
    print("MGPT VQ-VAE TRAINING COMPLETED SUCCESSFULLY!")
    print("="*70)

    # Bundle weights with the hyper-parameters needed to rebuild the model.
    final_model_path = os.path.join(CHECKPOINT_DIR, 'final_mgpt_vqvae_model.pt')
    torch.save({
        'model_state_dict': net.state_dict(),
        'model_config': {
            'nfeats': feature_dim,
            'code_num': num_codes,
            'code_dim': embedding_dim,
            'quantizer': "ema_reset"
        },
        'training_completed': True
    }, final_model_path)
    print(f"Final model saved to: {final_model_path}")
|
| 436 |
+
|
| 437 |
+
# Run the full training pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|
train_pipeline.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main training pipeline for Motion LLM (Matched to test_overfit.py logic)
|
| 3 |
+
Run this script to execute the full training process matching the reference implementation.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import torch
|
| 8 |
+
import json
|
| 9 |
+
import argparse
|
| 10 |
+
from types import SimpleNamespace
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
# Import updated modules
|
| 14 |
+
from config import (
|
| 15 |
+
SEED, DATA_JSON_PATH, MODEL_NAME, PIPELINE_OUTPUT_DIR,
|
| 16 |
+
HF_STAGE1_REPO_ID, HF_STAGE2_REPO_ID, HF_STAGE2_SAVE_SUBDIR,
|
| 17 |
+
FORCE_STAGE2_FROM_STAGE1, HF_USE_HUB, HF_TOKEN,
|
| 18 |
+
EVALUATION_WORDS, EVAL_SAMPLE_LIMIT, RUN_EVALS_ONLY,
|
| 19 |
+
TEST_EVAL_OUTPUT_DIR, TEST_EVAL_DOWNLOAD_DIR, TEST_EVAL_EXTRACT_DIR,
|
| 20 |
+
TEST_EVAL_SAMPLE_LIMIT, TEST_EVAL_MAX_ZIPS, TEST_EVAL_HF_REPO, TEST_EVAL_HF_SUBFOLDER
|
| 21 |
+
)
|
| 22 |
+
from data import read_json_data, deduplicate_and_prepare_data, build_motion_vocab
|
| 23 |
+
from model import setup_model_and_tokenizer_raw, ensure_tokenizer_has_motion_tokens
|
| 24 |
+
from train import (
|
| 25 |
+
train_stage1_raw, train_stage2_raw, resolve_and_ensure_repo,
|
| 26 |
+
repo_has_stage_latest, load_model_and_tokenizer_from_hf,
|
| 27 |
+
download_training_state, repo_get_latest_epoch_subfolder,
|
| 28 |
+
load_model_and_tokenizer_from_hf_subfolder, download_training_state_from_subfolder
|
| 29 |
+
)
|
| 30 |
+
from metrics import (
|
| 31 |
+
evaluate_metrics_encoder_style, run_inference_on_all_samples,
|
| 32 |
+
evaluate_metrics_motiongpt_style, save_side_by_side_visualizations
|
| 33 |
+
)
|
| 34 |
+
import test_dataset_eval
|
| 35 |
+
|
| 36 |
+
# Suppress warnings
|
| 37 |
+
warnings.filterwarnings("ignore")
|
| 38 |
+
|
def main():
    """Main function to run the entire pipeline matching test_overfit.py.

    Stages:
      1. Load the raw JSON dataset.
      2. Deduplicate/clean it and collect every motion token string.
      3. Stage 1 training, optionally resumed from the Hugging Face Hub.
      4. Stage 2 training, resumed from HF or branched off the Stage 1 model.
      5. Metric evaluation on a fixed word list (FID / diversity / multimodality).
      6. Held-out test-set evaluation via ``test_dataset_eval`` (best effort).
    """
    print("="*80)
    print(" Motion LLM Training Pipeline (Matches test_overfit.py)")
    print("="*80)

    # Set seeds (Python + CPU + all CUDA devices) for reproducibility.
    random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load ALL data
    print(f"\n[1/6] Loading dataset from {DATA_JSON_PATH}...")
    all_entries = read_json_data(DATA_JSON_PATH)

    # 2. Clean the ENTIRE dataset and get all tokens
    print("\n[2/6] Cleaning dataset...")
    cleaned_data, all_motion_tokens = deduplicate_and_prepare_data(all_entries)

    # 3. Stage 1: Initialize or resume from HF, then train
    print("\n[3/6] Stage 1 Setup & Training...")
    resolved_stage1_repo = resolve_and_ensure_repo(HF_STAGE1_REPO_ID, HF_TOKEN) if HF_USE_HUB else None
    start_epoch_s1 = 0
    stage1_loaded = None
    if resolved_stage1_repo:
        # Prefer the 'stage1/latest' snapshot when the repo publishes one.
        if repo_has_stage_latest(resolved_stage1_repo, "stage1", HF_TOKEN):
            stage1_loaded = load_model_and_tokenizer_from_hf(resolved_stage1_repo, "stage1", HF_TOKEN)
            state_s1 = download_training_state(resolved_stage1_repo, "stage1", HF_TOKEN)
            if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
                start_epoch_s1 = state_s1["epoch_completed"]
        else:
            # Fallback: no 'latest' folder; select highest epoch-XXX
            latest_s1_sub = repo_get_latest_epoch_subfolder(resolved_stage1_repo, "stage1", HF_TOKEN)
            if latest_s1_sub:
                stage1_loaded = load_model_and_tokenizer_from_hf_subfolder(resolved_stage1_repo, latest_s1_sub, HF_TOKEN)
                state_s1 = download_training_state_from_subfolder(resolved_stage1_repo, latest_s1_sub, HF_TOKEN)
                if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
                    start_epoch_s1 = state_s1["epoch_completed"]

    if stage1_loaded:
        base_model, tokenizer = stage1_loaded
        # Ensure tokenizer contains all motion tokens (add missing if dataset expanded)
        added = ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
        if added > 0:
            base_model.resize_token_embeddings(len(tokenizer))
    else:
        # Cold start: build the base model/tokenizer from scratch with the full
        # motion-token vocabulary.
        base_model, tokenizer = setup_model_and_tokenizer_raw(MODEL_NAME, all_motion_tokens)

    print(f"\nStarting Stage 1 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s1}).")
    motion_model = train_stage1_raw(
        base_model,
        tokenizer,
        cleaned_data,
        device,
        start_epoch=start_epoch_s1,
        hf_repo_id=resolved_stage1_repo,
    )

    # 4. Stage 2: Initialize or resume from HF, then train
    print("\n[4/6] Stage 2 Setup & Training...")
    resolved_stage2_repo = resolve_and_ensure_repo(HF_STAGE2_REPO_ID, HF_TOKEN) if HF_USE_HUB else None
    start_epoch_s2 = 0
    stage2_loaded = None
    print(f"Stage 2 resume policy: FORCE_STAGE2_FROM_STAGE1={FORCE_STAGE2_FROM_STAGE1}, save_subdir='{HF_STAGE2_SAVE_SUBDIR}'")

    if not FORCE_STAGE2_FROM_STAGE1 and resolved_stage2_repo:
        # Prefer loading from the configured Stage 2 save subdir (e.g., 'stage2_v2')
        if repo_has_stage_latest(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN):
            stage2_loaded = load_model_and_tokenizer_from_hf(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN)
            state_s2 = download_training_state(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN)
            if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
                start_epoch_s2 = state_s2["epoch_completed"]
            print(f"Resuming Stage 2 from HF subfolder: {HF_STAGE2_SAVE_SUBDIR}/latest (epoch_completed={start_epoch_s2})")
        else:
            # Same epoch-XXX fallback as Stage 1.
            latest_s2_sub = repo_get_latest_epoch_subfolder(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN)
            if latest_s2_sub:
                stage2_loaded = load_model_and_tokenizer_from_hf_subfolder(resolved_stage2_repo, latest_s2_sub, HF_TOKEN)
                state_s2 = download_training_state_from_subfolder(resolved_stage2_repo, latest_s2_sub, HF_TOKEN)
                if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
                    start_epoch_s2 = state_s2["epoch_completed"]
                print(f"Resuming Stage 2 from HF subfolder: {latest_s2_sub} (epoch_completed={start_epoch_s2})")

    if stage2_loaded:
        stage2_model, tokenizer = stage2_loaded
        added2 = ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
        if added2 > 0:
            stage2_model.resize_token_embeddings(len(tokenizer))
    else:
        stage2_model = motion_model  # Start Stage 2 from Stage 1 model

    print(f"\nStarting Stage 2 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s2}).")
    final_model = train_stage2_raw(
        stage2_model,
        tokenizer,
        cleaned_data,
        device,
        start_epoch=start_epoch_s2,
        hf_repo_id=resolved_stage2_repo,
        hf_stage_subdir=HF_STAGE2_SAVE_SUBDIR,
    )

    # Save final model locally
    if not os.path.exists(PIPELINE_OUTPUT_DIR):
        os.makedirs(PIPELINE_OUTPUT_DIR)
    final_model.save_pretrained(PIPELINE_OUTPUT_DIR)
    tokenizer.save_pretrained(PIPELINE_OUTPUT_DIR)
    print(f"Model saved to {PIPELINE_OUTPUT_DIR}")

    # 5. Evaluation on Specific Words
    print("\n[5/6] Evaluation on Specific Words...")
    print("--- Filtering data for evaluation on specific words ---")
    evaluation_data = [item for item in cleaned_data if item['word'].lower() in EVALUATION_WORDS]
    print(f"Found {len(evaluation_data)} samples for evaluation words: {EVALUATION_WORDS}")

    metrics_json_path = os.path.join(PIPELINE_OUTPUT_DIR, "metrics.json")

    # 6. Metrics-only mode or full flow
    if RUN_EVALS_ONLY:
        # Compute the 3 metrics using VQ-VAE encoder features and save to JSON
        metrics_enc = evaluate_metrics_encoder_style(
            final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
        )
        os.makedirs(PIPELINE_OUTPUT_DIR, exist_ok=True)
        metrics_payload = {
            "source": "vqvae_encoder",
            "fid": metrics_enc.get("fid"),
            "diversity": {
                "ground_truth": metrics_enc.get("diversity_gt"),
                "model": metrics_enc.get("diversity_gen"),
            },
            "multimodality": {
                "ground_truth": metrics_enc.get("mim_gt"),
                "model": metrics_enc.get("mim_gen"),
            },
            "num_pairs": len(metrics_enc.get("pairs", [])),
        }
        with open(metrics_json_path, "w", encoding="utf-8") as f:
            json.dump(metrics_payload, f, ensure_ascii=False, indent=2)
        print(f"\n✅ Saved metrics to {metrics_json_path}")
        return

    # Full flow: inference logs + MotionGPT-style metrics + encoder metrics + visualizations
    run_inference_on_all_samples(final_model, tokenizer, evaluation_data, device)
    metrics_token = evaluate_metrics_motiongpt_style(final_model, tokenizer, evaluation_data, all_motion_tokens, device)
    # Also compute encoder-based 3 metrics
    metrics_enc = evaluate_metrics_encoder_style(
        final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
    )
    # Visualizations (skip if metrics-only)
    viz_dir = os.path.join(PIPELINE_OUTPUT_DIR, "html_visualizations")
    save_side_by_side_visualizations(metrics_token["pairs"], viz_dir, limit=4)

    # 7. Run Test Dataset Evaluation (test_dataset_eval.py)
    # Best-effort: any failure here is caught and reported, never fatal.
    print("\n[6/6] Running Evaluation on Held-out Test Dataset...")
    try:
        # Construct args matching test_dataset_eval.parse_args
        eval_args = SimpleNamespace(
            drive_url=None,
            drive_id=None,
            local_extracted_dir=None, # Will assume user needs to configure this or it uses defaults if not provided
            # Note: test_dataset_eval requires one of drive/local. We can try to rely on defaults or skip if not configured.
            # We will set download_dir and extract_dir from config.
            download_dir=TEST_EVAL_DOWNLOAD_DIR,
            extract_dir=TEST_EVAL_EXTRACT_DIR,
            max_zips=TEST_EVAL_MAX_ZIPS,
            hf_repo_id=TEST_EVAL_HF_REPO,
            hf_subfolder=TEST_EVAL_HF_SUBFOLDER,
            vqvae_ckpt=None,
            stats_path=None,
            output_dir=TEST_EVAL_OUTPUT_DIR,
            sample_limit=TEST_EVAL_SAMPLE_LIMIT,
            seed=SEED
        )

        # For this pipeline, we might want to pass the *currently loaded* model instead of reloading from HF?
        # test_dataset_eval.run_evaluation loads from HF.
        # The prompt asked to "incorporate... code of test_dataset_eval.py".
        # Ideally we pass the model object, but run_evaluation is written to load from HF.
        # Given we just saved and pushed (if enabled), loading from HF is fine.
        # If we haven't pushed (HF_USE_HUB=False), run_evaluation might fail if it tries to load from HF.
        # However, the prompt implies using test_overfit.py training setup which pushes to HF.

        # Critical fix: If we want to use the *local* model we just trained, we should modify test_dataset_eval or pass it.
        # But test_dataset_eval.run_evaluation doesn't accept model arg.
        # For now, we'll attempt to run it as designed (loading from HF).
        # If HF_USE_HUB is False, this step might fail.

        # Let's check if we can use local_extracted_dir if it exists, otherwise drive download.
        # We will use a try-except block.

        print("Calling test_dataset_eval.run_evaluation...")
        # We need to provide either drive-url/id or local-extracted.
        # We'll try to use the extracted dir if it has content, otherwise default to download if URL known?
        # Actually, since we don't have a drive URL in config (it was an arg), we might skip this if not set up?
        # But the user said "include the code".

        # We'll default to using the extract dir if it exists, otherwise we might need to ask or skip.
        # Let's assume the user has data or we use the default drive-id if known (it wasn't in the provided file).
        # Wait, test_dataset_eval.py has mutually exclusive required group.
        # I'll add a fallback: if TEST_EVAL_EXTRACT_DIR exists and has files, use it.

        if os.path.exists(TEST_EVAL_EXTRACT_DIR) and os.listdir(TEST_EVAL_EXTRACT_DIR):
            eval_args.local_extracted_dir = TEST_EVAL_EXTRACT_DIR
        else:
            # We don't have a drive URL hardcoded.
            # We will mock the arg to fail gracefully or print a message.
            print("⚠️ Skipping test_dataset_eval: No local data found and no Drive URL configured.")
            eval_args = None

        if eval_args:
            test_dataset_eval.run_evaluation(eval_args)

    except Exception as e:
        print(f"⚠️ Test dataset evaluation failed: {e}")

    print("\n" + "="*60)
    print("Training pipeline complete!")
    print("="*60)
    print(f"Models saved to: {PIPELINE_OUTPUT_DIR}")
| 262 |
+
|
# Script entry point.
if __name__ == "__main__":
    main()
train_vqvae.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
import glob
|
| 8 |
+
import warnings
|
| 9 |
+
import json
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import math
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import sys
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
# ==============================================================================
|
| 18 |
+
# 0) SETUP: Architecture files
|
| 19 |
+
# ==============================================================================
|
| 20 |
+
# Make sure your mGPT folder is in the Python path
|
| 21 |
+
# sys.path.append('/path/to/your/mGPT_folder')
|
| 22 |
+
from mGPT.archs.mgpt_vq import VQVae
|
| 23 |
+
from mGPT.archs.tools import quantize_cnn
|
| 24 |
+
|
| 25 |
+
warnings.filterwarnings("ignore")
|
| 26 |
+
|
# ==============================================================================
# 1) CONFIGURATION
# ==============================================================================
# One-time sanity printouts: the model classes below print shapes/stats only
# while SANITY_CHECK_ENABLED is True and sanity_check_counter is still 0.
# NOTE(review): the counter is presumably bumped by the training loop after
# batch 0 — not visible in this section; confirm before relying on it.
SANITY_CHECK_ENABLED = True
sanity_check_counter = 0
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # global compute device
print("Device:", DEVICE)
print(f"Sanity checks are {'ENABLED' if SANITY_CHECK_ENABLED else 'DISABLED'}.")
| 35 |
+
|
| 36 |
+
# ==============================================================================
|
| 37 |
+
# 2) VQ-VAE MODEL (Your instrumented classes are fine)
|
| 38 |
+
# ==============================================================================
|
class QuantizeEMAReset_Sanity(quantize_cnn.QuantizeEMAReset):
    """Drop-in QuantizeEMAReset subclass that adds one-time sanity printouts.

    Quantization behavior is unchanged from the parent; the extra prints and
    asserts fire only for the first batch (current_batch_idx == 0) while the
    module-level sanity_check_counter is still 0.
    """

    def forward(self, x, current_batch_idx=0):
        # x: encoder latents, shape (N, width, T) per the unpack below.
        global sanity_check_counter
        N, width, T = x.shape
        # preprocess is parent-defined; presumably flattens time into the batch
        # dim (the view/permute at the end undoes it) — confirm in quantize_cnn.
        x_proc = self.preprocess(x)
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
            print("[Quantizer.forward] Input shape `x`: ", x.shape)
            print("[Quantizer.forward] Shape after preprocess `x_proc`: ", x_proc.shape)
            print(f"[Quantizer.forward] Codebook shape: {self.codebook.shape}")
            if self.training and not self.init: print("[Quantizer.forward] Codebook is UNINITIALIZED.")
            else: print(f"[Quantizer.forward] Codebook stats: min={self.codebook.min():.3f}, max={self.codebook.max():.3f}, mean={self.codebook.mean():.3f}")
        # Lazily initialize the codebook from the first training batch.
        if self.training and not self.init: self.init_codebook(x_proc)
        code_idx = self.quantize(x_proc)   # nearest-codeword indices
        x_d = self.dequantize(code_idx)    # codeword vectors for those indices
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
            print(f"[Quantizer.forward] Code index range: min={code_idx.min()}, max={code_idx.max()}")
            assert code_idx.max() < self.nb_code, "A code index is out of bounds!"
        # EMA codebook update only in training mode; eval just reports perplexity.
        if self.training: perplexity = self.update_codebook(x_proc, code_idx)
        else: perplexity = self.compute_perplexity(code_idx)
        # Commitment loss pulls encoder outputs toward their (detached) codewords.
        commit_loss = F.mse_loss(x_proc, x_d.detach())
        # Straight-through estimator: forward pass uses x_d, gradients flow to x_proc.
        x_d = x_proc + (x_d - x_proc).detach()
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()  # back to (N, width, T)
        return x_d, commit_loss, perplexity
| 62 |
+
|
class VQVae_Sanity(VQVae):
    """VQVae subclass that swaps in the instrumented quantizer and checks the
    tensor shapes along the encode → quantize → decode path on the first batch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the stock EMA-reset quantizer with the sanity-check version.
        # NOTE(review): this builds a fresh quantizer, discarding the one the
        # parent constructed — fine at init time since both start untrained.
        if isinstance(self.quantizer, quantize_cnn.QuantizeEMAReset):
            self.quantizer = QuantizeEMAReset_Sanity(
                self.quantizer.nb_code, self.quantizer.code_dim, self.quantizer.mu
            )

    def forward(self, features, current_batch_idx=0):
        # features: (N, T, nfeats) motion batch; preprocess is parent-defined
        # (a permute to channels-first, per the print below) — confirm in VQVae.
        global sanity_check_counter
        x_in = self.preprocess(features)
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0: print("[VQVae.forward] Shape after preprocess (permute): ", x_in.shape)
        x_encoder = self.encoder(x_in)
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
            print("[VQVae.forward] Shape after encoder `x_encoder`: ", x_encoder.shape)
            # Assumes down_t=3 stride-2 stages (the config used elsewhere in this
            # repo), i.e. temporal downsampling by 2**3 — confirm for other configs.
            total_downsample_factor = 2**3
            expected_len = math.ceil(features.shape[1] / total_downsample_factor)
            print(f"[VQVae.forward] Calculated expected quantized length: ~{expected_len}")
            assert abs(x_encoder.shape[2] - expected_len) <= 1, "Temporal downsampling seems incorrect."
        x_quantized, loss, perplexity = self.quantizer(x_encoder, current_batch_idx)
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0: print("[VQVae.forward] Shape after quantizer `x_quantized`: ", x_quantized.shape)
        x_decoder = self.decoder(x_quantized)
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
            print("[VQVae.forward] Shape after decoder `x_decoder`: ", x_decoder.shape)
            assert x_decoder.shape[2] == features.shape[1], "Decoder output temporal dim mismatch!"
        x_out = self.postprocess(x_decoder)
        return x_out, loss, perplexity
| 89 |
+
|
# Monkey-patching
# Rebind the names on the already-imported mGPT module so any other code that
# pulls VQVae / QuantizeEMAReset from mGPT.archs.mgpt_vq gets the instrumented
# sanity-check variants defined above instead of the stock classes.
sys.modules['mGPT.archs.mgpt_vq'].VQVae = VQVae_Sanity
sys.modules['mGPT.archs.mgpt_vq'].QuantizeEMAReset = QuantizeEMAReset_Sanity
| 93 |
+
|
class MotionGPT_VQVAE_Wrapper(nn.Module):
    """Wrapper around the MotionGPT VQVae that also returns code indices and
    carries a per-feature reconstruction-loss weight buffer.

    The weights up-weight selected parameter groups (offsets derived from
    `param_dims`) so those channels dominate the weighted recon loss computed
    by the training loop.
    """

    def __init__(self, smpl_dim, codebook_size=512, code_dim=512, **kwargs):
        super().__init__()
        self.smpl_dim = smpl_dim
        self.vqvae = VQVae(
            nfeats=smpl_dim, code_num=codebook_size, code_dim=code_dim,
            output_emb_width=code_dim, **kwargs
        )
        # Per-group feature widths (sum = 182). Presumably the SMPL-X layout
        # [shape, body_pose, l_hand, r_hand, global_orient, expression, jaw,
        # transl] — TODO confirm against the dataset's actual feature order.
        param_dims = [10, 63, 45, 45, 3, 10, 3, 3]
        param_starts = np.cumsum([0] + param_dims[:-1]).tolist()
        loss_weights = torch.ones(smpl_dim)
        loss_weights[param_starts[1]:param_starts[5]] = 10.0  # groups 1-4
        loss_weights[param_starts[0]:param_starts[1]] = 5.0   # group 0
        loss_weights[param_starts[5]:param_starts[6]] = 8.0   # group 5
        # Buffer (not a parameter): moves with .to(device), excluded from optimizer.
        self.register_buffer('loss_weights', loss_weights)
        print(f"Initialized MotionGPT VQ-VAE with {codebook_size} codebook size")

    def forward(self, x, current_batch_idx=0):
        # Returns (reconstruction, vq loss, code indices, codebook perplexity).
        global sanity_check_counter
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
            print("\n" + "="*50)
            print("--- VQ-VAE WRAPPER SANITY CHECK (Batch 0) ---")
            print(f"[Input] Shape of input features `x`: {x.shape}")
            print("-"*50)
        x_recon, vq_loss, perplexity = self.vqvae(x, current_batch_idx)
        if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
            print("[Output] Shape of reconstructed features `x_recon`: ", x_recon.shape)
            assert x.shape == x_recon.shape, "Shape mismatch!"
            print(f"[Output] vq_loss: {vq_loss.item():.6f}, perplexity: {perplexity.item():.2f}")
            print("--- VQ-VAE WRAPPER SANITY CHECK COMPLETE ---")
            print("="*50 + "\n")
        # NOTE(review): this runs a second full encoder pass just to obtain the
        # code indices (the quantized forward above does not expose them) —
        # roughly doubles encoder cost per step.
        indices, _ = self.vqvae.encode(x)
        return x_recon, vq_loss, indices, perplexity
| 126 |
+
|
| 127 |
+
# ==============================================================================
|
| 128 |
+
# 3) DATA LOADING
|
| 129 |
+
# ==============================================================================
|
def load_motion_from_npz(file_path):
    """Load the 'motion' array stored in an .npz archive as a float32 tensor.

    Any failure (missing file, corrupt archive, absent 'motion' key) is
    reported with a warning and mapped to None so callers can skip the file.
    """
    try:
        with np.load(file_path) as archive:
            raw_motion = archive['motion']
        return torch.tensor(raw_motion, dtype=torch.float32)
    except Exception as err:  # best-effort loader: report and skip
        print(f"Warning: Could not load {os.path.basename(file_path)}. Skipping. Error: {err}")
        return None
| 138 |
+
|
class NpzMotionDataset(Dataset):
    """Dataset over .npz motion files found recursively under a root folder.

    Each item is a (T, nfeats) float32 tensor, z-normalized with precomputed
    mean/std stats when a stats file is supplied. Unreadable files and clips
    shorter than `min_seq_len` frames yield None, which the collate function
    is expected to filter out.
    """

    def __init__(self, root_dir, stats_path=None, min_seq_len=64):
        self.min_seq_len = min_seq_len
        print(f"\n[Dataset] Initializing from NPZ files in: '{root_dir}'")
        glob_pattern = os.path.join(root_dir, '**', '*.npz')
        self.files = glob.glob(glob_pattern, recursive=True)
        if not self.files:
            raise FileNotFoundError(f"FATAL: No .npz files found at '{glob_pattern}'.")
        print(f"[Dataset] Found {len(self.files)} total .npz files.")

        have_stats = bool(stats_path) and os.path.exists(stats_path)
        if have_stats:
            stats = torch.load(stats_path, map_location='cpu')
            self.mean, self.std = stats['mean'], stats['std']
            print("[Dataset] Successfully loaded normalization stats to CPU.")
        else:
            print("❗ [Dataset] WARNING: Stats file not found. Proceeding without normalization. This will affect loss values and model performance.")
            self.mean, self.std = 0, 1

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Returns None for unusable samples; collate must drop those entries.
        seq = load_motion_from_npz(self.files[idx])
        if seq is None or seq.shape[0] < self.min_seq_len:
            return None
        return (seq - self.mean) / self.std
| 169 |
+
|
| 170 |
+
# ==============================================================================
|
| 171 |
+
# 4) CHECKPOINT & CODEBOOK INITIALIZATION
|
| 172 |
+
# ==============================================================================
|
class CheckpointManager:
    """Writes rolling VQ-VAE training checkpoints into a directory, retaining
    only the newest `max_checkpoints` files (ordered by modification time)."""

    def __init__(self, checkpoint_dir, max_checkpoints=3):
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints

    def save_checkpoint(self, model, optimizer, epoch, loss, metadata=None):
        """Persist model/optimizer state for `epoch`, then prune old files."""
        checkpoint_path = os.path.join(self.checkpoint_dir, f'vqvae_epoch_{epoch:03d}.pt')
        payload = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'timestamp': datetime.now().isoformat(),
            'metadata': metadata or {},
        }
        torch.save(payload, checkpoint_path)
        print(f"✅ Saved checkpoint: {checkpoint_path}")
        self.cleanup_old_checkpoints()

    def cleanup_old_checkpoints(self):
        """Delete the oldest checkpoints beyond the retention limit."""
        found = glob.glob(os.path.join(self.checkpoint_dir, 'vqvae_epoch_*.pt'))
        if len(found) <= self.max_checkpoints:
            return
        for stale in sorted(found, key=os.path.getmtime)[:-self.max_checkpoints]:
            os.remove(stale)
            print(f"🗑️ Removed old checkpoint: {stale}")

    def load_latest_checkpoint(self):
        """Return the most recently written checkpoint dict, or None if absent."""
        found = glob.glob(os.path.join(self.checkpoint_dir, 'vqvae_epoch_*.pt'))
        if not found:
            return None
        latest_checkpoint_path = max(found, key=os.path.getmtime)
        print(f"🔄 Loading latest checkpoint: {latest_checkpoint_path}")
        # weights_only=False: checkpoints include optimizer state and metadata.
        return torch.load(latest_checkpoint_path, map_location=DEVICE, weights_only=False)
| 203 |
+
|
def initialize_codebook_from_dataset(model, dataloader, num_batches=100):
    """Seed the VQ codebook with encoder latents sampled from real data.

    Runs the wrapped VQ-VAE encoder over up to `num_batches` batches, flattens
    the latents to (num_vectors, code_dim), and picks a random codebook-sized
    subset as the initial codewords, so the quantizer starts from the data
    distribution rather than a cold random init.
    """
    print(f"⚙️ Collecting data from {num_batches} batches for codebook initialization...")
    all_latents = []
    model.eval()  # no dropout/BN updates while sampling latents
    with torch.no_grad():
        for i, batch_data in enumerate(dataloader):
            if i >= num_batches: break
            # batch_data is presumably (padded_motion, lengths) from the collate
            # fn; empty batches come through as (None, None) and are skipped.
            if batch_data and batch_data[0] is not None:
                motion_batch, _ = batch_data
                x = motion_batch.to(DEVICE)
                z_e = model.vqvae.encoder(model.vqvae.preprocess(x))  # (N, code_dim, T')
                z_e_flat = z_e.permute(0, 2, 1).reshape(-1, z_e.shape[1])  # (N*T', code_dim)
                all_latents.append(z_e_flat.cpu())
    if not all_latents: raise ValueError("Could not collect any latents for initialization.")
    all_latents = torch.cat(all_latents, dim=0)
    print(f"Collected {all_latents.shape[0]} latent vectors.")
    codebook_size = model.vqvae.quantizer.nb_code
    # Uniform random subset of the collected latents becomes the codebook.
    indices = torch.randperm(all_latents.shape[0])[:codebook_size]
    initial_codebook = all_latents[indices].to(DEVICE)
    # NOTE(review): assumes the quantizer's init_codebook accepts a
    # (nb_code, code_dim) tensor — confirm against quantize_cnn.
    model.vqvae.quantizer.init_codebook(initial_codebook)
    print("✅ Codebook initialized successfully from a diverse data sample.")
    model.train()  # restore training mode for the caller
| 226 |
+
|
| 227 |
+
# ==============================================================================
|
| 228 |
+
# 5) CORRECTED & COMPLETE TRAINING FUNCTION (No Globals)
|
| 229 |
+
# ==============================================================================
|
| 230 |
+
def train_vqvae_colab(vq_model, dataset, checkpoint_dir, num_epochs=300, batch_size=32, lr=2e-4):
    """
    The complete, updated training function for Colab using .npz files.
    This version avoids global variables by accepting checkpoint_dir as an argument.

    Args:
        vq_model: MotionGPT_VQVAE_Wrapper to train (moved to DEVICE here).
        dataset: Dataset yielding variable-length motion tensors of shape (T, feat).
        checkpoint_dir: directory for checkpoints and codebook-usage plots.
        num_epochs: total epochs to train (resume-aware).
        batch_size: dataloader batch size.
        lr: initial AdamW learning rate.

    Returns:
        The trained model.
    """
    global sanity_check_counter
    print("\n" + "="*70 + "\n STARTING VQ-VAE TRAINING ON COLAB \n" + "="*70)

    optimizer = torch.optim.AdamW(vq_model.parameters(), lr=lr, weight_decay=1e-4)
    # T_max is in EPOCHS, so the scheduler must be stepped once per epoch.
    # FIX: previously scheduler.step() was called per *batch*, which annealed
    # the learning rate to its minimum within the first epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    loss_fn = nn.SmoothL1Loss(reduction='none')
    # Use the passed-in checkpoint_dir
    checkpoint_manager = CheckpointManager(checkpoint_dir)

    start_epoch = 1
    checkpoint = checkpoint_manager.load_latest_checkpoint()
    if checkpoint:
        # strict=False tolerates minor architecture drift between runs.
        vq_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 1) + 1
        print(f"✅ Resumed training from epoch {start_epoch}")
    else:
        print("No CheckPoint Found")
    vq_model.to(DEVICE).train()
    codebook_size = vq_model.vqvae.quantizer.nb_code

    def collate_fn_enhanced(batch):
        # Drop failed samples, cap at 256 frames, and pad lengths up to a
        # multiple of 8 (the encoder's temporal downsampling factor).
        batch = [item for item in batch if item is not None]
        if not batch:
            return None, None
        batch.sort(key=lambda x: x.shape[0], reverse=True)
        max_len = min(batch[0].shape[0], 256)
        padded_max_len = math.ceil(max_len / 8) * 8
        padded_batch = torch.zeros(len(batch), padded_max_len, batch[0].shape[1])
        lengths = [min(x.shape[0], padded_max_len) for x in batch]
        for i, x_item in enumerate(batch):
            padded_batch[i, :lengths[i], :] = x_item[:lengths[i], :]
        return padded_batch, torch.tensor(lengths)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                            collate_fn=collate_fn_enhanced, drop_last=True, pin_memory=True)

    # Data-driven codebook initialization only on a fresh (non-resumed) run.
    if start_epoch == 1 and not getattr(vq_model.vqvae.quantizer, 'init', False):
        initialize_codebook_from_dataset(vq_model, dataloader, num_batches=100)

    for epoch in range(start_epoch, num_epochs + 1):
        print(f"\n{'='*30} EPOCH {epoch}/{num_epochs} {'='*30}")
        epoch_losses, epoch_vq_losses, epoch_rec_losses, epoch_perplexity = [], [], [], []
        epoch_indices = []

        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
            if not batch_data or batch_data[0] is None:
                continue

            motion_batch, lengths = batch_data
            x = motion_batch.to(DEVICE)
            x_recon, vq_loss, indices, perplexity = vq_model(x, batch_idx)

            # Feature-weighted reconstruction loss, masked so padding frames
            # do not contribute.
            rec_loss_unreduced = loss_fn(x_recon, x) * vq_model.loss_weights
            mask = torch.zeros_like(x[:, :, 0], device=DEVICE)
            for i, length in enumerate(lengths):
                mask[i, :length] = 1.0
            mask = mask.unsqueeze(-1).expand_as(rec_loss_unreduced)
            rec_loss = (rec_loss_unreduced * mask).sum() / mask.sum()

            beta = 0.25  # standard commitment-loss weight
            total_loss = rec_loss + (beta * vq_loss)

            optimizer.zero_grad(set_to_none=True)
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(vq_model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_losses.append(total_loss.item())
            epoch_vq_losses.append(vq_loss.item())
            epoch_rec_losses.append(rec_loss.item())
            epoch_perplexity.append(perplexity.item())
            epoch_indices.append(indices.cpu().numpy().flatten())

            if batch_idx % 50 == 0 and batch_idx > 0:
                print(f"\n[E:{epoch:03d}] B:{batch_idx:03d} | Loss: {total_loss.item():.4f} (Rec: {rec_loss.item():.4f}, VQ: {vq_loss.item():.6f}) | Perplexity: {perplexity.item():.2f}")

            if SANITY_CHECK_ENABLED and batch_idx == 0 and sanity_check_counter == 0:
                sanity_check_counter += 1

        # FIX: step the cosine schedule once per epoch to match T_max=num_epochs.
        scheduler.step()

        if not epoch_losses:
            continue

        all_epoch_indices_flat = np.concatenate(epoch_indices)
        counts = np.bincount(all_epoch_indices_flat, minlength=codebook_size)
        avg_usage = (counts > 0).sum()
        with torch.no_grad():
            code_variance = vq_model.vqvae.quantizer.codebook.var(dim=0).mean().item()

        print(f"\n[EPOCH {epoch:03d} SUMMARY]")
        print(f" Avg Loss: {np.mean(epoch_losses):.4f} (Rec: {np.mean(epoch_rec_losses):.4f}, VQ: {np.mean(epoch_vq_losses):.6f})")
        print(f" Avg Perplexity: {np.mean(epoch_perplexity):.2f}")
        print(f" Codebook Usage: {avg_usage}/{codebook_size} ({(avg_usage/codebook_size)*100:.1f}%) | Variance: {code_variance:.6f}")

        # Use the passed-in checkpoint_dir for saving plots
        hist_path = os.path.join(checkpoint_dir, f'codebook_usage_epoch_{epoch:03d}.png')
        plt.figure(figsize=(12, 6))
        plt.hist(all_epoch_indices_flat, bins=codebook_size)
        plt.title(f'Codebook Usage - Epoch {epoch}')
        plt.savefig(hist_path)
        plt.close()

        if epoch > 0 and epoch % 5 == 0:
            print("\n--- Performing End-of-Epoch Tasks ---")
            # Quick qualitative check: reconstruction MSE on one sample.
            vq_model.eval()
            with torch.no_grad():
                val_data = next(iter(dataloader))
                if val_data and val_data[0] is not None:
                    motion_batch, lengths = val_data
                    x_val = motion_batch.to(DEVICE)
                    x_recon_val, _, _, _ = vq_model(x_val, -1)
                    orig = x_val[0, :lengths[0]].cpu().numpy()
                    recon = x_recon_val[0, :lengths[0]].cpu().numpy()
                    mse = ((orig - recon) ** 2).mean()
                    print(f"Reconstruction MSE on sample: {mse:.6f}")

            # Re-seed codes that were barely used this epoch with fresh encoder
            # latents to fight codebook collapse.
            with torch.no_grad():
                usage_threshold = 10
                underutilized_indices = torch.from_numpy(np.where(counts < usage_threshold)[0]).to(DEVICE)
                num_to_reset = len(underutilized_indices)
                if num_to_reset > 0:
                    print(f"[CODEBOOK MGMT] Resetting {num_to_reset} underutilized codes.")
                    reset_data = next(iter(dataloader))
                    if reset_data and reset_data[0] is not None:
                        motion_batch, _ = reset_data
                        x_reset = motion_batch.to(DEVICE)
                        z_e = vq_model.vqvae.encoder(vq_model.vqvae.preprocess(x_reset))
                        z_e_flat = z_e.permute(0, 2, 1).reshape(-1, z_e.shape[1])
                        if z_e_flat.shape[0] >= num_to_reset:
                            # Renamed from `indices` to avoid shadowing the
                            # quantizer indices collected above.
                            reset_pick = torch.randperm(z_e_flat.size(0))[:num_to_reset]
                            vq_model.vqvae.quantizer.codebook.data[underutilized_indices] = z_e_flat[reset_pick]
            vq_model.train()

        if epoch > 0 and epoch % 5 == 0:
            checkpoint_manager.save_checkpoint(vq_model, optimizer, epoch, np.mean(epoch_losses))

    print("\n✅ Training loop finished.")
    return vq_model
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# ==============================================================================
|
| 370 |
+
# 6) MAIN EXECUTION SCRIPT (No Globals)
|
| 371 |
+
# ==============================================================================
|
| 372 |
+
def main_colab():
    """Entry point: mount Google Drive, build the model and dataset, train, and save.

    Colab-only: imports `google.colab` and assumes Drive-backed paths. All paths
    are defined locally (no module-level globals).
    """
    from google.colab import drive
    drive.mount('/content/drive')
    print("✅ Google Drive mounted successfully.")

    GDRIVE_ROOT = '/content/drive/MyDrive'

    # Define all paths locally within the main function.
    # Stats live on the local Colab disk; data/checkpoints live on Drive.
    STATS_PATH = f'/content/dataset_stats.pt'
    DATA_ROOT = f'{GDRIVE_ROOT}/kaggle_upload/npz_data/batch_1'
    CHECKPOINT_DIR = f'{GDRIVE_ROOT}/Colab_Checkpoints/MotionGPT_VQVAE_Final'

    # The 'global' keyword is no longer needed
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    print(f"Data Root: {DATA_ROOT}")
    print(f"Stats Path: {STATS_PATH}")
    print(f"Checkpoint Dir: {CHECKPOINT_DIR}")

    # Architecture hyperparameters — must match any checkpoint being resumed.
    smpl_dim = 182
    codebook_size = 512
    code_dim = 512
    vq_model = MotionGPT_VQVAE_Wrapper(
        smpl_dim=smpl_dim, codebook_size=codebook_size, code_dim=code_dim,
        quantizer="ema_reset", width=512, depth=3, down_t=3, stride_t=2,
        dilation_growth_rate=3, activation='relu', norm=None
    ).to(DEVICE)

    # NOTE(review): min_seq_len=64 — sequences shorter than this are presumably
    # dropped by the dataset; confirm against NpzMotionDataset.
    motion_dataset = NpzMotionDataset(
        root_dir=DATA_ROOT,
        stats_path=STATS_PATH,
        min_seq_len=64
    )

    # Pass CHECKPOINT_DIR as an argument to the training function
    vq_model = train_vqvae_colab(
        vq_model,
        motion_dataset,
        checkpoint_dir=CHECKPOINT_DIR,  # Pass the path here
        num_epochs=1000,
        batch_size=32,
        lr=2e-4
    )

    print("\n" + "="*70 + "\nVQ-VAE TRAINING COMPLETED SUCCESSFULLY!\n" + "="*70)
    # Save weights-only final artifact (no optimizer state).
    final_model_path = os.path.join(CHECKPOINT_DIR, 'final_vqvae_model.pt')
    torch.save({'model_state_dict': vq_model.state_dict()}, final_model_path)
    print(f"Final model saved to: {final_model_path}")
|
| 419 |
+
|
| 420 |
+
# Script entry point: run the full Colab VQ-VAE training pipeline.
if __name__ == "__main__":
    main_colab()
|
visualize.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Visualization script to convert motion tokens to SMPL-X 3D animation.
|
| 3 |
+
Requires VQ-VAE checkpoint, dataset stats, and SMPL-X model files.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
# Visualize from LLM output string
|
| 7 |
+
python visualize.py --tokens "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
|
| 8 |
+
|
| 9 |
+
# Visualize from saved file
|
| 10 |
+
python visualize.py --input motion_output.txt
|
| 11 |
+
|
| 12 |
+
# Generate and visualize in one go
|
| 13 |
+
python visualize.py --prompt "walking" --stage 3
|
| 14 |
+
|
| 15 |
+
# Custom paths
|
| 16 |
+
python visualize.py --tokens "..." --vqvae-ckpt /path/to/vqvae.pt --smplx-dir /path/to/smplx
|
| 17 |
+
"""
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import re
|
| 21 |
+
import argparse
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
from config import WORK_DIR, DATA_DIR
|
| 29 |
+
|
| 30 |
+
# Try importing visualization dependencies
|
| 31 |
+
try:
|
| 32 |
+
import plotly.graph_objects as go
|
| 33 |
+
except ImportError:
|
| 34 |
+
print("Installing plotly...")
|
| 35 |
+
os.system("pip install -q plotly")
|
| 36 |
+
import plotly.graph_objects as go
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
import smplx
|
| 40 |
+
except ImportError:
|
| 41 |
+
print("Installing smplx...")
|
| 42 |
+
os.system("pip install -q smplx==0.1.28")
|
| 43 |
+
import smplx
|
| 44 |
+
|
| 45 |
+
# =====================================================================
|
| 46 |
+
# Configuration - can be overridden via command-line or environment
|
| 47 |
+
# =====================================================================
|
| 48 |
+
# VQ-VAE checkpoint path (trained motion encoder/decoder)
|
| 49 |
+
VQVAE_CHECKPOINT = os.environ.get(
|
| 50 |
+
"VQVAE_CHECKPOINT",
|
| 51 |
+
os.path.join(DATA_DIR, "vqvae_model.pt")
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# Dataset normalization stats (mean/std used during VQ-VAE training)
|
| 55 |
+
STATS_PATH = os.environ.get(
|
| 56 |
+
"VQVAE_STATS_PATH",
|
| 57 |
+
os.path.join(DATA_DIR, "vqvae_stats.pt")
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# SMPL-X model directory (contains SMPLX_NEUTRAL.npz, etc.)
|
| 61 |
+
SMPLX_MODEL_DIR = os.environ.get(
|
| 62 |
+
"SMPLX_MODEL_DIR",
|
| 63 |
+
os.path.join(DATA_DIR, "smplx_models")
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Output directory for HTML animations
|
| 67 |
+
OUTPUT_DIR = os.environ.get("VIS_OUTPUT_DIR", WORK_DIR)
|
| 68 |
+
|
| 69 |
+
# Device
|
| 70 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 71 |
+
|
| 72 |
+
# VQ-VAE architecture params (must match training config)
|
| 73 |
+
SMPL_DIM = 182
|
| 74 |
+
CODEBOOK_SIZE = 512
|
| 75 |
+
CODE_DIM = 512
|
| 76 |
+
VQ_ARGS = dict(
|
| 77 |
+
width=512,
|
| 78 |
+
depth=3,
|
| 79 |
+
down_t=2,
|
| 80 |
+
stride_t=2,
|
| 81 |
+
dilation_growth_rate=3,
|
| 82 |
+
activation='relu',
|
| 83 |
+
norm=None,
|
| 84 |
+
quantizer="ema_reset"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# SMPL-X parameter layout (must match VQ-VAE training)
|
| 88 |
+
PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
|
| 89 |
+
PARAM_NAMES = ["betas", "body_pose", "left_hand_pose", "right_hand_pose",
|
| 90 |
+
"trans", "expression", "jaw_pose", "eye_pose"]
|
| 91 |
+
|
| 92 |
+
# =====================================================================
|
| 93 |
+
# Import VQ-VAE architecture
|
| 94 |
+
# =====================================================================
|
| 95 |
+
try:
|
| 96 |
+
# Add SignMotionGPT to path if not already
|
| 97 |
+
sign_mgpt_dir = os.path.join(os.path.dirname(__file__))
|
| 98 |
+
if sign_mgpt_dir not in sys.path:
|
| 99 |
+
sys.path.insert(0, sign_mgpt_dir)
|
| 100 |
+
|
| 101 |
+
from mGPT.archs.mgpt_vq import VQVae
|
| 102 |
+
except ImportError as e:
|
| 103 |
+
print(f"❌ Could not import VQVae: {e}")
|
| 104 |
+
print("Make sure mGPT/archs/mgpt_vq.py exists in the project.")
|
| 105 |
+
sys.exit(1)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# =====================================================================
|
| 109 |
+
# VQ-VAE Wrapper
|
| 110 |
+
# =====================================================================
|
| 111 |
+
class MotionGPT_VQVAE_Wrapper(nn.Module):
    """Wrapper matching the VQ-VAE training setup"""
    # NOTE(review): presumably kept so state-dict keys carry the same `vqvae.`
    # prefix as checkpoints saved during training — verify against the trainer.
    def __init__(self, smpl_dim=SMPL_DIM, codebook_size=CODEBOOK_SIZE,
                 code_dim=CODE_DIM, **kwargs):
        super().__init__()
        # VQVae comes from mGPT.archs.mgpt_vq; extra kwargs (width, depth,
        # down_t, ...) are forwarded unchanged and must match the training config.
        self.vqvae = VQVae(
            nfeats=smpl_dim,
            code_num=codebook_size,
            code_dim=code_dim,
            output_emb_width=code_dim,
            **kwargs
        )
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# =====================================================================
|
| 126 |
+
# Token Parsing
|
| 127 |
+
# =====================================================================
|
| 128 |
+
def parse_motion_tokens(token_str):
    """
    Parse motion tokens from LLM output string.
    Accepts:
      - "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
      - "177 135 152 200 46..."
      - List/array of ints

    Returns:
        List of token integers

    Raises:
        ValueError: when the input is neither list-like nor a parseable string.
    """
    # List-like input: coerce each element to int and return directly.
    if isinstance(token_str, (list, tuple, np.ndarray)):
        return [int(tok) for tok in token_str]

    if not isinstance(token_str, str):
        raise ValueError("Tokens must be string or list-like")

    # Preferred format: explicit <motion_ID> tags.
    tagged = re.findall(r'<motion_(\d+)>', token_str)
    if tagged:
        return [int(tok) for tok in tagged]

    # Fallback: whitespace-separated integers.
    stripped = token_str.strip()
    if stripped:
        try:
            return [int(tok) for tok in stripped.split()]
        except ValueError:
            pass

    raise ValueError(f"Could not parse motion tokens from: {stripped[:100]}...")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# =====================================================================
|
| 162 |
+
# Model Loading
|
| 163 |
+
# =====================================================================
|
| 164 |
+
def load_vqvae(checkpoint_path, device=DEVICE, vq_args=VQ_ARGS):
    """Instantiate the VQ-VAE wrapper and restore trained weights.

    Raises FileNotFoundError with setup instructions when the checkpoint is
    missing. Returns the model in eval mode on `device`.
    """
    if not os.path.exists(checkpoint_path):
        message = (
            f"VQ-VAE checkpoint not found: {checkpoint_path}\n"
            "Please download it and set VQVAE_CHECKPOINT environment variable "
            "or use --vqvae-ckpt argument."
        )
        raise FileNotFoundError(message)

    print(f"Loading VQ-VAE from: {checkpoint_path}")
    wrapper = MotionGPT_VQVAE_Wrapper(
        smpl_dim=SMPL_DIM,
        codebook_size=CODEBOOK_SIZE,
        code_dim=CODE_DIM,
        **vq_args
    ).to(device)

    # Checkpoint may be a raw state dict or wrapped in 'model_state_dict'.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    weights = checkpoint.get('model_state_dict', checkpoint)
    wrapper.load_state_dict(weights, strict=False)
    wrapper.eval()

    print(f"✅ VQ-VAE loaded (codebook size: {CODEBOOK_SIZE})")
    return wrapper
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def load_stats(stats_path):
    """Load the mean/std normalization statistics saved at VQ-VAE training time.

    Returns (mean, std) as numpy arrays (or the stored scalar defaults), or
    (None, None) when the stats file is missing — in which case callers skip
    denormalization.
    """
    if not stats_path or not os.path.exists(stats_path):
        print(f"⚠️ Stats file not found: {stats_path}")
        print(" Will skip denormalization (may affect quality)")
        return None, None

    print(f"Loading stats from: {stats_path}")
    payload = torch.load(stats_path, map_location='cpu', weights_only=False)
    mean = payload.get('mean', 0)
    std = payload.get('std', 1)

    # Normalize tensor stats to numpy for downstream arithmetic.
    if torch.is_tensor(mean):
        mean = mean.cpu().numpy()
    if torch.is_tensor(std):
        std = std.cpu().numpy()

    print(f"✅ Stats loaded (mean shape: {np.array(mean).shape})")
    return mean, std
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def load_smplx_model(model_dir, device=DEVICE):
    """Construct a neutral-gender SMPL-X body model from `model_dir`.

    Raises FileNotFoundError with setup instructions when the directory is
    missing.
    """
    if not os.path.exists(model_dir):
        message = (
            f"SMPL-X model directory not found: {model_dir}\n"
            "Please download SMPL-X models and set SMPLX_MODEL_DIR environment variable "
            "or use --smplx-dir argument."
        )
        raise FileNotFoundError(message)

    print(f"Loading SMPL-X from: {model_dir}")
    # use_pca=False: hand poses are supplied as full axis-angle vectors,
    # matching the 45-dim hand slots in the VQ-VAE parameter layout.
    body_model = smplx.SMPLX(
        model_path=model_dir,
        model_type='smplx',
        gender='neutral',
        use_pca=False,
        create_global_orient=True,
        create_body_pose=True,
        create_betas=True,
        create_expression=True,
        create_jaw_pose=True,
        create_left_hand_pose=True,
        create_right_hand_pose=True,
        create_transl=True
    ).to(device)

    print(f"✅ SMPL-X loaded")
    return body_model
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# =====================================================================
|
| 242 |
+
# Token Decoding
|
| 243 |
+
# =====================================================================
|
| 244 |
+
def decode_tokens_to_params(tokens, vqvae_model, mean=None, std=None, device=DEVICE):
    """
    Decode motion tokens to SMPL-X parameters.

    Args:
        tokens: List of motion token IDs
        vqvae_model: Trained VQ-VAE model
        mean: Optional normalization mean
        std: Optional normalization std
        device: Device to run on

    Returns:
        numpy array of shape (T, SMPL_DIM) with SMPL-X parameters
    """
    # Empty input -> empty (0, SMPL_DIM) output, no model calls.
    if not tokens:
        return np.zeros((0, SMPL_DIM), dtype=np.float32)

    # Prepare token indices
    idx = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)  # (1, T_q)
    T_q = idx.shape[1]

    quantizer = vqvae_model.vqvae.quantizer

    # Get code dimension from the codebook when available; otherwise fall back
    # to the configured CODE_DIM constant.
    if hasattr(quantizer, "codebook"):
        codebook = quantizer.codebook.to(device)
        code_dim = codebook.shape[1]
    else:
        code_dim = CODE_DIM

    # Dequantize tokens. The quantizer's own dequantize() is preferred, but its
    # output layout varies between implementations, so the shape is normalized
    # to (N, code_dim, T_q) and any failure falls through to manual lookup.
    x_quantized = None
    if hasattr(quantizer, "dequantize"):
        try:
            with torch.no_grad():
                dq = quantizer.dequantize(idx)
                if dq is not None:
                    dq = dq.contiguous()
                    # Ensure shape is (N, code_dim, T_q)
                    # NOTE(review): if code_dim == T_q the two branches are
                    # ambiguous; the channel-first interpretation wins here.
                    if dq.ndim == 3 and dq.shape[1] == code_dim:
                        x_quantized = dq
                    elif dq.ndim == 3 and dq.shape[1] == T_q:
                        x_quantized = dq.permute(0, 2, 1).contiguous()
                    else:
                        x_quantized = None
        except Exception:
            # Broad on purpose: any dequantize failure triggers the fallback.
            x_quantized = None

    # Fallback: manual codebook lookup (embedding gather + channel-first permute).
    if x_quantized is None:
        if not hasattr(quantizer, "codebook"):
            raise RuntimeError("No dequantize method and no codebook available")
        with torch.no_grad():
            emb = codebook[idx]  # (1, T_q, code_dim)
            x_quantized = emb.permute(0, 2, 1).contiguous()  # (1, code_dim, T_q)

    # Decode through VQ-VAE decoder (temporal upsampling happens inside).
    with torch.no_grad():
        x_dec = vqvae_model.vqvae.decoder(x_quantized)
        smpl_out = vqvae_model.vqvae.postprocess(x_dec)  # (1, T_out, SMPL_DIM)
        params_np = smpl_out.squeeze(0).cpu().numpy()  # (T_out, SMPL_DIM)

    # Denormalize if stats provided (inverse of the training-time z-score).
    if (mean is not None) and (std is not None):
        mean_arr = np.array(mean).reshape(1, -1)
        std_arr = np.array(std).reshape(1, -1)
        params_np = (params_np * std_arr) + mean_arr

    return params_np
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# =====================================================================
|
| 316 |
+
# SMPL-X Parameter to Vertices
|
| 317 |
+
# =====================================================================
|
| 318 |
+
def params_to_vertices(params_seq, smplx_model, batch_size=32):
    """
    Convert SMPL-X parameters to 3D vertices.

    Args:
        params_seq: numpy array (T, SMPL_DIM)
        smplx_model: loaded SMPL-X model
        batch_size: batch size for processing

    Returns:
        verts: numpy array (T, V, 3)
        faces: numpy array (F, 3)
    """
    # Compute parameter slicing indices from the fixed PARAM_DIMS layout:
    # [betas, body_pose, lhand, rhand, trans, expression, jaw, eye].
    starts = np.cumsum([0] + PARAM_DIMS[:-1])
    ends = starts + np.array(PARAM_DIMS)

    T = params_seq.shape[0]
    all_verts = []

    # Infer number of body joints (SMPL-X default is 21 body joints).
    num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)

    with torch.no_grad():
        for s in range(0, T, batch_size):
            batch = params_seq[s:s+batch_size]  # (B, SMPL_DIM)
            B = batch.shape[0]

            # Extract parameters by slicing each named segment.
            np_parts = {}
            for name, st, ed in zip(PARAM_NAMES, starts, ends):
                np_parts[name] = batch[:, st:ed].astype(np.float32)

            # Convert to tensors on DEVICE for the SMPL-X forward pass.
            tensor_parts = {
                name: torch.from_numpy(arr).to(DEVICE)
                for name, arr in np_parts.items()
            }

            # Handle body pose (may or may not include global orient):
            # 63 dims = 21 joints * 3 (no global orient),
            # 66 dims = global orient (first 3) + 21 joints * 3.
            body_t = tensor_parts['body_pose']
            L_body = body_t.shape[1]
            expected_no_go = num_body_joints * 3
            expected_with_go = (num_body_joints + 1) * 3

            if L_body == expected_with_go:
                global_orient = body_t[:, :3].contiguous()
                body_pose_only = body_t[:, 3:].contiguous()
            elif L_body == expected_no_go:
                global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
                body_pose_only = body_t
            else:
                # Best-effort fallback for unexpected widths: peel off a global
                # orient when oversized, zero-pad when undersized.
                if L_body > expected_no_go:
                    global_orient = body_t[:, :3].contiguous()
                    body_pose_only = body_t[:, 3:].contiguous()
                else:
                    pad_len = max(0, expected_no_go - L_body)
                    body_pose_only = F.pad(body_t, (0, pad_len))
                    global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)

            # Call SMPL-X. The single 3-dim 'eye_pose' slot is reused for both
            # eyes (leye_pose and reye_pose share it).
            out = smplx_model(
                betas=tensor_parts['betas'],
                global_orient=global_orient,
                body_pose=body_pose_only,
                left_hand_pose=tensor_parts['left_hand_pose'],
                right_hand_pose=tensor_parts['right_hand_pose'],
                expression=tensor_parts['expression'],
                jaw_pose=tensor_parts['jaw_pose'],
                leye_pose=tensor_parts['eye_pose'],
                reye_pose=tensor_parts['eye_pose'],
                transl=tensor_parts['trans'],
                return_verts=True
            )

            verts = out.vertices.detach().cpu().numpy()  # (B, V, 3)
            all_verts.append(verts)

    verts_all = np.concatenate(all_verts, axis=0)  # (T, V, 3)
    faces = smplx_model.faces.astype(np.int32)

    return verts_all, faces
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# =====================================================================
|
| 404 |
+
# Visualization
|
| 405 |
+
# =====================================================================
|
| 406 |
+
def animate_motion(verts, faces, title="Generated Motion", output_path=None, fps=20):
    """
    Render a vertex sequence as an interactive Plotly 3D animation.

    Args:
        verts: numpy array (T, V, 3) of per-frame mesh vertices
        faces: numpy array (F, 3) of triangle indices
        title: plot title (also used as the trace name)
        output_path: optional path; when given the figure is saved as HTML
        fps: playback speed in frames per second

    Returns:
        Plotly figure object
    """
    n_frames = verts.shape[0]
    # Triangulation is shared by every frame; only vertex positions change.
    tri_i, tri_j, tri_k = faces.T.tolist()

    def mesh_at(t):
        """Mesh3d trace for frame t."""
        return go.Mesh3d(
            x=verts[t, :, 0],
            y=verts[t, :, 1],
            z=verts[t, :, 2],
            i=tri_i, j=tri_j, k=tri_k,
            flatshading=True,
            opacity=0.7
        )

    # Initial (static) mesh carries the trace name.
    first_mesh = go.Mesh3d(
        x=verts[0, :, 0],
        y=verts[0, :, 1],
        z=verts[0, :, 2],
        i=tri_i, j=tri_j, k=tri_k,
        name=title,
        flatshading=True,
        opacity=0.7
    )

    animation_frames = [
        go.Frame(data=[mesh_at(t)], name=str(t)) for t in range(n_frames)
    ]

    fig = go.Figure(data=[first_mesh], frames=animation_frames)

    play = dict(
        label="Play",
        method="animate",
        args=[None, {
            "frame": {"duration": 1000 // fps, "redraw": True},
            "fromcurrent": True
        }]
    )
    pause = dict(
        label="Pause",
        method="animate",
        args=[[None], {
            "frame": {"duration": 0, "redraw": False}
        }]
    )

    fig.update_layout(
        title_text=title,
        scene=dict(
            aspectmode='data',
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            camera=dict(eye=dict(x=0, y=-2, z=0.7))
        ),
        updatemenus=[dict(type="buttons", buttons=[play, pause])]
    )

    # Persist as a standalone interactive HTML file when requested.
    if output_path:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(output_path)
        print(f"✅ Animation saved to: {output_path}")

    return fig
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# =====================================================================
|
| 494 |
+
# Main Visualization Pipeline
|
| 495 |
+
# =====================================================================
|
| 496 |
+
def visualize(
    tokens,
    vqvae_ckpt=VQVAE_CHECKPOINT,
    stats_path=STATS_PATH,
    smplx_dir=SMPLX_MODEL_DIR,
    output_html=None,
    title="Generated Motion",
    fps=20
):
    """
    Run the full token-to-animation pipeline.

    Steps: parse the motion tokens, load the VQ-VAE / normalization stats /
    SMPL-X model, decode the tokens into SMPL-X parameters, convert those to
    mesh vertices, and render an interactive Plotly animation (optionally
    written to an HTML file).

    Args:
        tokens: Motion tokens, either a raw string or a list of ints.
        vqvae_ckpt: Path to the VQ-VAE checkpoint.
        stats_path: Path to the normalization statistics file.
        smplx_dir: Directory containing the SMPL-X model files.
        output_html: Destination HTML path; defaults to a file under
            OUTPUT_DIR when None.
        title: Title displayed on the animation.
        fps: Playback frame rate.

    Returns:
        The Plotly figure, or None when no tokens/frames were produced.
    """
    banner = "=" * 60
    print(banner)
    print("Motion Visualization Pipeline")
    print(banner)

    # Step 1: turn the raw token input into a list of code indices.
    print("\n[1/5] Parsing tokens...")
    ids = parse_motion_tokens(tokens)
    print(f" Parsed {len(ids)} tokens")
    if not ids:
        print("❌ No tokens to visualize")
        return None

    # Steps 2-4: bring up the decoder stack.
    print("\n[2/5] Loading VQ-VAE...")
    codec = load_vqvae(vqvae_ckpt, device=DEVICE)

    print("\n[3/5] Loading normalization stats...")
    mean, std = load_stats(stats_path)

    print("\n[4/5] Loading SMPL-X model...")
    body_model = load_smplx_model(smplx_dir, device=DEVICE)

    # Step 5: decode the codes and render the result.
    print("\n[5/5] Decoding and rendering...")
    print(" Decoding tokens to SMPL-X parameters...")
    smplx_params = decode_tokens_to_params(ids, codec, mean, std, device=DEVICE)
    print(f" Decoded params shape: {smplx_params.shape}")

    if smplx_params.shape[0] == 0:
        print("❌ No frames produced from decoder")
        return None

    # Parameters -> posed mesh vertices (batched for memory).
    print(" Converting parameters to vertices...")
    vertices, mesh_faces = params_to_vertices(smplx_params, body_model, batch_size=32)
    print(f" Vertices shape: {vertices.shape}, Faces: {mesh_faces.shape}")

    # Render and (by default) persist the animation as HTML.
    print(" Creating animation...")
    if output_html is None:
        output_html = os.path.join(OUTPUT_DIR, "motion_animation.html")

    figure = animate_motion(vertices, mesh_faces, title=title, output_path=output_html, fps=fps)

    print("\n" + banner)
    print("✅ Visualization complete!")
    print(banner)

    return figure
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
# =====================================================================
|
| 572 |
+
# CLI
|
| 573 |
+
# =====================================================================
|
| 574 |
+
def main():
    """Command-line entry point: parse args, obtain tokens, visualize."""
    cli = argparse.ArgumentParser(
        description="Visualize motion tokens as 3D SMPL-X animation"
    )

    # Exactly one token source must be supplied.
    source = cli.add_mutually_exclusive_group(required=True)
    source.add_argument(
        "--tokens", type=str,
        help="Motion tokens string (e.g., '<MOT_BEGIN><motion_177>...<MOT_END>' or '177 135 152...')"
    )
    source.add_argument(
        "--input", type=str,
        help="Path to file containing motion tokens"
    )
    source.add_argument(
        "--prompt", type=str,
        help="Generate tokens from text prompt first (requires --stage)"
    )

    # Generation settings (only meaningful together with --prompt).
    cli.add_argument(
        "--stage", type=int, default=3, choices=[1, 2, 3],
        help="Stage model to use for generation (default: 3)"
    )

    # Checkpoint / asset locations.
    cli.add_argument(
        "--vqvae-ckpt", type=str, default=VQVAE_CHECKPOINT,
        help=f"Path to VQ-VAE checkpoint (default: {VQVAE_CHECKPOINT})"
    )
    cli.add_argument(
        "--stats", type=str, default=STATS_PATH,
        help=f"Path to normalization stats (default: {STATS_PATH})"
    )
    cli.add_argument(
        "--smplx-dir", type=str, default=SMPLX_MODEL_DIR,
        help=f"Path to SMPL-X model directory (default: {SMPLX_MODEL_DIR})"
    )

    # Output settings.
    cli.add_argument(
        "--output", type=str, default=None,
        help="Path to save HTML animation (default: motion_animation.html)"
    )
    cli.add_argument(
        "--title", type=str, default="Generated Motion",
        help="Animation title"
    )
    cli.add_argument(
        "--fps", type=int, default=20,
        help="Frames per second for animation (default: 20)"
    )

    opts = cli.parse_args()

    # Resolve the token source: generate from a prompt, read from disk,
    # or take the string straight from the command line.
    if opts.prompt:
        print("Generating motion tokens from prompt...")
        from inference import inference
        motion_tokens = inference(
            prompt=opts.prompt,
            stage=opts.stage,
            output_file=None,
            per_prompt_vocab=True
        )
    elif opts.input:
        with open(opts.input, 'r') as f:
            motion_tokens = f.read().strip()
    else:
        motion_tokens = opts.tokens

    # Hand everything to the visualization pipeline.
    visualize(
        tokens=motion_tokens,
        vqvae_ckpt=opts.vqvae_ckpt,
        stats_path=opts.stats,
        smplx_dir=opts.smplx_dir,
        output_html=opts.output,
        title=opts.title,
        fps=opts.fps
    )
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
| 681 |
+
|