rdz-falcon commited on
Commit
4bd136e
·
1 Parent(s): bf06606

Deploy SignMotionGPT Demo with LFS

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/motion_llm_dataset.json filter=lfs diff=lfs merge=lfs -text
37
+ /content/SignMotionGPT/data/vqvae_model.pt filter=lfs diff=lfs merge=lfs -text
38
+ /content/SignMotionGPT/data/smplx_models/SMPLX_NEUTRAL.npz filter=lfs diff=lfs merge=lfs -text
39
+ /content/SignMotionGPT/data/vqvae_stats.pt filter=lfs diff=lfs merge=lfs -text
INFERENCE_AND_VIS.md ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference & Visualization Quick Reference
2
+
3
+ ## Overview
4
+ After training your 3-stage SignMotionGPT model, use these scripts to generate and visualize motions.
5
+
6
+ ---
7
+
8
+ ## 1. Inference (Generate Motion Tokens)
9
+
10
+ ### Basic Usage
11
+ ```bash
12
+ # Generate from Stage 3 model (recommended)
13
+ python inference.py --prompt "walking forward"
14
+
15
+ # Try different stages
16
+ python inference.py --prompt "dancing" --stage 1 # Motion-only LM
17
+ python inference.py --prompt "dancing" --stage 2 # Multi-task
18
+ python inference.py --prompt "dancing" --stage 3 # T2M SFT (best quality)
19
+ ```
20
+
21
+ ### Save Output
22
+ ```bash
23
+ python inference.py --prompt "jumping" --output my_motion.txt
24
+ ```
25
+
26
+ ### With Participant ID
27
+ ```bash
28
+ python inference.py --prompt "yoga pose" --pid P40
29
+ ```
30
+
31
+ ### Expected Output
32
+ ```
33
+ ============================================================
34
+ Motion Generation Inference - Stage 3
35
+ ============================================================
36
+ Prompt: 'walking forward'
37
+ Device: cuda
38
+
39
+ Loading Stage 3 model from: /kaggle/working/SignMotionGPT/stage3_t2m_sft
40
+ ✅ Stage 3 model loaded successfully
41
+
42
+ Generating motion for: 'walking forward'
43
+
44
+ ============================================================
45
+ Generated Motion:
46
+ ============================================================
47
+ <MOT_BEGIN><motion_224><motion_39><motion_76>...<MOT_END>
48
+ ============================================================
49
+ ```
50
+
51
+ ---
52
+
53
+ ## 2. Visualization (Motion Tokens → 3D Animation)
54
+
55
+ ### Prerequisites
56
+
57
+ #### Option A: Use Google Drive (Colab/Kaggle)
58
+ Edit `setup_env.sh` and add your Google Drive file IDs:
59
+ ```bash
60
+ VQVAE_MODEL_ID="1AbCdEfGhIj" # VQ-VAE checkpoint (.pt)
61
+ VQVAE_STATS_ID="2KlMnOpQrSt" # Normalization stats (.pt)
62
+ SMPLX_MODELS_ID="3UvWxYzAbCd" # SMPL-X models (.zip)
63
+ ```
64
+
65
+ Then run:
66
+ ```bash
67
+ bash setup_env.sh
68
+ ```
69
+
70
+ #### Option B: Manual Setup (Local)
71
+ ```bash
72
+ export VQVAE_CHECKPOINT=/path/to/vqvae_model.pt
73
+ export VQVAE_STATS_PATH=/path/to/vqvae_stats.pt
74
+ export SMPLX_MODEL_DIR=/path/to/smplx_models
75
+ ```
76
+
77
+ ### Basic Usage
78
+
79
+ ```bash
80
+ # Visualize token string
81
+ python visualize.py --tokens "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
82
+
83
+ # Visualize from file
84
+ python visualize.py --input my_motion.txt
85
+
86
+ # Generate + visualize in one command
87
+ python visualize.py --prompt "walking" --stage 3
88
+ ```
89
+
90
+ ### Custom Output
91
+ ```bash
92
+ python visualize.py \
93
+ --input motion_tokens.txt \
94
+ --output walk_animation.html \
95
+ --title "Walking Forward" \
96
+ --fps 30
97
+ ```
98
+
99
+ ### With Custom Paths
100
+ ```bash
101
+ python visualize.py \
102
+ --tokens "<MOT_BEGIN>..." \
103
+ --vqvae-ckpt /custom/vqvae.pt \
104
+ --stats /custom/stats.pt \
105
+ --smplx-dir /custom/smplx_models \
106
+ --output animation.html
107
+ ```
108
+
109
+ ### Expected Output
110
+ ```
111
+ ============================================================
112
+ Motion Visualization Pipeline
113
+ ============================================================
114
+
115
+ [1/5] Parsing tokens...
116
+ Parsed 15 tokens
117
+
118
+ [2/5] Loading VQ-VAE...
119
+ ✅ VQ-VAE loaded (codebook size: 512)
120
+
121
+ [3/5] Loading normalization stats...
122
+ ✅ Stats loaded (mean shape: (182,))
123
+
124
+ [4/5] Loading SMPL-X model...
125
+ ✅ SMPL-X loaded
126
+
127
+ [5/5] Decoding and rendering...
128
+ Decoding tokens to SMPL-X parameters...
129
+ Decoded params shape: (16, 182)
130
+ Converting parameters to vertices...
131
+ Vertices shape: (16, 10475, 3), Faces: (20908, 3)
132
+ Creating animation...
133
+ ✅ Animation saved to: motion_animation.html
134
+
135
+ ============================================================
136
+ ✅ Visualization complete!
137
+ ============================================================
138
+ ```
139
+
140
+ ---
141
+
142
+ ## 3. Complete Workflow Example
143
+
144
+ ### A. Train (already done)
145
+ ```bash
146
+ python train_pipeline.py
147
+ ```
148
+
149
+ ### B. Generate Motion Tokens
150
+ ```bash
151
+ python inference.py --prompt "college" --stage 3 --output college_motion.txt
152
+ ```
153
+
154
+ ### C. Visualize
155
+ ```bash
156
+ python visualize.py --input college_motion.txt --output college_animation.html
157
+ ```
158
+
159
+ ### D. View Animation
160
+ Open `college_animation.html` in a browser. You'll see an interactive 3D SMPL-X character performing the motion. Use mouse to rotate/zoom, and click Play/Pause buttons.
161
+
162
+ ---
163
+
164
+ ## 4. Troubleshooting
165
+
166
+ ### Inference Issues
167
+
168
+ **"Checkpoint not found"**
169
+ - Ensure you've trained all stages first: `python train_pipeline.py`
170
+ - Check that `OUT_S1`, `OUT_S2`, `OUT_S3` directories exist in `WORK_DIR`
171
+
172
+ **"Dataset not found"**
173
+ - Inference needs the dataset to build vocabulary
174
+ - Set `DATA_JSON_PATH` in `config.py` or via environment variable
175
+
176
+ ### Visualization Issues
177
+
178
+ **"VQ-VAE checkpoint not found"**
179
+ - Download VQ-VAE model or set `VQVAE_CHECKPOINT` path
180
+ - The VQ-VAE is separate from LLM training (used to decode tokens to SMPL-X params)
181
+
182
+ **"SMPL-X models not found"**
183
+ - Download SMPL-X models from https://smpl-x.is.tue.mpg.de/
184
+ - Extract to a directory and set `SMPLX_MODEL_DIR`
185
+
186
+ **"No tokens to visualize"**
187
+ - Check token format: should contain `<motion_ID>` tags or space-separated numbers
188
+ - Example valid formats:
189
+ - `<MOT_BEGIN><motion_177><motion_135><MOT_END>`
190
+ - `177 135 152 200 46 142`
191
+
192
+ **"Shape mismatch" or "Decoding errors"**
193
+ - Ensure VQ-VAE checkpoint matches the codebook size used in LLM training
194
+ - Check `CODEBOOK_SIZE`, `CODE_DIM`, `SMPL_DIM` in `visualize.py` match training
195
+
196
+ ---
197
+
198
+ ## 5. Configuration
199
+
200
+ ### Key Environment Variables
201
+
202
+ | Variable | Purpose | Default |
203
+ |----------|---------|---------|
204
+ | `VQVAE_CHECKPOINT` | VQ-VAE model path | `./data/vqvae_model.pt` |
205
+ | `VQVAE_STATS_PATH` | Normalization stats | `./data/vqvae_stats.pt` |
206
+ | `SMPLX_MODEL_DIR` | SMPL-X models directory | `./data/smplx_models` |
207
+ | `VIS_OUTPUT_DIR` | Output directory for animations | `WORK_DIR` |
208
+
209
+ ### VQ-VAE Architecture (must match training)
210
+ In `visualize.py`:
211
+ ```python
212
+ SMPL_DIM = 182 # SMPL-X parameter dimension
213
+ CODEBOOK_SIZE = 512 # Motion vocabulary size
214
+ CODE_DIM = 512 # Latent code dimension
215
+ VQ_ARGS = dict(
216
+ width=512,
217
+ depth=3,
218
+ down_t=2,
219
+ stride_t=2,
220
+ ...
221
+ )
222
+ ```
223
+
224
+ ---
225
+
226
+ ## 6. Tips
227
+
228
+ ### Inference
229
+ - **Stage 3** generally produces best quality for text-to-motion
230
+ - **Stage 2** can handle M2T and denoising (but inference.py only does T2M)
231
+ - **Stage 1** generates motion without text conditioning (still needs prompt for length)
232
+ - Use `--no-per-prompt-vocab` to allow novel combinations (less constrained)
233
+
234
+ ### Visualization
235
+ - **FPS 20-30** works well for most motions
236
+ - Longer sequences may take a few seconds to render
237
+ - The HTML file is self-contained and can be shared
238
+ - 3D mesh has ~10K vertices; animations can be large for long sequences
239
+
240
+ ### Performance
241
+ - Inference: ~1-2 seconds per generation (depends on length)
242
+ - Visualization: ~3-10 seconds (depends on sequence length and batch size)
243
+ - Both run on GPU if available, fall back to CPU otherwise
244
+
245
+ ---
246
+
247
+ ## 7. Next Steps
248
+
249
+ - **Batch Inference**: Loop over multiple prompts and save outputs
250
+ - **Evaluate Quality**: Compare generated tokens to ground truth using edit distance
251
+ - **Fine-tune Generation**: Adjust `GEN_TEMPERATURE`, `GEN_TOP_P` in `config.py`
252
+ - **Export to Other Formats**: Extend `visualize.py` to export BVH, FBX, or USD
253
+
README.md CHANGED
@@ -1,12 +1,116 @@
1
- ---
2
- title: SignMotionGPT
3
- emoji: 🌍
4
- colorFrom: yellow
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 6.0.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ### 1) Configure setup script (one time)
3
+
4
+ Run the setup:
5
+
6
+ ```bash
7
+ bash setup_env.sh
8
+ ```
9
+
10
+ After setup, defaults are:
11
+ - `WORK_DIR` = current directory
12
+ - `DATA_JSON_PATH` = `./data/motion_llm_dataset.json`
13
+
14
+ You can override via environment variables if needed:
15
+
16
+ ```bash
17
+ export WORK_DIR=/path/to/workdir
18
+ export DATA_JSON_PATH=/path/to/motion_llm_dataset.json
19
+ ```
20
+
21
+ ## Overview
22
+
23
+ This repository implements a robust 2-stage training pipeline for motion generation, replicating the high-performance "overfit" test setup:
24
+ - **Stage 1**: Motion-only Language Model (MLM) - Pre-training on motion token sequences to learn the "language of motion".
25
+ - **Stage 2**: Text-to-Motion Fine-Tuning (T2M) - Supervised fine-tuning to align text prompts with motion sequences.
26
+
27
+ Key features:
28
+ - **Integrated Evaluation**: Automatically computes FID, Diversity, and Multimodality (MIM) metrics.
29
+ - **Side-by-Side Visualization**: Generates HTML comparisons of Ground Truth vs Generated motions.
30
+ - **Test Set Evaluation**: Can optionally run evaluation on a held-out test set (SMPL-X data).
31
+ - **Hugging Face Integration**: Automatic checkpointing and resuming from the Hub.
32
+
33
+ ## Installation
34
+
35
+ ```bash
36
+ # Clone the repository
37
+ git clone https://github.com/rajvizala/SignMotionGPT.git
38
+ cd SignMotionGPT
39
+
40
+ # Setup Everything
41
+ bash setup_env.sh
42
+ ```
43
+
44
+ ## Dataset Format
45
+
46
+ Your dataset should be a JSON file with the following structure:
47
+
48
+ ```json
49
+ [
50
+ {
51
+ "text_query": "a person walks forward",
52
+ "motion_tokens": "42 18 91 ...",
53
+ "participant_id": "P001" // Optional
54
+ },
55
+ ...
56
+ ]
57
+ ```
58
+
59
+ ## Quick Start
60
+
61
+ ### 1. Configure Training
62
+
63
+ Edit `config.py` to set your paths and hyperparameters. Key settings include:
64
+ - `DATA_JSON_PATH`: Path to your dataset.
65
+ - `MODEL_NAME`: Base model (e.g., "Qwen/Qwen3-0.6B").
66
+ - `PIPELINE_OUTPUT_DIR`: Directory for checkpoints and results.
67
+ - `HF_TOKEN`: Your Hugging Face token (or set via env var).
68
+
69
+ ### 2. Run Full Pipeline
70
+
71
+ ```bash
72
+ python train_pipeline.py
73
+ ```
74
+
75
+ This script orchestrates the entire process:
76
+ 1. **Data Loading & Cleaning**: Deduplicates samples and builds vocabulary.
77
+ 2. **Stage 1 Training**: Motion Language Modeling (Pre-training).
78
+ 3. **Stage 2 Training**: Text-to-Motion Fine-Tuning.
79
+ 4. **Evaluation**: Runs inference on specific words, computes metrics (FID, Diversity, MIM), and generates visualizations.
80
+ 5. **Test Set Evaluation**: (Optional) Runs evaluation on held-out test data if configured.
81
+
82
+ ### 3. Environment Variables
83
+
84
+ You can control many aspects via environment variables without editing code:
85
+
86
+ ```bash
87
+ # Training Config
88
+ export PIPELINE_S1_EPOCHS=20
89
+ export PIPELINE_S2_EPOCHS=20
90
+ export PIPELINE_S1_BATCH=8
91
+ export PIPELINE_S2_BATCH=8
92
+
93
+ # Hugging Face
94
+ export HUGGINGFACE_HUB_TOKEN="your_token"
95
+ export HF_UPLOAD_INTERVAL_EPOCHS=2
96
+
97
+ # Evaluation
98
+ export EVALUATION_WORDS="passport,send,library"
99
+ export TEST_EVAL_SAMPLE_LIMIT=100
100
+ ```
101
+
102
+ ## Held-out Test Dataset Evaluation
103
+
104
+ The pipeline includes integration with `test_dataset_eval.py` to measure performance on an unseen SMPL-X test dataset.
105
+
106
+ To enable this, ensure `TEST_EVAL_DOWNLOAD_DIR` or `TEST_EVAL_EXTRACT_DIR` are configured in `config.py` or via env vars. The pipeline will attempt to run this after training if data is available.
107
+
108
+ ## Visualization
109
+
110
+ The pipeline automatically generates side-by-side HTML visualizations in the output directory (`html_visualizations` folder). You can open these in any browser to compare Ground Truth motions with the model's generations.
111
+
112
+ To manually visualize tokens:
113
+
114
+ ```bash
115
+ python visualize.py --tokens "<MOT_BEGIN><motion_177>...<MOT_END>" --output my_anim.html
116
+ ```
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import sys
5
+ import warnings
6
+ from pathlib import Path
7
+
8
+ # Add root to path to allow imports from project root when running from demo-code/
9
+ # or when running from root
10
+ current_dir = os.path.dirname(os.path.abspath(__file__))
11
+ parent_dir = os.path.dirname(current_dir)
12
+ sys.path.append(current_dir)
13
+ sys.path.append(parent_dir)
14
+
15
+ # Import project modules
16
+ try:
17
+ from inference import load_trained_model, inference as run_inference_cmd
18
+ from visualize import visualize
19
+ from model import setup_model_and_tokenizer, get_motion_token_info
20
+ from generate import generate_t2m
21
+ from data import compute_length_stats, build_prompt_vocab, check_has_participant_id, load_dataset
22
+ import config
23
+ except ImportError as e:
24
+ print(f"Error importing project modules: {e}")
25
+ print("Make sure you are running this from the project root or have the project structure intact.")
26
+
27
+ # Constants
28
+ HF_REPO_ID = "rdz-falcon/SignMotionGPT"
29
+ EPOCH_SUBFOLDER = "stage2/epoch-030"
30
+
31
+ def load_model_from_hf(repo_id, subfolder, token=None):
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
33
+ print(f"Loading model from HF: {repo_id}/{subfolder}")
34
+ try:
35
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
36
+ model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
37
+ return model, tokenizer
38
+ except Exception as e:
39
+ print(f"Error loading model: {e}")
40
+ return None, None
41
+
42
+ # Global model cache
43
+ MODEL = None
44
+ TOKENIZER = None
45
+ MOTION_TOKEN_IDS = None
46
+ MOT_BEGIN_ID = None
47
+ MOT_END_ID = None
48
+ CODEBOOK_SIZE = 512
49
+
50
+ def init_model():
51
+ global MODEL, TOKENIZER, MOTION_TOKEN_IDS, MOT_BEGIN_ID, MOT_END_ID
52
+ if MODEL is not None:
53
+ return
54
+
55
+ token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
56
+
57
+ # Load model/tokenizer
58
+ MODEL, TOKENIZER = load_model_from_hf(HF_REPO_ID, EPOCH_SUBFOLDER, token)
59
+
60
+ if MODEL is None:
61
+ raise RuntimeError(f"Failed to load model from {HF_REPO_ID}/{EPOCH_SUBFOLDER}")
62
+
63
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+ MODEL.to(device)
65
+ MODEL.eval()
66
+
67
+ # Setup token info
68
+ motion_token_ids = []
69
+ for i in range(CODEBOOK_SIZE):
70
+ t = f"<motion_{i}>"
71
+ if t in TOKENIZER.get_vocab():
72
+ motion_token_ids.append(TOKENIZER.convert_tokens_to_ids(t))
73
+
74
+ MOTION_TOKEN_IDS = motion_token_ids
75
+ MOT_BEGIN_ID = TOKENIZER.convert_tokens_to_ids("<MOT_BEGIN>") if "<MOT_BEGIN>" in TOKENIZER.get_vocab() else None
76
+ MOT_END_ID = TOKENIZER.convert_tokens_to_ids("<MOT_END>") if "<MOT_END>" in TOKENIZER.get_vocab() else None
77
+
78
+ print("Model initialized.")
79
+
80
+ def generate_motion_app(text_prompt):
81
+ if not text_prompt:
82
+ return None, "Please enter a prompt."
83
+
84
+ if MODEL is None:
85
+ try:
86
+ init_model()
87
+ except Exception as e:
88
+ return None, f"Model Initialization Failed: {e}"
89
+
90
+ device = MODEL.device
91
+ print(f"Generating for: {text_prompt}")
92
+
93
+ try:
94
+ generated_tokens = generate_t2m(
95
+ model=MODEL,
96
+ tokenizer=TOKENIZER,
97
+ prompt_text=text_prompt,
98
+ mot_begin_id=MOT_BEGIN_ID,
99
+ mot_end_id=MOT_END_ID,
100
+ motion_token_ids=MOTION_TOKEN_IDS,
101
+ length_stats_by_text={}, # Fallback to global_median_len
102
+ global_median_len=100, # Reasonable default
103
+ prompt_vocab=None,
104
+ has_pid=False,
105
+ per_prompt_vocab=False # Allow all tokens
106
+ )
107
+ except Exception as e:
108
+ return None, f"Generation Error: {e}"
109
+
110
+ # Visualization
111
+ try:
112
+ # Ensure paths for VQ-VAE and SMPL-X
113
+ # In HF Spaces, we assume these are in the repo (e.g., ./data)
114
+ data_dir = os.environ.get("DATA_DIR", "data")
115
+ vqvae_ckpt = os.path.join(data_dir, "vqvae_model.pt")
116
+ stats_path = os.path.join(data_dir, "vqvae_stats.pt")
117
+ smplx_dir = os.path.join(data_dir, "smplx_models")
118
+
119
+ # Check existence
120
+ missing = []
121
+ if not os.path.exists(vqvae_ckpt): missing.append(vqvae_ckpt)
122
+ if not os.path.exists(stats_path): missing.append(stats_path)
123
+ if not os.path.exists(smplx_dir): missing.append(smplx_dir)
124
+
125
+ if missing:
126
+ return None, f"Missing visualization files in {data_dir}: {missing}. Please ensure they are uploaded to the Space."
127
+
128
+ # Output to a temporary file
129
+ # Gradio needs a file path or HTML string. visualize returns a Figure.
130
+ output_html = "temp_viz.html"
131
+
132
+ fig = visualize(
133
+ tokens=generated_tokens,
134
+ vqvae_ckpt=vqvae_ckpt,
135
+ stats_path=stats_path,
136
+ smplx_dir=smplx_dir,
137
+ output_html=output_html,
138
+ title=f"Motion: {text_prompt}",
139
+ fps=20
140
+ )
141
+
142
+ if fig is None:
143
+ return None, "Visualization failed (no frames produced)."
144
+
145
+ return fig, f"Success! Generated tokens length: {len(generated_tokens.split())}"
146
+
147
+ except Exception as e:
148
+ return None, f"Visualization Error: {e}"
149
+
150
+
151
+ # Gradio UI
152
+ with gr.Interface(
153
+ fn=generate_motion_app,
154
+ inputs=gr.Textbox(label="Enter Motion Prompt", placeholder="e.g. walking forward"),
155
+ outputs=[
156
+ gr.Plot(label="Motion Visualization"),
157
+ gr.Textbox(label="Status/Output")
158
+ ],
159
+ title="SignMotionGPT Demo",
160
+ description="Generate Sign Language/Motion Avatars from Text. Using model checkpoint: epoch 30."
161
+ ) as demo:
162
+ pass
163
+
164
+ if __name__ == "__main__":
165
+ demo.launch()
166
+
collators.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data collators with label masking for training
3
+ """
4
+ import torch
5
+
6
+
7
+ class AssistantSpanCollator:
8
+ """
9
+ Collator that masks labels to only train on assistant responses.
10
+
11
+ For where=="mot": labels only inside <MOT_BEGIN>...<MOT_END> in assistant
12
+ For where=="text": labels entire assistant span (for M2T tasks)
13
+ """
14
+
15
+ def __init__(self, tokenizer, max_length):
16
+ self.tok = tokenizer
17
+ self.max_len = max_length
18
+
19
+ # Get special token IDs
20
+ self.im_start = self.tok.convert_tokens_to_ids("<|im_start|>")
21
+ self.im_end = self.tok.convert_tokens_to_ids("<|im_end|>")
22
+ self.mot_beg = self.tok.convert_tokens_to_ids("<MOT_BEGIN>")
23
+ self.mot_end = self.tok.convert_tokens_to_ids("<MOT_END>")
24
+
25
+ def __call__(self, examples):
26
+ texts = [e["text"] for e in examples]
27
+ wheres = [e["where"] for e in examples]
28
+
29
+ # Tokenize
30
+ enc = self.tok(
31
+ texts,
32
+ return_tensors="pt",
33
+ padding=True,
34
+ truncation=True,
35
+ max_length=self.max_len
36
+ )
37
+
38
+ input_ids = enc["input_ids"]
39
+ labels = input_ids.clone().fill_(-100)
40
+
41
+ # Apply label masking per example
42
+ for i, w in enumerate(wheres):
43
+ seq = input_ids[i]
44
+
45
+ # Find last <|im_start|> (start of assistant)
46
+ starts = (seq == self.im_start).nonzero(as_tuple=True)[0]
47
+ if starts.numel() == 0:
48
+ continue
49
+
50
+ a_start = int(starts[-1].item())
51
+
52
+ # Find corresponding <|im_end|>
53
+ sub = seq[a_start+1:]
54
+ ends = (sub == self.im_end).nonzero(as_tuple=True)[0]
55
+ a_end = (a_start + 1 + int(ends[0].item())) if ends.numel() > 0 else (seq.size(0) - 1)
56
+
57
+ if w == "text":
58
+ # Label entire assistant span
59
+ labels[i, a_start+1:a_end] = seq[a_start+1:a_end]
60
+ else:
61
+ # Label only motion tokens between <MOT_BEGIN> and <MOT_END>
62
+ asst = seq[a_start+1:a_end]
63
+ bpos = (asst == self.mot_beg).nonzero(as_tuple=True)[0]
64
+ epos = (asst == self.mot_end).nonzero(as_tuple=True)[0]
65
+
66
+ if bpos.numel() > 0 and epos.numel() > 0 and epos[0] >= bpos[0]:
67
+ b = a_start + 1 + int(bpos[0].item())
68
+ e = a_start + 1 + int(epos[0].item())
69
+ labels[i, b:e+1] = seq[b:e+1]
70
+
71
+ return {
72
+ "input_ids": input_ids,
73
+ "attention_mask": enc["attention_mask"],
74
+ "labels": labels
75
+ }
config.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration file for Motion LLM training
3
+ """
4
+ import os
5
+ import torch
6
+
7
+ # Random seed
8
+ SEED = 42
9
+
10
+ # Paths
11
+ # WORK_DIR defaults to current working directory if not explicitly set
12
+ WORK_DIR = os.environ.get("WORK_DIR", os.getcwd())
13
+ DATA_DIR = os.environ.get("DATA_DIR", os.path.join(WORK_DIR, "data"))
14
+ os.makedirs(DATA_DIR, exist_ok=True)
15
+
16
+ # Single-file JSON dataset path (can be overridden via env)
17
+ DATA_JSON_PATH = os.environ.get(
18
+ "DATA_JSON_PATH",
19
+ os.path.join(DATA_DIR, "motion_llm_dataset.json"),
20
+ )
21
+
22
+ # Directory Configuration
23
+ # PIPELINE_OUTPUT_DIR matches test_overfit's default "./motion_gpt_full_model"
24
+ PIPELINE_OUTPUT_DIR = os.environ.get("PIPELINE_OUTPUT_DIR", "./motion_gpt_full_model")
25
+ METRICS_JSON_PATH = os.path.join(PIPELINE_OUTPUT_DIR, "metrics.json")
26
+ CHECKPOINTS_DIR = os.path.join(PIPELINE_OUTPUT_DIR, "checkpoints")
27
+
28
+ # Model configuration
29
+ MODEL_NAME = "Qwen/Qwen3-0.6B" # Matches test_overfit.py
30
+ MAX_SEQ_LEN = 512 # Kept from previous config, though test_overfit uses 256 in datasets
31
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
32
+
33
+ # Evaluation Words (matches test_overfit.py)
34
+ EVALUATION_WORDS = ["passport", "send", "library", "push"]
35
+
36
+ # Training Hyperparameters (matches test_overfit.py)
37
+ # Stage 1
38
+ S1_EPOCHS = 20
39
+ S1_LR = 5e-5
40
+ S1_BATCH_SIZE = 8
41
+
42
+ # Stage 2
43
+ S2_EPOCHS = 20
44
+ S2_LR = 2e-5
45
+ S2_BATCH_SIZE = 8
46
+
47
+ # Inference Hyperparameters (matches test_overfit.py)
48
+ INFERENCE_REPETITION_PENALTY = 1.2
49
+ INFERENCE_TEMPERATURE = 0.7
50
+ INFERENCE_TOP_K = 50
51
+
52
+ # Special Tokens (matches test_overfit.py)
53
+ M_START = "<M_START>"
54
+ M_END = "<M_END>"
55
+ PAD_TOKEN = "<PAD>"
56
+
57
+ # Hugging Face Hub Configuration
58
+ HF_USE_HUB = True
59
+ HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("hf_auth_token")
60
+ HF_USER = os.environ.get("HF_USER", "rdz-falcon") # Derived from test_overfit.py repo ids
61
+ HF_STAGE1_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
62
+ HF_STAGE2_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
63
+ HF_PRIVATE_REPO = os.environ.get("HF_PRIVATE", "true").lower() != "false"
64
+ FORCE_STAGE2_FROM_STAGE1_RAW = os.environ.get("FORCE_STAGE2_FROM_STAGE1", "false")
65
+ FORCE_STAGE2_FROM_STAGE1 = str(FORCE_STAGE2_FROM_STAGE1_RAW).strip().lower() not in ("0", "false", "no", "off")
66
+ HF_STAGE2_SAVE_SUBDIR = os.environ.get("HF_STAGE2_SAVE_SUBDIR", "stage2_v2")
67
+ CHECKPOINT_UPLOAD_INTERVAL_EPOCHS = int(os.environ.get("HF_UPLOAD_INTERVAL_EPOCHS", "2"))
68
+ HF_DISABLE_PROGRESS = os.environ.get("HF_DISABLE_PROGRESS", "true").lower() != "false"
69
+
70
+ # Evaluation controls
71
+ RUN_EVALS_ONLY = False
72
+ EVAL_SAMPLE_LIMIT = 100
73
+
74
+ # Test Eval Config (from test_dataset_eval.py defaults)
75
+ TEST_EVAL_OUTPUT_DIR = os.environ.get("TEST_EVAL_OUTPUT_DIR", PIPELINE_OUTPUT_DIR)
76
+ TEST_EVAL_DOWNLOAD_DIR = os.environ.get(
77
+ "TEST_EVAL_DOWNLOAD_DIR", os.path.join(WORK_DIR, "test_data", "downloads")
78
+ )
79
+ TEST_EVAL_EXTRACT_DIR = os.environ.get(
80
+ "TEST_EVAL_EXTRACT_DIR", os.path.join(WORK_DIR, "test_data", "extracted")
81
+ )
82
+ TEST_EVAL_SAMPLE_LIMIT = int(os.environ.get("TEST_EVAL_SAMPLE_LIMIT", "300"))
83
+ TEST_EVAL_MAX_ZIPS = int(os.environ.get("TEST_EVAL_MAX_ZIPS", "500"))
84
+ TEST_EVAL_HF_REPO = os.environ.get("TEST_EVAL_HF_REPO", "rdz-falcon/SignMotionGPTfit-archive")
85
+ TEST_EVAL_HF_SUBFOLDER = os.environ.get(
86
+ "TEST_EVAL_HF_SUBFOLDER", f"{HF_STAGE2_SAVE_SUBDIR}/latest"
87
+ )
data.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset loading and vocabulary building utilities
3
+ """
4
+ import json
5
+ import os
6
+ import random
7
+ from typing import List, Dict, Tuple, Any
8
+ from collections import defaultdict
9
+ import torch
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from transformers import AutoTokenizer
12
+ from config import M_START, M_END, PAD_TOKEN
13
+
14
+ # ======================================================================================
15
+ # Logic from test_overfit.py
16
+ # ======================================================================================
17
+
18
+ def read_json_data(json_path: str) -> List[Dict[str, Any]]:
19
+ """Loads the dataset from the specified JSON file."""
20
+ if not os.path.exists(json_path):
21
+ raise FileNotFoundError(f"Dataset not found at: {json_path}")
22
+ with open(json_path, "r", encoding="utf-8") as f:
23
+ return json.load(f)
24
+
25
+ def deduplicate_and_prepare_data(entries: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[str]]:
26
+ """
27
+ Cleans the entire dataset by ensuring each (word, participant_id) pair is unique.
28
+ If a conflict is found (same pair, different motion), it keeps only the first one encountered.
29
+ Then, it prepares the full list of motion tokens from the cleaned data.
30
+ """
31
+ print("\n---> Cleaning dataset by removing ambiguous (word, participant_id) pairs...")
32
+
33
+ unique_samples = {}
34
+ conflicts_found = 0
35
+
36
+ for entry in entries:
37
+ word = entry.get("word", "").lower()
38
+ pid = entry.get("participant_id", "")
39
+ key = (word, pid)
40
+
41
+ if key not in unique_samples:
42
+ unique_samples[key] = entry
43
+ else:
44
+ # A sample for this key already exists. We only care if it's a conflict.
45
+ existing_tokens = unique_samples[key].get("motion_tokens")
46
+ current_tokens = entry.get("motion_tokens")
47
+ if existing_tokens != current_tokens:
48
+ conflicts_found += 1
49
+ # We do nothing, effectively discarding this new conflicting sample.
50
+
51
+ cleaned_data = list(unique_samples.values())
52
+
53
+ print(f"Original samples: {len(entries)}")
54
+ print(f"Cleaned samples (unique (word, pid) pairs): {len(cleaned_data)}")
55
+ print(f"Removed {len(entries) - len(cleaned_data)} total samples. ({conflicts_found} were direct conflicts).")
56
+
57
+ print("\n---> Extracting motion tokens from the full cleaned dataset...")
58
+ all_motion_tokens = set()
59
+ for entry in cleaned_data:
60
+ motion_tokens = entry.get("motion_tokens", "").strip().split()
61
+ for token in motion_tokens:
62
+ all_motion_tokens.add(f"<M{token}>")
63
+
64
+ unique_tokens = sorted(list(all_motion_tokens))
65
+ print(f"Found {len(unique_tokens)} unique motion tokens in the entire dataset.")
66
+
67
+ return cleaned_data, unique_tokens
68
+
69
+ class MotionDataset(Dataset):
70
+ """Dataset for Stage 1: Contains only motion token sequences."""
71
+ def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
72
+ self.tokenizer = tokenizer
73
+ self.max_length = max_length
74
+ self.sequences = []
75
+
76
+ for item in data:
77
+ tokens_str = item.get("motion_tokens", "")
78
+ wrapped_tokens = " ".join([f"<M{t}>" for t in tokens_str.split()])
79
+ full_sequence = f"{M_START} {wrapped_tokens} {M_END}"
80
+ self.sequences.append(full_sequence)
81
+
82
+ def __len__(self):
83
+ return len(self.sequences)
84
+
85
+ def __getitem__(self, idx):
86
+ return self.tokenizer(
87
+ self.sequences[idx],
88
+ truncation=True,
89
+ max_length=self.max_length,
90
+ padding="max_length",
91
+ return_tensors="pt"
92
+ )
93
+
94
+ class TextMotionDataset(Dataset):
95
+ """Dataset for Stage 2: Contains (prompt, motion_sequence) pairs."""
96
+ def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
97
+ self.tokenizer = tokenizer
98
+ self.max_length = max_length
99
+ self.items = []
100
+
101
+ for item in data:
102
+ prompt = f"Instruction: Generate motion for word '{item['word']}' with variant '{item['participant_id']}'.\nMotion: "
103
+
104
+ tokens_str = item.get("motion_tokens", "")
105
+ wrapped_tokens = " ".join([f"<M{t}>" for t in tokens_str.split()])
106
+ target_sequence = f"{M_START} {wrapped_tokens} {M_END}"
107
+
108
+ full_text = prompt + target_sequence
109
+
110
+ tokenized = self.tokenizer(
111
+ full_text,
112
+ truncation=True,
113
+ max_length=self.max_length,
114
+ padding="max_length",
115
+ return_tensors="pt"
116
+ )
117
+
118
+ prompt_tokenized = self.tokenizer(prompt, return_tensors="pt")
119
+ prompt_len = prompt_tokenized.input_ids.shape[1]
120
+
121
+ labels = tokenized['input_ids'].clone()
122
+ labels[0, :prompt_len] = -100
123
+
124
+ self.items.append({
125
+ "input_ids": tokenized['input_ids'].squeeze(0),
126
+ "attention_mask": tokenized['attention_mask'].squeeze(0),
127
+ "labels": labels.squeeze(0)
128
+ })
129
+
130
+ def __len__(self):
131
+ return len(self.items)
132
+
133
+ def __getitem__(self, idx):
134
+ return self.items[idx]
135
+
136
+ # ======================================================================================
137
+ # Legacy utilities (kept for compatibility if needed, but mostly superseded)
138
+ # ======================================================================================
139
+
140
+ def build_motion_vocab(dataset):
141
+ """
142
+ Build motion vocabulary by finding max token ID
143
+ Returns: (codebook_size, max_token_id)
144
+ """
145
+ def max_token_in_example(ex):
146
+ return max(int(x) for x in ex["motion_tokens"].split())
147
+
148
+ global_max_id = 0
149
+ for ex in dataset:
150
+ global_max_id = max(global_max_id, max_token_in_example(ex))
151
+
152
+ codebook_size = global_max_id + 1
153
+ return codebook_size, global_max_id
154
+
155
+ def motion_specials_to_ids(s: str) -> List[int]:
156
+ """Extract motion IDs from special tokens"""
157
+ toks = s.strip().split()
158
+ ids = []
159
+ for t in toks:
160
+ if t.startswith("<motion_") or (t.startswith("<M") and t.endswith(">") and t[2:-1].isdigit()):
161
+ # Handle both <motion_ID> and <MID> formats
162
+ try:
163
+ if t.startswith("<motion_"):
164
+ ids.append(int(t[8:-1]))
165
+ else:
166
+ ids.append(int(t[2:-1]))
167
+ except:
168
+ pass
169
+ return ids
data/motion_llm_dataset.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba9a0521241d7c72d0759c739ea323eee47e04cf41a5a7b756b9e083b40bc4e1
3
+ size 16798494
data/smplx_models/SMPLX_NEUTRAL.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:376021446ddc86e99acacd795182bbef903e61d33b76b9d8b359c2b0865bd992
3
+ size 108752058
data/vqvae_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fadbf3fb4ded1c6fe7752e7e381b627a46fa37787d051d969b73d97f81b278fb
3
+ size 231392924
data/vqvae_stats.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa86de891dd702ca71f0006cfbf68839c5eba35fb728891ab9f1890949dca943
3
+ size 2876
generate.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generation and inference utilities with constrained decoding
3
+ """
4
+ import torch
5
+ from transformers import LogitsProcessor, LogitsProcessorList
6
+ from typing import Dict
7
+ from config import (
8
+ SYSTEM_MSG, GEN_MAX_NEW_TOKENS, GEN_TEMPERATURE,
9
+ GEN_TOP_P, GEN_TOP_K, GEN_NO_REPEAT_NGRAM_SIZE,
10
+ GEN_REPETITION_PENALTY, GEN_END_LOGIT_SLOPE
11
+ )
12
+
13
+
14
+ class LengthAwareMotionLogitsProcessor(LogitsProcessor):
15
+ """
16
+ Constrained decoding processor that:
17
+ 1. Enforces motion token vocabulary
18
+ 2. Controls sequence length (min/soft_target/max)
19
+ 3. Biases toward ending at soft_target length
20
+ """
21
+
22
+ def __init__(self, prompt_len, mot_begin_id, mot_end_id, motion_ids,
23
+ hard_min, soft_target, hard_max, end_logit_slope=0.25):
24
+ super().__init__()
25
+ self.prompt_len = int(prompt_len)
26
+ self.mot_begin_id = int(mot_begin_id)
27
+ self.mot_end_id = int(mot_end_id)
28
+ self.motion_ids = torch.tensor(sorted(set(int(x) for x in motion_ids)))
29
+ self.motion_plus_end = torch.tensor(
30
+ sorted(set(list(self.motion_ids.tolist()) + [self.mot_end_id]))
31
+ )
32
+ self.hard_min = int(hard_min)
33
+ self.soft_target = int(soft_target)
34
+ self.hard_max = int(hard_max)
35
+ self.end_logit_slope = float(end_logit_slope)
36
+
37
+ def __call__(self, input_ids, scores):
38
+ device = scores.device
39
+ bs = scores.size(0)
40
+ mask = torch.full_like(scores, float("-inf"))
41
+
42
+ for b in range(bs):
43
+ gen = input_ids[b, self.prompt_len:]
44
+
45
+ # No tokens generated yet - must start with MOT_BEGIN
46
+ if gen.numel() == 0:
47
+ allowed = torch.tensor([self.mot_begin_id], device=device)
48
+ mask[b].index_fill_(0, allowed, 0.0)
49
+ continue
50
+
51
+ # Find MOT_BEGIN position
52
+ begin_pos = (gen == self.mot_begin_id).nonzero(as_tuple=True)[0]
53
+ if begin_pos.numel() == 0:
54
+ allowed = torch.tensor([self.mot_begin_id], device=device)
55
+ mask[b].index_fill_(0, allowed, 0.0)
56
+ continue
57
+
58
+ # Already generated MOT_END - force EOS
59
+ if (gen == self.mot_end_id).any():
60
+ allowed = torch.tensor([self.mot_end_id], device=device)
61
+ mask[b].index_fill_(0, allowed, 0.0)
62
+ continue
63
+
64
+ # Count motion tokens after MOT_BEGIN
65
+ after_begin = gen[begin_pos[0].item() + 1:]
66
+ cur_len = after_begin.numel()
67
+
68
+ # Before minimum length - only allow motion tokens
69
+ if cur_len < self.hard_min:
70
+ allowed = self.motion_ids.to(device)
71
+ mask[b].index_fill_(0, allowed, 0.0)
72
+
73
+ # After maximum length - force end
74
+ elif cur_len >= self.hard_max:
75
+ allowed = torch.tensor([self.mot_end_id], device=device)
76
+ mask[b].index_fill_(0, allowed, 0.0)
77
+
78
+ # Between min and max - allow motion tokens or end
79
+ else:
80
+ allowed = self.motion_plus_end.to(device)
81
+ mask[b].index_fill_(0, allowed, 0.0)
82
+
83
+ # Bias toward ending at soft_target
84
+ distance = max(0, cur_len - self.soft_target)
85
+ bias = self.end_logit_slope * float(distance)
86
+ scores[b, self.mot_end_id] = scores[b, self.mot_end_id] + bias
87
+
88
+ return scores + mask
89
+
90
+
91
+ def get_len_controls(prompt_text: str, length_stats_by_text: Dict, global_median_len: int):
92
+ """
93
+ Get length controls (min/soft_target/max) for a given prompt
94
+ """
95
+ s = length_stats_by_text.get(prompt_text)
96
+ if s is None:
97
+ med = global_median_len
98
+ else:
99
+ med = s["median"]
100
+
101
+ hard_min = max(1, int(0.6 * med))
102
+ soft_tgt = med
103
+ hard_max = max(hard_min + 4, int(1.4 * med))
104
+
105
+ return hard_min, soft_tgt, hard_max
106
+
107
+
108
+ def generate_t2m(
109
+ model,
110
+ tokenizer,
111
+ prompt_text: str,
112
+ mot_begin_id: int,
113
+ mot_end_id: int,
114
+ motion_token_ids: list,
115
+ length_stats_by_text: Dict,
116
+ global_median_len: int,
117
+ prompt_vocab: Dict = None,
118
+ pid: str = None,
119
+ has_pid: bool = False,
120
+ max_new_tokens: int = None,
121
+ per_prompt_vocab: bool = True
122
+ ):
123
+ """
124
+ Generate motion sequence from text prompt with constrained decoding
125
+ """
126
+ model.eval()
127
+ device = next(model.parameters()).device
128
+
129
+ if max_new_tokens is None:
130
+ max_new_tokens = GEN_MAX_NEW_TOKENS
131
+
132
+ # Build prompt
133
+ pid_tok = ""
134
+ if has_pid and pid is not None:
135
+ pid_tok = f"<PID_{pid}>"
136
+
137
+ user_text = f"<T2M>{pid_tok}\n\n" + prompt_text
138
+ prompt = (
139
+ "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
140
+ + "<|im_start|>user\n" + user_text + "\n<|im_end|>\n"
141
+ + "<|im_start|>assistant\n"
142
+ )
143
+
144
+ # Tokenize
145
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
146
+ prompt_len = inputs["input_ids"].size(1)
147
+
148
+ # Get length controls
149
+ hard_min, soft_tgt, hard_max = get_len_controls(
150
+ prompt_text, length_stats_by_text, global_median_len
151
+ )
152
+
153
+ # Get allowed motion tokens
154
+ if per_prompt_vocab and prompt_vocab:
155
+ allowed_motion_ids = prompt_vocab.get(prompt_text, motion_token_ids)
156
+ else:
157
+ allowed_motion_ids = motion_token_ids
158
+
159
+ # Setup constrained decoding
160
+ processors = LogitsProcessorList([
161
+ LengthAwareMotionLogitsProcessor(
162
+ prompt_len=prompt_len,
163
+ mot_begin_id=mot_begin_id,
164
+ mot_end_id=mot_end_id,
165
+ motion_ids=allowed_motion_ids,
166
+ hard_min=hard_min,
167
+ soft_target=soft_tgt,
168
+ hard_max=hard_max,
169
+ end_logit_slope=GEN_END_LOGIT_SLOPE,
170
+ )
171
+ ])
172
+
173
+ # Generate
174
+ with torch.no_grad():
175
+ out = model.generate(
176
+ input_ids=inputs["input_ids"],
177
+ attention_mask=inputs.get("attention_mask"),
178
+ max_new_tokens=min(max_new_tokens, hard_max + 4),
179
+ do_sample=True,
180
+ temperature=GEN_TEMPERATURE,
181
+ top_p=GEN_TOP_P,
182
+ top_k=GEN_TOP_K,
183
+ no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM_SIZE,
184
+ repetition_penalty=GEN_REPETITION_PENALTY,
185
+ logits_processor=processors,
186
+ eos_token_id=mot_end_id,
187
+ pad_token_id=tokenizer.eos_token_id,
188
+ )
189
+
190
+ # Decode
191
+ decoded = tokenizer.decode(out[0], skip_special_tokens=False)
192
+ reply = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
193
+
194
+ return reply
inference.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for generating motion tokens from text prompts.
3
+ Run after training to generate motion sequences from any text description.
4
+
5
+ Usage:
6
+ python inference.py --prompt "walking forward" --stage 3
7
+ python inference.py --prompt "dancing" --stage 2 --output motion_output.txt
8
+ """
9
+ import os
10
+ import argparse
11
+ import torch
12
+ from pathlib import Path
13
+
14
+ from config import (
15
+ OUT_S1, OUT_S2, OUT_S3, MAX_SEQ_LEN, DATA_JSON_PATH,
16
+ WORK_DIR
17
+ )
18
+ from data import (
19
+ load_dataset, compute_length_stats, build_prompt_vocab,
20
+ check_has_participant_id
21
+ )
22
+ from model import setup_model_and_tokenizer, get_motion_token_info
23
+ from generate import generate_t2m
24
+
25
+
26
+ def load_trained_model(stage: int, device: torch.device):
27
+ """
28
+ Load a trained model from a specific stage checkpoint.
29
+
30
+ Args:
31
+ stage: Stage number (1, 2, or 3)
32
+ device: Device to load model on
33
+
34
+ Returns:
35
+ model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id
36
+ """
37
+ stage_dirs = {1: OUT_S1, 2: OUT_S2, 3: OUT_S3}
38
+ stage_dir = stage_dirs.get(stage)
39
+
40
+ if not stage_dir or not os.path.exists(stage_dir):
41
+ raise FileNotFoundError(
42
+ f"Stage {stage} checkpoint not found at {stage_dir}. "
43
+ f"Train stage {stage} first."
44
+ )
45
+
46
+ print(f"\nLoading Stage {stage} model from: {stage_dir}")
47
+
48
+ # Load dataset to build vocab (needed for model setup)
49
+ if not os.path.exists(DATA_JSON_PATH):
50
+ raise FileNotFoundError(f"Dataset not found: {DATA_JSON_PATH}")
51
+
52
+ raw_ds = load_dataset(DATA_JSON_PATH)
53
+
54
+ # Build motion vocab
55
+ def max_token_in_example(ex):
56
+ return max(int(x) for x in ex["motion_tokens"].split())
57
+
58
+ global_max_id = max(max_token_in_example(ex) for ex in raw_ds)
59
+ codebook_size = global_max_id + 1
60
+
61
+ # Check for participant IDs
62
+ has_pid = check_has_participant_id(raw_ds)
63
+ unique_pids = None
64
+ if has_pid:
65
+ unique_pids = sorted({str(ex["participant_id"]) for ex in raw_ds})
66
+
67
+ # Setup model and tokenizer with same config as training
68
+ model, tokenizer, _ = setup_model_and_tokenizer(codebook_size, unique_pids)
69
+
70
+ # Load trained weights from checkpoint
71
+ # Try different checkpoint naming patterns
72
+ possible_ckpts = [
73
+ os.path.join(stage_dir, "pytorch_model.bin"),
74
+ os.path.join(stage_dir, "model.safetensors"),
75
+ os.path.join(stage_dir, "adapter_model.bin"),
76
+ ]
77
+
78
+ loaded = False
79
+ for ckpt_path in possible_ckpts:
80
+ if os.path.exists(ckpt_path):
81
+ print(f"Loading checkpoint: {ckpt_path}")
82
+ # Unsloth/PEFT models save adapters separately
83
+ # The model will auto-load from the directory
84
+ loaded = True
85
+ break
86
+
87
+ if not loaded:
88
+ print(f"⚠️ No explicit checkpoint file found, using model directory: {stage_dir}")
89
+
90
+ # Move model to device
91
+ model.to(device)
92
+ model.eval()
93
+
94
+ # Get motion token info
95
+ motion_token_ids, mot_begin_id, mot_end_id = get_motion_token_info(
96
+ tokenizer, codebook_size
97
+ )
98
+
99
+ print(f"✅ Stage {stage} model loaded successfully")
100
+ print(f" Vocabulary size: {len(tokenizer)}")
101
+ print(f" Motion tokens: {len(motion_token_ids)}")
102
+
103
+ return model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id, raw_ds
104
+
105
+
106
+ def inference(
107
+ prompt: str,
108
+ stage: int = 3,
109
+ pid: str = None,
110
+ output_file: str = None,
111
+ per_prompt_vocab: bool = True,
112
+ device: torch.device = None
113
+ ):
114
+ """
115
+ Generate motion tokens from a text prompt.
116
+
117
+ Args:
118
+ prompt: Text description of desired motion
119
+ stage: Which training stage model to use (1, 2, or 3)
120
+ pid: Optional participant ID for personalization
121
+ output_file: Optional file to save output tokens
122
+ per_prompt_vocab: Whether to use per-prompt vocabulary constraints
123
+ device: Device to run inference on
124
+
125
+ Returns:
126
+ Generated motion token string
127
+ """
128
+ if device is None:
129
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
+
131
+ print("="*60)
132
+ print(f"Motion Generation Inference - Stage {stage}")
133
+ print("="*60)
134
+ print(f"Prompt: '{prompt}'")
135
+ print(f"Device: {device}")
136
+
137
+ # Load model and dataset
138
+ model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id, raw_ds = load_trained_model(stage, device)
139
+
140
+ # Compute length stats and prompt vocab
141
+ print("\nComputing dataset statistics...")
142
+ length_stats_by_text, global_median_len = compute_length_stats(raw_ds)
143
+ prompt_vocab = build_prompt_vocab(raw_ds)
144
+ has_pid = check_has_participant_id(raw_ds)
145
+
146
+ # Generate motion tokens
147
+ print(f"\nGenerating motion for: '{prompt}'")
148
+ print(f"Per-prompt vocabulary: {per_prompt_vocab}")
149
+
150
+ generated = generate_t2m(
151
+ model=model,
152
+ tokenizer=tokenizer,
153
+ prompt_text=prompt,
154
+ mot_begin_id=mot_begin_id,
155
+ mot_end_id=mot_end_id,
156
+ motion_token_ids=motion_token_ids,
157
+ length_stats_by_text=length_stats_by_text,
158
+ global_median_len=global_median_len,
159
+ prompt_vocab=prompt_vocab,
160
+ has_pid=has_pid,
161
+ per_prompt_vocab=per_prompt_vocab,
162
+ pid=pid
163
+ )
164
+
165
+ print("\n" + "="*60)
166
+ print("Generated Motion:")
167
+ print("="*60)
168
+ print(generated)
169
+ print("="*60)
170
+
171
+ # Optionally save to file
172
+ if output_file:
173
+ output_path = Path(output_file)
174
+ output_path.parent.mkdir(parents=True, exist_ok=True)
175
+ with open(output_path, 'w') as f:
176
+ f.write(generated)
177
+ print(f"\n✅ Output saved to: {output_file}")
178
+
179
+ return generated
180
+
181
+
182
+ def main():
183
+ parser = argparse.ArgumentParser(
184
+ description="Generate motion tokens from text prompts using trained SignMotionGPT model"
185
+ )
186
+ parser.add_argument(
187
+ "--prompt",
188
+ type=str,
189
+ required=True,
190
+ help="Text description of the desired motion (e.g., 'walking forward', 'dancing')"
191
+ )
192
+ parser.add_argument(
193
+ "--stage",
194
+ type=int,
195
+ default=3,
196
+ choices=[1, 2, 3],
197
+ help="Which training stage model to use (1=motion-only, 2=multi-task, 3=T2M SFT, default=3)"
198
+ )
199
+ parser.add_argument(
200
+ "--pid",
201
+ type=str,
202
+ default=None,
203
+ help="Optional participant ID for personalized generation (e.g., 'P40')"
204
+ )
205
+ parser.add_argument(
206
+ "--output",
207
+ type=str,
208
+ default=None,
209
+ help="Optional output file to save generated tokens"
210
+ )
211
+ parser.add_argument(
212
+ "--no-per-prompt-vocab",
213
+ action="store_true",
214
+ help="Disable per-prompt vocabulary constraints (allows all motion tokens)"
215
+ )
216
+ parser.add_argument(
217
+ "--device",
218
+ type=str,
219
+ default=None,
220
+ choices=["cpu", "cuda", "cuda:0", "cuda:1"],
221
+ help="Device to run inference on (default: auto-detect)"
222
+ )
223
+
224
+ args = parser.parse_args()
225
+
226
+ # Setup device
227
+ if args.device:
228
+ device = torch.device(args.device)
229
+ else:
230
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
231
+
232
+ # Run inference
233
+ inference(
234
+ prompt=args.prompt,
235
+ stage=args.stage,
236
+ pid=args.pid,
237
+ output_file=args.output,
238
+ per_prompt_vocab=not args.no_per_prompt_vocab,
239
+ device=device
240
+ )
241
+
242
+
243
+ if __name__ == "__main__":
244
+ main()
mGPT/__init__.py ADDED
File without changes
mGPT/archs/__init__.py ADDED
File without changes
mGPT/archs/mgpt_vq.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List, Optional, Union
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor, nn
6
+ from torch.distributions.distribution import Distribution
7
+ from .tools.resnet import Resnet1D
8
+ from .tools.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
9
+ from collections import OrderedDict
10
+
11
+
12
+ class VQVae(nn.Module):
13
+
14
+ def __init__(self,
15
+ nfeats: int,
16
+ quantizer: str = "ema_reset",
17
+ code_num=512,
18
+ code_dim=512,
19
+ output_emb_width=512,
20
+ down_t=3,
21
+ stride_t=2,
22
+ width=512,
23
+ depth=3,
24
+ dilation_growth_rate=3,
25
+ norm=None,
26
+ activation: str = "relu",
27
+ **kwargs) -> None:
28
+
29
+ super().__init__()
30
+
31
+ self.code_dim = code_dim
32
+
33
+ self.encoder = Encoder(nfeats,
34
+ output_emb_width,
35
+ down_t,
36
+ stride_t,
37
+ width,
38
+ depth,
39
+ dilation_growth_rate,
40
+ activation=activation,
41
+ norm=norm)
42
+
43
+ self.decoder = Decoder(nfeats,
44
+ output_emb_width,
45
+ down_t,
46
+ stride_t,
47
+ width,
48
+ depth,
49
+ dilation_growth_rate,
50
+ activation=activation,
51
+ norm=norm)
52
+
53
+ if quantizer == "ema_reset":
54
+ self.quantizer = QuantizeEMAReset(code_num, code_dim, mu=0.99)
55
+ elif quantizer == "orig":
56
+ self.quantizer = Quantizer(code_num, code_dim, beta=1.0)
57
+ elif quantizer == "ema":
58
+ self.quantizer = QuantizeEMA(code_num, code_dim, mu=0.99)
59
+ elif quantizer == "reset":
60
+ self.quantizer = QuantizeReset(code_num, code_dim)
61
+
62
+ def preprocess(self, x):
63
+ # (bs, T, Jx3) -> (bs, Jx3, T)
64
+ x = x.permute(0, 2, 1)
65
+ return x
66
+
67
+ def postprocess(self, x):
68
+ # (bs, Jx3, T) -> (bs, T, Jx3)
69
+ x = x.permute(0, 2, 1)
70
+ return x
71
+
72
+ def forward(self, features: Tensor):
73
+ # Preprocess
74
+ x_in = self.preprocess(features)
75
+
76
+ # Encode
77
+ x_encoder = self.encoder(x_in)
78
+
79
+ # quantization
80
+ x_quantized, loss, perplexity = self.quantizer(x_encoder)
81
+
82
+ # decoder
83
+ x_decoder = self.decoder(x_quantized)
84
+ x_out = self.postprocess(x_decoder)
85
+
86
+ return x_out, loss, perplexity
87
+
88
+ def encode(
89
+ self,
90
+ features: Tensor,
91
+ ) -> Union[Tensor, Distribution]:
92
+
93
+ N, T, _ = features.shape
94
+ x_in = self.preprocess(features)
95
+ x_encoder = self.encoder(x_in)
96
+ x_encoder = self.postprocess(x_encoder)
97
+ x_encoder = x_encoder.contiguous().view(-1,
98
+ x_encoder.shape[-1]) # (NT, C)
99
+ code_idx = self.quantizer.quantize(x_encoder)
100
+ code_idx = code_idx.view(N, -1)
101
+
102
+ # latent, dist
103
+ return code_idx, None
104
+
105
+ def decode(self, z: Tensor):
106
+
107
+ x_d = self.quantizer.dequantize(z)
108
+ x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()
109
+
110
+ # decoder
111
+ x_decoder = self.decoder(x_d)
112
+ x_out = self.postprocess(x_decoder)
113
+ return x_out
114
+
115
+
116
+ class Encoder(nn.Module):
117
+
118
+ def __init__(self,
119
+ input_emb_width=3,
120
+ output_emb_width=512,
121
+ down_t=3,
122
+ stride_t=2,
123
+ width=512,
124
+ depth=3,
125
+ dilation_growth_rate=3,
126
+ activation='relu',
127
+ norm=None):
128
+ super().__init__()
129
+
130
+ blocks = []
131
+ filter_t, pad_t = stride_t * 2, stride_t // 2
132
+ blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
133
+ blocks.append(nn.ReLU())
134
+
135
+ for i in range(down_t):
136
+ input_dim = width
137
+ block = nn.Sequential(
138
+ nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
139
+ Resnet1D(width,
140
+ depth,
141
+ dilation_growth_rate,
142
+ activation=activation,
143
+ norm=norm),
144
+ )
145
+ blocks.append(block)
146
+ blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
147
+ self.model = nn.Sequential(*blocks)
148
+
149
+ def forward(self, x):
150
+ return self.model(x)
151
+
152
+
153
+ class Decoder(nn.Module):
154
+
155
+ def __init__(self,
156
+ input_emb_width=3,
157
+ output_emb_width=512,
158
+ down_t=3,
159
+ stride_t=2,
160
+ width=512,
161
+ depth=3,
162
+ dilation_growth_rate=3,
163
+ activation='relu',
164
+ norm=None):
165
+ super().__init__()
166
+ blocks = []
167
+
168
+ filter_t, pad_t = stride_t * 2, stride_t // 2
169
+ blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
170
+ blocks.append(nn.ReLU())
171
+ for i in range(down_t):
172
+ out_dim = width
173
+ block = nn.Sequential(
174
+ Resnet1D(width,
175
+ depth,
176
+ dilation_growth_rate,
177
+ reverse_dilation=True,
178
+ activation=activation,
179
+ norm=norm), nn.Upsample(scale_factor=2,
180
+ mode='nearest'),
181
+ nn.Conv1d(width, out_dim, 3, 1, 1))
182
+ blocks.append(block)
183
+ blocks.append(nn.Conv1d(width, width, 3, 1, 1))
184
+ blocks.append(nn.ReLU())
185
+ blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
186
+ self.model = nn.Sequential(*blocks)
187
+
188
+ def forward(self, x):
189
+ return self.model(x)
mGPT/archs/tools/__init__.py ADDED
File without changes
mGPT/archs/tools/quantize_cnn.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ class QuantizeEMAReset(nn.Module):
8
+ def __init__(self, nb_code, code_dim, mu):
9
+ super().__init__()
10
+ self.nb_code = nb_code
11
+ self.code_dim = code_dim
12
+ self.mu = mu
13
+ self.reset_codebook()
14
+
15
+ def reset_codebook(self):
16
+ self.init = False
17
+ self.code_sum = None
18
+ self.code_count = None
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))
21
+
22
+ def _tile(self, x):
23
+ nb_code_x, code_dim = x.shape
24
+ if nb_code_x < self.nb_code:
25
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
26
+ std = 0.01 / np.sqrt(code_dim)
27
+ out = x.repeat(n_repeats, 1)
28
+ out = out + torch.randn_like(out) * std
29
+ else :
30
+ out = x
31
+ return out
32
+
33
+ def init_codebook(self, x):
34
+ out = self._tile(x)
35
+ self.codebook = out[:self.nb_code]
36
+ self.code_sum = self.codebook.clone()
37
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
38
+ self.init = True
39
+
40
+ @torch.no_grad()
41
+ def compute_perplexity(self, code_idx) :
42
+ # Calculate new centres
43
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
44
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
45
+
46
+ code_count = code_onehot.sum(dim=-1) # nb_code
47
+ prob = code_count / torch.sum(code_count)
48
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
49
+ return perplexity
50
+
51
+ @torch.no_grad()
52
+ def update_codebook(self, x, code_idx):
53
+
54
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
55
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
56
+
57
+ code_sum = torch.matmul(code_onehot, x) # nb_code, w
58
+ code_count = code_onehot.sum(dim=-1) # nb_code
59
+
60
+ out = self._tile(x)
61
+ code_rand = out[:self.nb_code]
62
+
63
+ # Update centres
64
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
65
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
66
+
67
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
68
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
69
+
70
+ self.codebook = usage * code_update + (1 - usage) * code_rand
71
+ prob = code_count / torch.sum(code_count)
72
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
73
+
74
+
75
+ return perplexity
76
+
77
+ def preprocess(self, x):
78
+ # NCT -> NTC -> [NT, C]
79
+ x = x.permute(0, 2, 1).contiguous()
80
+ x = x.view(-1, x.shape[-1])
81
+ return x
82
+
83
+ def quantize(self, x):
84
+ # Calculate latent code x_l
85
+ k_w = self.codebook.t()
86
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
87
+ keepdim=True) # (N * L, b)
88
+ _, code_idx = torch.min(distance, dim=-1)
89
+ return code_idx
90
+
91
+ def dequantize(self, code_idx):
92
+ x = F.embedding(code_idx, self.codebook)
93
+ return x
94
+
95
+
96
+ def forward(self, x):
97
+ N, width, T = x.shape
98
+
99
+ # Preprocess
100
+ x = self.preprocess(x)
101
+
102
+ # Init codebook if not inited
103
+ if self.training and not self.init:
104
+ self.init_codebook(x)
105
+
106
+ # quantize and dequantize through bottleneck
107
+ code_idx = self.quantize(x)
108
+ x_d = self.dequantize(code_idx)
109
+
110
+ # Update embeddings
111
+ if self.training:
112
+ perplexity = self.update_codebook(x, code_idx)
113
+ else :
114
+ perplexity = self.compute_perplexity(code_idx)
115
+
116
+ # Loss
117
+ commit_loss = F.mse_loss(x, x_d.detach())
118
+
119
+ # Passthrough
120
+ x_d = x + (x_d - x).detach()
121
+
122
+ # Postprocess
123
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
124
+
125
+ return x_d, commit_loss, perplexity
126
+
127
+
128
+
129
+ class Quantizer(nn.Module):
130
+ def __init__(self, n_e, e_dim, beta):
131
+ super(Quantizer, self).__init__()
132
+
133
+ self.e_dim = e_dim
134
+ self.n_e = n_e
135
+ self.beta = beta
136
+
137
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
138
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
139
+
140
+ def forward(self, z):
141
+
142
+ N, width, T = z.shape
143
+ z = self.preprocess(z)
144
+ assert z.shape[-1] == self.e_dim
145
+ z_flattened = z.contiguous().view(-1, self.e_dim)
146
+
147
+ # B x V
148
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
149
+ # B x 1
150
+ min_encoding_indices = torch.argmin(d, dim=1)
151
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
152
+
153
+ # compute loss for embedding
154
+ loss = torch.mean((z_q - z.detach())**2) + self.beta * torch.mean((z_q.detach() - z)**2)
155
+
156
+ # preserve gradients
157
+ z_q = z + (z_q - z).detach()
158
+ z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
159
+
160
+ min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
161
+ e_mean = torch.mean(min_encodings, dim=0)
162
+ perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10)))
163
+ return z_q, loss, perplexity
164
+
165
+ def quantize(self, z):
166
+
167
+ assert z.shape[-1] == self.e_dim
168
+
169
+ # B x V
170
+ d = torch.sum(z ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * torch.matmul(z, self.embedding.weight.t())
171
+ # B x 1
172
+ min_encoding_indices = torch.argmin(d, dim=1)
173
+ return min_encoding_indices
174
+
175
+ def dequantize(self, indices):
176
+
177
+ index_flattened = indices.view(-1)
178
+ z_q = self.embedding(index_flattened)
179
+ z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous()
180
+ return z_q
181
+
182
+ def preprocess(self, x):
183
+ # NCT -> NTC -> [NT, C]
184
+ x = x.permute(0, 2, 1).contiguous()
185
+ x = x.view(-1, x.shape[-1])
186
+ return x
187
+
188
+
189
+
190
+ class QuantizeReset(nn.Module):
191
+ def __init__(self, nb_code, code_dim):
192
+ super().__init__()
193
+ self.nb_code = nb_code
194
+ self.code_dim = code_dim
195
+ self.reset_codebook()
196
+ self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))
197
+
198
+ def reset_codebook(self):
199
+ self.init = False
200
+ self.code_count = None
201
+
202
+ def _tile(self, x):
203
+ nb_code_x, code_dim = x.shape
204
+ if nb_code_x < self.nb_code:
205
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
206
+ std = 0.01 / np.sqrt(code_dim)
207
+ out = x.repeat(n_repeats, 1)
208
+ out = out + torch.randn_like(out) * std
209
+ else :
210
+ out = x
211
+ return out
212
+
213
+ def init_codebook(self, x):
214
+ out = self._tile(x)
215
+ self.codebook = nn.Parameter(out[:self.nb_code])
216
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
217
+ self.init = True
218
+
219
+ @torch.no_grad()
220
+ def compute_perplexity(self, code_idx) :
221
+ # Calculate new centres
222
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
223
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
224
+
225
+ code_count = code_onehot.sum(dim=-1) # nb_code
226
+ prob = code_count / torch.sum(code_count)
227
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
228
+ return perplexity
229
+
230
+ def update_codebook(self, x, code_idx):
231
+
232
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
233
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
234
+
235
+ code_count = code_onehot.sum(dim=-1) # nb_code
236
+
237
+ out = self._tile(x)
238
+ code_rand = out[:self.nb_code]
239
+
240
+ # Update centres
241
+ self.code_count = code_count # nb_code
242
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
243
+
244
+ self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
245
+ prob = code_count / torch.sum(code_count)
246
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
247
+
248
+
249
+ return perplexity
250
+
251
+ def preprocess(self, x):
252
+ # NCT -> NTC -> [NT, C]
253
+ x = x.permute(0, 2, 1).contiguous()
254
+ x = x.view(-1, x.shape[-1])
255
+ return x
256
+
257
+ def quantize(self, x):
258
+ # Calculate latent code x_l
259
+ k_w = self.codebook.t()
260
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
261
+ keepdim=True) # (N * L, b)
262
+ _, code_idx = torch.min(distance, dim=-1)
263
+ return code_idx
264
+
265
+ def dequantize(self, code_idx):
266
+ x = F.embedding(code_idx, self.codebook)
267
+ return x
268
+
269
+
270
+ def forward(self, x):
271
+ N, width, T = x.shape
272
+ # Preprocess
273
+ x = self.preprocess(x)
274
+ # Init codebook if not inited
275
+ if self.training and not self.init:
276
+ self.init_codebook(x)
277
+ # quantize and dequantize through bottleneck
278
+ code_idx = self.quantize(x)
279
+ x_d = self.dequantize(code_idx)
280
+ # Update embeddings
281
+ if self.training:
282
+ perplexity = self.update_codebook(x, code_idx)
283
+ else :
284
+ perplexity = self.compute_perplexity(code_idx)
285
+
286
+ # Loss
287
+ commit_loss = F.mse_loss(x, x_d.detach())
288
+
289
+ # Passthrough
290
+ x_d = x + (x_d - x).detach()
291
+
292
+ # Postprocess
293
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
294
+
295
+ return x_d, commit_loss, perplexity
296
+
297
+
298
+ class QuantizeEMA(nn.Module):
299
+ def __init__(self, nb_code, code_dim, mu):
300
+ super().__init__()
301
+ self.nb_code = nb_code
302
+ self.code_dim = code_dim
303
+ self.mu = mu
304
+ self.reset_codebook()
305
+
306
+ def reset_codebook(self):
307
+ self.init = False
308
+ self.code_sum = None
309
+ self.code_count = None
310
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda())
311
+
312
+ def _tile(self, x):
313
+ nb_code_x, code_dim = x.shape
314
+ if nb_code_x < self.nb_code:
315
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
316
+ std = 0.01 / np.sqrt(code_dim)
317
+ out = x.repeat(n_repeats, 1)
318
+ out = out + torch.randn_like(out) * std
319
+ else :
320
+ out = x
321
+ return out
322
+
323
+ def init_codebook(self, x):
324
+ out = self._tile(x)
325
+ self.codebook = out[:self.nb_code]
326
+ self.code_sum = self.codebook.clone()
327
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
328
+ self.init = True
329
+
330
+ @torch.no_grad()
331
+ def compute_perplexity(self, code_idx) :
332
+ # Calculate new centres
333
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
334
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
335
+
336
+ code_count = code_onehot.sum(dim=-1) # nb_code
337
+ prob = code_count / torch.sum(code_count)
338
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
339
+ return perplexity
340
+
341
+ @torch.no_grad()
342
+ def update_codebook(self, x, code_idx):
343
+
344
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
345
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
346
+
347
+ code_sum = torch.matmul(code_onehot, x) # nb_code, w
348
+ code_count = code_onehot.sum(dim=-1) # nb_code
349
+
350
+ # Update centres
351
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
352
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
353
+
354
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
355
+
356
+ self.codebook = code_update
357
+ prob = code_count / torch.sum(code_count)
358
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
359
+
360
+ return perplexity
361
+
362
+ def preprocess(self, x):
363
+ # NCT -> NTC -> [NT, C]
364
+ x = x.permute(0, 2, 1).contiguous()
365
+ x = x.view(-1, x.shape[-1])
366
+ return x
367
+
368
+ def quantize(self, x):
369
+ # Calculate latent code x_l
370
+ k_w = self.codebook.t()
371
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
372
+ keepdim=True) # (N * L, b)
373
+ _, code_idx = torch.min(distance, dim=-1)
374
+ return code_idx
375
+
376
+ def dequantize(self, code_idx):
377
+ x = F.embedding(code_idx, self.codebook)
378
+ return x
379
+
380
+
381
+ def forward(self, x):
382
+ N, width, T = x.shape
383
+
384
+ # Preprocess
385
+ x = self.preprocess(x)
386
+
387
+ # Init codebook if not inited
388
+ if self.training and not self.init:
389
+ self.init_codebook(x)
390
+
391
+ # quantize and dequantize through bottleneck
392
+ code_idx = self.quantize(x)
393
+ x_d = self.dequantize(code_idx)
394
+
395
+ # Update embeddings
396
+ if self.training:
397
+ perplexity = self.update_codebook(x, code_idx)
398
+ else :
399
+ perplexity = self.compute_perplexity(code_idx)
400
+
401
+ # Loss
402
+ commit_loss = F.mse_loss(x, x_d.detach())
403
+
404
+ # Passthrough
405
+ x_d = x + (x_d - x).detach()
406
+
407
+ # Postprocess
408
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
409
+
410
+ return x_d, commit_loss, perplexity
mGPT/archs/tools/resnet.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn as nn
3
+ import torch
4
+
5
+ class nonlinearity(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ def forward(self, x):
10
+ # swish
11
+ return x * torch.sigmoid(x)
12
+
13
+ class ResConv1DBlock(nn.Module):
14
+ def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
15
+ super().__init__()
16
+ padding = dilation
17
+ self.norm = norm
18
+ if norm == "LN":
19
+ self.norm1 = nn.LayerNorm(n_in)
20
+ self.norm2 = nn.LayerNorm(n_in)
21
+ elif norm == "GN":
22
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
23
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
24
+ elif norm == "BN":
25
+ self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
26
+ self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
27
+
28
+ else:
29
+ self.norm1 = nn.Identity()
30
+ self.norm2 = nn.Identity()
31
+
32
+ if activation == "relu":
33
+ self.activation1 = nn.ReLU()
34
+ self.activation2 = nn.ReLU()
35
+
36
+ elif activation == "silu":
37
+ self.activation1 = nonlinearity()
38
+ self.activation2 = nonlinearity()
39
+
40
+ elif activation == "gelu":
41
+ self.activation1 = nn.GELU()
42
+ self.activation2 = nn.GELU()
43
+
44
+ self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
45
+ self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,)
46
+
47
+
48
+ def forward(self, x):
49
+ x_orig = x
50
+ if self.norm == "LN":
51
+ x = self.norm1(x.transpose(-2, -1))
52
+ x = self.activation1(x.transpose(-2, -1))
53
+ else:
54
+ x = self.norm1(x)
55
+ x = self.activation1(x)
56
+
57
+ x = self.conv1(x)
58
+
59
+ if self.norm == "LN":
60
+ x = self.norm2(x.transpose(-2, -1))
61
+ x = self.activation2(x.transpose(-2, -1))
62
+ else:
63
+ x = self.norm2(x)
64
+ x = self.activation2(x)
65
+
66
+ x = self.conv2(x)
67
+ x = x + x_orig
68
+ return x
69
+
70
+ class Resnet1D(nn.Module):
71
+ def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
72
+ super().__init__()
73
+
74
+ blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)]
75
+ if reverse_dilation:
76
+ blocks = blocks[::-1]
77
+
78
+ self.model = nn.Sequential(*blocks)
79
+
80
+ def forward(self, x):
81
+ return self.model(x)
metrics.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation metrics for motion generation
3
+ """
4
+ import random
5
+ import os
6
+ import re
7
+ import json
8
+ import numpy as np
9
+ import scipy.linalg
10
+ import torch
11
+ from typing import List, Tuple, Dict, Optional, Any
12
+ from rapidfuzz.distance import Levenshtein
13
+ from collections import defaultdict
14
+ from data import motion_specials_to_ids
15
+ from config import (
16
+ SEED, PIPELINE_OUTPUT_DIR, M_START, M_END,
17
+ INFERENCE_TEMPERATURE, INFERENCE_TOP_K, INFERENCE_REPETITION_PENALTY
18
+ )
19
+
20
+ random.seed(SEED)
21
+
22
+ # ======================================================================================
23
+ # Logic from test_overfit.py (Metrics & Visualization)
24
+ # ======================================================================================
25
+
26
+ def calculate_activation_statistics_np(activations: np.ndarray):
27
+ """
28
+ Params:
29
+ -- activations: num_samples x dim_feat (numpy)
30
+ Returns:
31
+ -- mu: dim_feat
32
+ -- sigma: dim_feat x dim_feat
33
+ """
34
+ mu = np.mean(activations, axis=0)
35
+ cov = np.cov(activations, rowvar=False)
36
+ return mu, cov
37
+
38
+ def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6):
39
+ """Numpy implementation of the Frechet Distance."""
40
+ mu1 = np.atleast_1d(mu1)
41
+ mu2 = np.atleast_1d(mu2)
42
+ sigma1 = np.atleast_2d(sigma1)
43
+ sigma2 = np.atleast_2d(sigma2)
44
+ assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
45
+ assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
46
+ diff = mu1 - mu2
47
+ covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
48
+ if not np.isfinite(covmean).all():
49
+ offset = np.eye(sigma1.shape[0]) * eps
50
+ covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
51
+ if np.iscomplexobj(covmean):
52
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
53
+ m = np.max(np.abs(covmean.imag))
54
+ raise ValueError(f"Imaginary component {m}")
55
+ covmean = covmean.real
56
+ tr_covmean = np.trace(covmean)
57
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
58
+
59
+ def calculate_diversity_np(activation: np.ndarray, diversity_times: int = 200) -> float:
60
+ """Mean pairwise L2 distance across random pairs."""
61
+ assert len(activation.shape) == 2
62
+ if activation.shape[0] < 2:
63
+ return 0.0
64
+ num_samples = activation.shape[0]
65
+ effective_times = min(diversity_times, max(1, num_samples - 1))
66
+ first_indices = np.random.choice(num_samples, effective_times, replace=False)
67
+ second_indices = np.random.choice(num_samples, effective_times, replace=False)
68
+ diffs = activation[first_indices] - activation[second_indices]
69
+ dist = np.linalg.norm(diffs, axis=1)
70
+ return float(dist.mean())
71
+
72
+ def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int = 20) -> float:
73
+ """
74
+ activation: [num_labels, num_per_label, D]
75
+ Returns mean pairwise within-label diversity (higher = more multimodal).
76
+ """
77
+ assert len(activation.shape) == 3
78
+ num_labels, num_per_label, _ = activation.shape
79
+ if num_per_label < 2:
80
+ return float("nan")
81
+ effective_times = min(multimodality_times, max(1, num_per_label - 1))
82
+ first_dices = np.random.choice(num_per_label, effective_times, replace=False)
83
+ second_dices = np.random.choice(num_per_label, effective_times, replace=False)
84
+ diffs = activation[:, first_dices] - activation[:, second_dices]
85
+ dist = np.linalg.norm(diffs, axis=2)
86
+ return float(dist.mean())
87
+
88
+ # --------------------------------------------------------------------------------------
89
+ # Token sequence → activation (bag-of-motion-tokens) helpers
90
+ # --------------------------------------------------------------------------------------
91
+ def _extract_motion_tokens_from_sequence(seq: str) -> list[str]:
92
+ # Expect tokens like <M123>, within M_START/M_END fences; keep only <M...>
93
+ return [tok for tok in seq.split() if tok.startswith("<M") and tok.endswith(">")]
94
+
95
+ def _extract_ids_from_sequence(seq: str) -> list[int]:
96
+ return [int(t[2:-1]) for t in _extract_motion_tokens_from_sequence(seq) if t[2:-1].isdigit()]
97
+
98
+ def _build_token_index(tokens_vocab: list[str]) -> Dict[str, int]:
99
+ return {tok: idx for idx, tok in enumerate(tokens_vocab)}
100
+
101
+ def _sequence_to_activation(seq: str, token_to_index: Dict[str, int]) -> np.ndarray:
102
+ vec = np.zeros((len(token_to_index),), dtype=np.float32)
103
+ for tok in _extract_motion_tokens_from_sequence(seq):
104
+ idx = token_to_index.get(tok)
105
+ if idx is not None:
106
+ vec[idx] += 1.0
107
+ # Normalize to unit length to reduce length bias
108
+ norm = np.linalg.norm(vec)
109
+ if norm > 0:
110
+ vec = vec / norm
111
+ return vec
112
+
113
+ def generate_motion(model, tokenizer, prompt, device):
114
+ """Generates a motion sequence from a prompt using sampling."""
115
+ model.eval()
116
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
117
+
118
+ with torch.no_grad():
119
+ output = model.generate(
120
+ **inputs,
121
+ max_new_tokens=100,
122
+ do_sample=True,
123
+ temperature=INFERENCE_TEMPERATURE,
124
+ top_k=INFERENCE_TOP_K,
125
+ repetition_penalty=INFERENCE_REPETITION_PENALTY,
126
+ pad_token_id=tokenizer.pad_token_id,
127
+ eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
128
+ early_stopping=True
129
+ )
130
+
131
+ decoded = tokenizer.decode(output[0], skip_special_tokens=False)
132
+ if "Motion: " in decoded:
133
+ motion_part = decoded.split("Motion: ")[-1]
134
+ else:
135
+ motion_part = decoded
136
+ return motion_part.strip()
137
+
138
+ def _collect_eval_pairs(model, tokenizer, data, device) -> list[Tuple[str, str, str]]:
139
+ """
140
+ Returns list of (word, participant_id, gt_sequence, generated_sequence) for each sample in data.
141
+ """
142
+ results = []
143
+ for sample in data:
144
+ gt_tokens_str = sample.get("motion_tokens", "")
145
+ gt_wrapped = " ".join([f"<M{t}>" for t in gt_tokens_str.split()])
146
+ gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
147
+ prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
148
+ generated_sequence = generate_motion(model, tokenizer, prompt, device)
149
+ pid = str(sample.get("participant_id", ""))
150
+ results.append((sample["word"], pid, gt_sequence, generated_sequence))
151
+ return results
152
+
153
+ def _activations_from_pairs(pairs: list[Tuple[str, str, str]], vocab_tokens: list[str]):
154
+ """
155
+ Build numpy activations and labels arrays from sequences.
156
+ Returns:
157
+ gt_acts: (N, D)
158
+ gen_acts: (N, D)
159
+ labels: list[str] length N (word labels)
160
+ """
161
+ token_to_index = _build_token_index(vocab_tokens)
162
+ gt_vecs = []
163
+ gen_vecs = []
164
+ labels = []
165
+ for pair in pairs:
166
+ # Support both legacy 3-tuple (word, gt, gen) and new 4-tuple (word, pid, gt, gen)
167
+ if len(pair) == 4:
168
+ word, _pid, gt_seq, gen_seq = pair
169
+ else:
170
+ word, gt_seq, gen_seq = pair
171
+ gt_vecs.append(_sequence_to_activation(gt_seq, token_to_index))
172
+ gen_vecs.append(_sequence_to_activation(gen_seq, token_to_index))
173
+ labels.append(word)
174
+ return np.stack(gt_vecs, axis=0), np.stack(gen_vecs, axis=0), labels
175
+
176
+ def _to_label_tensor3(acts: np.ndarray, labels: list[str]) -> np.ndarray:
177
+ """
178
+ Convert N x D activations with string labels to [L, K, D] by truncating each label
179
+ to the minimum count across labels.
180
+ """
181
+ label_to_indices: Dict[str, list[int]] = {}
182
+ for i, lbl in enumerate(labels):
183
+ label_to_indices.setdefault(lbl, []).append(i)
184
+ per_label_counts = [len(idxs) for idxs in label_to_indices.values()]
185
+ if len(per_label_counts) == 0:
186
+ raise ValueError("No labels found for multimodality computation.")
187
+ min_count = max(2, min(per_label_counts))
188
+ label_names = sorted(label_to_indices.keys())
189
+ stacked = []
190
+ for lbl in label_names:
191
+ idxs = label_to_indices[lbl][:min_count]
192
+ stacked.append(acts[idxs])
193
+ return np.stack(stacked, axis=0) # [L, K, D]
194
+
195
+ def evaluate_metrics_motiongpt_style(model, tokenizer, eval_data, all_motion_tokens, device):
196
+ """
197
+ Computes:
198
+ - Diversity: GT vs GEN (pair)
199
+ - Multimodality (MIM): GT vs GEN (pair)
200
+ - FID: between GT and GEN
201
+ """
202
+ print("\n" + "="*80)
203
+ print(" METRICS EVALUATION (FID, Diversity, Multimodality)")
204
+ print("="*80)
205
+ pairs = _collect_eval_pairs(model, tokenizer, eval_data, device)
206
+ gt_acts, gen_acts, labels = _activations_from_pairs(pairs, all_motion_tokens)
207
+ # Diversity
208
+ diversity_times = min(200, max(4, gt_acts.shape[0] - 1))
209
+ diversity_gt = calculate_diversity_np(gt_acts, diversity_times=diversity_times)
210
+ diversity_gen = calculate_diversity_np(gen_acts, diversity_times=diversity_times)
211
+ # Multimodality (MIM)
212
+ try:
213
+ gt_lbl_tensor = _to_label_tensor3(gt_acts, labels)
214
+ gen_lbl_tensor = _to_label_tensor3(gen_acts, labels)
215
+ multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
216
+ mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
217
+ mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
218
+ except Exception as exc:
219
+ print(f"⚠️ Multimodality could not be computed reliably: {exc}")
220
+ mim_gt = float("nan")
221
+ mim_gen = float("nan")
222
+ # FID
223
+ mu_gen, cov_gen = calculate_activation_statistics_np(gen_acts)
224
+ mu_gt, cov_gt = calculate_activation_statistics_np(gt_acts)
225
+ fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
226
+ print(f"Diversity: GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
227
+ print(f"Multimodality (MIM): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
228
+ print(f"FID (GT vs GEN): {fid:.4f}")
229
+ return {
230
+ "diversity_gt": diversity_gt,
231
+ "diversity_gen": diversity_gen,
232
+ "mim_gt": mim_gt,
233
+ "mim_gen": mim_gen,
234
+ "fid": fid,
235
+ "pairs": pairs, # for visualization usage
236
+ }
237
+
238
+ def _encode_params_to_feature(params: np.ndarray, vq_model, mean, std, device) -> np.ndarray:
239
+ """
240
+ Convert SMPL-X parameter sequence (T, D) into a single clip feature using
241
+ the VQ-VAE encoder output BEFORE quantization. Average-pool over time to get (D_embed,).
242
+ """
243
+ if params.size == 0:
244
+ return np.zeros((getattr(vq_model.vqvae, "output_emb_width", 512),), dtype=np.float32)
245
+ x = torch.from_numpy(params.astype(np.float32)).to(device) # [T, D]
246
+ x = x.unsqueeze(0) # [1, T, D]
247
+ with torch.no_grad():
248
+ # Normalize / preprocess
249
+ x_pre = None
250
+ if hasattr(vq_model.vqvae, "preprocess"):
251
+ try:
252
+ x_pre = vq_model.vqvae.preprocess(x) # expected to return tensor ready for encoder
253
+ except Exception:
254
+ x_pre = None
255
+ if x_pre is None:
256
+ # Manual normalization with provided mean/std
257
+ if mean is not None and std is not None:
258
+ mean_t = torch.from_numpy(np.array(mean, dtype=np.float32)).to(device).view(1, 1, -1)
259
+ std_t = torch.from_numpy(np.array(std, dtype=np.float32)).to(device).view(1, 1, -1)
260
+ x_norm = (x - mean_t) / (std_t + 1e-8)
261
+ else:
262
+ x_norm = x
263
+ # Some encoders expect [N, D, T]
264
+ x_pre = x_norm.transpose(1, 2).contiguous() # [1, D, T]
265
+ # Encode to get pre-quant latent
266
+ z_e = vq_model.vqvae.encoder(x_pre)
267
+ # z_e could be [N, D_embed, T_q] or [N, T_q, D_embed]
268
+ if z_e.dim() == 3:
269
+ embed_dim_known = getattr(vq_model.vqvae, "output_emb_width", None)
270
+ if embed_dim_known is not None:
271
+ if z_e.shape[1] == embed_dim_known:
272
+ time_axis = 2 # [N, D_embed, T_q]
273
+ elif z_e.shape[2] == embed_dim_known:
274
+ time_axis = 1 # [N, T_q, D_embed]
275
+ else:
276
+ time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
277
+ else:
278
+ time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
279
+ feat = z_e.mean(dim=time_axis).squeeze(0)
280
+ elif z_e.dim() == 2:
281
+ feat = z_e.squeeze(0)
282
+ else:
283
+ feat = z_e.view(1, -1).mean(dim=0)
284
+ feat_np = feat.detach().cpu().numpy().astype(np.float32)
285
+ # L2 normalize
286
+ norm = np.linalg.norm(feat_np)
287
+ if norm > 0:
288
+ feat_np = feat_np / norm
289
+ return feat_np
290
+
291
+ def evaluate_metrics_encoder_style(
292
+ model,
293
+ tokenizer,
294
+ eval_data,
295
+ device,
296
+ vqvae_ckpt: Optional[str] = None,
297
+ stats_path: Optional[str] = None,
298
+ sample_limit: int = 100,
299
+ ):
300
+ """
301
+ Computes FID, Diversity, and MIM using VQ-VAE encoder pre-quantization features.
302
+ """
303
+ print("\n" + "="*80)
304
+ print(" METRICS EVALUATION (VQ-VAE Encoder Features)")
305
+ print("="*80)
306
+ # Lazy import to reuse your visualization utilities and stats
307
+ try:
308
+ from visualize import load_vqvae, load_stats, VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS
309
+ vq_ckpt = vqvae_ckpt or os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
310
+ stats_p = stats_path or os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
311
+ vq_model = load_vqvae(vq_ckpt, device=device)
312
+ mean, std = load_stats(stats_p)
313
+ from visualize import decode_tokens_to_params
314
+ except Exception as exc:
315
+ print(f"⚠️ Could not set up VQ-VAE encoder metrics: {exc}")
316
+ return {}
317
+ # Collect GT/GEN token sequences for pairs (limit to speed-up)
318
+ pairs = _collect_eval_pairs(model, tokenizer, eval_data[:sample_limit], device)
319
+ # Build features
320
+ gt_feats = []
321
+ gen_feats = []
322
+ labels = []
323
+ for pair in pairs:
324
+ if len(pair) == 4:
325
+ word, _pid, gt_seq, gen_seq = pair
326
+ else:
327
+ word, gt_seq, gen_seq = pair
328
+ # Decode to SMPL-X
329
+ tokens_gt = _extract_ids_from_sequence(gt_seq)
330
+ tokens_gen = _extract_ids_from_sequence(gen_seq)
331
+ try:
332
+ params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std, device=device) # (T, D) denorm
333
+ except Exception:
334
+ params_gt = np.zeros((0, 182), dtype=np.float32)
335
+ try:
336
+ params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std, device=device) # (T, D) denorm
337
+ except Exception:
338
+ params_gen = np.zeros((0, 182), dtype=np.float32)
339
+ # Encode (pre-quant) -> pooled feature
340
+ feat_gt = _encode_params_to_feature(params_gt, vq_model, mean, std, device)
341
+ feat_gen = _encode_params_to_feature(params_gen, vq_model, mean, std, device)
342
+ gt_feats.append(feat_gt)
343
+ gen_feats.append(feat_gen)
344
+ labels.append(word)
345
+ gt_feats = np.stack(gt_feats, axis=0)
346
+ gen_feats = np.stack(gen_feats, axis=0)
347
+ # Diversity
348
+ diversity_times = min(200, max(4, gt_feats.shape[0] - 1))
349
+ diversity_gt = calculate_diversity_np(gt_feats, diversity_times=diversity_times)
350
+ diversity_gen = calculate_diversity_np(gen_feats, diversity_times=diversity_times)
351
+ # Multimodality (MIM)
352
+ try:
353
+ gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
354
+ gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
355
+ multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
356
+ mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
357
+ mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
358
+ except Exception as exc:
359
+ print(f"⚠️ Multimodality could not be computed reliably: {exc}")
360
+ mim_gt = float("nan")
361
+ mim_gen = float("nan")
362
+ # FID (on encoder features)
363
+ mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
364
+ mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
365
+ fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
366
+ print(f"Diversity (encoder feats): GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
367
+ print(f"Multimodality (MIM, encoder): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
368
+ print(f"FID (encoder feats, GT vs GEN): {fid:.4f}")
369
+ return {
370
+ "diversity_gt": diversity_gt,
371
+ "diversity_gen": diversity_gen,
372
+ "mim_gt": mim_gt,
373
+ "mim_gen": mim_gen,
374
+ "fid": fid,
375
+ "pairs": pairs,
376
+ }
377
+
378
+ def save_side_by_side_visualizations(pairs: list[Tuple[str, str, str]], output_dir: str, limit: int = 4):
379
+ """
380
+ Generate side-by-side 3D animations for GT vs GEN.
381
+ """
382
+ try:
383
+ from visualize import (
384
+ load_vqvae, load_stats, load_smplx_model,
385
+ decode_tokens_to_params, params_to_vertices,
386
+ VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS, SMPLX_MODEL_DIR as DEFAULT_SMPLX
387
+ )
388
+ import plotly.graph_objects as go
389
+ from plotly.subplots import make_subplots
390
+ except Exception as exc:
391
+ print(f"⚠️ Visualization skipped (missing dependencies): {exc}")
392
+ return
393
+
394
+ os.makedirs(output_dir, exist_ok=True)
395
+ vqvae_ckpt = os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
396
+ stats_path = os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
397
+ smplx_dir = os.getenv("SMPLX_MODEL_DIR", DEFAULT_SMPLX)
398
+
399
+ print("Loading VQ-VAE, stats, SMPL-X ...")
400
+ vq_model = load_vqvae(vqvae_ckpt)
401
+ mean, std = load_stats(stats_path)
402
+ smplx_model = load_smplx_model(smplx_dir)
403
+
404
+ def animate_side_by_side(verts_left, faces, verts_right, fps=20, titles=("Ground Truth", "LLM Generated"), output_html=None):
405
+ T = min(verts_left.shape[0], verts_right.shape[0])
406
+ verts_left, verts_right = verts_left[:T], verts_right[:T]
407
+ i, j, k = faces.T.tolist()
408
+ fig = make_subplots(
409
+ rows=1, cols=2,
410
+ specs=[[{'type': 'scene'}, {'type': 'scene'}]],
411
+ horizontal_spacing=0.05,
412
+ subplot_titles=list(titles)
413
+ )
414
+ left_mesh = go.Mesh3d(x=verts_left[0,:,0], y=verts_left[0,:,1], z=verts_left[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
415
+ right_mesh = go.Mesh3d(x=verts_right[0,:,0], y=verts_right[0,:,1], z=verts_right[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
416
+ fig.add_trace(left_mesh, row=1, col=1)
417
+ fig.add_trace(right_mesh, row=1, col=2)
418
+ frames = []
419
+ for t in range(T):
420
+ frames.append(go.Frame(
421
+ name=str(t),
422
+ data=[
423
+ go.Mesh3d(x=verts_left[t,:,0], y=verts_left[t,:,1], z=verts_left[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene"),
424
+ go.Mesh3d(x=verts_right[t,:,0], y=verts_right[t,:,1], z=verts_right[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene2")
425
+ ]
426
+ ))
427
+ fig.frames = frames
428
+ fig.update_layout(
429
+ showlegend=False,
430
+ margin=dict(l=10, r=10, t=50, b=10),
431
+ scene=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
432
+ camera=dict(eye=dict(x=0,y=-2,z=0.7))),
433
+ scene2=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
434
+ camera=dict(eye=dict(x=0,y=-2,z=0.7))),
435
+ updatemenus=[dict(
436
+ type="buttons", x=0.5, xanchor="center", y=1.15, yanchor="top",
437
+ buttons=[
438
+ dict(label="Play", method="animate", args=[None, {"frame": {"duration": max(1,1000//fps), "redraw": True}, "fromcurrent": True}]),
439
+ dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}}])
440
+ ]
441
+ )]
442
+ )
443
+ if output_html:
444
+ fig.write_html(output_html)
445
+ print(f"✅ Saved: {output_html}")
446
+ return fig
447
+
448
+ # Determine which words to include (up to `limit` distinct words)
449
+ allowed_words = None
450
+ if isinstance(limit, int) and limit > 0:
451
+ ordered_unique_words = []
452
+ for pair in pairs:
453
+ word = pair[0]
454
+ if word not in ordered_unique_words:
455
+ ordered_unique_words.append(word)
456
+ if len(ordered_unique_words) >= limit:
457
+ break
458
+ allowed_words = set(ordered_unique_words)
459
+
460
+ for pair in pairs:
461
+ try:
462
+ if len(pair) == 4:
463
+ word, pid, gt_seq, gen_seq = pair
464
+ else:
465
+ word, gt_seq, gen_seq = pair
466
+ pid = "unknown"
467
+ if allowed_words is not None and word not in allowed_words:
468
+ continue
469
+ tokens_gt = _extract_ids_from_sequence(gt_seq)
470
+ tokens_gen = _extract_ids_from_sequence(gen_seq)
471
+ params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std)
472
+ params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std)
473
+ verts_gt, faces = params_to_vertices(params_gt, smplx_model)
474
+ verts_gen, _ = params_to_vertices(params_gen, smplx_model)
475
+ out_dir = os.path.join(output_dir)
476
+ os.makedirs(out_dir, exist_ok=True)
477
+ # Sanitize for filesystem safety
478
+ safe_word = re.sub(r'[^A-Za-z0-9_-]+', '_', str(word))
479
+ safe_pid = re.sub(r'[^A-Za-z0-9_-]+', '_', str(pid))
480
+ output_html = os.path.join(out_dir, f"word_{safe_word}_{safe_pid}_side_by_side.html")
481
+ animate_side_by_side(
482
+ verts_left=verts_gt,
483
+ faces=faces,
484
+ verts_right=verts_gen,
485
+ fps=20,
486
+ titles=("Ground Truth", "LLM Generated"),
487
+ output_html=output_html
488
+ )
489
+ except Exception as exc:
490
+ print(f"⚠️ Error creating visualization for word '{pair[0]}': {exc}")
491
+
492
+ def run_inference_on_all_samples(model, tokenizer, data, device):
493
+ """
494
+ Runs inference on ALL available samples for the trained words and compares
495
+ each one to its specific ground truth.
496
+ """
497
+ print("\n" + "="*80)
498
+ print(" INFERENCE AND EVALUATION (ALL SAMPLES)")
499
+ print(" Goal: Test the model's performance on every variant.")
500
+ print("="*80)
501
+
502
+ def compare_sequences(gt: str, gen: str):
503
+ """Provides a simple visual diff of two sequences without external libraries."""
504
+ gt_tokens = gt.split()
505
+ gen_tokens = gen.split()
506
+
507
+ print("\nDetailed Comparison (✅ = Match, ❌ = Mismatch/Missing/Added):")
508
+
509
+ gt_str = " GT: "
510
+ gen_str = " GEN: "
511
+ diff_str = " "
512
+
513
+ max_len = max(len(gt_tokens), len(gen_tokens))
514
+
515
+ for i in range(max_len):
516
+ gt_tok = gt_tokens[i] if i < len(gt_tokens) else "___"
517
+ gen_tok = gen_tokens[i] if i < len(gen_tokens) else "___"
518
+
519
+ max_tok_len = max(len(gt_tok), len(gen_tok))
520
+ gt_tok_padded = gt_tok.ljust(max_tok_len)
521
+ gen_tok_padded = gen_tok.ljust(max_tok_len)
522
+
523
+ gt_str += gt_tok_padded + " "
524
+ gen_str += gen_tok_padded + " "
525
+
526
+ if gt_tok == gen_tok:
527
+ diff_str += "✅".ljust(max_tok_len) + " "
528
+ else:
529
+ diff_str += "❌".ljust(max_tok_len) + " "
530
+
531
+ print(gt_str)
532
+ print(gen_str)
533
+ print(diff_str)
534
+
535
+ data_by_word = {}
536
+ for item in data:
537
+ word = item['word']
538
+ if word not in data_by_word:
539
+ data_by_word[word] = []
540
+ data_by_word[word].append(item)
541
+
542
+ for word, samples in data_by_word.items():
543
+ print(f"\n\n{'='*25} TESTING WORD: '{word}' {'='*25}")
544
+ num_correct = 0
545
+
546
+ for i, sample in enumerate(samples):
547
+ print(f"\n--- Testing Variant {i+1}/{len(samples)}: '{sample['participant_id']}' ---")
548
+
549
+ gt_tokens_str = sample.get("motion_tokens", "")
550
+ gt_wrapped = " ".join([f"<M{t}>" for t in gt_tokens_str.split()])
551
+ gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
552
+ print(f"Ground Truth:\n{gt_sequence}")
553
+
554
+ prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
555
+ generated_sequence = generate_motion(model, tokenizer, prompt, device)
556
+ print(f"\nLLM Generated:\n{generated_sequence}")
557
+
558
+ compare_sequences(gt_sequence, generated_sequence)
559
+
560
+ if gt_sequence.strip() == generated_sequence.strip():
561
+ num_correct += 1
562
+
563
+ print("-" * 80)
564
+
565
+ accuracy = (num_correct / len(samples)) * 100
566
+ print(f"\nSUMMARY FOR '{word}': {num_correct}/{len(samples)} correct ({accuracy:.1f}%)")
567
+
568
+
569
+ # ======================================================================================
570
+ # Existing Utilities (Compatibility)
571
+ # ======================================================================================
572
+ def seq_edit_distance(a_ids: List[int], b_ids: List[int]) -> int:
573
+ """Token-level Levenshtein distance"""
574
+ return Levenshtein.distance(a_ids, b_ids)
575
+
576
+ def best_ref_distance(pred_ids: List[int], refs: List[List[int]]) -> int:
577
+ """Find minimum edit distance to any reference"""
578
+ if not refs:
579
+ return len(pred_ids)
580
+ return min(seq_edit_distance(pred_ids, r) for r in refs)
581
+
582
+ def build_text_to_refs(dataset):
583
+ """
584
+ Build mapping from text prompts to list of reference motion sequences
585
+ """
586
+ text_to_refs = defaultdict(list)
587
+ for ex in dataset:
588
+ text_to_refs[ex["text_query"]].append(
589
+ [int(x) for x in ex["motion_tokens"].split()]
590
+ )
591
+ return text_to_refs
592
+
593
+ def _concat(ids_list: List[List[int]]) -> List[int]:
594
+ out = []
595
+ for s in ids_list:
596
+ out.extend(s)
597
+ return out
598
+
599
+ def _distinct_n(ids_list: List[List[int]], n: int) -> float:
600
+ if n <= 0:
601
+ return 0.0
602
+ total = 0
603
+ uniq = set()
604
+ for seq in ids_list:
605
+ if len(seq) < n:
606
+ continue
607
+ total += (len(seq) - n + 1)
608
+ for i in range(len(seq) - n + 1):
609
+ uniq.add(tuple(seq[i:i+n]))
610
+ if total == 0:
611
+ return 0.0
612
+ return len(uniq) / float(total)
613
+
614
+ def token_fid_diag(gens: List[List[int]], refs: List[List[int]], codebook_size: int) -> float:
615
+ """
616
+ Diagonal-covariance Fréchet distance between histograms of token usage.
617
+ This is a lightweight proxy for FID using token distributions.
618
+ """
619
+ if len(gens) == 0 or len(refs) == 0:
620
+ return float("nan")
621
+
622
+ def feats(batch: List[List[int]]) -> np.ndarray:
623
+ mats = []
624
+ for seq in batch:
625
+ hist = np.bincount([x for x in seq if 0 <= x < codebook_size], minlength=codebook_size).astype(np.float64)
626
+ s = hist.sum()
627
+ if s > 0:
628
+ hist /= s
629
+ mats.append(hist)
630
+ return np.stack(mats, axis=0)
631
+
632
+ G = feats(gens)
633
+ R = feats(refs)
634
+ mu_g = G.mean(axis=0)
635
+ mu_r = R.mean(axis=0)
636
+ var_g = G.var(axis=0)
637
+ var_r = R.var(axis=0)
638
+ mean_term = np.sum((mu_g - mu_r) ** 2)
639
+ # Diagonal covariance approximation
640
+ cov_term = np.sum(var_g + var_r - 2.0 * np.sqrt(np.clip(var_g * var_r, 0.0, None)))
641
+ return float(mean_term + cov_term)
642
+
643
+ def compute_token_metrics(
644
+ gen_by_text: Dict[str, List[int]],
645
+ text_to_refs: Dict[str, List[List[int]]],
646
+ codebook_size: int,
647
+ ) -> Dict[str, float]:
648
+ """
649
+ Compute token-level metrics:
650
+ - FID_diag: Fréchet distance between token histograms (diag cov)
651
+ - MIM: average min edit distance to references
652
+ - Diversity: distinct-1 and distinct-2
653
+ """
654
+ gens = list(gen_by_text.values())
655
+ refs_all = _concat([v for v in text_to_refs.values()])
656
+ # refs_all is concatenated list of ids; split sequences are needed
657
+ ref_seqs = [r for refs in text_to_refs.values() for r in refs]
658
+
659
+ fid_diag = token_fid_diag(gens, ref_seqs, codebook_size)
660
+
661
+ # MIM: average best edit distance per prompt (only over prompts we generated)
662
+ mim_dists = []
663
+ for text, gen_ids in gen_by_text.items():
664
+ refs = text_to_refs.get(text, [])
665
+ mim_dists.append(best_ref_distance(gen_ids, refs))
666
+ mim = float(sum(mim_dists) / len(mim_dists)) if mim_dists else float("nan")
667
+
668
+ div1 = _distinct_n(gens, 1)
669
+ div2 = _distinct_n(gens, 2)
670
+
671
+ return {
672
+ "FID_diag": fid_diag,
673
+ "MIM": mim,
674
+ "distinct_1": div1,
675
+ "distinct_2": div2,
676
+ }
677
+
678
+ def eval_t2m_set(
679
+ model,
680
+ tokenizer,
681
+ sample_pairs: List[Tuple[str, List[List[int]]]],
682
+ mot_begin_id: int,
683
+ mot_end_id: int,
684
+ motion_token_ids: list,
685
+ length_stats_by_text: dict,
686
+ global_median_len: int,
687
+ prompt_vocab: dict = None,
688
+ has_pid: bool = False,
689
+ per_prompt_vocab: bool = True,
690
+ n_eval: int = 100
691
+ ):
692
+ """
693
+ Evaluate text-to-motion generation on a set of samples
694
+ Returns a compact dict with avg_edit_dist & median_len; kept for pipeline compatibility.
695
+ """
696
+ random.shuffle(sample_pairs)
697
+ subset = sample_pairs[:min(n_eval, len(sample_pairs))]
698
+
699
+ dists = []
700
+ lens = []
701
+
702
+ for text, ref_list in subset:
703
+ gen = generate_t2m(
704
+ model=model,
705
+ tokenizer=tokenizer,
706
+ prompt_text=text,
707
+ mot_begin_id=mot_begin_id,
708
+ mot_end_id=mot_end_id,
709
+ motion_token_ids=motion_token_ids,
710
+ length_stats_by_text=length_stats_by_text,
711
+ global_median_len=global_median_len,
712
+ prompt_vocab=prompt_vocab,
713
+ pid=None,
714
+ has_pid=has_pid,
715
+ per_prompt_vocab=per_prompt_vocab
716
+ )
717
+ span = gen.split("<MOT_BEGIN>")[-1]
718
+ span = span.split("<MOT_END>")[0]
719
+ pred_ids = motion_specials_to_ids(span)
720
+ d = best_ref_distance(pred_ids, ref_list)
721
+ dists.append(d)
722
+ lens.append(len(pred_ids))
723
+
724
+ if dists:
725
+ avg_dist = sum(dists) / len(dists)
726
+ median_len = sorted(lens)[len(lens)//2] if lens else 0
727
+ print(f"Eval T2M: avg_edit_dist={avg_dist:.2f}, median_len={median_len}, n={len(dists)}")
728
+ return {"avg_edit_dist": avg_dist, "median_len": median_len, "n_samples": len(dists)}
729
+ else:
730
+ print("Eval T2M: no samples")
731
+ return {}
model.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model and tokenizer initialization
3
+ """
4
+ import torch
5
+ from typing import List, Set, Tuple
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from unsloth import FastLanguageModel
8
+ from config import (
9
+ MODEL_NAME, MAX_SEQ_LEN, DTYPE,
10
+ LORA_R, LORA_ALPHA, LORA_DROPOUT,
11
+ LORA_TARGET_MODULES, LORA_MODULES_TO_SAVE,
12
+ PAD_TOKEN, M_START, M_END
13
+ )
14
+
15
+ # ======================================================================================
16
+ # Logic from test_overfit.py (Standard Transformers)
17
+ # ======================================================================================
18
+
19
+ def setup_model_and_tokenizer_raw(model_name: str, motion_tokens: List[str]) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
20
+ """Loads the model and tokenizer, adding special and motion tokens (Standard Transformers)."""
21
+ print(f"\n---> Loading base model and tokenizer: {model_name}")
22
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
23
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
24
+
25
+ # Add special tokens (matches test_overfit.py)
26
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]})
27
+
28
+ print(f"Adding {len(motion_tokens)} motion tokens to the tokenizer.")
29
+ tokenizer.add_tokens(motion_tokens, special_tokens=True)
30
+
31
+ model.resize_token_embeddings(len(tokenizer))
32
+ model.config.pad_token_id = tokenizer.pad_token_id
33
+
34
+ return model, tokenizer
35
+
36
+ def ensure_tokenizer_has_motion_tokens(tokenizer: AutoTokenizer, motion_tokens: List[str]) -> int:
37
+ """
38
+ Adds any missing motion tokens to the tokenizer. Returns number of tokens added.
39
+ """
40
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]})
41
+ added = tokenizer.add_tokens(motion_tokens, special_tokens=True)
42
+ return added
43
+
44
+ # ======================================================================================
45
+ # Existing Logic (Unsloth / LoRA)
46
+ # ======================================================================================
47
+
48
+ def build_special_tokens(codebook_size: int, unique_pids: List[str] = None) -> List[str]:
49
+ """
50
+ Build all special tokens for motion vocabulary
51
+ """
52
+ # Motion tokens
53
+ motion_tokens = [f"<motion_{i}>" for i in range(codebook_size)]
54
+
55
+ # Boundary tokens
56
+ boundary_tokens = ["<MOT_BEGIN>", "<MOT_END>"]
57
+
58
+ # Task tokens
59
+ task_tokens = ["<T2M>", "<M2T>", "<DENOISE>", "<MOTION_MASK>"]
60
+
61
+ # Participant ID tokens
62
+ pid_tokens = []
63
+ if unique_pids:
64
+ pid_tokens = ["<PID_NULL>"] + [f"<PID_{pid}>" for pid in unique_pids]
65
+
66
+ return boundary_tokens + motion_tokens + task_tokens + pid_tokens
67
+
68
+
69
+ def setup_model_and_tokenizer(codebook_size: int, unique_pids: List[str] = None):
70
+ """
71
+ Initialize model and tokenizer with custom tokens (Unsloth LoRA)
72
+ Returns: (model, tokenizer, new_token_ids)
73
+ """
74
+ # Build special tokens
75
+ additional_special_tokens = build_special_tokens(codebook_size, unique_pids)
76
+
77
+ # Load base model
78
+ model, tokenizer = FastLanguageModel.from_pretrained(
79
+ model_name=MODEL_NAME,
80
+ max_seq_length=MAX_SEQ_LEN,
81
+ dtype=DTYPE,
82
+ load_in_4bit=False,
83
+ trust_remote_code=True,
84
+ )
85
+
86
+ # Configure tokenizer
87
+ tokenizer.padding_side = "right"
88
+
89
+ # Add special tokens
90
+ existing = set(tokenizer.special_tokens_map_extended.get("additional_special_tokens", []))
91
+ to_add = [t for t in additional_special_tokens if t not in existing]
92
+
93
+ if to_add:
94
+ tokenizer.add_special_tokens({"additional_special_tokens": to_add})
95
+
96
+ if tokenizer.pad_token is None:
97
+ tokenizer.pad_token = tokenizer.eos_token
98
+
99
+ # Resize embeddings
100
+ model.resize_token_embeddings(len(tokenizer))
101
+
102
+ # Apply LoRA
103
+ model = FastLanguageModel.get_peft_model(
104
+ model,
105
+ r=LORA_R,
106
+ lora_alpha=LORA_ALPHA,
107
+ lora_dropout=LORA_DROPOUT,
108
+ bias="none",
109
+ target_modules=LORA_TARGET_MODULES,
110
+ modules_to_save=LORA_MODULES_TO_SAVE,
111
+ use_gradient_checkpointing="unsloth",
112
+ )
113
+
114
+ # Get new token IDs for gradient masking
115
+ new_token_ids = set(tokenizer.convert_tokens_to_ids(additional_special_tokens))
116
+
117
+ # Apply gradient mask to prevent base vocab drift
118
+ apply_gradient_mask(model, new_token_ids)
119
+
120
+ return model, tokenizer, new_token_ids
121
+
122
+
123
+ def apply_gradient_mask(model, new_token_ids: Set[int]):
124
+ """
125
+ Apply gradient mask so only new token embeddings are updated
126
+ """
127
+ def mask_rows_hook(param, rows: set):
128
+ mask = torch.zeros(param.size(0), device=param.device, dtype=param.dtype)
129
+ idxs = sorted(list(rows))
130
+ if len(idxs) > 0:
131
+ mask[idxs] = 1.0
132
+ param.register_hook(lambda g: g * mask.unsqueeze(1))
133
+
134
+ with torch.no_grad():
135
+ emb = model.get_input_embeddings().weight
136
+ head = model.get_output_embeddings().weight
137
+
138
+ mask_rows_hook(emb, new_token_ids)
139
+ mask_rows_hook(head, new_token_ids)
140
+
141
+
142
+ def get_motion_token_info(tokenizer, codebook_size: int):
143
+ """
144
+ Get motion token IDs and boundary token IDs
145
+ Returns: (motion_token_ids, mot_begin_id, mot_end_id)
146
+ """
147
+ motion_token_strs = [f"<motion_{i}>" for i in range(codebook_size)]
148
+ motion_token_ids = tokenizer.convert_tokens_to_ids(motion_token_strs)
149
+ mot_begin_id = tokenizer.convert_tokens_to_ids("<MOT_BEGIN>")
150
+ mot_end_id = tokenizer.convert_tokens_to_ids("<MOT_END>")
151
+
152
+ return motion_token_ids, mot_begin_id, mot_end_id
requirements.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ accelerate
5
+ numpy
6
+ scipy
7
+ rapidfuzz
8
+ huggingface_hub
9
+ plotly
10
+ smplx
11
+ # Core dependencies
12
+ torch>=2.0.0
13
+ transformers>=4.40.0
14
+ datasets>=2.14.0
15
+ accelerate>=0.20.0
16
+
17
+ # Unsloth for efficient training
18
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
19
+
20
+ # Training utilities
21
+ bitsandbytes>=0.41.0
22
+ peft>=0.4.0
23
+ trl>=0.4.7
24
+
25
+ # Evaluation
26
+ rapidfuzz>=3.0.0
27
+
28
+ # Utilities
29
+ numpy>=1.24.0
30
+ tqdm>=4.65.0
31
+ huggingface_hub>=0.22.0
32
+ gdown>=5.2.0
setup_env.sh ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ set -euo pipefail
4
+
5
+ # Usage:
6
+ # bash setup_env.sh
7
+ #
8
+ # - Installs Python dependencies from requirements.txt
9
+ # - Downloads a public Google Drive dataset file into ./data/motion_llm_dataset.json
10
+ # - Exports env vars for this session (optional) and prints instructions
11
+
12
+ THIS_DIR="$(pwd)"
13
+ DATA_DIR="$THIS_DIR/data"
14
+ mkdir -p "$DATA_DIR"
15
+
16
+ # --- Explicit placeholders (replace these later) ---
17
+ # Training dataset
18
+ GDRIVE_ID="11711RgTmzauXpYVFoqLF8DZXiZlZovfn"
19
+
20
+ # Visualization assets (optional - only needed for visualize.py)
21
+ VQVAE_MODEL_ID="1JEMKVZWFG4Ue7k3Nm7q1o7-uBVsVricY"
22
+ VQVAE_STATS_ID="1WTwP5DdBl4c-X5Kj7jXtlEHofOX2BifZ"
23
+ SMPLX_MODELS_ID="1tZEfqw9zHgOaBEw5X_oazAEnesRtE9ky"
24
+
25
+ # Hugging Face token
26
+ HF_TOKEN_IN=""
27
+ # ---------------------------------------------------
28
+
29
+ echo "Installing Python dependencies..."
30
+ python -m pip install --upgrade pip
31
+ pip install -r requirements.txt
32
+
33
+ if [[ -n "$GDRIVE_ID" ]] && [[ "$GDRIVE_ID" != "YOUR_GOOGLE_DRIVE_FILE_ID_HERE" ]]; then
34
+ echo "Downloading training dataset from Google Drive (file id: $GDRIVE_ID)..."
35
+ gdown --id "$GDRIVE_ID" -O "$DATA_DIR/motion_llm_dataset.json"
36
+ else
37
+ echo "No training dataset Google Drive ID provided. Skipping dataset download."
38
+ fi
39
+
40
+ # Download visualization assets if IDs are provided
41
+ if [[ -n "$VQVAE_MODEL_ID" ]] && [[ "$VQVAE_MODEL_ID" != "YOUR_VQVAE_CHECKPOINT_GDRIVE_ID_HERE" ]]; then
42
+ echo "Downloading VQ-VAE model from Google Drive (file id: $VQVAE_MODEL_ID)..."
43
+ gdown --id "$VQVAE_MODEL_ID" -O "$DATA_DIR/vqvae_model.pt"
44
+ fi
45
+
46
+ if [[ -n "$VQVAE_STATS_ID" ]] && [[ "$VQVAE_STATS_ID" != "YOUR_VQVAE_STATS_GDRIVE_ID_HERE" ]]; then
47
+ echo "Downloading VQ-VAE stats from Google Drive (file id: $VQVAE_STATS_ID)..."
48
+ gdown --id "$VQVAE_STATS_ID" -O "$DATA_DIR/vqvae_stats.pt"
49
+ fi
50
+
51
+ if [[ -n "$SMPLX_MODELS_ID" ]] && [[ "$SMPLX_MODELS_ID" != "YOUR_SMPLX_MODELS_GDRIVE_ID_HERE" ]]; then
52
+ echo "Downloading SMPL-X neutral model (.npz) from Google Drive (file id: $SMPLX_MODELS_ID)..."
53
+ mkdir -p "$DATA_DIR/smplx_models"
54
+ gdown --id "$SMPLX_MODELS_ID" -O "$DATA_DIR/smplx_models/SMPLX_NEUTRAL.npz"
55
+ echo "Saved SMPLX_NEUTRAL.npz to $DATA_DIR/smplx_models"
56
+ fi
57
+
58
+ if [[ -n "$HF_TOKEN_IN" ]]; then
59
+ echo "Exporting HUGGINGFACE_HUB_TOKEN for this shell session..."
60
+ export HUGGINGFACE_HUB_TOKEN="$HF_TOKEN_IN"
61
+ fi
62
+
63
+ echo
64
+ echo "Environment setup complete."
65
+ echo "- WORK_DIR defaults to: $THIS_DIR"
66
+ echo "- DATA_JSON_PATH defaults to: $DATA_DIR/motion_llm_dataset.json"
67
+ echo "- To persist HF token, set an environment variable before running:"
68
+ echo " export HUGGINGFACE_HUB_TOKEN=hf_..."
69
+ echo
70
+ echo "You can now run your training scripts."
templates.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt templates and mapping functions for different training stages
3
+ """
4
+ import random
5
+ from data import ids_to_motion_specials
6
+ from config import SYSTEM_MSG, SEED
7
+
8
+ random.seed(SEED)
9
+
10
+
11
+ def pid_token_from_example(ex, has_pid: bool):
12
+ """Get participant ID token from example"""
13
+ if not has_pid:
14
+ return ""
15
+
16
+ pid = ex.get("participant_id", None)
17
+ if pid is not None:
18
+ return f"<PID_{pid}>"
19
+ return "<PID_NULL>"
20
+
21
+
22
+ def map_stage1(ex, has_pid: bool):
23
+ """
24
+ Stage 1: Word + optional PID conditioning to learn motion language.
25
+ The user explicitly provides the word (+PID); assistant outputs motion span.
26
+ """
27
+ mot = ids_to_motion_specials(ex["motion_tokens"])
28
+ assistant = f"<MOT_BEGIN> {mot} <MOT_END>"
29
+ pid_tok = pid_token_from_example(ex, has_pid)
30
+ word = ex.get("word", ex.get("text_query", ""))
31
+
32
+ # Word + PID conditioning (no natural language chatter to keep it compact)
33
+ user = f"<T2M>{pid_tok}\nword: {word}"
34
+ text = (
35
+ "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
36
+ + "<|im_start|>user\n" + user + "\n<|im_end|>\n"
37
+ + "<|im_start|>assistant\n" + assistant + "\n<|im_end|>\n"
38
+ )
39
+
40
+ return {"text": text, "where": "mot"}
41
+
42
+
43
+ def map_stage2(ex, has_pid: bool):
44
+ """
45
+ Stage 2: Multi-task (T2M/M2T/DENOISE)
46
+ Randomly choose between text-to-motion, motion-to-text, or denoising
47
+ """
48
+ t = ex["text_query"]
49
+ mot = ids_to_motion_specials(ex["motion_tokens"])
50
+ pid_tok = pid_token_from_example(ex, has_pid)
51
+
52
+ # Sample task type
53
+ task = random.choices(["t2m", "m2t", "denoise"], weights=[0.5, 0.3, 0.2], k=1)[0]
54
+
55
+ if task == "t2m":
56
+ # Text to motion
57
+ assistant = f"<MOT_BEGIN> {mot} <MOT_END>"
58
+ text = (
59
+ "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
60
+ + "<|im_start|>user\n" + f"<T2M>{pid_tok}\n\n" + t + "\n<|im_end|>\n"
61
+ + "<|im_start|>assistant\n" + assistant + "\n<|im_end|>\n"
62
+ )
63
+ where = "mot"
64
+
65
+ elif task == "m2t":
66
+ # Motion to text
67
+ user = f"<M2T>{pid_tok}\n\n<MOT_BEGIN> {mot} <MOT_END>"
68
+ text = (
69
+ "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
70
+ + "<|im_start|>user\n" + user + "\n<|im_end|>\n"
71
+ + "<|im_start|>assistant\n" + t + "\n<|im_end|>\n"
72
+ )
73
+ where = "text"
74
+
75
+ else:
76
+ # Denoising
77
+ toks = mot.split()
78
+ noisy = []
79
+ for tok in toks:
80
+ if random.random() < 0.15:
81
+ noisy.append("<MOTION_MASK>")
82
+ else:
83
+ noisy.append(tok)
84
+
85
+ user = f"<DENOISE>{pid_tok}\n\n<MOT_BEGIN> {' '.join(noisy)} <MOT_END>"
86
+ assistant = f"<MOT_BEGIN> {mot} <MOT_END>"
87
+ text = (
88
+ "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
89
+ + "<|im_start|>user\n" + user + "\n<|im_end|>\n"
90
+ + "<|im_start|>assistant\n" + assistant + "\n<|im_end|>\n"
91
+ )
92
+ where = "mot"
93
+
94
+ return {"text": text, "where": where, "text_query": t}
95
+
96
+
97
+ def map_stage3(ex, has_pid: bool):
98
+ """
99
+ Stage 3 (Instruct): Word-only request, no participant ID.
100
+ The system prompt directs: "Output motion tokens for the given word".
101
+ """
102
+ t = ex["text_query"]
103
+ mot = ids_to_motion_specials(ex["motion_tokens"])
104
+ assistant = f"<MOT_BEGIN> {mot} <MOT_END>"
105
+
106
+ # Instruct-style, no PID
107
+ user = f"<T2M>\nword: {t}"
108
+ text = (
109
+ "<|im_start|>system\n" + SYSTEM_MSG + "<|im_end|>\n"
110
+ + "<|im_start|>user\n" + user + "\n<|im_end|>\n"
111
+ + "<|im_start|>assistant\n" + assistant + "\n<|im_end|>\n"
112
+ )
113
+
114
+ return {
115
+ "text": text,
116
+ "where": "mot",
117
+ "text_query": t,
118
+ "motion_tokens": ex["motion_tokens"]
119
+ }
120
+
121
+
122
+ def create_mapper(stage: int, has_pid: bool):
123
+ """
124
+ Create a mapper function for a specific stage
125
+ """
126
+ if stage == 1:
127
+ return lambda ex: map_stage1(ex, has_pid)
128
+ elif stage == 2:
129
+ return lambda ex: map_stage2(ex, has_pid)
130
+ elif stage == 3:
131
+ return lambda ex: map_stage3(ex, has_pid)
132
+ else:
133
+ raise ValueError(f"Unknown stage: {stage}")
test_dataset_eval.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluate the SignMotionGPT model on a held-out SMPL-X test dataset.
3
+
4
+ The script can download Google Drive archives or consume an already extracted
5
+ directory of `video_data.pkl` files. Each sequence is converted into encoder
6
+ features via the project VQ-VAE utilities and compared against motions generated
7
+ by the LLM to compute FID/Diversity/Multimodality metrics.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import pickle
16
+ import random
17
+ import sys
18
+ import zipfile
19
+ from typing import Dict, List, Optional, Tuple
20
+
21
+ import numpy as np
22
+ import torch
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer
24
+
25
+ from config import (
26
+ TEST_EVAL_DOWNLOAD_DIR,
27
+ TEST_EVAL_EXTRACT_DIR,
28
+ TEST_EVAL_HF_REPO,
29
+ TEST_EVAL_HF_SUBFOLDER,
30
+ TEST_EVAL_MAX_ZIPS,
31
+ TEST_EVAL_OUTPUT_DIR,
32
+ TEST_EVAL_SAMPLE_LIMIT,
33
+ )
34
+
35
+ M_START = "<M_START>"
36
+ M_END = "<M_END>"
37
+ PAD_TOKEN = "<PAD>"
38
+
39
+ INFERENCE_REPETITION_PENALTY = 1.2
40
+ INFERENCE_TEMPERATURE = 0.7
41
+ INFERENCE_TOP_K = 50
42
+
43
+
44
+ # -----------------------------------------------------------------------------
45
+ # Download / extraction helpers
46
+ # -----------------------------------------------------------------------------
47
+ def try_import_gdown() -> bool:
48
+ try:
49
+ import gdown # noqa: F401
50
+
51
+ return True
52
+ except Exception:
53
+ return False
54
+
55
+
56
+ def download_drive_folder(folder_url_or_id: str, dest_dir: str) -> None:
57
+ os.makedirs(dest_dir, exist_ok=True)
58
+ if not try_import_gdown():
59
+ raise RuntimeError("gdown is required for Drive downloads. Install with `pip install gdown`.")
60
+ import gdown
61
+
62
+ if "drive.google.com" in folder_url_or_id:
63
+ url = folder_url_or_id
64
+ else:
65
+ url = f"https://drive.google.com/drive/folders/{folder_url_or_id}"
66
+ print(f"Downloading Drive folder to {dest_dir} ...")
67
+ gdown.download_folder(url=url, output=dest_dir, quiet=False, use_cookies=False)
68
+ print("Download complete.")
69
+
70
+
71
+ def list_zip_files(download_dir: str) -> List[str]:
72
+ matches: List[str] = []
73
+ for root, _dirs, files in os.walk(download_dir):
74
+ for name in files:
75
+ if name.lower().endswith(".zip"):
76
+ matches.append(os.path.join(root, name))
77
+ return sorted(matches)
78
+
79
+
80
+ def extract_zip_files(zip_paths: List[str], extract_dir: str, limit: Optional[int]) -> List[str]:
81
+ os.makedirs(extract_dir, exist_ok=True)
82
+ extracted_roots: List[str] = []
83
+ for idx, zp in enumerate(zip_paths):
84
+ if limit is not None and idx >= limit:
85
+ break
86
+ try:
87
+ with zipfile.ZipFile(zp, "r") as archive:
88
+ subdir = os.path.splitext(os.path.basename(zp))[0]
89
+ target = os.path.join(extract_dir, subdir)
90
+ os.makedirs(target, exist_ok=True)
91
+ archive.extractall(target)
92
+ extracted_roots.append(target)
93
+ except Exception as exc:
94
+ print(f"⚠️ Failed to extract {zp}: {exc}")
95
+ print(f"Extracted {len(extracted_roots)} archives.")
96
+ return extracted_roots
97
+
98
+
99
+ def find_video_pkl_paths(extracted_root: str) -> List[str]:
100
+ matches: List[str] = []
101
+ for root, _dirs, files in os.walk(extracted_root):
102
+ for name in files:
103
+ if name == "video_data.pkl":
104
+ matches.append(os.path.join(root, name))
105
+ return matches
106
+
107
+
108
+ def parse_word_from_path(path: str) -> str:
109
+ base = os.path.basename(os.path.dirname(path))
110
+ if "-" in base:
111
+ word = base.split("-", 1)[1]
112
+ else:
113
+ word = base
114
+ return word.strip().lower()
115
+
116
+
117
+ # -----------------------------------------------------------------------------
118
+ # SMPL-X helpers
119
+ # -----------------------------------------------------------------------------
120
+ def try_to_array(value) -> Optional[np.ndarray]:
121
+ if isinstance(value, np.ndarray):
122
+ return value
123
+ try:
124
+ return np.asarray(value)
125
+ except Exception:
126
+ return None
127
+
128
+
129
+ def load_smplx_params_from_pkl(pkl_path: str) -> Optional[np.ndarray]:
130
+ try:
131
+ with open(pkl_path, "rb") as handle:
132
+ payload = pickle.load(handle)
133
+ except Exception as exc:
134
+ print(f"⚠️ Could not read {pkl_path}: {exc}")
135
+ return None
136
+
137
+ if not isinstance(payload, (list, tuple)) or len(payload) == 0:
138
+ return None
139
+
140
+ def get_vec(frame: dict, key: str, expected: int, allow_trim: bool = True) -> np.ndarray:
141
+ val = frame.get(key)
142
+ arr = try_to_array(val)
143
+ if arr is None:
144
+ return np.zeros((expected,), dtype=np.float32)
145
+ arr = np.array(arr, dtype=np.float32).reshape(-1)
146
+ if arr.size == expected:
147
+ return arr
148
+ if allow_trim and arr.size > expected:
149
+ if key == "body_pose" and arr.size == 66 and expected == 63:
150
+ return arr[3:3 + 63]
151
+ return arr[:expected]
152
+ if arr.size < expected:
153
+ out = np.zeros((expected,), dtype=np.float32)
154
+ out[: arr.size] = arr
155
+ return out
156
+ return arr[:expected]
157
+
158
+ sequences: List[np.ndarray] = []
159
+ for frame in payload:
160
+ if not isinstance(frame, dict):
161
+ continue
162
+ vec = np.concatenate(
163
+ [
164
+ get_vec(frame, "shape", 10),
165
+ get_vec(frame, "body_pose", 63),
166
+ get_vec(frame, "lhand_pose", 45),
167
+ get_vec(frame, "rhand_pose", 45),
168
+ get_vec(frame, "cam_trans", 3),
169
+ get_vec(frame, "expression", 10),
170
+ get_vec(frame, "jaw_pose", 3),
171
+ np.zeros((3,), dtype=np.float32), # eye pose placeholder
172
+ ],
173
+ axis=0,
174
+ )
175
+ sequences.append(vec)
176
+ if not sequences:
177
+ return None
178
+ return np.stack(sequences, axis=0).astype(np.float32)
179
+
180
+
181
+ def import_visualize_helpers():
182
+ try:
183
+ from visualize import (
184
+ load_vqvae,
185
+ load_stats,
186
+ decode_tokens_to_params,
187
+ VQVAE_CHECKPOINT as DEFAULT_VQ,
188
+ STATS_PATH as DEFAULT_STATS,
189
+ )
190
+
191
+ return load_vqvae, load_stats, decode_tokens_to_params, DEFAULT_VQ, DEFAULT_STATS
192
+ except Exception as exc:
193
+ raise RuntimeError(f"Failed to import visualize helpers: {exc}") from exc
194
+
195
+
196
+ def _encode_params_to_feature(
197
+ params: np.ndarray,
198
+ vq_model,
199
+ mean,
200
+ std,
201
+ device: torch.device,
202
+ ) -> Optional[np.ndarray]:
203
+ if params is None or params.size == 0:
204
+ return None
205
+ clip = torch.from_numpy(params.astype(np.float32)).unsqueeze(0).to(device)
206
+ with torch.no_grad():
207
+ x_pre = None
208
+ if hasattr(vq_model.vqvae, "preprocess"):
209
+ try:
210
+ x_pre = vq_model.vqvae.preprocess(clip)
211
+ except Exception:
212
+ x_pre = None
213
+ if x_pre is None:
214
+ if mean is not None and std is not None:
215
+ mean_t = torch.from_numpy(np.array(mean, dtype=np.float32)).to(device).view(1, 1, -1)
216
+ std_t = torch.from_numpy(np.array(std, dtype=np.float32)).to(device).view(1, 1, -1)
217
+ clip = (clip - mean_t) / (std_t + 1e-8)
218
+ x_pre = clip.transpose(1, 2).contiguous()
219
+ latent = vq_model.vqvae.encoder(x_pre)
220
+ if latent.dim() == 3:
221
+ embed_dim = getattr(vq_model.vqvae, "output_emb_width", None)
222
+ if embed_dim is not None:
223
+ if latent.shape[1] == embed_dim:
224
+ axis = 2
225
+ elif latent.shape[2] == embed_dim:
226
+ axis = 1
227
+ else:
228
+ axis = 2 if latent.shape[2] < latent.shape[1] else 1
229
+ else:
230
+ axis = 2 if latent.shape[2] < latent.shape[1] else 1
231
+ feat = latent.mean(dim=axis).squeeze(0)
232
+ elif latent.dim() == 2:
233
+ feat = latent.squeeze(0)
234
+ else:
235
+ feat = latent.view(1, -1).mean(dim=0)
236
+ vec = feat.detach().cpu().numpy().astype(np.float32)
237
+ norm = np.linalg.norm(vec)
238
+ if norm > 0:
239
+ vec = vec / norm
240
+ return vec
241
+
242
+
243
+ # -----------------------------------------------------------------------------
244
+ # Metrics helpers
245
+ # -----------------------------------------------------------------------------
246
+ def calculate_activation_statistics_np(activations: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
247
+ mu = np.mean(activations, axis=0)
248
+ cov = np.cov(activations, rowvar=False)
249
+ return mu, cov
250
+
251
+
252
+ def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6) -> float:
253
+ from scipy.linalg import sqrtm
254
+
255
+ mu1 = np.atleast_1d(mu1)
256
+ mu2 = np.atleast_1d(mu2)
257
+ sigma1 = np.atleast_2d(sigma1)
258
+ sigma2 = np.atleast_2d(sigma2)
259
+ assert mu1.shape == mu2.shape, "Mean vectors must match"
260
+ assert sigma1.shape == sigma2.shape, "Covariance matrices must match"
261
+ diff = mu1 - mu2
262
+ covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
263
+ if not np.isfinite(covmean).all():
264
+ offset = np.eye(sigma1.shape[0]) * eps
265
+ covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
266
+ if np.iscomplexobj(covmean):
267
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
268
+ raise ValueError("Covmean contains large imaginary components")
269
+ covmean = covmean.real
270
+ return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))
271
+
272
+
273
+ def calculate_diversity_np(activation: np.ndarray, diversity_times: int = 200) -> float:
274
+ assert activation.ndim == 2
275
+ n = activation.shape[0]
276
+ if n < 2:
277
+ return float("nan")
278
+ times = min(diversity_times, max(1, n - 1))
279
+ idx1 = np.random.choice(n, times, replace=False)
280
+ idx2 = np.random.choice(n, times, replace=False)
281
+ diffs = activation[idx1] - activation[idx2]
282
+ return float(np.linalg.norm(diffs, axis=1).mean())
283
+
284
+
285
+ def _to_label_tensor3(acts: np.ndarray, labels: List[str]) -> np.ndarray:
286
+ label_to_indices: Dict[str, List[int]] = {}
287
+ for idx, lbl in enumerate(labels):
288
+ label_to_indices.setdefault(lbl, []).append(idx)
289
+ counts = [len(v) for v in label_to_indices.values()]
290
+ if not counts:
291
+ raise ValueError("No labels available for multimodality computation.")
292
+ min_count = max(2, min(counts))
293
+ stacked = []
294
+ for lbl in sorted(label_to_indices.keys()):
295
+ stacked.append(acts[label_to_indices[lbl][:min_count]])
296
+ return np.stack(stacked, axis=0)
297
+
298
+
299
+ def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int = 20) -> float:
300
+ assert activation.ndim == 3
301
+ _, per_label, _ = activation.shape
302
+ if per_label < 2:
303
+ return float("nan")
304
+ times = min(multimodality_times, max(1, per_label - 1))
305
+ first = np.random.choice(per_label, times, replace=False)
306
+ second = np.random.choice(per_label, times, replace=False)
307
+ diffs = activation[:, first] - activation[:, second]
308
+ return float(np.linalg.norm(diffs, axis=2).mean())
309
+
310
+
311
+ # -----------------------------------------------------------------------------
312
+ # Generation helpers
313
+ # -----------------------------------------------------------------------------
314
+ def extract_ids_from_sequence(seq: str) -> List[int]:
315
+ content = seq
316
+ if M_START in seq and M_END in seq:
317
+ content = seq.split(M_START, 1)[-1].split(M_END, 1)[0]
318
+ ids: List[int] = []
319
+ for tok in content.split():
320
+ if tok.startswith("<M") and tok.endswith(">"):
321
+ payload = tok[2:-1]
322
+ if payload.isdigit():
323
+ ids.append(int(payload))
324
+ return ids
325
+
326
+
327
+ def generate_motion_text(model, tokenizer, word: str, device: torch.device) -> str:
328
+ model.eval()
329
+ prompt = f"Instruction: Generate motion for word '{word}' with variant 'unknown'.\nMotion: "
330
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
331
+ with torch.no_grad():
332
+ output = model.generate(
333
+ **inputs,
334
+ max_new_tokens=100,
335
+ do_sample=True,
336
+ temperature=INFERENCE_TEMPERATURE,
337
+ top_k=INFERENCE_TOP_K,
338
+ repetition_penalty=INFERENCE_REPETITION_PENALTY,
339
+ pad_token_id=tokenizer.pad_token_id,
340
+ eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
341
+ )
342
+ decoded = tokenizer.decode(output[0], skip_special_tokens=False)
343
+ if "Motion: " in decoded:
344
+ return decoded.split("Motion: ", 1)[-1].strip()
345
+ return decoded.strip()
346
+
347
+
348
+ # -----------------------------------------------------------------------------
349
+ # Core evaluation
350
+ # -----------------------------------------------------------------------------
351
+ def parse_args() -> argparse.Namespace:
352
+ parser = argparse.ArgumentParser(
353
+ "Evaluate the trained Stage 2 model on an unseen SMPL-X test dataset."
354
+ )
355
+ group = parser.add_mutually_exclusive_group(required=True)
356
+ group.add_argument("--drive-url", type=str, help="Google Drive folder URL to download archives from.")
357
+ group.add_argument("--drive-id", type=str, help="Google Drive folder ID to download archives from.")
358
+ group.add_argument(
359
+ "--local-extracted-dir",
360
+ type=str,
361
+ help="Use an existing directory that already contains extracted `video_data.pkl` files.",
362
+ )
363
+
364
+ parser.add_argument("--max-zips", type=int, default=TEST_EVAL_MAX_ZIPS, help="Maximum number of zip files to extract.")
365
+ parser.add_argument("--download-dir", type=str, default=TEST_EVAL_DOWNLOAD_DIR, help="Directory to store downloaded zips.")
366
+ parser.add_argument("--extract-dir", type=str, default=TEST_EVAL_EXTRACT_DIR, help="Directory to extract archives into.")
367
+
368
+ parser.add_argument("--hf-repo-id", type=str, default=TEST_EVAL_HF_REPO, help="Hugging Face repo containing the Stage 2 checkpoint.")
369
+ parser.add_argument(
370
+ "--hf-subfolder",
371
+ type=str,
372
+ default=TEST_EVAL_HF_SUBFOLDER,
373
+ help="Subfolder inside the repo that hosts the Stage 2 model (e.g., `stage2_v2/epoch-020`).",
374
+ )
375
+
376
+ parser.add_argument("--vqvae-ckpt", type=str, default=None, help="Optional override for VQ-VAE checkpoint path.")
377
+ parser.add_argument("--stats-path", type=str, default=None, help="Optional override for VQ-VAE stats file.")
378
+
379
+ parser.add_argument("--output-dir", type=str, default=TEST_EVAL_OUTPUT_DIR, help="Directory to write metrics JSON.")
380
+ parser.add_argument("--sample-limit", type=int, default=TEST_EVAL_SAMPLE_LIMIT, help="Maximum number of samples to evaluate.")
381
+ parser.add_argument("--seed", type=int, default=42, help="Random seed.")
382
+ return parser.parse_args()
383
+
384
+
385
+ def run_evaluation(args: argparse.Namespace) -> Dict[str, object]:
386
+ random.seed(args.seed)
387
+ np.random.seed(args.seed)
388
+ torch.manual_seed(args.seed)
389
+
390
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
391
+ os.makedirs(args.output_dir, exist_ok=True)
392
+ metrics_path = os.path.join(args.output_dir, "metrics_test.json")
393
+
394
+ print(f"Loading Stage 2 model from HF: {args.hf_repo_id} (subfolder='{args.hf_subfolder}')")
395
+ tokenizer = AutoTokenizer.from_pretrained(args.hf_repo_id, subfolder=args.hf_subfolder, trust_remote_code=True)
396
+ model = AutoModelForCausalLM.from_pretrained(args.hf_repo_id, subfolder=args.hf_subfolder, trust_remote_code=True)
397
+ if tokenizer.pad_token is None:
398
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
399
+ model.resize_token_embeddings(len(tokenizer))
400
+ model.config.pad_token_id = tokenizer.pad_token_id
401
+ model.to(device)
402
+
403
+ load_vqvae, load_stats, decode_tokens_to_params, DEFAULT_VQ, DEFAULT_STATS = import_visualize_helpers()
404
+ vq_ckpt = args.vqvae_ckpt if args.vqvae_ckpt else os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
405
+ stats_path = args.stats_path if args.stats_path else os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
406
+ print(f"Loading VQ-VAE from: {vq_ckpt}")
407
+ vq_model = load_vqvae(vq_ckpt, device=device)
408
+ print(f"Loading stats from: {stats_path}")
409
+ mean, std = load_stats(stats_path)
410
+
411
+ extracted_dirs: List[str] = []
412
+ if args.local_extracted_dir:
413
+ if not os.path.isdir(args.local_extracted_dir):
414
+ raise FileNotFoundError(f"Local extracted dir not found: {args.local_extracted_dir}")
415
+ extracted_dirs = [args.local_extracted_dir]
416
+ else:
417
+ folder_ref = args.drive_url if args.drive_url else args.drive_id
418
+ download_drive_folder(folder_ref, args.download_dir)
419
+ zips = list_zip_files(args.download_dir)
420
+ if not zips:
421
+ raise RuntimeError("No zip files found after download.")
422
+ extracted_dirs = extract_zip_files(zips, args.extract_dir, limit=args.max_zips)
423
+
424
+ samples: List[Tuple[str, str]] = []
425
+ for root in extracted_dirs:
426
+ for pkl_path in find_video_pkl_paths(root):
427
+ samples.append((parse_word_from_path(pkl_path), pkl_path))
428
+ if not samples:
429
+ raise RuntimeError("No `video_data.pkl` files discovered in the extracted directories.")
430
+
431
+ random.shuffle(samples)
432
+ samples = samples[: args.sample_limit]
433
+ print(f"Found {len(samples)} samples to evaluate.")
434
+
435
+ gt_features: List[np.ndarray] = []
436
+ gen_features: List[np.ndarray] = []
437
+ labels: List[str] = []
438
+
439
+ for idx, (word, pkl_path) in enumerate(samples, 1):
440
+ params_gt = load_smplx_params_from_pkl(pkl_path)
441
+ if params_gt is None or params_gt.ndim != 2:
442
+ print(f"Skipping {pkl_path}: invalid SMPL-X payload.")
443
+ continue
444
+ try:
445
+ feat_gt = _encode_params_to_feature(params_gt, vq_model, mean, std, device)
446
+ except Exception as exc:
447
+ print(f"Skipping {pkl_path}: encoder failed ({exc}).")
448
+ continue
449
+ if feat_gt is None:
450
+ print(f"Skipping {pkl_path}: empty GT feature.")
451
+ continue
452
+
453
+ gen_text = generate_motion_text(model, tokenizer, word, device)
454
+ token_ids = extract_ids_from_sequence(gen_text)
455
+ if not token_ids:
456
+ print(f"Skipping GEN for '{word}': no motion tokens produced.")
457
+ continue
458
+ try:
459
+ params_gen = decode_tokens_to_params(token_ids, vq_model, mean, std, device=device)
460
+ except Exception as exc:
461
+ print(f"Skipping GEN for '{word}': decode failed ({exc}).")
462
+ continue
463
+ feat_gen = _encode_params_to_feature(params_gen, vq_model, mean, std, device)
464
+ if feat_gen is None:
465
+ print(f"Skipping GEN for '{word}': empty GEN feature.")
466
+ continue
467
+
468
+ gt_features.append(feat_gt)
469
+ gen_features.append(feat_gen)
470
+ labels.append(word)
471
+ if idx % 25 == 0:
472
+ print(f"Processed {idx} samples...")
473
+
474
+ if len(gt_features) < 5 or len(gen_features) < 5:
475
+ print("⚠️ Not enough samples to compute stable metrics; results may be noisy.")
476
+
477
+ gt_feats = np.stack(gt_features, axis=0)
478
+ gen_feats = np.stack(gen_features, axis=0)
479
+
480
+ diversity_gt = calculate_diversity_np(gt_feats, diversity_times=min(200, max(4, gt_feats.shape[0] - 1)))
481
+ diversity_gen = calculate_diversity_np(gen_feats, diversity_times=min(200, max(4, gen_feats.shape[0] - 1)))
482
+
483
+ try:
484
+ gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
485
+ gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
486
+ mim_gt = calculate_multimodality_np(
487
+ gt_lbl_tensor, multimodality_times=min(20, max(3, gt_lbl_tensor.shape[1] - 1))
488
+ )
489
+ mim_gen = calculate_multimodality_np(
490
+ gen_lbl_tensor, multimodality_times=min(20, max(3, gen_lbl_tensor.shape[1] - 1))
491
+ )
492
+ except Exception as exc:
493
+ print(f"⚠️ Multimodality could not be computed reliably: {exc}")
494
+ mim_gt = float("nan")
495
+ mim_gen = float("nan")
496
+
497
+ mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
498
+ mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
499
+ fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
500
+
501
+ metrics_payload = {
502
+ "source": "test_raw_smplx_encoder_features",
503
+ "counts": {
504
+ "samples_total": len(samples),
505
+ "samples_used": int(gt_feats.shape[0]),
506
+ },
507
+ "fid": fid,
508
+ "diversity": {
509
+ "ground_truth": diversity_gt,
510
+ "model": diversity_gen,
511
+ },
512
+ "multimodality": {
513
+ "ground_truth": mim_gt,
514
+ "model": mim_gen,
515
+ },
516
+ }
517
+ with open(metrics_path, "w", encoding="utf-8") as handle:
518
+ json.dump(metrics_payload, handle, ensure_ascii=False, indent=2)
519
+ print(f"\n✅ Saved test metrics to {metrics_path}")
520
+ return metrics_payload
521
+
522
+
523
+ def main() -> None:
524
+ args = parse_args()
525
+ try:
526
+ run_evaluation(args)
527
+ except Exception as exc:
528
+ print(f"Evaluation failed: {exc}")
529
+ sys.exit(1)
530
+
531
+
532
+ if __name__ == "__main__":
533
+ main()
534
+
test_overfit.py ADDED
@@ -0,0 +1,1562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import re
4
+ import json
5
+ import random
6
+ from typing import Dict, List, Tuple, Any, Optional
7
+ import shutil
8
+ from datetime import datetime
9
+ import time
10
+
11
+ import torch
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+ from torch.optim import AdamW
15
+ from huggingface_hub import HfApi, upload_folder, hf_hub_download
16
+
17
+ import numpy as np
18
+ import scipy.linalg
19
+ # ======================================================================================
20
+ # 0. Configuration
21
+ # ======================================================================================
22
+ # --- Paths and Words ---
23
+ DATASET_PATH = "/content/SignMotionGPT/data/motion_llm_dataset.json"
24
+ MODEL_NAME = "Qwen/Qwen3-0.6B"
25
+ # We will train on the full dataset, but use these words for our final evaluation
26
+ EVALUATION_WORDS = ["passport", "send", "library", "push"]
27
+ OUTPUT_DIR = "./motion_gpt_full_model"
28
+
29
+ # --- Evaluation controls ---
30
+ # If True: after training, only compute metrics (FID, Diversity, MIM) and save to JSON.
31
+ # Skip per-sample inference logs and HTML visualizations.
32
+ # If False: run the existing flow and also compute these 3 metrics.
33
+ RUN_EVALS_ONLY = False
34
+ EVAL_SAMPLE_LIMIT = 100
35
+ METRICS_JSON_PATH = ""
36
+
37
+ # --- Training Hyperparameters ---
38
+ # NOTE: Training on the full dataset will take longer.
39
+ # These epochs are a starting point.
40
+ S1_EPOCHS = 20
41
+ S1_LR = 5e-5
42
+ S1_BATCH_SIZE = 8 # Kept small for Colab VRAM
43
+
44
+ S2_EPOCHS = 20
45
+ S2_LR = 2e-5
46
+ S2_BATCH_SIZE = 8
47
+
48
+ # --- Inference Hyperparameters ---
49
+ INFERENCE_REPETITION_PENALTY = 1.2
50
+ INFERENCE_TEMPERATURE = 0.7
51
+ INFERENCE_TOP_K = 50
52
+
53
+ # --- Special Tokens ---
54
+ M_START = "<M_START>"
55
+ M_END = "<M_END>"
56
+ PAD_TOKEN = "<PAD>"
57
+
58
+ # --- Hugging Face Hub Configuration ---
59
+ # Provide HUGGINGFACE_HUB_TOKEN or hf_auth_token in environment for private repos.
60
+ HF_USE_HUB = True
61
+ hf_auth_token = os.getenv("hf_auth_token")
62
+ if hf_auth_token is None:
63
+ raise ValueError("hf_auth_token environment variable is not set")
64
+ HF_STAGE1_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
65
+ HF_STAGE2_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
66
+ HF_PRIVATE_REPO = os.environ.get("HF_PRIVATE", "true").lower() != "false"
67
+ FORCE_STAGE2_FROM_STAGE1_RAW = os.environ.get("FORCE_STAGE2_FROM_STAGE1", "false")
68
+ FORCE_STAGE2_FROM_STAGE1 = str(FORCE_STAGE2_FROM_STAGE1_RAW).strip().lower() not in ("0", "false", "no", "off")
69
+ # Save Stage 2 checkpoints to a new subfolder so old stage2 checkpoints remain intact
70
+ HF_STAGE2_SAVE_SUBDIR = os.environ.get("HF_STAGE2_SAVE_SUBDIR", "stage2_v2")
71
+
72
+ # --- Local Checkpoint Root ---
73
+ CHECKPOINTS_DIR = ""
74
+
75
+ # --- Upload frequency and progress control ---
76
+ # Push to Hugging Face only every N epochs (still save locally every epoch)
77
+ CHECKPOINT_UPLOAD_INTERVAL_EPOCHS = int(os.environ.get("HF_UPLOAD_INTERVAL_EPOCHS", "2"))
78
+ # Disable HF Hub progress bars to reduce noisy logs (set HF_DISABLE_PROGRESS=false to re-enable)
79
+ HF_DISABLE_PROGRESS = os.environ.get("HF_DISABLE_PROGRESS", "true").lower() != "false"
80
+
81
+
82
+ def _refresh_runtime_paths() -> None:
83
+ """Refresh derived paths when OUTPUT_DIR changes."""
84
+ global METRICS_JSON_PATH, CHECKPOINTS_DIR
85
+ METRICS_JSON_PATH = os.path.join(OUTPUT_DIR, "metrics.json")
86
+ CHECKPOINTS_DIR = os.path.join(OUTPUT_DIR, "checkpoints")
87
+
88
+
89
+ def _apply_progress_setting() -> None:
90
+ """Apply huggingface_hub progress bar preference."""
91
+ if HF_DISABLE_PROGRESS:
92
+ try:
93
+ # Also respected by huggingface_hub internal progress usage
94
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
95
+ from huggingface_hub.utils import disable_progress_bars # type: ignore
96
+
97
+ disable_progress_bars()
98
+ except Exception:
99
+ pass
100
+ else:
101
+ os.environ.pop("HF_HUB_DISABLE_PROGRESS_BARS", None)
102
+
103
+
104
+ def apply_config_overrides(overrides: Optional[Dict[str, Any]] = None) -> None:
105
+ """
106
+ Allow external callers to override module-level configuration prior to running main().
107
+ """
108
+ global hf_auth_token, HF_DISABLE_PROGRESS, OUTPUT_DIR
109
+ if not overrides:
110
+ return
111
+
112
+ updated_paths = False
113
+ progress_flag_updated = False
114
+ for key, value in overrides.items():
115
+ if key == "hf_auth_token":
116
+ hf_auth_token = value
117
+ continue
118
+ if key not in globals():
119
+ print(f"[config] Unknown override ignored: {key}")
120
+ continue
121
+ globals()[key] = value
122
+ if key == "OUTPUT_DIR":
123
+ updated_paths = True
124
+ if key == "HF_DISABLE_PROGRESS":
125
+ progress_flag_updated = True
126
+ if updated_paths:
127
+ _refresh_runtime_paths()
128
+ if progress_flag_updated:
129
+ _apply_progress_setting()
130
+
131
+
132
+ _refresh_runtime_paths()
133
+ _apply_progress_setting()
134
+
135
+
136
+ # ======================================================================================
137
+ # 1. Data Loading and Preparation (NEW & IMPROVED)
138
+ # ======================================================================================
139
+ def read_json_data(json_path: str) -> List[Dict[str, Any]]:
140
+ """Loads the dataset from the specified JSON file."""
141
+ if not os.path.exists(json_path):
142
+ raise FileNotFoundError(f"Dataset not found at: {json_path}")
143
+ with open(json_path, "r", encoding="utf-8") as f:
144
+ return json.load(f)
145
+
146
+ def deduplicate_and_prepare_data(entries: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[str]]:
147
+ """
148
+ Cleans the entire dataset by ensuring each (word, participant_id) pair is unique.
149
+ If a conflict is found (same pair, different motion), it keeps only the first one encountered.
150
+ Then, it prepares the full list of motion tokens from the cleaned data.
151
+ """
152
+ print("\n---> Cleaning dataset by removing ambiguous (word, participant_id) pairs...")
153
+
154
+ unique_samples = {}
155
+ conflicts_found = 0
156
+
157
+ for entry in entries:
158
+ word = entry.get("word", "").lower()
159
+ pid = entry.get("participant_id", "")
160
+ key = (word, pid)
161
+
162
+ if key not in unique_samples:
163
+ unique_samples[key] = entry
164
+ else:
165
+ # A sample for this key already exists. We only care if it's a conflict.
166
+ existing_tokens = unique_samples[key].get("motion_tokens")
167
+ current_tokens = entry.get("motion_tokens")
168
+ if existing_tokens != current_tokens:
169
+ conflicts_found += 1
170
+ # We do nothing, effectively discarding this new conflicting sample.
171
+
172
+ cleaned_data = list(unique_samples.values())
173
+
174
+ print(f"Original samples: {len(entries)}")
175
+ print(f"Cleaned samples (unique (word, pid) pairs): {len(cleaned_data)}")
176
+ print(f"Removed {len(entries) - len(cleaned_data)} total samples. ({conflicts_found} were direct conflicts).")
177
+
178
+ print("\n---> Extracting motion tokens from the full cleaned dataset...")
179
+ all_motion_tokens = set()
180
+ for entry in cleaned_data:
181
+ motion_tokens = entry.get("motion_tokens", "").strip().split()
182
+ for token in motion_tokens:
183
+ all_motion_tokens.add(f"<M{token}>")
184
+
185
+ unique_tokens = sorted(list(all_motion_tokens))
186
+ print(f"Found {len(unique_tokens)} unique motion tokens in the entire dataset.")
187
+
188
+ return cleaned_data, unique_tokens
189
+
190
+ # ======================================================================================
191
+ # 2. Model and Tokenizer Setup
192
+ # ======================================================================================
193
+ def setup_model_and_tokenizer(model_name: str, motion_tokens: List[str]) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
194
+ """Loads the model and tokenizer, adding special and motion tokens."""
195
+ print(f"\n---> Loading base model and tokenizer: {model_name}")
196
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
197
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
198
+
199
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]})
200
+
201
+ print(f"Adding {len(motion_tokens)} motion tokens to the tokenizer.")
202
+ tokenizer.add_tokens(motion_tokens, special_tokens=True)
203
+
204
+ model.resize_token_embeddings(len(tokenizer))
205
+ model.config.pad_token_id = tokenizer.pad_token_id
206
+
207
+ return model, tokenizer
208
+
209
+ # ======================================================================================
210
+ # 2b. Hugging Face Hub Utilities and Checkpointing
211
+ # ======================================================================================
212
+ def _format_seconds(seconds: float) -> str:
213
+ """Formats seconds into H:MM:SS or M:SS."""
214
+ seconds = int(max(0, seconds))
215
+ h = seconds // 3600
216
+ m = (seconds % 3600) // 60
217
+ s = seconds % 60
218
+ if h > 0:
219
+ return f"{h:d}:{m:02d}:{s:02d}"
220
+ return f"{m:d}:{s:02d}"
221
+
222
+ def _ensure_dir(path: str) -> None:
223
+ os.makedirs(path, exist_ok=True)
224
+
225
+ def _resolve_and_ensure_repo(repo_id: str) -> Optional[str]:
226
+ """
227
+ Ensures the HF repo exists. Returns the fully-qualified repo_id (namespace/repo)
228
+ when token is available; otherwise returns the input repo_id.
229
+ """
230
+ if not HF_USE_HUB:
231
+ return None
232
+ if hf_auth_token is None:
233
+ print("⚠️ HF token not found. Set HUGGINGFACE_HUB_TOKEN or hf_auth_token to enable Hub sync.")
234
+ return None
235
+ api = HfApi()
236
+ try:
237
+ who = api.whoami(token=hf_auth_token)
238
+ namespace = who.get("name") or (who.get("orgs", [None])[0] if isinstance(who.get("orgs"), list) else None)
239
+ except Exception as exc:
240
+ print(f"⚠️ Unable to resolve HF namespace: {exc}")
241
+ namespace = None
242
+ if "/" not in repo_id and namespace:
243
+ full_repo_id = f"{namespace}/{repo_id}"
244
+ else:
245
+ full_repo_id = repo_id
246
+ try:
247
+ api.create_repo(
248
+ repo_id=full_repo_id,
249
+ token=hf_auth_token,
250
+ repo_type="model",
251
+ private=HF_PRIVATE_REPO,
252
+ exist_ok=True,
253
+ )
254
+ except Exception as exc:
255
+ print(f"⚠️ create_repo failed (may already exist): {exc}")
256
+ return full_repo_id
257
+
258
+ def _repo_has_stage_latest(repo_id: str, stage: str) -> bool:
259
+ """Checks if a stage/latest checkpoint exists in the HF repo."""
260
+ if not HF_USE_HUB or hf_auth_token is None:
261
+ return False
262
+ api = HfApi()
263
+ try:
264
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=hf_auth_token)
265
+ return any(path.startswith(f"{stage}/latest/") and path.endswith("config.json") for path in files)
266
+ except Exception as exc:
267
+ print(f"⚠️ Could not list files for {repo_id}: {exc}")
268
+ return False
269
+
270
+ def _repo_list_epoch_numbers(repo_id: str, stage: str) -> List[int]:
271
+ """
272
+ Returns sorted list of epoch numbers available under {stage}/epoch-XXX/ by scanning files.
273
+ Works even if 'latest' does not exist.
274
+ """
275
+ if not HF_USE_HUB or hf_auth_token is None:
276
+ return []
277
+ api = HfApi()
278
+ try:
279
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=hf_auth_token)
280
+ except Exception as exc:
281
+ print(f"⚠️ Could not list files for {repo_id}: {exc}")
282
+ return []
283
+ epoch_numbers: List[int] = []
284
+ pattern = re.compile(rf"^{re.escape(stage)}/epoch-(\d+)/config\.json$")
285
+ for path in files:
286
+ m = pattern.match(path)
287
+ if m:
288
+ try:
289
+ epoch_numbers.append(int(m.group(1)))
290
+ except ValueError:
291
+ pass
292
+ return sorted(set(epoch_numbers))
293
+
294
+ def _repo_get_latest_epoch_subfolder(repo_id: str, stage: str) -> Optional[str]:
295
+ """
296
+ Returns subfolder path like '{stage}/epoch-XXX' for the highest available epoch, or None.
297
+ """
298
+ epochs = _repo_list_epoch_numbers(repo_id, stage)
299
+ if not epochs:
300
+ return None
301
+ latest = max(epochs)
302
+ return f"{stage}/epoch-{latest:03d}"
303
+
304
+ def _load_model_and_tokenizer_from_hf_subfolder(repo_id: str, subfolder: str) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
305
+ """
306
+ Loads model and tokenizer from HF under a specific subfolder (e.g., 'stage1/epoch-020').
307
+ """
308
+ if not HF_USE_HUB or hf_auth_token is None:
309
+ return None
310
+ print(f"\n---> Loading checkpoint from Hugging Face: {repo_id} (subfolder='{subfolder}')")
311
+ try:
312
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
313
+ model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
314
+ except Exception as exc:
315
+ print(f"⚠️ Failed to load model/tokenizer from subfolder '{subfolder}': {exc}")
316
+ return None
317
+ if tokenizer.pad_token is None:
318
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
319
+ model.resize_token_embeddings(len(tokenizer))
320
+ model.config.pad_token_id = tokenizer.pad_token_id
321
+ return model, tokenizer
322
+
323
+ def _download_training_state_from_subfolder(repo_id: str, subfolder: str) -> Optional[Dict[str, Any]]:
324
+ """
325
+ Downloads training_state.json from a specific subfolder (e.g., 'stage1/epoch-020').
326
+ """
327
+ if not HF_USE_HUB or hf_auth_token is None:
328
+ return None
329
+ try:
330
+ state_path = hf_hub_download(
331
+ repo_id=repo_id,
332
+ filename=f"{subfolder}/training_state.json",
333
+ repo_type="model",
334
+ token=hf_auth_token,
335
+ )
336
+ with open(state_path, "r", encoding="utf-8") as f:
337
+ return json.load(f)
338
+ except Exception:
339
+ return None
340
+
341
+ def _download_training_state(repo_id: str, stage: str) -> Optional[Dict[str, Any]]:
342
+ """Downloads training_state.json from HF if present."""
343
+ if not HF_USE_HUB or hf_auth_token is None:
344
+ return None
345
+ try:
346
+ state_path = hf_hub_download(
347
+ repo_id=repo_id,
348
+ filename=f"{stage}/latest/training_state.json",
349
+ repo_type="model",
350
+ token=hf_auth_token,
351
+ )
352
+ with open(state_path, "r", encoding="utf-8") as f:
353
+ return json.load(f)
354
+ except Exception:
355
+ return None
356
+
357
+ def _download_optimizer_state(repo_id: str, stage: str) -> Optional[str]:
358
+ """Downloads optimizer.pt for resuming optimizer state."""
359
+ if not HF_USE_HUB or hf_auth_token is None:
360
+ return None
361
+ try:
362
+ opt_path = hf_hub_download(
363
+ repo_id=repo_id,
364
+ filename=f"{stage}/latest/optimizer.pt",
365
+ repo_type="model",
366
+ token=hf_auth_token,
367
+ )
368
+ return opt_path
369
+ except Exception:
370
+ return None
371
+
372
+ def _load_model_and_tokenizer_from_hf(repo_id: str, stage: str) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
373
+ """
374
+ Loads model and tokenizer from HF under subfolder {stage}/latest if available.
375
+ """
376
+ if not _repo_has_stage_latest(repo_id, stage):
377
+ return None
378
+ print(f"\n---> Loading checkpoint from Hugging Face: {repo_id} (subfolder='{stage}/latest')")
379
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=f"{stage}/latest", trust_remote_code=True)
380
+ model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=f"{stage}/latest", trust_remote_code=True)
381
+ if tokenizer.pad_token is None:
382
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
383
+ model.resize_token_embeddings(len(tokenizer))
384
+ model.config.pad_token_id = tokenizer.pad_token_id
385
+ return model, tokenizer
386
+
387
+ def _ensure_tokenizer_has_motion_tokens(tokenizer: AutoTokenizer, motion_tokens: List[str]) -> int:
388
+ """
389
+ Adds any missing motion tokens to the tokenizer. Returns number of tokens added.
390
+ """
391
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]})
392
+ added = tokenizer.add_tokens(motion_tokens, special_tokens=True)
393
+ return added
394
+
395
+ def _save_and_push_checkpoint(
396
+ stage: str,
397
+ epoch_index_zero_based: int,
398
+ model: AutoModelForCausalLM,
399
+ tokenizer: AutoTokenizer,
400
+ optimizer: AdamW,
401
+ avg_loss: float,
402
+ dataloader_len: int,
403
+ batch_size: int,
404
+ total_epochs: int,
405
+ repo_id: Optional[str],
406
+ ) -> None:
407
+ """
408
+ Saves checkpoint locally (per-epoch and latest) and pushes to HF under:
409
+ - {stage}/epoch-XXX
410
+ - {stage}/latest
411
+ Also saves optimizer state and training_state.json to preserve resume info.
412
+ """
413
+ epoch_number = epoch_index_zero_based + 1
414
+ stage_dir = os.path.join(CHECKPOINTS_DIR, stage)
415
+ epoch_dir_name = f"epoch-{epoch_number:03d}"
416
+ epoch_dir = os.path.join(stage_dir, epoch_dir_name)
417
+ latest_dir = os.path.join(stage_dir, "latest")
418
+ _ensure_dir(epoch_dir)
419
+ _ensure_dir(stage_dir)
420
+
421
+ # Save model + tokenizer
422
+ model.save_pretrained(epoch_dir)
423
+ tokenizer.save_pretrained(epoch_dir)
424
+
425
+ # Save optimizer state
426
+ torch.save(optimizer.state_dict(), os.path.join(epoch_dir, "optimizer.pt"))
427
+
428
+ # Save training state
429
+ training_state = {
430
+ "stage": stage,
431
+ "epoch_completed": epoch_number,
432
+ "total_epochs_for_stage": total_epochs,
433
+ "global_step": epoch_number * dataloader_len,
434
+ "avg_loss": float(avg_loss),
435
+ "batch_size": batch_size,
436
+ "saved_at": datetime.utcnow().isoformat() + "Z",
437
+ }
438
+ with open(os.path.join(epoch_dir, "training_state.json"), "w", encoding="utf-8") as f:
439
+ json.dump(training_state, f, ensure_ascii=False, indent=2)
440
+
441
+ # Update "latest"
442
+ if os.path.exists(latest_dir):
443
+ shutil.rmtree(latest_dir)
444
+ shutil.copytree(epoch_dir, latest_dir)
445
+
446
+ # Push to Hugging Face
447
+ if HF_USE_HUB and repo_id and hf_auth_token:
448
+ try:
449
+ upload_folder(
450
+ repo_id=repo_id,
451
+ folder_path=epoch_dir,
452
+ path_in_repo=f"{stage}/{epoch_dir_name}",
453
+ repo_type="model",
454
+ token=hf_auth_token,
455
+ commit_message=f"{stage}: save {epoch_dir_name}",
456
+ )
457
+ upload_folder(
458
+ repo_id=repo_id,
459
+ folder_path=latest_dir,
460
+ path_in_repo=f"{stage}/latest",
461
+ repo_type="model",
462
+ token=hf_auth_token,
463
+ commit_message=f"{stage}: update latest -> {epoch_dir_name}",
464
+ )
465
+ print(f"☁️ Pushed checkpoint to HF: {repo_id} ({stage}/{epoch_dir_name} and {stage}/latest)")
466
+ except Exception as exc:
467
+ print(f"⚠️ Failed to push checkpoint to HF: {exc}")
468
+ else:
469
+ print("ℹ️ Skipped HF push (Hub disabled or token/repo missing).")
470
+
471
+ # ======================================================================================
472
+ # 3. Training Stage 1: Motion Language Modeling
473
+ # ======================================================================================
474
+ class MotionDataset(Dataset):
475
+ """Dataset for Stage 1: Contains only motion token sequences."""
476
+ def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
477
+ self.tokenizer = tokenizer
478
+ self.max_length = max_length
479
+ self.sequences = []
480
+
481
+ for item in data:
482
+ tokens_str = item.get("motion_tokens", "")
483
+ wrapped_tokens = " ".join([f"<M{t}>" for t in tokens_str.split()])
484
+ full_sequence = f"{M_START} {wrapped_tokens} {M_END}"
485
+ self.sequences.append(full_sequence)
486
+
487
+ def __len__(self):
488
+ return len(self.sequences)
489
+
490
+ def __getitem__(self, idx):
491
+ return self.tokenizer(
492
+ self.sequences[idx],
493
+ truncation=True,
494
+ max_length=self.max_length,
495
+ padding="max_length",
496
+ return_tensors="pt"
497
+ )
498
+
499
+ def train_stage1(
500
+ model,
501
+ tokenizer,
502
+ data,
503
+ device,
504
+ start_epoch: int = 0,
505
+ hf_repo_id: Optional[str] = None,
506
+ ):
507
+ """Trains the model on motion sequences only to learn the 'language of motion'.
508
+ Resumes from Hugging Face if available (model/tokenizer/optimizer)."""
509
+ print("\n" + "="*80)
510
+ print(" STAGE 1: MOTION LANGUAGE MODELING (PRE-TRAINING)")
511
+ print(f" Training on {len(data)} samples.")
512
+ print("="*80)
513
+
514
+ dataset = MotionDataset(data, tokenizer)
515
+ dataloader = DataLoader(dataset, batch_size=S1_BATCH_SIZE, shuffle=True)
516
+
517
+ optimizer = AdamW(model.parameters(), lr=S1_LR)
518
+ model.to(device)
519
+ model.train()
520
+
521
+ # Try to resume optimizer if we resumed from HF
522
+ if hf_repo_id and start_epoch > 0 and HF_USE_HUB and hf_auth_token:
523
+ opt_path = _download_optimizer_state(hf_repo_id, "stage1")
524
+ if opt_path is not None:
525
+ try:
526
+ optimizer.load_state_dict(torch.load(opt_path, map_location=device))
527
+ print("↩️ Resumed optimizer state for Stage 1 from HF.")
528
+ except Exception as exc:
529
+ print(f"⚠️ Failed to load optimizer state for Stage 1: {exc}")
530
+
531
+ for epoch in range(start_epoch, S1_EPOCHS):
532
+ total_loss = 0
533
+ total_batches = len(dataloader)
534
+ epoch_start_time = time.time()
535
+ step_interval = max(1, total_batches // 50) # ~2% progress updates
536
+ for i, batch in enumerate(dataloader, 1):
537
+ optimizer.zero_grad()
538
+
539
+ input_ids = batch['input_ids'].squeeze(1).to(device)
540
+ attention_mask = batch['attention_mask'].squeeze(1).to(device)
541
+
542
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
543
+
544
+ loss = outputs.loss
545
+ loss.backward()
546
+ optimizer.step()
547
+ total_loss += loss.item()
548
+
549
+ # Progress with ETA
550
+ if i == 1 or (i % step_interval == 0) or (i == total_batches):
551
+ elapsed = time.time() - epoch_start_time
552
+ est_total = (elapsed / i) * total_batches
553
+ eta = est_total - elapsed
554
+ pct = (i / total_batches) * 100.0
555
+ print(
556
+ f"\r[Stage 1] Epoch {epoch+1}/{S1_EPOCHS} - "
557
+ f"{i}/{total_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
558
+ end="",
559
+ flush=True,
560
+ )
561
+
562
+ # Finish the progress line
563
+ print()
564
+ avg_loss = total_loss / len(dataloader)
565
+ print(f"--- End of Epoch {epoch+1}/{S1_EPOCHS}, Average Loss: {avg_loss:.4f} ---")
566
+ # Save checkpoint locally every epoch; push to HF only at interval or final epoch
567
+ push_this_epoch = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S1_EPOCHS)
568
+ repo_for_epoch = hf_repo_id if push_this_epoch else None
569
+ _save_and_push_checkpoint(
570
+ stage="stage1",
571
+ epoch_index_zero_based=epoch,
572
+ model=model,
573
+ tokenizer=tokenizer,
574
+ optimizer=optimizer,
575
+ avg_loss=avg_loss,
576
+ dataloader_len=len(dataloader),
577
+ batch_size=S1_BATCH_SIZE,
578
+ total_epochs=S1_EPOCHS,
579
+ repo_id=repo_for_epoch,
580
+ )
581
+
582
+ print("\n✅ Stage 1 Training Complete.")
583
+ return model
584
+
585
+ # ======================================================================================
586
+ # 4. Training Stage 2: Text-to-Motion Fine-Tuning
587
+ # ======================================================================================
588
+ class TextMotionDataset(Dataset):
589
+ """Dataset for Stage 2: Contains (prompt, motion_sequence) pairs."""
590
+ def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
591
+ self.tokenizer = tokenizer
592
+ self.max_length = max_length
593
+ self.items = []
594
+
595
+ for item in data:
596
+ prompt = f"Instruction: Generate motion for word '{item['word']}' with variant '{item['participant_id']}'.\nMotion: "
597
+
598
+ tokens_str = item.get("motion_tokens", "")
599
+ wrapped_tokens = " ".join([f"<M{t}>" for t in tokens_str.split()])
600
+ target_sequence = f"{M_START} {wrapped_tokens} {M_END}"
601
+
602
+ full_text = prompt + target_sequence
603
+
604
+ tokenized = self.tokenizer(
605
+ full_text,
606
+ truncation=True,
607
+ max_length=self.max_length,
608
+ padding="max_length",
609
+ return_tensors="pt"
610
+ )
611
+
612
+ prompt_tokenized = self.tokenizer(prompt, return_tensors="pt")
613
+ prompt_len = prompt_tokenized.input_ids.shape[1]
614
+
615
+ labels = tokenized['input_ids'].clone()
616
+ labels[0, :prompt_len] = -100
617
+
618
+ self.items.append({
619
+ "input_ids": tokenized['input_ids'].squeeze(0),
620
+ "attention_mask": tokenized['attention_mask'].squeeze(0),
621
+ "labels": labels.squeeze(0)
622
+ })
623
+
624
+ def __len__(self):
625
+ return len(self.items)
626
+
627
+ def __getitem__(self, idx):
628
+ return self.items[idx]
629
+
630
+ def train_stage2(
631
+ model,
632
+ tokenizer,
633
+ data,
634
+ device,
635
+ start_epoch: int = 0,
636
+ hf_repo_id: Optional[str] = None,
637
+ hf_stage_subdir: str = "stage2",
638
+ ):
639
+ """Fine-tunes the motion-aware model to connect text prompts to motions.
640
+ Resumes from Hugging Face if available (model/tokenizer/optimizer)."""
641
+ print("\n" + "="*80)
642
+ print(" STAGE 2: TEXT-TO-MOTION FINE-TUNING")
643
+ print(f" Training on {len(data)} samples.")
644
+ print("="*80)
645
+
646
+ dataset = TextMotionDataset(data, tokenizer)
647
+ dataloader = DataLoader(dataset, batch_size=S2_BATCH_SIZE, shuffle=True)
648
+
649
+ optimizer = AdamW(model.parameters(), lr=S2_LR)
650
+ model.to(device)
651
+ model.train()
652
+
653
+ # Try to resume optimizer if we resumed from HF
654
+ if hf_repo_id and start_epoch > 0 and HF_USE_HUB and hf_auth_token:
655
+ opt_path = _download_optimizer_state(hf_repo_id, hf_stage_subdir)
656
+ if opt_path is not None:
657
+ try:
658
+ optimizer.load_state_dict(torch.load(opt_path, map_location=device))
659
+ print("↩️ Resumed optimizer state for Stage 2 from HF.")
660
+ except Exception as exc:
661
+ print(f"⚠️ Failed to load optimizer state for Stage 2: {exc}")
662
+
663
+ for epoch in range(start_epoch, S2_EPOCHS):
664
+ total_loss = 0
665
+ total_batches = len(dataloader)
666
+ epoch_start_time = time.time()
667
+ step_interval = max(1, total_batches // 50) # ~2% progress updates
668
+ for i, batch in enumerate(dataloader, 1):
669
+ optimizer.zero_grad()
670
+
671
+ input_ids = batch['input_ids'].to(device)
672
+ attention_mask = batch['attention_mask'].to(device)
673
+ labels = batch['labels'].to(device)
674
+
675
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
676
+
677
+ loss = outputs.loss
678
+ loss.backward()
679
+ optimizer.step()
680
+ total_loss += loss.item()
681
+
682
+ # Progress with ETA
683
+ if i == 1 or (i % step_interval == 0) or (i == total_batches):
684
+ elapsed = time.time() - epoch_start_time
685
+ est_total = (elapsed / i) * total_batches
686
+ eta = est_total - elapsed
687
+ pct = (i / total_batches) * 100.0
688
+ print(
689
+ f"\r[Stage 2] Epoch {epoch+1}/{S2_EPOCHS} - "
690
+ f"{i}/{total_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
691
+ end="",
692
+ flush=True,
693
+ )
694
+
695
+ # Finish the progress line
696
+ print()
697
+ avg_loss = total_loss / len(dataloader)
698
+ print(f"--- End of Epoch {epoch+1}/{S2_EPOCHS}, Average Loss: {avg_loss:.4f} ---")
699
+ # Save checkpoint locally every epoch; push to HF only at interval or final epoch
700
+ push_this_epoch = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S2_EPOCHS)
701
+ repo_for_epoch = hf_repo_id if push_this_epoch else None
702
+ _save_and_push_checkpoint(
703
+ stage=hf_stage_subdir,
704
+ epoch_index_zero_based=epoch,
705
+ model=model,
706
+ tokenizer=tokenizer,
707
+ optimizer=optimizer,
708
+ avg_loss=avg_loss,
709
+ dataloader_len=len(dataloader),
710
+ batch_size=S2_BATCH_SIZE,
711
+ total_epochs=S2_EPOCHS,
712
+ repo_id=repo_for_epoch,
713
+ )
714
+
715
+ print("\n✅ Stage 2 Training Complete.")
716
+ if not os.path.exists(OUTPUT_DIR):
717
+ os.makedirs(OUTPUT_DIR)
718
+ model.save_pretrained(OUTPUT_DIR)
719
+ tokenizer.save_pretrained(OUTPUT_DIR)
720
+ print(f"Model saved to {OUTPUT_DIR}")
721
+ return model
722
+
723
+ # ======================================================================================
724
+ # 5. Inference and Comparison
725
+ # ======================================================================================
726
+ def generate_motion(model, tokenizer, prompt, device):
727
+ """Generates a motion sequence from a prompt using sampling."""
728
+ model.eval()
729
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
730
+
731
+ with torch.no_grad():
732
+ output = model.generate(
733
+ **inputs,
734
+ max_new_tokens=100,
735
+ do_sample=True,
736
+ temperature=INFERENCE_TEMPERATURE,
737
+ top_k=INFERENCE_TOP_K,
738
+ repetition_penalty=INFERENCE_REPETITION_PENALTY,
739
+ pad_token_id=tokenizer.pad_token_id,
740
+ eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
741
+ early_stopping=True
742
+ )
743
+
744
+ decoded = tokenizer.decode(output[0], skip_special_tokens=False)
745
+ motion_part = decoded.split("Motion: ")[-1]
746
+ return motion_part.strip()
747
+
748
+ def compare_sequences(gt: str, gen: str):
749
+ """Provides a simple visual diff of two sequences without external libraries."""
750
+ gt_tokens = gt.split()
751
+ gen_tokens = gen.split()
752
+
753
+ print("\nDetailed Comparison (✅ = Match, ❌ = Mismatch/Missing/Added):")
754
+
755
+ gt_str = " GT: "
756
+ gen_str = " GEN: "
757
+ diff_str = " "
758
+
759
+ max_len = max(len(gt_tokens), len(gen_tokens))
760
+
761
+ for i in range(max_len):
762
+ gt_tok = gt_tokens[i] if i < len(gt_tokens) else "___"
763
+ gen_tok = gen_tokens[i] if i < len(gen_tokens) else "___"
764
+
765
+ max_tok_len = max(len(gt_tok), len(gen_tok))
766
+ gt_tok_padded = gt_tok.ljust(max_tok_len)
767
+ gen_tok_padded = gen_tok.ljust(max_tok_len)
768
+
769
+ gt_str += gt_tok_padded + " "
770
+ gen_str += gen_tok_padded + " "
771
+
772
+ if gt_tok == gen_tok:
773
+ diff_str += "✅".ljust(max_tok_len) + " "
774
+ else:
775
+ diff_str += "❌".ljust(max_tok_len) + " "
776
+
777
+ print(gt_str)
778
+ print(gen_str)
779
+ print(diff_str)
780
+
781
+ def run_inference_on_all_samples(model, tokenizer, data, device):
782
+ """
783
+ Runs inference on ALL available samples for the trained words and compares
784
+ each one to its specific ground truth.
785
+ """
786
+ print("\n" + "="*80)
787
+ print(" INFERENCE AND EVALUATION (ALL SAMPLES)")
788
+ print(" Goal: Test the model's performance on every variant.")
789
+ print("="*80)
790
+
791
+ data_by_word = {}
792
+ for item in data:
793
+ word = item['word']
794
+ if word not in data_by_word:
795
+ data_by_word[word] = []
796
+ data_by_word[word].append(item)
797
+
798
+ for word, samples in data_by_word.items():
799
+ print(f"\n\n{'='*25} TESTING WORD: '{word}' {'='*25}")
800
+ num_correct = 0
801
+
802
+ for i, sample in enumerate(samples):
803
+ print(f"\n--- Testing Variant {i+1}/{len(samples)}: '{sample['participant_id']}' ---")
804
+
805
+ gt_tokens_str = sample.get("motion_tokens", "")
806
+ gt_wrapped = " ".join([f"<M{t}>" for t in gt_tokens_str.split()])
807
+ gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
808
+ print(f"Ground Truth:\n{gt_sequence}")
809
+
810
+ prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
811
+ generated_sequence = generate_motion(model, tokenizer, prompt, device)
812
+ print(f"\nLLM Generated:\n{generated_sequence}")
813
+
814
+ compare_sequences(gt_sequence, generated_sequence)
815
+
816
+ if gt_sequence.strip() == generated_sequence.strip():
817
+ num_correct += 1
818
+
819
+ print("-" * 80)
820
+
821
+ accuracy = (num_correct / len(samples)) * 100
822
+ print(f"\nSUMMARY FOR '{word}': {num_correct}/{len(samples)} correct ({accuracy:.1f}%)")
823
+
824
+ # ======================================================================================
825
+ # 5b. Metrics: FID, Diversity, Multimodality (MIM) using MotionGPT-style utils
826
+ # ======================================================================================
827
+ def calculate_activation_statistics_np(activations: np.ndarray):
828
+ """
829
+ Params:
830
+ -- activations: num_samples x dim_feat (numpy)
831
+ Returns:
832
+ -- mu: dim_feat
833
+ -- sigma: dim_feat x dim_feat
834
+ """
835
+ mu = np.mean(activations, axis=0)
836
+ cov = np.cov(activations, rowvar=False)
837
+ return mu, cov
838
+
839
+ def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6):
840
+ """Numpy implementation of the Frechet Distance."""
841
+ mu1 = np.atleast_1d(mu1)
842
+ mu2 = np.atleast_1d(mu2)
843
+ sigma1 = np.atleast_2d(sigma1)
844
+ sigma2 = np.atleast_2d(sigma2)
845
+ assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
846
+ assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
847
+ diff = mu1 - mu2
848
+ covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
849
+ if not np.isfinite(covmean).all():
850
+ offset = np.eye(sigma1.shape[0]) * eps
851
+ covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
852
+ if np.iscomplexobj(covmean):
853
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
854
+ m = np.max(np.abs(covmean.imag))
855
+ raise ValueError(f"Imaginary component {m}")
856
+ covmean = covmean.real
857
+ tr_covmean = np.trace(covmean)
858
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
859
+
860
+ def calculate_diversity_np(activation: np.ndarray, diversity_times: int = 200) -> float:
861
+ """Mean pairwise L2 distance across random pairs."""
862
+ assert len(activation.shape) == 2
863
+ assert activation.shape[0] > max(2, diversity_times)
864
+ num_samples = activation.shape[0]
865
+ first_indices = np.random.choice(num_samples, diversity_times, replace=False)
866
+ second_indices = np.random.choice(num_samples, diversity_times, replace=False)
867
+ diffs = activation[first_indices] - activation[second_indices]
868
+ dist = np.linalg.norm(diffs, axis=1)
869
+ return float(dist.mean())
870
+
871
+ def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int = 20) -> float:
872
+ """
873
+ activation: [num_labels, num_per_label, D]
874
+ Returns mean pairwise within-label diversity (higher = more multimodal).
875
+ """
876
+ assert len(activation.shape) == 3
877
+ num_labels, num_per_label, _ = activation.shape
878
+ assert num_per_label > multimodality_times
879
+ first_dices = np.random.choice(num_per_label, multimodality_times, replace=False)
880
+ second_dices = np.random.choice(num_per_label, multimodality_times, replace=False)
881
+ diffs = activation[:, first_dices] - activation[:, second_dices]
882
+ dist = np.linalg.norm(diffs, axis=2)
883
+ return float(dist.mean())
884
+
885
+ # --------------------------------------------------------------------------------------
886
+ # Token sequence → activation (bag-of-motion-tokens) helpers
887
+ # --------------------------------------------------------------------------------------
888
+ def _extract_motion_tokens_from_sequence(seq: str) -> list[str]:
889
+ # Expect tokens like <M123>, within M_START/M_END fences; keep only <M...>
890
+ return [tok for tok in seq.split() if tok.startswith("<M") and tok.endswith(">")]
891
+
892
+ def _build_token_index(tokens_vocab: list[str]) -> Dict[str, int]:
893
+ return {tok: idx for idx, tok in enumerate(tokens_vocab)}
894
+
895
+ def _sequence_to_activation(seq: str, token_to_index: Dict[str, int]) -> np.ndarray:
896
+ vec = np.zeros((len(token_to_index),), dtype=np.float32)
897
+ for tok in _extract_motion_tokens_from_sequence(seq):
898
+ idx = token_to_index.get(tok)
899
+ if idx is not None:
900
+ vec[idx] += 1.0
901
+ # Normalize to unit length to reduce length bias
902
+ norm = np.linalg.norm(vec)
903
+ if norm > 0:
904
+ vec = vec / norm
905
+ return vec
906
+
907
+ def _collect_eval_pairs(model, tokenizer, data, device) -> list[Tuple[str, str, str]]:
908
+ """
909
+ Returns list of (word, participant_id, gt_sequence, generated_sequence) for each sample in data.
910
+ """
911
+ results = []
912
+ for sample in data:
913
+ gt_tokens_str = sample.get("motion_tokens", "")
914
+ gt_wrapped = " ".join([f"<M{t}>" for t in gt_tokens_str.split()])
915
+ gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
916
+ prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
917
+ generated_sequence = generate_motion(model, tokenizer, prompt, device)
918
+ pid = str(sample.get("participant_id", ""))
919
+ results.append((sample["word"], pid, gt_sequence, generated_sequence))
920
+ return results
921
+
922
+ def _activations_from_pairs(pairs: list[Tuple[str, str, str]], vocab_tokens: list[str]):
923
+ """
924
+ Build numpy activations and labels arrays from sequences.
925
+ Returns:
926
+ gt_acts: (N, D)
927
+ gen_acts: (N, D)
928
+ labels: list[str] length N (word labels)
929
+ """
930
+ token_to_index = _build_token_index(vocab_tokens)
931
+ gt_vecs = []
932
+ gen_vecs = []
933
+ labels = []
934
+ for pair in pairs:
935
+ # Support both legacy 3-tuple (word, gt, gen) and new 4-tuple (word, pid, gt, gen)
936
+ if len(pair) == 4:
937
+ word, _pid, gt_seq, gen_seq = pair
938
+ else:
939
+ word, gt_seq, gen_seq = pair
940
+ gt_vecs.append(_sequence_to_activation(gt_seq, token_to_index))
941
+ gen_vecs.append(_sequence_to_activation(gen_seq, token_to_index))
942
+ labels.append(word)
943
+ return np.stack(gt_vecs, axis=0), np.stack(gen_vecs, axis=0), labels
944
+
945
+ def _to_label_tensor3(acts: np.ndarray, labels: list[str]) -> np.ndarray:
946
+ """
947
+ Convert N x D activations with string labels to [L, K, D] by truncating each label
948
+ to the minimum count across labels.
949
+ """
950
+ label_to_indices: Dict[str, list[int]] = {}
951
+ for i, lbl in enumerate(labels):
952
+ label_to_indices.setdefault(lbl, []).append(i)
953
+ per_label_counts = [len(idxs) for idxs in label_to_indices.values()]
954
+ if len(per_label_counts) == 0:
955
+ raise ValueError("No labels found for multimodality computation.")
956
+ min_count = max(2, min(per_label_counts))
957
+ label_names = sorted(label_to_indices.keys())
958
+ stacked = []
959
+ for lbl in label_names:
960
+ idxs = label_to_indices[lbl][:min_count]
961
+ stacked.append(acts[idxs])
962
+ return np.stack(stacked, axis=0) # [L, K, D]
963
+
964
+ def evaluate_metrics_motiongpt_style(model, tokenizer, eval_data, all_motion_tokens, device):
965
+ """
966
+ Computes:
967
+ - Diversity: GT vs GEN (pair)
968
+ - Multimodality (MIM): GT vs GEN (pair)
969
+ - FID: between GT and GEN
970
+ """
971
+ print("\n" + "="*80)
972
+ print(" METRICS EVALUATION (FID, Diversity, Multimodality)")
973
+ print("="*80)
974
+ pairs = _collect_eval_pairs(model, tokenizer, eval_data, device)
975
+ gt_acts, gen_acts, labels = _activations_from_pairs(pairs, all_motion_tokens)
976
+ # Diversity
977
+ diversity_times = min(200, max(4, gt_acts.shape[0] - 1))
978
+ diversity_gt = calculate_diversity_np(gt_acts, diversity_times=diversity_times)
979
+ diversity_gen = calculate_diversity_np(gen_acts, diversity_times=diversity_times)
980
+ # Multimodality (MIM)
981
+ try:
982
+ gt_lbl_tensor = _to_label_tensor3(gt_acts, labels)
983
+ gen_lbl_tensor = _to_label_tensor3(gen_acts, labels)
984
+ multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
985
+ mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
986
+ mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
987
+ except Exception as exc:
988
+ print(f"⚠️ Multimodality could not be computed reliably: {exc}")
989
+ mim_gt = float("nan")
990
+ mim_gen = float("nan")
991
+ # FID
992
+ mu_gen, cov_gen = calculate_activation_statistics_np(gen_acts)
993
+ mu_gt, cov_gt = calculate_activation_statistics_np(gt_acts)
994
+ fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
995
+ print(f"Diversity: GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
996
+ print(f"Multimodality (MIM): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
997
+ print(f"FID (GT vs GEN): {fid:.4f}")
998
+ return {
999
+ "diversity_gt": diversity_gt,
1000
+ "diversity_gen": diversity_gen,
1001
+ "mim_gt": mim_gt,
1002
+ "mim_gen": mim_gen,
1003
+ "fid": fid,
1004
+ "pairs": pairs, # for visualization usage
1005
+ }
1006
+
1007
+ # ======================================================================================
1008
+ # 5b-ALT. Metrics using VQ-VAE codebook embeddings (near-standard activations)
1009
+ # ======================================================================================
1010
+ def _sequence_to_codebook_feature(seq: str, vq_model) -> np.ndarray:
1011
+ """
1012
+ Build a single clip feature by mean-pooling VQ-VAE codebook embeddings
1013
+ corresponding to the token ids in the sequence. L2-normalized.
1014
+ """
1015
+ token_ids = _extract_ids_from_sequence(seq)
1016
+ # Resolve code dimension and codebook availability
1017
+ quantizer = getattr(vq_model.vqvae, "quantizer", None)
1018
+ if quantizer is None:
1019
+ raise RuntimeError("VQ-VAE quantizer missing; cannot extract codebook embeddings.")
1020
+ # Try dequantize -> mean over time (preferred)
1021
+ feat_vec = None
1022
+ if hasattr(quantizer, "dequantize") and token_ids:
1023
+ try:
1024
+ idx = torch.tensor(token_ids, dtype=torch.long, device=next(vq_model.parameters()).device).unsqueeze(0)
1025
+ with torch.no_grad():
1026
+ dq = quantizer.dequantize(idx)
1027
+ if dq is not None:
1028
+ # Expect shape [N, code_dim, T]; average over T
1029
+ if dq.ndim == 3:
1030
+ if dq.shape[0] == 1:
1031
+ x = dq.squeeze(0) # [code_dim, T] or [T, code_dim]
1032
+ else:
1033
+ x = dq.mean(dim=0)
1034
+ if x.shape[0] < x.shape[1]:
1035
+ # [code_dim, T]
1036
+ feat = x.mean(dim=1)
1037
+ else:
1038
+ # [T, code_dim]
1039
+ feat = x.mean(dim=0)
1040
+ feat_vec = feat.detach().cpu().numpy().astype(np.float32)
1041
+ except Exception:
1042
+ feat_vec = None
1043
+ # Fallback: direct codebook lookup -> mean over token ids
1044
+ if feat_vec is None:
1045
+ codebook = getattr(quantizer, "codebook", None)
1046
+ if codebook is None:
1047
+ raise RuntimeError("Quantizer has neither dequantize() nor codebook.")
1048
+ code_np = codebook.detach().cpu().numpy() # [K, D]
1049
+ if not token_ids:
1050
+ feat_vec = np.zeros((code_np.shape[1],), dtype=np.float32)
1051
+ else:
1052
+ ids = np.asarray(token_ids, dtype=np.int64)
1053
+ ids = np.clip(ids, 0, code_np.shape[0] - 1)
1054
+ feat_vec = code_np[ids].mean(axis=0).astype(np.float32)
1055
+ # L2-normalize to reduce length/scale bias
1056
+ norm = np.linalg.norm(feat_vec)
1057
+ if norm > 0:
1058
+ feat_vec = feat_vec / norm
1059
+ return feat_vec
1060
+
1061
+
1062
+ def _activations_from_pairs_codebook(pairs: list[Tuple[str, str, str]], vq_model):
1063
+ """
1064
+ Produce codebook-embedding features for GT and GEN sequences and their labels.
1065
+ Returns:
1066
+ gt_feats: (N, D)
1067
+ gen_feats: (N, D)
1068
+ labels: list[str] of length N (word labels)
1069
+ """
1070
+ gt_feats = []
1071
+ gen_feats = []
1072
+ labels = []
1073
+ for pair in pairs:
1074
+ if len(pair) == 4:
1075
+ word, _pid, gt_seq, gen_seq = pair
1076
+ else:
1077
+ word, gt_seq, gen_seq = pair
1078
+ gt_feats.append(_sequence_to_codebook_feature(gt_seq, vq_model))
1079
+ gen_feats.append(_sequence_to_codebook_feature(gen_seq, vq_model))
1080
+ labels.append(word)
1081
+ return np.stack(gt_feats, axis=0), np.stack(gen_feats, axis=0), labels
1082
+
1083
+
1084
+ def evaluate_metrics_codebook_style(model, tokenizer, eval_data, device, vqvae_ckpt: Optional[str] = None):
1085
+ """
1086
+ Computes FID, Diversity, and MIM using features derived from the VQ-VAE codebook:
1087
+ - Feature per clip = mean-pooled codebook embeddings over token sequence, L2-normalized
1088
+ - Diversity/MIM computed exactly as in MotionGPT-style helpers but on these features
1089
+ - FID computed via full covariance Fréchet distance on these features
1090
+ Returns a dict mirroring evaluate_metrics_motiongpt_style.
1091
+ """
1092
+ print("\n" + "="*80)
1093
+ print(" METRICS EVALUATION (Codebook-Embedding Features)")
1094
+ print("="*80)
1095
+ # Lazy import to avoid hard dependency at module import time
1096
+ try:
1097
+ from visualize import load_vqvae, VQVAE_CHECKPOINT as DEFAULT_VQ
1098
+ vq_ckpt = vqvae_ckpt or os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
1099
+ vq_model = load_vqvae(vq_ckpt, device=device)
1100
+ except Exception as exc:
1101
+ print(f"⚠️ Could not load VQ-VAE for codebook metrics: {exc}")
1102
+ return {}
1103
+ # Collect pairs and build features
1104
+ pairs = _collect_eval_pairs(model, tokenizer, eval_data, device)
1105
+ gt_feats, gen_feats, labels = _activations_from_pairs_codebook(pairs, vq_model)
1106
+ # Diversity
1107
+ diversity_times = min(200, max(4, gt_feats.shape[0] - 1))
1108
+ diversity_gt = calculate_diversity_np(gt_feats, diversity_times=diversity_times)
1109
+ diversity_gen = calculate_diversity_np(gen_feats, diversity_times=diversity_times)
1110
+ # Multimodality (MIM)
1111
+ try:
1112
+ gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
1113
+ gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
1114
+ multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
1115
+ mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
1116
+ mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
1117
+ except Exception as exc:
1118
+ print(f"⚠️ Multimodality could not be computed reliably: {exc}")
1119
+ mim_gt = float("nan")
1120
+ mim_gen = float("nan")
1121
+ # FID (on codebook features)
1122
+ mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
1123
+ mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
1124
+ fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
1125
+ print(f"Diversity (codebook feats): GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
1126
+ print(f"Multimodality (MIM, codebook): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
1127
+ print(f"FID (codebook feats, GT vs GEN): {fid:.4f}")
1128
+ return {
1129
+ "diversity_gt": diversity_gt,
1130
+ "diversity_gen": diversity_gen,
1131
+ "mim_gt": mim_gt,
1132
+ "mim_gen": mim_gen,
1133
+ "fid": fid,
1134
+ "pairs": pairs,
1135
+ }
1136
+
1137
+ # ======================================================================================
1138
+ # 5b-ALT2. Metrics using VQ-VAE encoder pre-quantization features (as described)
1139
+ # ======================================================================================
1140
+ def _encode_params_to_feature(params: np.ndarray, vq_model, mean, std, device) -> np.ndarray:
1141
+ """
1142
+ Convert SMPL-X parameter sequence (T, D) into a single clip feature using
1143
+ the VQ-VAE encoder output BEFORE quantization. Average-pool over time to get (D_embed,).
1144
+ - Attempts to use vq_model.vqvae.preprocess; otherwise applies manual normalization with mean/std.
1145
+ - Handles encoder outputs shaped as [N, D, T] or [N, T, D_embed].
1146
+ """
1147
+ if params.size == 0:
1148
+ return np.zeros((getattr(vq_model.vqvae, "output_emb_width", 512),), dtype=np.float32)
1149
+ x = torch.from_numpy(params.astype(np.float32)).to(device) # [T, D]
1150
+ x = x.unsqueeze(0) # [1, T, D]
1151
+ with torch.no_grad():
1152
+ # Normalize / preprocess
1153
+ x_pre = None
1154
+ if hasattr(vq_model.vqvae, "preprocess"):
1155
+ try:
1156
+ x_pre = vq_model.vqvae.preprocess(x) # expected to return tensor ready for encoder
1157
+ except Exception:
1158
+ x_pre = None
1159
+ if x_pre is None:
1160
+ # Manual normalization with provided mean/std
1161
+ if mean is not None and std is not None:
1162
+ mean_t = torch.from_numpy(np.array(mean, dtype=np.float32)).to(device).view(1, 1, -1)
1163
+ std_t = torch.from_numpy(np.array(std, dtype=np.float32)).to(device).view(1, 1, -1)
1164
+ x_norm = (x - mean_t) / (std_t + 1e-8)
1165
+ else:
1166
+ x_norm = x
1167
+ # Some encoders expect [N, D, T]
1168
+ x_pre = x_norm.transpose(1, 2).contiguous() # [1, D, T]
1169
+ # Encode to get pre-quant latent
1170
+ z_e = vq_model.vqvae.encoder(x_pre)
1171
+ # z_e could be [N, D_embed, T_q] or [N, T_q, D_embed]
1172
+ if z_e.dim() == 3:
1173
+ # Determine which axis is time by comparing to known embed dim when available,
1174
+ # otherwise assume time is the smaller dimension (varies per clip).
1175
+ embed_dim_known = getattr(vq_model.vqvae, "output_emb_width", None)
1176
+ if embed_dim_known is not None:
1177
+ if z_e.shape[1] == embed_dim_known:
1178
+ time_axis = 2 # [N, D_embed, T_q]
1179
+ elif z_e.shape[2] == embed_dim_known:
1180
+ time_axis = 1 # [N, T_q, D_embed]
1181
+ else:
1182
+ time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
1183
+ else:
1184
+ time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
1185
+ feat = z_e.mean(dim=time_axis).squeeze(0)
1186
+ elif z_e.dim() == 2:
1187
+ feat = z_e.squeeze(0)
1188
+ else:
1189
+ # Fallback: flatten then reduce
1190
+ feat = z_e.view(1, -1).mean(dim=0)
1191
+ feat_np = feat.detach().cpu().numpy().astype(np.float32)
1192
+ # L2 normalize
1193
+ norm = np.linalg.norm(feat_np)
1194
+ if norm > 0:
1195
+ feat_np = feat_np / norm
1196
+ return feat_np
1197
+
1198
+
1199
+ def evaluate_metrics_encoder_style(
1200
+ model,
1201
+ tokenizer,
1202
+ eval_data,
1203
+ device,
1204
+ vqvae_ckpt: Optional[str] = None,
1205
+ stats_path: Optional[str] = None,
1206
+ sample_limit: int = 100,
1207
+ ):
1208
+ """
1209
+ Computes FID, Diversity, and MIM using VQ-VAE encoder pre-quantization features:
1210
+ - For each sample, decode tokens -> SMPL-X params, then run through VQ-VAE encoder,
1211
+ average-pool across time, L2-normalize to get a clip feature.
1212
+ - Diversity/MIM identical formulations but on these encoder features.
1213
+ - FID via full covariance Fréchet distance on these encoder features.
1214
+ Evaluates on up to 'sample_limit' samples for speed.
1215
+ """
1216
+ print("\n" + "="*80)
1217
+ print(" METRICS EVALUATION (VQ-VAE Encoder Features)")
1218
+ print("="*80)
1219
+ # Lazy import to reuse your visualization utilities and stats
1220
+ try:
1221
+ from visualize import load_vqvae, load_stats, VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS
1222
+ vq_ckpt = vqvae_ckpt or os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
1223
+ stats_p = stats_path or os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
1224
+ vq_model = load_vqvae(vq_ckpt, device=device)
1225
+ mean, std = load_stats(stats_p)
1226
+ from visualize import decode_tokens_to_params
1227
+ except Exception as exc:
1228
+ print(f"⚠️ Could not set up VQ-VAE encoder metrics: {exc}")
1229
+ return {}
1230
+ # Collect GT/GEN token sequences for pairs (limit to speed-up)
1231
+ pairs = _collect_eval_pairs(model, tokenizer, eval_data[:sample_limit], device)
1232
+ # Build features
1233
+ gt_feats = []
1234
+ gen_feats = []
1235
+ labels = []
1236
+ for pair in pairs:
1237
+ if len(pair) == 4:
1238
+ word, _pid, gt_seq, gen_seq = pair
1239
+ else:
1240
+ word, gt_seq, gen_seq = pair
1241
+ # Decode to SMPL-X
1242
+ tokens_gt = _extract_ids_from_sequence(gt_seq)
1243
+ tokens_gen = _extract_ids_from_sequence(gen_seq)
1244
+ try:
1245
+ params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std, device=device) # (T, D) denorm
1246
+ except Exception:
1247
+ params_gt = np.zeros((0, 182), dtype=np.float32)
1248
+ try:
1249
+ params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std, device=device) # (T, D) denorm
1250
+ except Exception:
1251
+ params_gen = np.zeros((0, 182), dtype=np.float32)
1252
+ # Encode (pre-quant) -> pooled feature
1253
+ feat_gt = _encode_params_to_feature(params_gt, vq_model, mean, std, device)
1254
+ feat_gen = _encode_params_to_feature(params_gen, vq_model, mean, std, device)
1255
+ gt_feats.append(feat_gt)
1256
+ gen_feats.append(feat_gen)
1257
+ labels.append(word)
1258
+ gt_feats = np.stack(gt_feats, axis=0)
1259
+ gen_feats = np.stack(gen_feats, axis=0)
1260
+ # Diversity
1261
+ diversity_times = min(200, max(4, gt_feats.shape[0] - 1))
1262
+ diversity_gt = calculate_diversity_np(gt_feats, diversity_times=diversity_times)
1263
+ diversity_gen = calculate_diversity_np(gen_feats, diversity_times=diversity_times)
1264
+ # Multimodality (MIM)
1265
+ try:
1266
+ gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
1267
+ gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
1268
+ multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
1269
+ mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
1270
+ mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
1271
+ except Exception as exc:
1272
+ print(f"⚠️ Multimodality could not be computed reliably: {exc}")
1273
+ mim_gt = float("nan")
1274
+ mim_gen = float("nan")
1275
+ # FID (on encoder features)
1276
+ mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
1277
+ mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
1278
+ fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)
1279
+ print(f"Diversity (encoder feats): GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
1280
+ print(f"Multimodality (MIM, encoder): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
1281
+ print(f"FID (encoder feats, GT vs GEN): {fid:.4f}")
1282
+ return {
1283
+ "diversity_gt": diversity_gt,
1284
+ "diversity_gen": diversity_gen,
1285
+ "mim_gt": mim_gt,
1286
+ "mim_gen": mim_gen,
1287
+ "fid": fid,
1288
+ "pairs": pairs,
1289
+ }
1290
+
1291
+ # ======================================================================================
1292
+ # 5c. Side-by-side visualization (4 samples)
1293
+ # ======================================================================================
1294
+ def _extract_ids_from_sequence(seq: str) -> list[int]:
1295
+ return [int(t[2:-1]) for t in _extract_motion_tokens_from_sequence(seq) if t[2:-1].isdigit()]
1296
+
1297
+ def save_side_by_side_visualizations(pairs: list[Tuple[str, str, str]], output_dir: str, limit: int = 4):
1298
+ """
1299
+ Generate side-by-side 3D animations for GT vs GEN, saving one HTML per sample
1300
+ using filename scheme: word_PID_side_by_side.html.
1301
+ - Processes ALL samples for up to `limit` distinct words (if provided).
1302
+ - Requires visualize.py utilities and plotly.
1303
+ """
1304
+ try:
1305
+ from visualize import (
1306
+ load_vqvae, load_stats, load_smplx_model,
1307
+ decode_tokens_to_params, params_to_vertices,
1308
+ VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS, SMPLX_MODEL_DIR as DEFAULT_SMPLX
1309
+ )
1310
+ import plotly.graph_objects as go
1311
+ from plotly.subplots import make_subplots
1312
+ except Exception as exc:
1313
+ print(f"⚠️ Visualization skipped (missing dependencies): {exc}")
1314
+ return
1315
+
1316
+ os.makedirs(output_dir, exist_ok=True)
1317
+ vqvae_ckpt = os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
1318
+ stats_path = os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
1319
+ smplx_dir = os.getenv("SMPLX_MODEL_DIR", DEFAULT_SMPLX)
1320
+
1321
+ print("Loading VQ-VAE, stats, SMPL-X ...")
1322
+ vq_model = load_vqvae(vqvae_ckpt)
1323
+ mean, std = load_stats(stats_path)
1324
+ smplx_model = load_smplx_model(smplx_dir)
1325
+
1326
+ def animate_side_by_side(verts_left, faces, verts_right, fps=20, titles=("Ground Truth", "LLM Generated"), output_html=None):
1327
+ T = min(verts_left.shape[0], verts_right.shape[0])
1328
+ verts_left, verts_right = verts_left[:T], verts_right[:T]
1329
+ i, j, k = faces.T.tolist()
1330
+ fig = make_subplots(
1331
+ rows=1, cols=2,
1332
+ specs=[[{'type': 'scene'}, {'type': 'scene'}]],
1333
+ horizontal_spacing=0.05,
1334
+ subplot_titles=list(titles)
1335
+ )
1336
+ left_mesh = go.Mesh3d(x=verts_left[0,:,0], y=verts_left[0,:,1], z=verts_left[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
1337
+ right_mesh = go.Mesh3d(x=verts_right[0,:,0], y=verts_right[0,:,1], z=verts_right[0,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False)
1338
+ fig.add_trace(left_mesh, row=1, col=1)
1339
+ fig.add_trace(right_mesh, row=1, col=2)
1340
+ frames = []
1341
+ for t in range(T):
1342
+ frames.append(go.Frame(
1343
+ name=str(t),
1344
+ data=[
1345
+ go.Mesh3d(x=verts_left[t,:,0], y=verts_left[t,:,1], z=verts_left[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene"),
1346
+ go.Mesh3d(x=verts_right[t,:,0], y=verts_right[t,:,1], z=verts_right[t,:,2], i=i,j=j,k=k,opacity=0.7,showscale=False,scene="scene2")
1347
+ ]
1348
+ ))
1349
+ fig.frames = frames
1350
+ fig.update_layout(
1351
+ showlegend=False,
1352
+ margin=dict(l=10, r=10, t=50, b=10),
1353
+ scene=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
1354
+ camera=dict(eye=dict(x=0,y=-2,z=0.7))),
1355
+ scene2=dict(aspectmode='data',xaxis=dict(visible=False),yaxis=dict(visible=False),zaxis=dict(visible=False),
1356
+ camera=dict(eye=dict(x=0,y=-2,z=0.7))),
1357
+ updatemenus=[dict(
1358
+ type="buttons", x=0.5, xanchor="center", y=1.15, yanchor="top",
1359
+ buttons=[
1360
+ dict(label="Play", method="animate", args=[None, {"frame": {"duration": max(1,1000//fps), "redraw": True}, "fromcurrent": True}]),
1361
+ dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}}])
1362
+ ]
1363
+ )]
1364
+ )
1365
+ if output_html:
1366
+ fig.write_html(output_html)
1367
+ print(f"✅ Saved: {output_html}")
1368
+ return fig
1369
+
1370
+ # Determine which words to include (up to `limit` distinct words)
1371
+ allowed_words = None
1372
+ if isinstance(limit, int) and limit > 0:
1373
+ ordered_unique_words = []
1374
+ for pair in pairs:
1375
+ word = pair[0]
1376
+ if word not in ordered_unique_words:
1377
+ ordered_unique_words.append(word)
1378
+ if len(ordered_unique_words) >= limit:
1379
+ break
1380
+ allowed_words = set(ordered_unique_words)
1381
+
1382
+ for pair in pairs:
1383
+ try:
1384
+ if len(pair) == 4:
1385
+ word, pid, gt_seq, gen_seq = pair
1386
+ else:
1387
+ word, gt_seq, gen_seq = pair
1388
+ pid = "unknown"
1389
+ if allowed_words is not None and word not in allowed_words:
1390
+ continue
1391
+ tokens_gt = _extract_ids_from_sequence(gt_seq)
1392
+ tokens_gen = _extract_ids_from_sequence(gen_seq)
1393
+ params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std)
1394
+ params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std)
1395
+ verts_gt, faces = params_to_vertices(params_gt, smplx_model)
1396
+ verts_gen, _ = params_to_vertices(params_gen, smplx_model)
1397
+ out_dir = os.path.join(output_dir)
1398
+ os.makedirs(out_dir, exist_ok=True)
1399
+ # Sanitize for filesystem safety
1400
+ safe_word = re.sub(r'[^A-Za-z0-9_-]+', '_', str(word))
1401
+ safe_pid = re.sub(r'[^A-Za-z0-9_-]+', '_', str(pid))
1402
+ output_html = os.path.join(out_dir, f"word_{safe_word}_{safe_pid}_side_by_side.html")
1403
+ animate_side_by_side(
1404
+ verts_left=verts_gt,
1405
+ faces=faces,
1406
+ verts_right=verts_gen,
1407
+ fps=20,
1408
+ titles=("Ground Truth", "LLM Generated"),
1409
+ output_html=output_html
1410
+ )
1411
+ except Exception as exc:
1412
+ print(f"⚠️ Error creating visualization for word '{pair[0]}': {exc}")
1413
+
1414
+ # ======================================================================================
1415
+ # 6. Main Execution Block (UPDATED)
1416
+ # ======================================================================================
1417
+ def main(config_overrides: Optional[Dict[str, Any]] = None):
1418
+ """Main function to run the entire pipeline."""
1419
+ apply_config_overrides(config_overrides)
1420
+ if config_overrides:
1421
+ printable = {k: v for k, v in config_overrides.items() if "token" not in k.lower()}
1422
+ if printable:
1423
+ print("\nApplied config overrides:")
1424
+ for key, value in printable.items():
1425
+ print(f" - {key} = {value}")
1426
+ random.seed(42)
1427
+ torch.manual_seed(42)
1428
+
1429
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1430
+ print(f"Using device: {device}")
1431
+
1432
+ # 1. Load ALL data
1433
+ all_entries = read_json_data(DATASET_PATH)
1434
+
1435
+ # 2. Clean the ENTIRE dataset and get all tokens
1436
+ cleaned_data, all_motion_tokens = deduplicate_and_prepare_data(all_entries)
1437
+
1438
+ # 3. Stage 1: Initialize or resume from HF, then train
1439
+ resolved_stage1_repo = _resolve_and_ensure_repo(HF_STAGE1_REPO_ID) if HF_USE_HUB else None
1440
+ start_epoch_s1 = 0
1441
+ stage1_loaded = None
1442
+ if resolved_stage1_repo:
1443
+ if _repo_has_stage_latest(resolved_stage1_repo, "stage1"):
1444
+ stage1_loaded = _load_model_and_tokenizer_from_hf(resolved_stage1_repo, "stage1")
1445
+ state_s1 = _download_training_state(resolved_stage1_repo, "stage1")
1446
+ if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
1447
+ start_epoch_s1 = state_s1["epoch_completed"]
1448
+ else:
1449
+ # Fallback: no 'latest' folder; select highest epoch-XXX
1450
+ latest_s1_sub = _repo_get_latest_epoch_subfolder(resolved_stage1_repo, "stage1")
1451
+ if latest_s1_sub:
1452
+ stage1_loaded = _load_model_and_tokenizer_from_hf_subfolder(resolved_stage1_repo, latest_s1_sub)
1453
+ state_s1 = _download_training_state_from_subfolder(resolved_stage1_repo, latest_s1_sub)
1454
+ if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
1455
+ start_epoch_s1 = state_s1["epoch_completed"]
1456
+
1457
+ if stage1_loaded:
1458
+ base_model, tokenizer = stage1_loaded
1459
+ # Ensure tokenizer contains all motion tokens (add missing if dataset expanded)
1460
+ added = _ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
1461
+ if added > 0:
1462
+ base_model.resize_token_embeddings(len(tokenizer))
1463
+ else:
1464
+ base_model, tokenizer = setup_model_and_tokenizer(MODEL_NAME, all_motion_tokens)
1465
+
1466
+ print(f"\nStarting Stage 1 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s1}).")
1467
+ motion_model = train_stage1(
1468
+ base_model,
1469
+ tokenizer,
1470
+ cleaned_data,
1471
+ device,
1472
+ start_epoch=start_epoch_s1,
1473
+ hf_repo_id=resolved_stage1_repo,
1474
+ )
1475
+
1476
+ # 4. Stage 2: Initialize or resume from HF, then train
1477
+ resolved_stage2_repo = _resolve_and_ensure_repo(HF_STAGE2_REPO_ID) if HF_USE_HUB else None
1478
+ start_epoch_s2 = 0
1479
+ stage2_loaded = None
1480
+ print(f"\nStage 2 resume policy: FORCE_STAGE2_FROM_STAGE1={FORCE_STAGE2_FROM_STAGE1}, save_subdir='{HF_STAGE2_SAVE_SUBDIR}'")
1481
+ # For this run we want Stage 2 to start from Stage 1 epoch-20 even if an old stage2 exists.
1482
+ # Only resume Stage 2 if explicitly allowed and if there is a checkpoint under the save subdir.
1483
+ if not FORCE_STAGE2_FROM_STAGE1 and resolved_stage2_repo:
1484
+ # Prefer loading from the configured Stage 2 save subdir (e.g., 'stage2_v2')
1485
+ if _repo_has_stage_latest(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR):
1486
+ stage2_loaded = _load_model_and_tokenizer_from_hf(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR)
1487
+ state_s2 = _download_training_state(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR)
1488
+ if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
1489
+ start_epoch_s2 = state_s2["epoch_completed"]
1490
+ print(f"Resuming Stage 2 from HF subfolder: {HF_STAGE2_SAVE_SUBDIR}/latest (epoch_completed={start_epoch_s2})")
1491
+ else:
1492
+ latest_s2_sub = _repo_get_latest_epoch_subfolder(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR)
1493
+ if latest_s2_sub:
1494
+ stage2_loaded = _load_model_and_tokenizer_from_hf_subfolder(resolved_stage2_repo, latest_s2_sub)
1495
+ state_s2 = _download_training_state_from_subfolder(resolved_stage2_repo, latest_s2_sub)
1496
+ if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
1497
+ start_epoch_s2 = state_s2["epoch_completed"]
1498
+ print(f"Resuming Stage 2 from HF subfolder: {latest_s2_sub} (epoch_completed={start_epoch_s2})")
1499
+
1500
+ if stage2_loaded:
1501
+ stage2_model, tokenizer = stage2_loaded
1502
+ added2 = _ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
1503
+ if added2 > 0:
1504
+ stage2_model.resize_token_embeddings(len(tokenizer))
1505
+ else:
1506
+ stage2_model = motion_model # Start Stage 2 from Stage 1 model
1507
+
1508
+ print(f"\nStarting Stage 2 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s2}).")
1509
+ final_model = train_stage2(
1510
+ stage2_model,
1511
+ tokenizer,
1512
+ cleaned_data,
1513
+ device,
1514
+ start_epoch=start_epoch_s2,
1515
+ hf_repo_id=resolved_stage2_repo,
1516
+ hf_stage_subdir=HF_STAGE2_SAVE_SUBDIR,
1517
+ )
1518
+
1519
+ # 5. Filter the cleaned data to get a smaller set for evaluation
1520
+ # This keeps the evaluation focused on our benchmark words and the logs readable
1521
+ print("\n--- Filtering data for evaluation on specific words ---")
1522
+ evaluation_data = [item for item in cleaned_data if item['word'].lower() in EVALUATION_WORDS]
1523
+ print(f"Found {len(evaluation_data)} samples for evaluation words: {EVALUATION_WORDS}")
1524
+
1525
+ # 6. Metrics-only mode or full flow
1526
+ if RUN_EVALS_ONLY:
1527
+ # Compute the 3 metrics using VQ-VAE encoder features and save to JSON
1528
+ metrics_enc = evaluate_metrics_encoder_style(
1529
+ final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
1530
+ )
1531
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
1532
+ metrics_payload = {
1533
+ "source": "vqvae_encoder",
1534
+ "fid": metrics_enc.get("fid"),
1535
+ "diversity": {
1536
+ "ground_truth": metrics_enc.get("diversity_gt"),
1537
+ "model": metrics_enc.get("diversity_gen"),
1538
+ },
1539
+ "multimodality": {
1540
+ "ground_truth": metrics_enc.get("mim_gt"),
1541
+ "model": metrics_enc.get("mim_gen"),
1542
+ },
1543
+ "num_pairs": len(metrics_enc.get("pairs", [])),
1544
+ }
1545
+ with open(METRICS_JSON_PATH, "w", encoding="utf-8") as f:
1546
+ json.dump(metrics_payload, f, ensure_ascii=False, indent=2)
1547
+ print(f"\n✅ Saved metrics to {METRICS_JSON_PATH}")
1548
+ return
1549
+
1550
+ # Full flow: inference logs + MotionGPT-style metrics + encoder metrics + visualizations
1551
+ run_inference_on_all_samples(final_model, tokenizer, evaluation_data, device)
1552
+ metrics_token = evaluate_metrics_motiongpt_style(final_model, tokenizer, evaluation_data, all_motion_tokens, device)
1553
+ # Also compute encoder-based 3 metrics
1554
+ metrics_enc = evaluate_metrics_encoder_style(
1555
+ final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
1556
+ )
1557
+ # Visualizations (skip if metrics-only)
1558
+ viz_dir = os.path.join(OUTPUT_DIR, "html_visualizations")
1559
+ save_side_by_side_visualizations(metrics_token["pairs"], viz_dir, limit=4)
1560
+
1561
+ if __name__ == "__main__":
1562
+ main()
train.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training utilities and functions
3
+ """
4
+ import math
5
+ import os
6
+ import re
7
+ import time
8
+ import json
9
+ import shutil
10
+ import torch
11
+ from datetime import datetime
12
+ from typing import Optional, Dict, Any, List, Tuple
13
+ from torch.optim import AdamW
14
+ from torch.utils.data import DataLoader
15
+
16
+ from transformers import TrainingArguments, Trainer, AutoModelForCausalLM, AutoTokenizer
17
+ from transformers.trainer_callback import TrainerCallback
18
+ from huggingface_hub import HfApi, upload_folder, snapshot_download, hf_hub_download
19
+
20
+ from config import (
21
+ BATCH_TRAIN, BATCH_EVAL, GRAD_ACCUM, LR, WARMUP,
22
+ LOG_STEPS, EVAL_STEPS, SAVE_STEPS, SEED, DTYPE,
23
+ HUB_REPO_S1, HUB_REPO_S2, HUB_REPO_S3, HF_TOKEN,
24
+ CHECKPOINTS_DIR, HF_USE_HUB, CHECKPOINT_UPLOAD_INTERVAL_EPOCHS,
25
+ S1_BATCH_SIZE, S1_LR, S1_EPOCHS, S2_BATCH_SIZE, S2_LR, S2_EPOCHS,
26
+ PAD_TOKEN, M_START, M_END
27
+ )
28
+
29
+ # ======================================================================================
30
+ # Logic from test_overfit.py (Raw Training Loops & HF Utils)
31
+ # ======================================================================================
32
+
33
+ def _format_seconds(seconds: float) -> str:
34
+ """Formats seconds into H:MM:SS or M:SS."""
35
+ seconds = int(max(0, seconds))
36
+ h = seconds // 3600
37
+ m = (seconds % 3600) // 60
38
+ s = seconds % 60
39
+ if h > 0:
40
+ return f"{h:d}:{m:02d}:{s:02d}"
41
+ return f"{m:d}:{s:02d}"
42
+
43
+ def _ensure_dir(path: str) -> None:
44
+ os.makedirs(path, exist_ok=True)
45
+
46
+ def resolve_and_ensure_repo(repo_id: str, hf_auth_token: Optional[str] = None) -> Optional[str]:
47
+ """
48
+ Ensures the HF repo exists. Returns the fully-qualified repo_id (namespace/repo)
49
+ when token is available; otherwise returns the input repo_id.
50
+ """
51
+ if not HF_USE_HUB:
52
+ return None
53
+ token = hf_auth_token or HF_TOKEN
54
+ if not token:
55
+ print("⚠️ HF token not found. Set HUGGINGFACE_HUB_TOKEN to enable Hub sync.")
56
+ return None
57
+ api = HfApi()
58
+ try:
59
+ who = api.whoami(token=token)
60
+ namespace = who.get("name") or (who.get("orgs", [None])[0] if isinstance(who.get("orgs"), list) else None)
61
+ except Exception as exc:
62
+ print(f"⚠️ Unable to resolve HF namespace: {exc}")
63
+ namespace = None
64
+ if "/" not in repo_id and namespace:
65
+ full_repo_id = f"{namespace}/{repo_id}"
66
+ else:
67
+ full_repo_id = repo_id
68
+ try:
69
+ api.create_repo(
70
+ repo_id=full_repo_id,
71
+ token=token,
72
+ repo_type="model",
73
+ private=True, # Default to private as in test_overfit config if not specified
74
+ exist_ok=True,
75
+ )
76
+ except Exception as exc:
77
+ print(f"⚠️ create_repo failed (may already exist): {exc}")
78
+ return full_repo_id
79
+
80
+ def repo_has_stage_latest(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> bool:
81
+ """Checks if a stage/latest checkpoint exists in the HF repo."""
82
+ token = hf_auth_token or HF_TOKEN
83
+ if not HF_USE_HUB or not token:
84
+ return False
85
+ api = HfApi()
86
+ try:
87
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
88
+ return any(path.startswith(f"{stage}/latest/") and path.endswith("config.json") for path in files)
89
+ except Exception as exc:
90
+ print(f"⚠️ Could not list files for {repo_id}: {exc}")
91
+ return False
92
+
93
+ def repo_list_epoch_numbers(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> List[int]:
94
+ """
95
+ Returns sorted list of epoch numbers available under {stage}/epoch-XXX/ by scanning files.
96
+ """
97
+ token = hf_auth_token or HF_TOKEN
98
+ if not HF_USE_HUB or not token:
99
+ return []
100
+ api = HfApi()
101
+ try:
102
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
103
+ except Exception as exc:
104
+ print(f"⚠️ Could not list files for {repo_id}: {exc}")
105
+ return []
106
+ epoch_numbers: List[int] = []
107
+ pattern = re.compile(rf"^{re.escape(stage)}/epoch-(\d+)/config\.json$")
108
+ for path in files:
109
+ m = pattern.match(path)
110
+ if m:
111
+ try:
112
+ epoch_numbers.append(int(m.group(1)))
113
+ except ValueError:
114
+ pass
115
+ return sorted(set(epoch_numbers))
116
+
117
+ def repo_get_latest_epoch_subfolder(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[str]:
118
+ """
119
+ Returns subfolder path like '{stage}/epoch-XXX' for the highest available epoch, or None.
120
+ """
121
+ epochs = repo_list_epoch_numbers(repo_id, stage, hf_auth_token)
122
+ if not epochs:
123
+ return None
124
+ latest = max(epochs)
125
+ return f"{stage}/epoch-{latest:03d}"
126
+
127
+ def load_model_and_tokenizer_from_hf_subfolder(repo_id: str, subfolder: str, hf_auth_token: Optional[str] = None) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
128
+ """
129
+ Loads model and tokenizer from HF under a specific subfolder.
130
+ """
131
+ if not HF_USE_HUB or (not hf_auth_token and not HF_TOKEN):
132
+ return None
133
+ print(f"\n---> Loading checkpoint from Hugging Face: {repo_id} (subfolder='{subfolder}')")
134
+ try:
135
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
136
+ model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, trust_remote_code=True)
137
+ except Exception as exc:
138
+ print(f"⚠️ Failed to load model/tokenizer from subfolder '{subfolder}': {exc}")
139
+ return None
140
+ if tokenizer.pad_token is None:
141
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
142
+ model.resize_token_embeddings(len(tokenizer))
143
+ model.config.pad_token_id = tokenizer.pad_token_id
144
+ return model, tokenizer
145
+
146
+ def download_training_state_from_subfolder(repo_id: str, subfolder: str, hf_auth_token: Optional[str] = None) -> Optional[Dict[str, Any]]:
147
+ """
148
+ Downloads training_state.json from a specific subfolder.
149
+ """
150
+ token = hf_auth_token or HF_TOKEN
151
+ if not HF_USE_HUB or not token:
152
+ return None
153
+ try:
154
+ state_path = hf_hub_download(
155
+ repo_id=repo_id,
156
+ filename=f"{subfolder}/training_state.json",
157
+ repo_type="model",
158
+ token=token,
159
+ )
160
+ with open(state_path, "r", encoding="utf-8") as f:
161
+ return json.load(f)
162
+ except Exception:
163
+ return None
164
+
165
+ def download_training_state(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[Dict[str, Any]]:
166
+ """Downloads training_state.json from HF if present."""
167
+ token = hf_auth_token or HF_TOKEN
168
+ if not HF_USE_HUB or not token:
169
+ return None
170
+ try:
171
+ state_path = hf_hub_download(
172
+ repo_id=repo_id,
173
+ filename=f"{stage}/latest/training_state.json",
174
+ repo_type="model",
175
+ token=token,
176
+ )
177
+ with open(state_path, "r", encoding="utf-8") as f:
178
+ return json.load(f)
179
+ except Exception:
180
+ return None
181
+
182
+ def download_optimizer_state(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[str]:
183
+ """Downloads optimizer.pt for resuming optimizer state."""
184
+ token = hf_auth_token or HF_TOKEN
185
+ if not HF_USE_HUB or not token:
186
+ return None
187
+ try:
188
+ opt_path = hf_hub_download(
189
+ repo_id=repo_id,
190
+ filename=f"{stage}/latest/optimizer.pt",
191
+ repo_type="model",
192
+ token=token,
193
+ )
194
+ return opt_path
195
+ except Exception:
196
+ return None
197
+
198
+ def load_model_and_tokenizer_from_hf(repo_id: str, stage: str, hf_auth_token: Optional[str] = None) -> Optional[Tuple[AutoModelForCausalLM, AutoTokenizer]]:
199
+ """
200
+ Loads model and tokenizer from HF under subfolder {stage}/latest if available.
201
+ """
202
+ if not repo_has_stage_latest(repo_id, stage, hf_auth_token):
203
+ return None
204
+ print(f"\n---> Loading checkpoint from Hugging Face: {repo_id} (subfolder='{stage}/latest')")
205
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=f"{stage}/latest", trust_remote_code=True)
206
+ model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=f"{stage}/latest", trust_remote_code=True)
207
+ if tokenizer.pad_token is None:
208
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
209
+ model.resize_token_embeddings(len(tokenizer))
210
+ model.config.pad_token_id = tokenizer.pad_token_id
211
+ return model, tokenizer
212
+
213
+ def save_and_push_checkpoint(
214
+ stage: str,
215
+ epoch_index_zero_based: int,
216
+ model: AutoModelForCausalLM,
217
+ tokenizer: AutoTokenizer,
218
+ optimizer: AdamW,
219
+ avg_loss: float,
220
+ dataloader_len: int,
221
+ batch_size: int,
222
+ total_epochs: int,
223
+ repo_id: Optional[str],
224
+ hf_auth_token: Optional[str] = None
225
+ ) -> None:
226
+ """
227
+ Saves checkpoint locally and pushes to HF.
228
+ """
229
+ token = hf_auth_token or HF_TOKEN
230
+ epoch_number = epoch_index_zero_based + 1
231
+ stage_dir = os.path.join(CHECKPOINTS_DIR, stage)
232
+ epoch_dir_name = f"epoch-{epoch_number:03d}"
233
+ epoch_dir = os.path.join(stage_dir, epoch_dir_name)
234
+ latest_dir = os.path.join(stage_dir, "latest")
235
+ _ensure_dir(epoch_dir)
236
+ _ensure_dir(stage_dir)
237
+
238
+ # Save model + tokenizer
239
+ model.save_pretrained(epoch_dir)
240
+ tokenizer.save_pretrained(epoch_dir)
241
+
242
+ # Save optimizer state
243
+ torch.save(optimizer.state_dict(), os.path.join(epoch_dir, "optimizer.pt"))
244
+
245
+ # Save training state
246
+ training_state = {
247
+ "stage": stage,
248
+ "epoch_completed": epoch_number,
249
+ "total_epochs_for_stage": total_epochs,
250
+ "global_step": epoch_number * dataloader_len,
251
+ "avg_loss": float(avg_loss),
252
+ "batch_size": batch_size,
253
+ "saved_at": datetime.utcnow().isoformat() + "Z",
254
+ }
255
+ with open(os.path.join(epoch_dir, "training_state.json"), "w", encoding="utf-8") as f:
256
+ json.dump(training_state, f, ensure_ascii=False, indent=2)
257
+
258
+ # Update "latest"
259
+ if os.path.exists(latest_dir):
260
+ shutil.rmtree(latest_dir)
261
+ shutil.copytree(epoch_dir, latest_dir)
262
+
263
+ # Push to Hugging Face
264
+ if HF_USE_HUB and repo_id and token:
265
+ try:
266
+ upload_folder(
267
+ repo_id=repo_id,
268
+ folder_path=epoch_dir,
269
+ path_in_repo=f"{stage}/{epoch_dir_name}",
270
+ repo_type="model",
271
+ token=token,
272
+ commit_message=f"{stage}: save {epoch_dir_name}",
273
+ )
274
+ upload_folder(
275
+ repo_id=repo_id,
276
+ folder_path=latest_dir,
277
+ path_in_repo=f"{stage}/latest",
278
+ repo_type="model",
279
+ token=token,
280
+ commit_message=f"{stage}: update latest -> {epoch_dir_name}",
281
+ )
282
+ print(f"☁️ Pushed checkpoint to HF: {repo_id} ({stage}/{epoch_dir_name} and {stage}/latest)")
283
+ except Exception as exc:
284
+ print(f"⚠️ Failed to push checkpoint to HF: {exc}")
285
+ else:
286
+ print("ℹ️ Skipped HF push (Hub disabled or token/repo missing).")
287
+
288
+ def train_stage1_raw(
289
+ model,
290
+ tokenizer,
291
+ data: List[Dict[str, Any]],
292
+ device,
293
+ start_epoch: int = 0,
294
+ hf_repo_id: Optional[str] = None,
295
+ ):
296
+ """Trains the model on motion sequences only to learn the 'language of motion'."""
297
+ from data import MotionDataset # Import here to avoid circular imports
298
+
299
+ print("\n" + "="*80)
300
+ print(" STAGE 1: MOTION LANGUAGE MODELING (PRE-TRAINING)")
301
+ print(f" Training on {len(data)} samples.")
302
+ print("="*80)
303
+
304
+ dataset = MotionDataset(data, tokenizer)
305
+ dataloader = DataLoader(dataset, batch_size=S1_BATCH_SIZE, shuffle=True)
306
+
307
+ optimizer = AdamW(model.parameters(), lr=S1_LR)
308
+ model.to(device)
309
+ model.train()
310
+
311
+ # Try to resume optimizer if we resumed from HF
312
+ token = HF_TOKEN
313
+ if hf_repo_id and start_epoch > 0 and HF_USE_HUB and token:
314
+ opt_path = download_optimizer_state(hf_repo_id, "stage1", token)
315
+ if opt_path is not None:
316
+ try:
317
+ optimizer.load_state_dict(torch.load(opt_path, map_location=device))
318
+ print("↩️ Resumed optimizer state for Stage 1 from HF.")
319
+ except Exception as exc:
320
+ print(f"⚠️ Failed to load optimizer state for Stage 1: {exc}")
321
+
322
+ for epoch in range(start_epoch, S1_EPOCHS):
323
+ total_loss = 0
324
+ total_batches = len(dataloader)
325
+ epoch_start_time = time.time()
326
+ step_interval = max(1, total_batches // 50) # ~2% progress updates
327
+ for i, batch in enumerate(dataloader, 1):
328
+ optimizer.zero_grad()
329
+
330
+ input_ids = batch['input_ids'].squeeze(1).to(device)
331
+ attention_mask = batch['attention_mask'].squeeze(1).to(device)
332
+
333
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
334
+
335
+ loss = outputs.loss
336
+ loss.backward()
337
+ optimizer.step()
338
+ total_loss += loss.item()
339
+
340
+ # Progress with ETA
341
+ if i == 1 or (i % step_interval == 0) or (i == total_batches):
342
+ elapsed = time.time() - epoch_start_time
343
+ est_total = (elapsed / i) * total_batches
344
+ eta = est_total - elapsed
345
+ pct = (i / total_batches) * 100.0
346
+ print(
347
+ f"\r[Stage 1] Epoch {epoch+1}/{S1_EPOCHS} - "
348
+ f"{i}/{total_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
349
+ end="",
350
+ flush=True,
351
+ )
352
+
353
+ # Finish the progress line
354
+ print()
355
+ avg_loss = total_loss / len(dataloader)
356
+ print(f"--- End of Epoch {epoch+1}/{S1_EPOCHS}, Average Loss: {avg_loss:.4f} ---")
357
+ # Save checkpoint locally every epoch; push to HF only at interval or final epoch
358
+ push_this_epoch = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S1_EPOCHS)
359
+ repo_for_epoch = hf_repo_id if push_this_epoch else None
360
+ save_and_push_checkpoint(
361
+ stage="stage1",
362
+ epoch_index_zero_based=epoch,
363
+ model=model,
364
+ tokenizer=tokenizer,
365
+ optimizer=optimizer,
366
+ avg_loss=avg_loss,
367
+ dataloader_len=len(dataloader),
368
+ batch_size=S1_BATCH_SIZE,
369
+ total_epochs=S1_EPOCHS,
370
+ repo_id=repo_for_epoch,
371
+ hf_auth_token=token
372
+ )
373
+
374
+ print("\n✅ Stage 1 Training Complete.")
375
+ return model
376
+
377
+ def train_stage2_raw(
378
+ model,
379
+ tokenizer,
380
+ data: List[Dict[str, Any]],
381
+ device,
382
+ start_epoch: int = 0,
383
+ hf_repo_id: Optional[str] = None,
384
+ hf_stage_subdir: str = "stage2",
385
+ ):
386
+ """Fine-tunes the motion-aware model to connect text prompts to motions."""
387
+ from data import TextMotionDataset # Import here to avoid circular imports
388
+
389
+ print("\n" + "="*80)
390
+ print(" STAGE 2: TEXT-TO-MOTION FINE-TUNING")
391
+ print(f" Training on {len(data)} samples.")
392
+ print("="*80)
393
+
394
+ dataset = TextMotionDataset(data, tokenizer)
395
+ dataloader = DataLoader(dataset, batch_size=S2_BATCH_SIZE, shuffle=True)
396
+
397
+ optimizer = AdamW(model.parameters(), lr=S2_LR)
398
+ model.to(device)
399
+ model.train()
400
+
401
+ # Try to resume optimizer if we resumed from HF
402
+ token = HF_TOKEN
403
+ if hf_repo_id and start_epoch > 0 and HF_USE_HUB and token:
404
+ opt_path = download_optimizer_state(hf_repo_id, hf_stage_subdir, token)
405
+ if opt_path is not None:
406
+ try:
407
+ optimizer.load_state_dict(torch.load(opt_path, map_location=device))
408
+ print("↩️ Resumed optimizer state for Stage 2 from HF.")
409
+ except Exception as exc:
410
+ print(f"⚠️ Failed to load optimizer state for Stage 2: {exc}")
411
+
412
+ for epoch in range(start_epoch, S2_EPOCHS):
413
+ total_loss = 0
414
+ total_batches = len(dataloader)
415
+ epoch_start_time = time.time()
416
+ step_interval = max(1, total_batches // 50) # ~2% progress updates
417
+ for i, batch in enumerate(dataloader, 1):
418
+ optimizer.zero_grad()
419
+
420
+ input_ids = batch['input_ids'].to(device)
421
+ attention_mask = batch['attention_mask'].to(device)
422
+ labels = batch['labels'].to(device)
423
+
424
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
425
+
426
+ loss = outputs.loss
427
+ loss.backward()
428
+ optimizer.step()
429
+ total_loss += loss.item()
430
+
431
+ # Progress with ETA
432
+ if i == 1 or (i % step_interval == 0) or (i == total_batches):
433
+ elapsed = time.time() - epoch_start_time
434
+ est_total = (elapsed / i) * total_batches
435
+ eta = est_total - elapsed
436
+ pct = (i / total_batches) * 100.0
437
+ print(
438
+ f"\r[Stage 2] Epoch {epoch+1}/{S2_EPOCHS} - "
439
+ f"{i}/{total_batches} ({pct:.1f}%) - ETA {_format_seconds(eta)}",
440
+ end="",
441
+ flush=True,
442
+ )
443
+
444
+ # Finish the progress line
445
+ print()
446
+ avg_loss = total_loss / len(dataloader)
447
+ print(f"--- End of Epoch {epoch+1}/{S2_EPOCHS}, Average Loss: {avg_loss:.4f} ---")
448
+ # Save checkpoint locally every epoch; push to HF only at interval or final epoch
449
+ push_this_epoch = ((epoch + 1) % CHECKPOINT_UPLOAD_INTERVAL_EPOCHS == 0) or ((epoch + 1) == S2_EPOCHS)
450
+ repo_for_epoch = hf_repo_id if push_this_epoch else None
451
+ save_and_push_checkpoint(
452
+ stage=hf_stage_subdir,
453
+ epoch_index_zero_based=epoch,
454
+ model=model,
455
+ tokenizer=tokenizer,
456
+ optimizer=optimizer,
457
+ avg_loss=avg_loss,
458
+ dataloader_len=len(dataloader),
459
+ batch_size=S2_BATCH_SIZE,
460
+ total_epochs=S2_EPOCHS,
461
+ repo_id=repo_for_epoch,
462
+ hf_auth_token=token
463
+ )
464
+
465
+ print("\n✅ Stage 2 Training Complete.")
466
+ return model
467
+
468
+ # ======================================================================================
469
+ # Existing Utilities
470
+ # ======================================================================================
471
+
472
+ def make_training_args(out_dir: str, epochs: int, two_point_hub: bool = False) -> TrainingArguments:
473
+ """
474
+ Create TrainingArguments for a training stage
475
+ """
476
+ return TrainingArguments(
477
+ output_dir=out_dir,
478
+ per_device_train_batch_size=BATCH_TRAIN,
479
+ per_device_eval_batch_size=BATCH_EVAL,
480
+ gradient_accumulation_steps=GRAD_ACCUM,
481
+ learning_rate=LR,
482
+ num_train_epochs=epochs,
483
+ logging_steps=LOG_STEPS,
484
+ eval_strategy="steps",
485
+ eval_steps=EVAL_STEPS,
486
+ # When using two-point hub checkpointing, disable periodic local saves and rely on forced saves
487
+ save_steps=(10**12 if two_point_hub else SAVE_STEPS),
488
+ save_total_limit=2,
489
+ warmup_ratio=WARMUP,
490
+ bf16=(DTYPE == torch.bfloat16),
491
+ fp16=(DTYPE == torch.float16),
492
+ lr_scheduler_type="cosine",
493
+ optim="adamw_torch",
494
+ report_to="none",
495
+ seed=SEED,
496
+ remove_unused_columns=False,
497
+ )
498
+
499
+
500
+ def latest_hub_checkpoint(repo_id: str) -> Optional[str]:
501
+ """
502
+ Download and return the local path to the latest checkpoint folder from the Hub.
503
+ Returns None if no checkpoint exists or on failure.
504
+ """
505
+ api = HfApi()
506
+ try:
507
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model")
508
+ except Exception as e:
509
+ print(f"Hub list failed for {repo_id}: {e}")
510
+ return None
511
+
512
+ def _step_key(dirname: str) -> int:
513
+ nums = re.findall(r"\d+", dirname)
514
+ return int(nums[-1]) if nums else -1
515
+
516
+ ckpt_dirs = sorted(
517
+ {p.split('/')[0] for p in files if p.startswith("checkpoint-")},
518
+ key=_step_key,
519
+ )
520
+ if not ckpt_dirs:
521
+ return None
522
+ latest = ckpt_dirs[-1]
523
+ local_root = snapshot_download(
524
+ repo_id=repo_id,
525
+ repo_type="model",
526
+ allow_patterns=[f"{latest}/**", "trainer_state.json"],
527
+ local_dir_use_symlinks=False,
528
+ )
529
+ return os.path.join(local_root, latest)
530
+
531
+
532
+ class TwoPointHubCheckpointCallback(TrainerCallback):
533
+ """
534
+ Save to Hugging Face Hub exactly twice per training: halfway and at final step.
535
+ Keeps only the most recent N checkpoints on Hub.
536
+ """
537
+
538
+ def __init__(self, repo_id: str, keep_last: int = 2, token: Optional[str] = None):
539
+ self.repo_id = repo_id
540
+ self.keep_last = keep_last
541
+ self.api = HfApi()
542
+ self.token = token or os.environ.get("HUGGINGFACE_HUB_TOKEN")
543
+ self._half_step: Optional[int] = None
544
+ self._final_step: Optional[int] = None
545
+ self._saved_steps = set()
546
+ self._pending_push_for_step: Optional[int] = None
547
+ try:
548
+ self.api.create_repo(repo_id=self.repo_id, private=True, exist_ok=True, token=self.token)
549
+ except Exception as e:
550
+ print(f"Could not ensure repo exists: {e}")
551
+
552
+ def _enforce_keep_last(self) -> None:
553
+ try:
554
+ files = self.api.list_repo_files(repo_id=self.repo_id, repo_type="model", token=self.token)
555
+
556
+ def _step_key(dirname: str) -> int:
557
+ nums = re.findall(r"\d+", dirname)
558
+ return int(nums[-1]) if nums else -1
559
+
560
+ ckpt_dirs = sorted(
561
+ {p.split('/')[0] for p in files if p.startswith("checkpoint-")},
562
+ key=_step_key,
563
+ )
564
+ if len(ckpt_dirs) <= self.keep_last:
565
+ return
566
+ to_delete = ckpt_dirs[:-self.keep_last]
567
+ for d in to_delete:
568
+ for f in [p for p in files if p.startswith(f"{d}/")]:
569
+ try:
570
+ self.api.delete_file(path=f, repo_id=self.repo_id, repo_type="model", token=self.token)
571
+ except Exception as e:
572
+ print(f"Failed deleting {f}: {e}")
573
+ except Exception as e:
574
+ print(f"Keep-last enforcement failed: {e}")
575
+
576
+ def on_train_begin(self, args, state, control, **kwargs):
577
+ # Prefer Trainer-computed max_steps
578
+ if state.max_steps and state.max_steps > 0:
579
+ self._half_step = max(1, state.max_steps // 2)
580
+ self._final_step = state.max_steps
581
+ print(f"Two-point checkpointing: half={self._half_step}, final={self._final_step}")
582
+ else:
583
+ # Best-effort fallback using dataloader length and grad accumulation if available
584
+ td = kwargs.get("train_dataloader")
585
+ if td is not None and args.gradient_accumulation_steps > 0:
586
+ steps_per_epoch = math.ceil(len(td) / args.gradient_accumulation_steps)
587
+ self._final_step = steps_per_epoch * int(args.num_train_epochs)
588
+ self._half_step = max(1, self._final_step // 2)
589
+ print(f"Two-point checkpointing (approx): half={self._half_step}, final={self._final_step}")
590
+
591
+ def on_step_end(self, args, state, control, **kwargs):
592
+ if not self._final_step:
593
+ return control
594
+ gs = state.global_step
595
+ if gs == self._half_step and gs not in self._saved_steps:
596
+ control.should_save = True
597
+ self._pending_push_for_step = gs
598
+ if gs == self._final_step and gs not in self._saved_steps:
599
+ control.should_save = True
600
+ self._pending_push_for_step = gs
601
+ return control
602
+
603
+ def on_save(self, args, state, control, **kwargs):
604
+ # Push only when we triggered this save
605
+ if self._pending_push_for_step is None:
606
+ return control
607
+ step = self._pending_push_for_step
608
+ self._pending_push_for_step = None
609
+ self._saved_steps.add(step)
610
+
611
+ ckpt_dirname = f"checkpoint-{step}"
612
+ try:
613
+ upload_folder(
614
+ repo_id=self.repo_id,
615
+ folder_path=args.output_dir,
616
+ repo_type="model",
617
+ token=self.token,
618
+ commit_message=f"upload {ckpt_dirname}",
619
+ allow_patterns=[f"{ckpt_dirname}/**", "trainer_state.json"],
620
+ )
621
+ self._enforce_keep_last()
622
+ print(f"Pushed {ckpt_dirname} to {self.repo_id}")
623
+ except Exception as e:
624
+ print(f"Hub upload failed for {ckpt_dirname}: {e}")
625
+ return control
626
+
627
+
628
+ def train_stage(
629
+ stage_name: str,
630
+ model,
631
+ tokenizer,
632
+ train_dataset,
633
+ eval_dataset,
634
+ data_collator,
635
+ out_dir: str,
636
+ epochs: int,
637
+ hub_repo: Optional[str] = None,
638
+ ):
639
+ """
640
+ Train a single stage
641
+ """
642
+ print(f"\n{'='*60}")
643
+ print(f"Training {stage_name}")
644
+ print(f"{'='*60}")
645
+
646
+ # Auto-select Hub repo by stage if not provided
647
+ if hub_repo is None:
648
+ s = (stage_name or "").lower()
649
+ if s.startswith("stage1"):
650
+ hub_repo = HUB_REPO_S1
651
+ elif s.startswith("stage2"):
652
+ hub_repo = HUB_REPO_S2
653
+ elif s.startswith("stage3"):
654
+ hub_repo = HUB_REPO_S3
655
+
656
+ args = make_training_args(out_dir, epochs, two_point_hub=(hub_repo is not None))
657
+
658
+ trainer = Trainer(
659
+ model=model,
660
+ tokenizer=tokenizer,
661
+ train_dataset=train_dataset,
662
+ eval_dataset=eval_dataset,
663
+ args=args,
664
+ data_collator=data_collator,
665
+ )
666
+
667
+ # Train-loss early stop (match test_overfit behavior)
668
+ class TrainLossStopCallback(TrainerCallback):
669
+ def __init__(self, threshold: float = 1.0):
670
+ self.threshold = float(threshold)
671
+ self.triggered = False
672
+
673
+ def on_log(self, args, state, control, logs=None, **kwargs):
674
+ if logs is None:
675
+ return control
676
+ loss = logs.get("loss")
677
+ if loss is not None and loss < self.threshold and state.global_step > 0 and not self.triggered:
678
+ self.triggered = True
679
+ print(f"\nTrain-loss early stop: loss={loss:.4f} < {self.threshold}")
680
+ control.should_training_stop = True
681
+ return control
682
+
683
+ trainer.add_callback(TrainLossStopCallback(threshold=1.0))
684
+
685
+ # Add two-point Hub checkpoint uploader if configured
686
+ if hub_repo:
687
+ # Pass token if available to avoid auth prompts in Kaggle/Colab
688
+ token = HF_TOKEN if isinstance(HF_TOKEN, str) and len(HF_TOKEN) > 0 else None
689
+ trainer.add_callback(TwoPointHubCheckpointCallback(hub_repo, token=token))
690
+
691
+ # Train (with auto-resume from Hub if available)
692
+ resume_path = latest_hub_checkpoint(hub_repo) if hub_repo else None
693
+ if resume_path:
694
+ print(f"Resuming from Hub checkpoint: {resume_path}")
695
+ trainer.train(resume_from_checkpoint=resume_path)
696
+ else:
697
+ print(f"Starting training for {stage_name}...")
698
+ trainer.train()
699
+
700
+ # Evaluate
701
+ print(f"Evaluating {stage_name}...")
702
+ metrics = trainer.evaluate()
703
+
704
+ # Compute perplexity
705
+ eval_loss = metrics.get("eval_loss", float("nan"))
706
+ ppl = math.exp(eval_loss) if not math.isnan(eval_loss) else float("nan")
707
+
708
+ print(f"\n{stage_name} Results:")
709
+ print(f" eval_loss: {eval_loss:.4f}")
710
+ print(f" perplexity: {ppl:.3f}")
711
+
712
+ # Save model (optional - can be commented out to save space)
713
+ # trainer.save_model(out_dir)
714
+ # print(f"Model saved to {out_dir}")
715
+
716
+ return metrics
717
+
718
+
719
+ def save_model_to_hub(model, tokenizer, repo_id: str, stage_name: str):
720
+ """
721
+ Save model and tokenizer to HuggingFace Hub
722
+ """
723
+ print(f"\nSaving {stage_name} to HuggingFace Hub: {repo_id}")
724
+ model.push_to_hub(repo_id, commit_message=f"Upload {stage_name}")
725
+ tokenizer.push_to_hub(repo_id, commit_message=f"Upload {stage_name}")
726
+ print(f"Successfully saved {stage_name}")
727
+
728
+
729
+ def load_model_from_hub(repo_id: str):
730
+ """
731
+ Load model and tokenizer from HuggingFace Hub
732
+ """
733
+ from unsloth import FastLanguageModel
734
+ from config import MAX_SEQ_LEN, DTYPE
735
+
736
+ print(f"\nLoading model from HuggingFace Hub: {repo_id}")
737
+ model, tokenizer = FastLanguageModel.from_pretrained(
738
+ model_name=repo_id,
739
+ max_seq_length=MAX_SEQ_LEN,
740
+ dtype=DTYPE,
741
+ load_in_4bit=True,
742
+ )
743
+ print(f"Successfully loaded model from {repo_id}")
744
+ return model, tokenizer
train_mgpt_vqvae.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import zipfile
4
+ import torch
5
+ import torch.nn as nn
6
+ import pandas as pd
7
+ import numpy as np
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import glob
10
+ import warnings
11
+ import json
12
+ import time
13
+ from datetime import datetime
14
+ import random
15
+ import math
16
+ import matplotlib.pyplot as plt
17
+ import sys
18
+
19
+ # Add the mGPT directory to the path
20
+ sys.path.append('/kaggle/working')
21
+
22
+ from mGPT.archs.mgpt_vq import VQVae
23
+
24
+ warnings.filterwarnings("ignore")
25
+
26
+ # Configuration
27
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
28
+ DATA_ROOT = '/kaggle/working/extracted_files'
29
+ CHECKPOINT_DIR = '/kaggle/working/checkpoints_mgpt'
30
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
31
+ print("Device:", DEVICE)
32
+
33
+ # ──────────────────────────────────────────────────────────
34
+ # Enhanced Dataset with File Tracking and Batching (UNCHANGED)
35
+ # ──────────────────────────────────────────────────────────
36
+
37
+ def load_smplx_from_folder(folder_path):
38
+ all_frame_dicts = []
39
+ for pkl_file in sorted(glob.glob(os.path.join(folder_path, '*.pkl'))):
40
+ try:
41
+ with open(pkl_file, 'rb') as f:
42
+ data = pickle.load(f)
43
+ if isinstance(data, list):
44
+ all_frame_dicts.extend(data)
45
+ elif isinstance(data, dict):
46
+ all_frame_dicts.append(data)
47
+ except Exception:
48
+ continue
49
+ if not all_frame_dicts:
50
+ return None
51
+
52
+ param_keys = ['shape','body_pose','lhand_pose','rhand_pose','jaw_pose',
53
+ 'expression','root_pose','cam_trans']
54
+ param_dims = [10,63,45,45,3,10,3,3]
55
+ sequences = []
56
+ for frame in all_frame_dicts:
57
+ vec = []
58
+ for key, dim in zip(param_keys, param_dims):
59
+ arr = np.zeros(dim)
60
+ if key in frame and frame[key] is not None:
61
+ v = np.array(frame[key]).flatten()
62
+ arr[:min(len(v), dim)] = v[:dim]
63
+ vec.append(arr)
64
+ sequences.append(np.concatenate(vec))
65
+ return torch.tensor(np.stack(sequences), dtype=torch.float32)
66
+
67
+ class EnhancedMotionDataset(Dataset):
68
+ def __init__(self, root_dir, processed_files_path, batch_folders=1000):
69
+ self.root_dir = root_dir
70
+ self.processed_files_path = processed_files_path
71
+ self.batch_folders = batch_folders
72
+
73
+ print(f"\n[DEBUG] Initializing Dataset.")
74
+ print(f"[DEBUG] Root directory: '{self.root_dir}'")
75
+
76
+ if not os.path.exists(self.root_dir):
77
+ print(f"[DEBUG] ERROR: The root directory '{self.root_dir}' does not exist!")
78
+ self.all_folders = []
79
+ else:
80
+ print(f"[DEBUG] Root directory exists.")
81
+ glob_path = os.path.join(root_dir, '*')
82
+ print(f"[DEBUG] Using glob pattern: '{glob_path}'")
83
+ all_paths = glob.glob(glob_path)
84
+ print(f"[DEBUG] Glob found {len(all_paths)} total paths.")
85
+ self.all_folders = [d for d in all_paths if os.path.isdir(d)]
86
+ print(f"[DEBUG] Found {len(self.all_folders)} directories.")
87
+
88
+ self.processed = self._load_processed()
89
+ print(f"[DEBUG] Loaded {len(self.processed)} processed folder paths.")
90
+
91
+ self.unprocessed = [f for f in self.all_folders if f not in self.processed]
92
+ print(f"[DEBUG] Found {len(self.unprocessed)} unprocessed folders.")
93
+
94
+ self._prep_batch()
95
+
96
+ def _load_processed(self):
97
+ if os.path.exists(self.processed_files_path):
98
+ with open(self.processed_files_path, 'r') as f:
99
+ return json.load(f)
100
+ return []
101
+
102
+ def _save_processed(self):
103
+ with open(self.processed_files_path, 'w') as f:
104
+ json.dump(self.processed, f)
105
+
106
+ def _prep_batch(self):
107
+ self.current = self.unprocessed[:self.batch_folders]
108
+ self.samples = self.current.copy()
109
+ print(f"→ Loading {len(self.samples)} folders this batch")
110
+
111
+ def mark_batch_as_processed(self):
112
+ self.processed += self.current
113
+ self._save_processed()
114
+
115
+ def get_next_batch(self):
116
+ all_folders = [d for d in glob.glob(os.path.join(self.root_dir, '*')) if os.path.isdir(d)]
117
+ self.processed = self._load_processed()
118
+ self.unprocessed = [f for f in all_folders if f not in self.processed]
119
+
120
+ if not self.unprocessed:
121
+ print("✅ All data processed")
122
+ return False
123
+ self._prep_batch()
124
+ return True
125
+
126
+ def __len__(self):
127
+ return len(self.samples)
128
+
129
+ def __getitem__(self, idx):
130
+ seq = load_smplx_from_folder(self.samples[idx])
131
+ if seq is None or seq.shape[0] < 64:
132
+ return None
133
+ return seq
134
+
135
+ # ───────────────────────────────────────────��──────────────
136
+ # Checkpoint Management (UNCHANGED)
137
+ # ──────────────────────────────────────────────────────────
138
+
139
+ class CheckpointManager:
140
+ def __init__(self, checkpoint_dir, max_checkpoints=2):
141
+ self.checkpoint_dir = checkpoint_dir
142
+ self.max_checkpoints = max_checkpoints
143
+
144
+ def save_checkpoint(self, model, optimizer, epoch, batch_idx, loss, metadata=None):
145
+ checkpoint = {
146
+ 'epoch': epoch,
147
+ 'batch_idx': batch_idx,
148
+ 'model_state_dict': model.state_dict(),
149
+ 'optimizer_state_dict': optimizer.state_dict(),
150
+ 'loss': loss,
151
+ 'timestamp': datetime.now().isoformat(),
152
+ 'metadata': metadata or {}
153
+ }
154
+ checkpoint_path = os.path.join(
155
+ self.checkpoint_dir,
156
+ f'mgpt_vqvae_epoch_{epoch:03d}_batch_{batch_idx:04d}.pt'
157
+ )
158
+ torch.save(checkpoint, checkpoint_path)
159
+ print(f"Saved checkpoint: {checkpoint_path}")
160
+ self.cleanup_old_checkpoints()
161
+ return checkpoint_path
162
+
163
+ def cleanup_old_checkpoints(self):
164
+ checkpoints = glob.glob(os.path.join(self.checkpoint_dir, 'mgpt_vqvae_epoch_*.pt'))
165
+ checkpoints.sort(key=os.path.getmtime, reverse=True)
166
+ if len(checkpoints) > self.max_checkpoints:
167
+ for checkpoint in checkpoints[self.max_checkpoints:]:
168
+ os.remove(checkpoint)
169
+ print(f"Removed old checkpoint: {checkpoint}")
170
+
171
+ def load_latest_checkpoint(self):
172
+ checkpoints = glob.glob(os.path.join(self.checkpoint_dir, 'mgpt_vqvae_epoch_*.pt'))
173
+ if not checkpoints:
174
+ return None
175
+ latest_checkpoint = max(checkpoints, key=os.path.getmtime)
176
+ print(f"Loading checkpoint: {latest_checkpoint}")
177
+ return torch.load(latest_checkpoint, map_location=DEVICE)
178
+
179
+ def get_checkpoint_info(self):
180
+ checkpoints = glob.glob(os.path.join(self.checkpoint_dir, 'mgpt_vqvae_epoch_*.pt'))
181
+ return len(checkpoints), checkpoints
182
+
183
+ # ──────────────────────────────────────────────────────────
184
+ # Enhanced Training Function with MotionGPT VQ-VAE
185
+ # ──────────────────────────────────────────────────────────
186
+
187
+ def train_mgpt_vqvae(vq_model, dataset, epochs_per_batch=20, batch_size=16, lr=1e-4):
188
+ print("\n" + "="*70)
189
+ print(" STARTING MGPT VQ-VAE TRAINING WITH CHECKPOINTING ")
190
+ print("="*70)
191
+
192
+ optimizer = torch.optim.AdamW(vq_model.parameters(), lr=lr, weight_decay=1e-4)
193
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
194
+ loss_fn = nn.SmoothL1Loss(reduction='none')
195
+ checkpoint_manager = CheckpointManager(CHECKPOINT_DIR)
196
+
197
+ checkpoint = checkpoint_manager.load_latest_checkpoint()
198
+ global_epoch = 1
199
+ if checkpoint:
200
+ vq_model.load_state_dict(checkpoint['model_state_dict'])
201
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
202
+ global_epoch = checkpoint.get('metadata', {}).get('global_epoch', checkpoint['epoch'])
203
+ print(f"Resumed from GLOBAL epoch {global_epoch}")
204
+
205
+ vq_model.to(DEVICE).train()
206
+
207
+ # Define loss weights for SMPL parameters
208
+ param_dims = [10, 63, 45, 45, 3, 10, 3, 3]
209
+ param_starts = np.cumsum([0] + param_dims[:-1]).tolist()
210
+ smpl_dim = sum(param_dims)
211
+ loss_weights = torch.ones(smpl_dim, device=DEVICE)
212
+ loss_weights[param_starts[1]:param_starts[5]] = 10.0 # pose parameters
213
+ loss_weights[param_starts[0]:param_starts[1]] = 5.0 # shape parameters
214
+ loss_weights[param_starts[5]:param_starts[6]] = 8.0 # expression parameters
215
+
216
+ def log_codebook_analysis(x_recon, loss, perplexity, epoch, batch_idx):
217
+ # Extract encoded indices for analysis
218
+ with torch.no_grad():
219
+ x_in = vq_model.preprocess(x_recon[:1]) # Use reconstructed sample for analysis
220
+ x_encoder = vq_model.encoder(x_in)
221
+ x_flat = vq_model.quantizer.preprocess(x_encoder)
222
+ indices = vq_model.quantizer.quantize(x_flat)
223
+
224
+ unique_codes = torch.unique(indices)
225
+ usage_percentage = (len(unique_codes) / vq_model.quantizer.nb_code) * 100
226
+
227
+ print(f"[ANALYSIS] Epoch {epoch}, Batch {batch_idx}")
228
+ print(f"Unique codes used: {len(unique_codes)}/{vq_model.quantizer.nb_code} ({usage_percentage:.1f}%)")
229
+ print(f"Perplexity: {perplexity:.2f}")
230
+ return usage_percentage, indices
231
+
232
+ def save_reconstruction_sample(x, x_recon, lengths, epoch):
233
+ original_seq = x[0, :lengths[0]].cpu().numpy()
234
+ recon_seq = x_recon[0, :lengths[0]].cpu().numpy()
235
+ filename = os.path.join(CHECKPOINT_DIR, f'mgpt_recon_epoch_{epoch}.npz')
236
+ np.savez(filename, original=original_seq, reconstructed=recon_seq)
237
+ print(f"Saved reconstruction sample to {filename}")
238
+ mse = ((original_seq - recon_seq) ** 2).mean()
239
+ print(f"Reconstruction MSE: {mse:.6f}")
240
+ return mse
241
+
242
+ def collate_fn_enhanced(batch):
243
+ batch = [item for item in batch if item is not None]
244
+ if not batch:
245
+ return None
246
+ batch.sort(key=lambda x: x.shape[0], reverse=True)
247
+ max_len = batch[0].shape[0]
248
+ max_len = min(max_len, 256)
249
+ downsampling_factor = 8
250
+ padded_max_len = math.ceil(max_len / downsampling_factor) * downsampling_factor
251
+ padded_batch = torch.zeros(len(batch), padded_max_len, batch[0].shape[1])
252
+ lengths = []
253
+ for i, x in enumerate(batch):
254
+ length = min(x.shape[0], padded_max_len)
255
+ padded_batch[i, :length, :] = x[:length, :]
256
+ lengths.append(length)
257
+ return padded_batch, torch.tensor(lengths)
258
+
259
+ while True:
260
+ print(f"\n{'='*50}")
261
+ print(f"Processing file batch with {len(dataset)} files")
262
+ print(f"{'='*50}")
263
+
264
+ if len(dataset) == 0:
265
+ if not dataset.get_next_batch():
266
+ print("✅ All data processed! Training complete.")
267
+ break
268
+ continue
269
+
270
+ dataloader = DataLoader(
271
+ dataset, batch_size=batch_size, shuffle=True,
272
+ num_workers=0, collate_fn=collate_fn_enhanced, drop_last=True
273
+ )
274
+
275
+ for epoch in range(global_epoch, global_epoch + epochs_per_batch):
276
+ epoch_losses, epoch_vq_losses, epoch_rec_losses = [], [], []
277
+ codebook_usage_history = []
278
+ epoch_indices = []
279
+
280
+ for batch_idx, batch_data in enumerate(dataloader):
281
+ if batch_data is None:
282
+ continue
283
+
284
+ motion_batch, lengths = batch_data
285
+ x = motion_batch.to(DEVICE)
286
+
287
+ # Forward pass through MotionGPT VQ-VAE
288
+ x_recon, vq_loss, perplexity = vq_model(x)
289
+
290
+ if batch_idx % 50 == 0:
291
+ usage_pct, indices = log_codebook_analysis(x_recon, vq_loss, perplexity, epoch, batch_idx)
292
+ epoch_indices.append(indices.cpu().numpy().flatten())
293
+
294
+ # Calculate reconstruction loss with weighted parameters
295
+ rec_loss_unreduced = loss_fn(x_recon, x) * loss_weights.unsqueeze(0).unsqueeze(0)
296
+ mask = torch.zeros_like(x[:, :, 0])
297
+ for i, length in enumerate(lengths):
298
+ mask[i, :length] = 1.0
299
+ mask = mask.unsqueeze(-1).expand_as(rec_loss_unreduced)
300
+ rec_loss = (rec_loss_unreduced * mask).sum() / mask.sum()
301
+
302
+ vq_weight = 1.0
303
+ total_loss = rec_loss + vq_weight * vq_loss
304
+
305
+ optimizer.zero_grad()
306
+ total_loss.backward()
307
+ torch.nn.utils.clip_grad_norm_(vq_model.parameters(), max_norm=1.0)
308
+ optimizer.step()
309
+ scheduler.step()
310
+
311
+ epoch_losses.append(total_loss.item())
312
+ epoch_vq_losses.append(vq_loss.item())
313
+ epoch_rec_losses.append(rec_loss.item())
314
+
315
+ if batch_idx % 20 == 0:
316
+ current_lr = optimizer.param_groups[0]['lr']
317
+ print(f"[E:{epoch:03d}] B:{batch_idx:03d} | "
318
+ f"Loss: {total_loss.item():.4f} "
319
+ f"(Rec: {rec_loss.item():.4f}, VQ: {vq_loss.item():.4f}) | "
320
+ f"Perplexity: {perplexity:.2f} | "
321
+ f"LR: {current_lr:.2e}")
322
+
323
+ if epoch_losses:
324
+ avg_loss = np.mean(epoch_losses)
325
+ avg_vq_loss = np.mean(epoch_vq_losses)
326
+ avg_rec_loss = np.mean(epoch_rec_losses)
327
+
328
+ print(f"\n[EPOCH {epoch:03d} SUMMARY]")
329
+ print(f"Avg Loss: {avg_loss:.4f} (Rec: {avg_rec_loss:.4f}, VQ: {avg_vq_loss:.4f})")
330
+
331
+ # Create histogram if we collected indices
332
+ if epoch_indices:
333
+ all_epoch_indices = np.concatenate(epoch_indices)
334
+ plt.figure(figsize=(12, 6))
335
+ plt.hist(all_epoch_indices, bins=vq_model.quantizer.nb_code,
336
+ range=(0, vq_model.quantizer.nb_code-1))
337
+ plt.title(f'MotionGPT Codebook Usage Distribution - Epoch {epoch}')
338
+ plt.xlabel('Codebook Index')
339
+ plt.ylabel('Frequency')
340
+ hist_path = os.path.join(CHECKPOINT_DIR, f'mgpt_codebook_usage_epoch_{epoch:03d}.png')
341
+ plt.savefig(hist_path)
342
+ plt.close()
343
+ print(f"Saved codebook usage histogram to {hist_path}")
344
+
345
+ if epoch > 0 and epoch % 5 == 0:
346
+ vq_model.eval()
347
+ with torch.no_grad():
348
+ for val_data in dataloader:
349
+ if val_data is not None:
350
+ motion_batch, lengths = val_data
351
+ x = motion_batch.to(DEVICE)
352
+ x_recon, _, _ = vq_model(x)
353
+ save_reconstruction_sample(x, x_recon, lengths, epoch)
354
+ break
355
+ vq_model.train()
356
+
357
+ if epoch > 0 and epoch % 10 == 0:
358
+ checkpoint_manager.save_checkpoint(
359
+ vq_model, optimizer, epoch, -1, np.mean(epoch_losses),
360
+ metadata={'global_epoch': epoch}
361
+ )
362
+
363
+ global_epoch += epochs_per_batch
364
+
365
+ dataset.mark_batch_as_processed()
366
+
367
+ if not dataset.get_next_batch():
368
+ print("✅ All data processed! Training complete.")
369
+ break
370
+
371
+ return vq_model
372
+
373
+ # ──────────────────────────────────────────────────────────
374
+ # Main Training Script
375
+ # ──────────────────────────────────────────────────────────
376
+
377
+ def main():
378
+ print("Starting MotionGPT VQ-VAE Training System")
379
+ print(f"Checkpoint directory: {CHECKPOINT_DIR}")
380
+
381
+ smpl_dim = 182
382
+ codebook_size = 512
383
+ code_dim = 512
384
+
385
+ # Initialize MotionGPT VQ-VAE
386
+ vq_model = VQVae(
387
+ nfeats=smpl_dim,
388
+ quantizer="ema_reset", # Options: "ema_reset", "orig", "ema", "reset"
389
+ code_num=codebook_size,
390
+ code_dim=code_dim,
391
+ output_emb_width=code_dim,
392
+ down_t=3,
393
+ stride_t=2,
394
+ width=512,
395
+ depth=3,
396
+ dilation_growth_rate=3,
397
+ norm=None,
398
+ activation="relu"
399
+ ).to(DEVICE)
400
+
401
+ total_params = sum(p.numel() for p in vq_model.parameters())
402
+ trainable_params = sum(p.numel() for p in vq_model.parameters() if p.requires_grad)
403
+ print(f"Total parameters: {total_params:,}")
404
+ print(f"Trainable parameters: {trainable_params:,}")
405
+
406
+ motion_dataset = EnhancedMotionDataset(
407
+ root_dir=DATA_ROOT,
408
+ processed_files_path=os.path.join(CHECKPOINT_DIR, 'processed_folders_mgpt.json'),
409
+ batch_folders=800
410
+ )
411
+
412
+ vq_model = train_mgpt_vqvae(
413
+ vq_model,
414
+ motion_dataset,
415
+ epochs_per_batch=15,
416
+ batch_size=12,
417
+ lr=2e-4
418
+ )
419
+
420
+ print("\n" + "="*70)
421
+ print("MGPT VQ-VAE TRAINING COMPLETED SUCCESSFULLY!")
422
+ print("="*70)
423
+
424
+ final_model_path = os.path.join(CHECKPOINT_DIR, 'final_mgpt_vqvae_model.pt')
425
+ torch.save({
426
+ 'model_state_dict': vq_model.state_dict(),
427
+ 'model_config': {
428
+ 'nfeats': smpl_dim,
429
+ 'code_num': codebook_size,
430
+ 'code_dim': code_dim,
431
+ 'quantizer': "ema_reset"
432
+ },
433
+ 'training_completed': True
434
+ }, final_model_path)
435
+ print(f"Final model saved to: {final_model_path}")
436
+
437
+ if __name__ == "__main__":
438
+ main()
train_pipeline.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main training pipeline for Motion LLM (Matched to test_overfit.py logic)
3
+ Run this script to execute the full training process matching the reference implementation.
4
+ """
5
+ import os
6
+ import random
7
+ import torch
8
+ import json
9
+ import argparse
10
+ from types import SimpleNamespace
11
+ import warnings
12
+
13
+ # Import updated modules
14
+ from config import (
15
+ SEED, DATA_JSON_PATH, MODEL_NAME, PIPELINE_OUTPUT_DIR,
16
+ HF_STAGE1_REPO_ID, HF_STAGE2_REPO_ID, HF_STAGE2_SAVE_SUBDIR,
17
+ FORCE_STAGE2_FROM_STAGE1, HF_USE_HUB, HF_TOKEN,
18
+ EVALUATION_WORDS, EVAL_SAMPLE_LIMIT, RUN_EVALS_ONLY,
19
+ TEST_EVAL_OUTPUT_DIR, TEST_EVAL_DOWNLOAD_DIR, TEST_EVAL_EXTRACT_DIR,
20
+ TEST_EVAL_SAMPLE_LIMIT, TEST_EVAL_MAX_ZIPS, TEST_EVAL_HF_REPO, TEST_EVAL_HF_SUBFOLDER
21
+ )
22
+ from data import read_json_data, deduplicate_and_prepare_data, build_motion_vocab
23
+ from model import setup_model_and_tokenizer_raw, ensure_tokenizer_has_motion_tokens
24
+ from train import (
25
+ train_stage1_raw, train_stage2_raw, resolve_and_ensure_repo,
26
+ repo_has_stage_latest, load_model_and_tokenizer_from_hf,
27
+ download_training_state, repo_get_latest_epoch_subfolder,
28
+ load_model_and_tokenizer_from_hf_subfolder, download_training_state_from_subfolder
29
+ )
30
+ from metrics import (
31
+ evaluate_metrics_encoder_style, run_inference_on_all_samples,
32
+ evaluate_metrics_motiongpt_style, save_side_by_side_visualizations
33
+ )
34
+ import test_dataset_eval
35
+
36
+ # Suppress warnings
37
+ warnings.filterwarnings("ignore")
38
+
39
+ def main():
40
+ """Main function to run the entire pipeline matching test_overfit.py."""
41
+ print("="*80)
42
+ print(" Motion LLM Training Pipeline (Matches test_overfit.py)")
43
+ print("="*80)
44
+
45
+ # Set seeds
46
+ random.seed(SEED)
47
+ torch.manual_seed(SEED)
48
+ if torch.cuda.is_available():
49
+ torch.cuda.manual_seed_all(SEED)
50
+
51
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+ print(f"Using device: {device}")
53
+
54
+ # 1. Load ALL data
55
+ print(f"\n[1/6] Loading dataset from {DATA_JSON_PATH}...")
56
+ all_entries = read_json_data(DATA_JSON_PATH)
57
+
58
+ # 2. Clean the ENTIRE dataset and get all tokens
59
+ print("\n[2/6] Cleaning dataset...")
60
+ cleaned_data, all_motion_tokens = deduplicate_and_prepare_data(all_entries)
61
+
62
+ # 3. Stage 1: Initialize or resume from HF, then train
63
+ print("\n[3/6] Stage 1 Setup & Training...")
64
+ resolved_stage1_repo = resolve_and_ensure_repo(HF_STAGE1_REPO_ID, HF_TOKEN) if HF_USE_HUB else None
65
+ start_epoch_s1 = 0
66
+ stage1_loaded = None
67
+ if resolved_stage1_repo:
68
+ if repo_has_stage_latest(resolved_stage1_repo, "stage1", HF_TOKEN):
69
+ stage1_loaded = load_model_and_tokenizer_from_hf(resolved_stage1_repo, "stage1", HF_TOKEN)
70
+ state_s1 = download_training_state(resolved_stage1_repo, "stage1", HF_TOKEN)
71
+ if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
72
+ start_epoch_s1 = state_s1["epoch_completed"]
73
+ else:
74
+ # Fallback: no 'latest' folder; select highest epoch-XXX
75
+ latest_s1_sub = repo_get_latest_epoch_subfolder(resolved_stage1_repo, "stage1", HF_TOKEN)
76
+ if latest_s1_sub:
77
+ stage1_loaded = load_model_and_tokenizer_from_hf_subfolder(resolved_stage1_repo, latest_s1_sub, HF_TOKEN)
78
+ state_s1 = download_training_state_from_subfolder(resolved_stage1_repo, latest_s1_sub, HF_TOKEN)
79
+ if state_s1 and isinstance(state_s1.get("epoch_completed"), int):
80
+ start_epoch_s1 = state_s1["epoch_completed"]
81
+
82
+ if stage1_loaded:
83
+ base_model, tokenizer = stage1_loaded
84
+ # Ensure tokenizer contains all motion tokens (add missing if dataset expanded)
85
+ added = ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
86
+ if added > 0:
87
+ base_model.resize_token_embeddings(len(tokenizer))
88
+ else:
89
+ base_model, tokenizer = setup_model_and_tokenizer_raw(MODEL_NAME, all_motion_tokens)
90
+
91
+ print(f"\nStarting Stage 1 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s1}).")
92
+ motion_model = train_stage1_raw(
93
+ base_model,
94
+ tokenizer,
95
+ cleaned_data,
96
+ device,
97
+ start_epoch=start_epoch_s1,
98
+ hf_repo_id=resolved_stage1_repo,
99
+ )
100
+
101
+ # 4. Stage 2: Initialize or resume from HF, then train
102
+ print("\n[4/6] Stage 2 Setup & Training...")
103
+ resolved_stage2_repo = resolve_and_ensure_repo(HF_STAGE2_REPO_ID, HF_TOKEN) if HF_USE_HUB else None
104
+ start_epoch_s2 = 0
105
+ stage2_loaded = None
106
+ print(f"Stage 2 resume policy: FORCE_STAGE2_FROM_STAGE1={FORCE_STAGE2_FROM_STAGE1}, save_subdir='{HF_STAGE2_SAVE_SUBDIR}'")
107
+
108
+ if not FORCE_STAGE2_FROM_STAGE1 and resolved_stage2_repo:
109
+ # Prefer loading from the configured Stage 2 save subdir (e.g., 'stage2_v2')
110
+ if repo_has_stage_latest(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN):
111
+ stage2_loaded = load_model_and_tokenizer_from_hf(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN)
112
+ state_s2 = download_training_state(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN)
113
+ if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
114
+ start_epoch_s2 = state_s2["epoch_completed"]
115
+ print(f"Resuming Stage 2 from HF subfolder: {HF_STAGE2_SAVE_SUBDIR}/latest (epoch_completed={start_epoch_s2})")
116
+ else:
117
+ latest_s2_sub = repo_get_latest_epoch_subfolder(resolved_stage2_repo, HF_STAGE2_SAVE_SUBDIR, HF_TOKEN)
118
+ if latest_s2_sub:
119
+ stage2_loaded = load_model_and_tokenizer_from_hf_subfolder(resolved_stage2_repo, latest_s2_sub, HF_TOKEN)
120
+ state_s2 = download_training_state_from_subfolder(resolved_stage2_repo, latest_s2_sub, HF_TOKEN)
121
+ if state_s2 and isinstance(state_s2.get("epoch_completed"), int):
122
+ start_epoch_s2 = state_s2["epoch_completed"]
123
+ print(f"Resuming Stage 2 from HF subfolder: {latest_s2_sub} (epoch_completed={start_epoch_s2})")
124
+
125
+ if stage2_loaded:
126
+ stage2_model, tokenizer = stage2_loaded
127
+ added2 = ensure_tokenizer_has_motion_tokens(tokenizer, all_motion_tokens)
128
+ if added2 > 0:
129
+ stage2_model.resize_token_embeddings(len(tokenizer))
130
+ else:
131
+ stage2_model = motion_model # Start Stage 2 from Stage 1 model
132
+
133
+ print(f"\nStarting Stage 2 training on {len(cleaned_data)} samples (resume from epoch {start_epoch_s2}).")
134
+ final_model = train_stage2_raw(
135
+ stage2_model,
136
+ tokenizer,
137
+ cleaned_data,
138
+ device,
139
+ start_epoch=start_epoch_s2,
140
+ hf_repo_id=resolved_stage2_repo,
141
+ hf_stage_subdir=HF_STAGE2_SAVE_SUBDIR,
142
+ )
143
+
144
+ # Save final model locally
145
+ if not os.path.exists(PIPELINE_OUTPUT_DIR):
146
+ os.makedirs(PIPELINE_OUTPUT_DIR)
147
+ final_model.save_pretrained(PIPELINE_OUTPUT_DIR)
148
+ tokenizer.save_pretrained(PIPELINE_OUTPUT_DIR)
149
+ print(f"Model saved to {PIPELINE_OUTPUT_DIR}")
150
+
151
+ # 5. Evaluation on Specific Words
152
+ print("\n[5/6] Evaluation on Specific Words...")
153
+ print("--- Filtering data for evaluation on specific words ---")
154
+ evaluation_data = [item for item in cleaned_data if item['word'].lower() in EVALUATION_WORDS]
155
+ print(f"Found {len(evaluation_data)} samples for evaluation words: {EVALUATION_WORDS}")
156
+
157
+ metrics_json_path = os.path.join(PIPELINE_OUTPUT_DIR, "metrics.json")
158
+
159
+ # 6. Metrics-only mode or full flow
160
+ if RUN_EVALS_ONLY:
161
+ # Compute the 3 metrics using VQ-VAE encoder features and save to JSON
162
+ metrics_enc = evaluate_metrics_encoder_style(
163
+ final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
164
+ )
165
+ os.makedirs(PIPELINE_OUTPUT_DIR, exist_ok=True)
166
+ metrics_payload = {
167
+ "source": "vqvae_encoder",
168
+ "fid": metrics_enc.get("fid"),
169
+ "diversity": {
170
+ "ground_truth": metrics_enc.get("diversity_gt"),
171
+ "model": metrics_enc.get("diversity_gen"),
172
+ },
173
+ "multimodality": {
174
+ "ground_truth": metrics_enc.get("mim_gt"),
175
+ "model": metrics_enc.get("mim_gen"),
176
+ },
177
+ "num_pairs": len(metrics_enc.get("pairs", [])),
178
+ }
179
+ with open(metrics_json_path, "w", encoding="utf-8") as f:
180
+ json.dump(metrics_payload, f, ensure_ascii=False, indent=2)
181
+ print(f"\n✅ Saved metrics to {metrics_json_path}")
182
+ return
183
+
184
+ # Full flow: inference logs + MotionGPT-style metrics + encoder metrics + visualizations
185
+ run_inference_on_all_samples(final_model, tokenizer, evaluation_data, device)
186
+ metrics_token = evaluate_metrics_motiongpt_style(final_model, tokenizer, evaluation_data, all_motion_tokens, device)
187
+ # Also compute encoder-based 3 metrics
188
+ metrics_enc = evaluate_metrics_encoder_style(
189
+ final_model, tokenizer, evaluation_data, device, sample_limit=EVAL_SAMPLE_LIMIT
190
+ )
191
+ # Visualizations (skip if metrics-only)
192
+ viz_dir = os.path.join(PIPELINE_OUTPUT_DIR, "html_visualizations")
193
+ save_side_by_side_visualizations(metrics_token["pairs"], viz_dir, limit=4)
194
+
195
+ # 7. Run Test Dataset Evaluation (test_dataset_eval.py)
196
+ print("\n[6/6] Running Evaluation on Held-out Test Dataset...")
197
+ try:
198
+ # Construct args matching test_dataset_eval.parse_args
199
+ eval_args = SimpleNamespace(
200
+ drive_url=None,
201
+ drive_id=None,
202
+ local_extracted_dir=None, # Will assume user needs to configure this or it uses defaults if not provided
203
+ # Note: test_dataset_eval requires one of drive/local. We can try to rely on defaults or skip if not configured.
204
+ # We will set download_dir and extract_dir from config.
205
+ download_dir=TEST_EVAL_DOWNLOAD_DIR,
206
+ extract_dir=TEST_EVAL_EXTRACT_DIR,
207
+ max_zips=TEST_EVAL_MAX_ZIPS,
208
+ hf_repo_id=TEST_EVAL_HF_REPO,
209
+ hf_subfolder=TEST_EVAL_HF_SUBFOLDER,
210
+ vqvae_ckpt=None,
211
+ stats_path=None,
212
+ output_dir=TEST_EVAL_OUTPUT_DIR,
213
+ sample_limit=TEST_EVAL_SAMPLE_LIMIT,
214
+ seed=SEED
215
+ )
216
+
217
+ # For this pipeline, we might want to pass the *currently loaded* model instead of reloading from HF?
218
+ # test_dataset_eval.run_evaluation loads from HF.
219
+ # The prompt asked to "incorporate... code of test_dataset_eval.py".
220
+ # Ideally we pass the model object, but run_evaluation is written to load from HF.
221
+ # Given we just saved and pushed (if enabled), loading from HF is fine.
222
+ # If we haven't pushed (HF_USE_HUB=False), run_evaluation might fail if it tries to load from HF.
223
+ # However, the prompt implies using test_overfit.py training setup which pushes to HF.
224
+
225
+ # Critical fix: If we want to use the *local* model we just trained, we should modify test_dataset_eval or pass it.
226
+ # But test_dataset_eval.run_evaluation doesn't accept model arg.
227
+ # For now, we'll attempt to run it as designed (loading from HF).
228
+ # If HF_USE_HUB is False, this step might fail.
229
+
230
+ # Let's check if we can use local_extracted_dir if it exists, otherwise drive download.
231
+ # We will use a try-except block.
232
+
233
+ print("Calling test_dataset_eval.run_evaluation...")
234
+ # We need to provide either drive-url/id or local-extracted.
235
+ # We'll try to use the extracted dir if it has content, otherwise default to download if URL known?
236
+ # Actually, since we don't have a drive URL in config (it was an arg), we might skip this if not set up?
237
+ # But the user said "include the code".
238
+
239
+ # We'll default to using the extract dir if it exists, otherwise we might need to ask or skip.
240
+ # Let's assume the user has data or we use the default drive-id if known (it wasn't in the provided file).
241
+ # Wait, test_dataset_eval.py has mutually exclusive required group.
242
+ # I'll add a fallback: if TEST_EVAL_EXTRACT_DIR exists and has files, use it.
243
+
244
+ if os.path.exists(TEST_EVAL_EXTRACT_DIR) and os.listdir(TEST_EVAL_EXTRACT_DIR):
245
+ eval_args.local_extracted_dir = TEST_EVAL_EXTRACT_DIR
246
+ else:
247
+ # We don't have a drive URL hardcoded.
248
+ # We will mock the arg to fail gracefully or print a message.
249
+ print("⚠️ Skipping test_dataset_eval: No local data found and no Drive URL configured.")
250
+ eval_args = None
251
+
252
+ if eval_args:
253
+ test_dataset_eval.run_evaluation(eval_args)
254
+
255
+ except Exception as e:
256
+ print(f"⚠️ Test dataset evaluation failed: {e}")
257
+
258
+ print("\n" + "="*60)
259
+ print("Training pipeline complete!")
260
+ print("="*60)
261
+ print(f"Models saved to: {PIPELINE_OUTPUT_DIR}")
262
+
263
+ if __name__ == "__main__":
264
+ main()
train_vqvae.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import glob
8
+ import warnings
9
+ import json
10
+ from datetime import datetime
11
+ import math
12
+ import matplotlib.pyplot as plt
13
+ import torch.nn.functional as F
14
+ import sys
15
+ from tqdm import tqdm
16
+
17
+ # ==============================================================================
18
+ # 0) SETUP: Architecture files
19
+ # ==============================================================================
20
+ # Make sure your mGPT folder is in the Python path
21
+ # sys.path.append('/path/to/your/mGPT_folder')
22
+ from mGPT.archs.mgpt_vq import VQVae
23
+ from mGPT.archs.tools import quantize_cnn
24
+
25
+ warnings.filterwarnings("ignore")
26
+
27
+ # ==============================================================================
28
+ # 1) CONFIGURATION
29
+ # ==============================================================================
30
+ SANITY_CHECK_ENABLED = True
31
+ sanity_check_counter = 0
32
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
33
+ print("Device:", DEVICE)
34
+ print(f"Sanity checks are {'ENABLED' if SANITY_CHECK_ENABLED else 'DISABLED'}.")
35
+
36
+ # ==============================================================================
37
+ # 2) VQ-VAE MODEL (Your instrumented classes are fine)
38
+ # ==============================================================================
39
+ class QuantizeEMAReset_Sanity(quantize_cnn.QuantizeEMAReset):
40
+ def forward(self, x, current_batch_idx=0):
41
+ global sanity_check_counter
42
+ N, width, T = x.shape
43
+ x_proc = self.preprocess(x)
44
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
45
+ print("[Quantizer.forward] Input shape `x`: ", x.shape)
46
+ print("[Quantizer.forward] Shape after preprocess `x_proc`: ", x_proc.shape)
47
+ print(f"[Quantizer.forward] Codebook shape: {self.codebook.shape}")
48
+ if self.training and not self.init: print("[Quantizer.forward] Codebook is UNINITIALIZED.")
49
+ else: print(f"[Quantizer.forward] Codebook stats: min={self.codebook.min():.3f}, max={self.codebook.max():.3f}, mean={self.codebook.mean():.3f}")
50
+ if self.training and not self.init: self.init_codebook(x_proc)
51
+ code_idx = self.quantize(x_proc)
52
+ x_d = self.dequantize(code_idx)
53
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
54
+ print(f"[Quantizer.forward] Code index range: min={code_idx.min()}, max={code_idx.max()}")
55
+ assert code_idx.max() < self.nb_code, "A code index is out of bounds!"
56
+ if self.training: perplexity = self.update_codebook(x_proc, code_idx)
57
+ else: perplexity = self.compute_perplexity(code_idx)
58
+ commit_loss = F.mse_loss(x_proc, x_d.detach())
59
+ x_d = x_proc + (x_d - x_proc).detach()
60
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
61
+ return x_d, commit_loss, perplexity
62
+
63
+ class VQVae_Sanity(VQVae):
64
+ def __init__(self, *args, **kwargs):
65
+ super().__init__(*args, **kwargs)
66
+ if isinstance(self.quantizer, quantize_cnn.QuantizeEMAReset):
67
+ self.quantizer = QuantizeEMAReset_Sanity(
68
+ self.quantizer.nb_code, self.quantizer.code_dim, self.quantizer.mu
69
+ )
70
+ def forward(self, features, current_batch_idx=0):
71
+ global sanity_check_counter
72
+ x_in = self.preprocess(features)
73
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0: print("[VQVae.forward] Shape after preprocess (permute): ", x_in.shape)
74
+ x_encoder = self.encoder(x_in)
75
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
76
+ print("[VQVae.forward] Shape after encoder `x_encoder`: ", x_encoder.shape)
77
+ total_downsample_factor = 2**3
78
+ expected_len = math.ceil(features.shape[1] / total_downsample_factor)
79
+ print(f"[VQVae.forward] Calculated expected quantized length: ~{expected_len}")
80
+ assert abs(x_encoder.shape[2] - expected_len) <= 1, "Temporal downsampling seems incorrect."
81
+ x_quantized, loss, perplexity = self.quantizer(x_encoder, current_batch_idx)
82
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0: print("[VQVae.forward] Shape after quantizer `x_quantized`: ", x_quantized.shape)
83
+ x_decoder = self.decoder(x_quantized)
84
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
85
+ print("[VQVae.forward] Shape after decoder `x_decoder`: ", x_decoder.shape)
86
+ assert x_decoder.shape[2] == features.shape[1], "Decoder output temporal dim mismatch!"
87
+ x_out = self.postprocess(x_decoder)
88
+ return x_out, loss, perplexity
89
+
90
+ # Monkey-patching
91
+ sys.modules['mGPT.archs.mgpt_vq'].VQVae = VQVae_Sanity
92
+ sys.modules['mGPT.archs.mgpt_vq'].QuantizeEMAReset = QuantizeEMAReset_Sanity
93
+
94
+ class MotionGPT_VQVAE_Wrapper(nn.Module):
95
+ def __init__(self, smpl_dim, codebook_size=512, code_dim=512, **kwargs):
96
+ super().__init__()
97
+ self.smpl_dim = smpl_dim
98
+ self.vqvae = VQVae(
99
+ nfeats=smpl_dim, code_num=codebook_size, code_dim=code_dim,
100
+ output_emb_width=code_dim, **kwargs
101
+ )
102
+ param_dims = [10, 63, 45, 45, 3, 10, 3, 3]
103
+ param_starts = np.cumsum([0] + param_dims[:-1]).tolist()
104
+ loss_weights = torch.ones(smpl_dim)
105
+ loss_weights[param_starts[1]:param_starts[5]] = 10.0
106
+ loss_weights[param_starts[0]:param_starts[1]] = 5.0
107
+ loss_weights[param_starts[5]:param_starts[6]] = 8.0
108
+ self.register_buffer('loss_weights', loss_weights)
109
+ print(f"Initialized MotionGPT VQ-VAE with {codebook_size} codebook size")
110
+ def forward(self, x, current_batch_idx=0):
111
+ global sanity_check_counter
112
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
113
+ print("\n" + "="*50)
114
+ print("--- VQ-VAE WRAPPER SANITY CHECK (Batch 0) ---")
115
+ print(f"[Input] Shape of input features `x`: {x.shape}")
116
+ print("-"*50)
117
+ x_recon, vq_loss, perplexity = self.vqvae(x, current_batch_idx)
118
+ if SANITY_CHECK_ENABLED and current_batch_idx == 0 and sanity_check_counter == 0:
119
+ print("[Output] Shape of reconstructed features `x_recon`: ", x_recon.shape)
120
+ assert x.shape == x_recon.shape, "Shape mismatch!"
121
+ print(f"[Output] vq_loss: {vq_loss.item():.6f}, perplexity: {perplexity.item():.2f}")
122
+ print("--- VQ-VAE WRAPPER SANITY CHECK COMPLETE ---")
123
+ print("="*50 + "\n")
124
+ indices, _ = self.vqvae.encode(x)
125
+ return x_recon, vq_loss, indices, perplexity
126
+
127
+ # ==============================================================================
128
+ # 3) DATA LOADING
129
+ # ==============================================================================
130
+ def load_motion_from_npz(file_path):
131
+ try:
132
+ with np.load(file_path) as data:
133
+ motion_data = data['motion']
134
+ return torch.tensor(motion_data, dtype=torch.float32)
135
+ except Exception as e:
136
+ print(f"Warning: Could not load {os.path.basename(file_path)}. Skipping. Error: {e}")
137
+ return None
138
+
139
+ class NpzMotionDataset(Dataset):
140
+ def __init__(self, root_dir, stats_path=None, min_seq_len=64):
141
+ self.min_seq_len = min_seq_len
142
+ print(f"\n[Dataset] Initializing from NPZ files in: '{root_dir}'")
143
+ glob_pattern = os.path.join(root_dir, '**', '*.npz')
144
+ self.files = glob.glob(glob_pattern, recursive=True)
145
+ if not self.files:
146
+ raise FileNotFoundError(f"FATAL: No .npz files found at '{glob_pattern}'.")
147
+ print(f"[Dataset] Found {len(self.files)} total .npz files.")
148
+
149
+ if stats_path and os.path.exists(stats_path):
150
+ stats = torch.load(stats_path, map_location='cpu')
151
+ self.mean = stats['mean']
152
+ self.std = stats['std']
153
+ print("[Dataset] Successfully loaded normalization stats to CPU.")
154
+ else:
155
+ print("❗ [Dataset] WARNING: Stats file not found. Proceeding without normalization. This will affect loss values and model performance.")
156
+ self.mean = 0
157
+ self.std = 1
158
+
159
+ def __len__(self):
160
+ return len(self.files)
161
+
162
+ def __getitem__(self, idx):
163
+ file_path = self.files[idx]
164
+ seq = load_motion_from_npz(file_path)
165
+ if seq is None or seq.shape[0] < self.min_seq_len:
166
+ return None
167
+ normalized_seq = (seq - self.mean) / self.std
168
+ return normalized_seq
169
+
170
+ # ==============================================================================
171
+ # 4) CHECKPOINT & CODEBOOK INITIALIZATION
172
+ # ==============================================================================
173
+ class CheckpointManager:
174
+ # (Your CheckpointManager code is fine, no changes needed here)
175
+ def __init__(self, checkpoint_dir, max_checkpoints=3):
176
+ self.checkpoint_dir = checkpoint_dir
177
+ self.max_checkpoints = max_checkpoints
178
+ def save_checkpoint(self, model, optimizer, epoch, loss, metadata=None):
179
+ checkpoint_path = os.path.join(self.checkpoint_dir, f'vqvae_epoch_{epoch:03d}.pt')
180
+ torch.save({
181
+ 'epoch': epoch,
182
+ 'model_state_dict': model.state_dict(),
183
+ 'optimizer_state_dict': optimizer.state_dict(),
184
+ 'loss': loss,
185
+ 'timestamp': datetime.now().isoformat(),
186
+ 'metadata': metadata or {}
187
+ }, checkpoint_path)
188
+ print(f"✅ Saved checkpoint: {checkpoint_path}")
189
+ self.cleanup_old_checkpoints()
190
+ def cleanup_old_checkpoints(self):
191
+ checkpoints = glob.glob(os.path.join(self.checkpoint_dir, 'vqvae_epoch_*.pt'))
192
+ if len(checkpoints) > self.max_checkpoints:
193
+ checkpoints.sort(key=os.path.getmtime)
194
+ for old_checkpoint in checkpoints[:-self.max_checkpoints]:
195
+ os.remove(old_checkpoint)
196
+ print(f"🗑️ Removed old checkpoint: {old_checkpoint}")
197
+ def load_latest_checkpoint(self):
198
+ checkpoints = glob.glob(os.path.join(self.checkpoint_dir, 'vqvae_epoch_*.pt'))
199
+ if not checkpoints: return None
200
+ latest_checkpoint_path = max(checkpoints, key=os.path.getmtime)
201
+ print(f"🔄 Loading latest checkpoint: {latest_checkpoint_path}")
202
+ return torch.load(latest_checkpoint_path, map_location=DEVICE, weights_only=False)
203
+
204
+ def initialize_codebook_from_dataset(model, dataloader, num_batches=100):
205
+ print(f"⚙️ Collecting data from {num_batches} batches for codebook initialization...")
206
+ all_latents = []
207
+ model.eval()
208
+ with torch.no_grad():
209
+ for i, batch_data in enumerate(dataloader):
210
+ if i >= num_batches: break
211
+ if batch_data and batch_data[0] is not None:
212
+ motion_batch, _ = batch_data
213
+ x = motion_batch.to(DEVICE)
214
+ z_e = model.vqvae.encoder(model.vqvae.preprocess(x))
215
+ z_e_flat = z_e.permute(0, 2, 1).reshape(-1, z_e.shape[1])
216
+ all_latents.append(z_e_flat.cpu())
217
+ if not all_latents: raise ValueError("Could not collect any latents for initialization.")
218
+ all_latents = torch.cat(all_latents, dim=0)
219
+ print(f"Collected {all_latents.shape[0]} latent vectors.")
220
+ codebook_size = model.vqvae.quantizer.nb_code
221
+ indices = torch.randperm(all_latents.shape[0])[:codebook_size]
222
+ initial_codebook = all_latents[indices].to(DEVICE)
223
+ model.vqvae.quantizer.init_codebook(initial_codebook)
224
+ print("✅ Codebook initialized successfully from a diverse data sample.")
225
+ model.train()
226
+
227
+ # ==============================================================================
228
+ # 5) CORRECTED & COMPLETE TRAINING FUNCTION (No Globals)
229
+ # ==============================================================================
230
+ def train_vqvae_colab(vq_model, dataset, checkpoint_dir, num_epochs=300, batch_size=32, lr=2e-4):
231
+ """
232
+ The complete, updated training function for Colab using .npz files.
233
+ This version avoids global variables by accepting checkpoint_dir as an argument.
234
+ """
235
+ global sanity_check_counter
236
+ print("\n" + "="*70 + "\n STARTING VQ-VAE TRAINING ON COLAB \n" + "="*70)
237
+
238
+ optimizer = torch.optim.AdamW(vq_model.parameters(), lr=lr, weight_decay=1e-4)
239
+ # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=2)
240
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
241
+ loss_fn = nn.SmoothL1Loss(reduction='none')
242
+ # Use the passed-in checkpoint_dir
243
+ checkpoint_manager = CheckpointManager(checkpoint_dir)
244
+
245
+ start_epoch = 1
246
+ checkpoint = checkpoint_manager.load_latest_checkpoint()
247
+ if checkpoint:
248
+ vq_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
249
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
250
+ start_epoch = checkpoint.get('epoch', 1) + 1
251
+ print(f"✅ Resumed training from epoch {start_epoch}")
252
+ else: print("No CheckPoint Found")
253
+ vq_model.to(DEVICE).train()
254
+ codebook_size = vq_model.vqvae.quantizer.nb_code
255
+
256
+ def collate_fn_enhanced(batch):
257
+ batch = [item for item in batch if item is not None]
258
+ if not batch: return None, None
259
+ batch.sort(key=lambda x: x.shape[0], reverse=True)
260
+ max_len = min(batch[0].shape[0], 256)
261
+ padded_max_len = math.ceil(max_len / 8) * 8
262
+ padded_batch = torch.zeros(len(batch), padded_max_len, batch[0].shape[1])
263
+ lengths = [min(x.shape[0], padded_max_len) for x in batch]
264
+ for i, x_item in enumerate(batch):
265
+ padded_batch[i, :lengths[i], :] = x_item[:lengths[i], :]
266
+ return padded_batch, torch.tensor(lengths)
267
+
268
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2,
269
+ collate_fn=collate_fn_enhanced, drop_last=True, pin_memory=True)
270
+
271
+ if start_epoch == 1 and not getattr(vq_model.vqvae.quantizer, 'init', False):
272
+ initialize_codebook_from_dataset(vq_model, dataloader, num_batches=100)
273
+
274
+ for epoch in range(start_epoch, num_epochs + 1):
275
+ print(f"\n{'='*30} EPOCH {epoch}/{num_epochs} {'='*30}")
276
+ epoch_losses, epoch_vq_losses, epoch_rec_losses, epoch_perplexity = [], [], [], []
277
+ epoch_indices = []
278
+
279
+ for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
280
+ if not batch_data or batch_data[0] is None: continue
281
+
282
+ motion_batch, lengths = batch_data
283
+ x = motion_batch.to(DEVICE)
284
+ x_recon, vq_loss, indices, perplexity = vq_model(x, batch_idx)
285
+
286
+ rec_loss_unreduced = loss_fn(x_recon, x) * vq_model.loss_weights
287
+ mask = torch.zeros_like(x[:, :, 0], device=DEVICE)
288
+ for i, length in enumerate(lengths): mask[i, :length] = 1.0
289
+ mask = mask.unsqueeze(-1).expand_as(rec_loss_unreduced)
290
+ rec_loss = (rec_loss_unreduced * mask).sum() / mask.sum()
291
+
292
+ # vq_weight = max(150.0 * (0.97 ** max(0, epoch - 3)), 1.0)
293
+ beta = 0.25 # This is a standard and effective value.
294
+ total_loss = rec_loss + (beta * vq_loss)
295
+ # total_loss = rec_loss + (vq_weight * vq_loss)
296
+
297
+ optimizer.zero_grad(set_to_none=True)
298
+ total_loss.backward()
299
+ torch.nn.utils.clip_grad_norm_(vq_model.parameters(), max_norm=1.0)
300
+ optimizer.step()
301
+ scheduler.step()
302
+
303
+ epoch_losses.append(total_loss.item())
304
+ epoch_vq_losses.append(vq_loss.item())
305
+ epoch_rec_losses.append(rec_loss.item())
306
+ epoch_perplexity.append(perplexity.item())
307
+ epoch_indices.append(indices.cpu().numpy().flatten())
308
+
309
+ if batch_idx % 50 == 0 and batch_idx > 0:
310
+ print(f"\n[E:{epoch:03d}] B:{batch_idx:03d} | Loss: {total_loss.item():.4f} (Rec: {rec_loss.item():.4f}, VQ: {vq_loss.item():.6f}) | Perplexity: {perplexity.item():.2f}")
311
+
312
+ if SANITY_CHECK_ENABLED and batch_idx == 0 and sanity_check_counter == 0:
313
+ sanity_check_counter += 1
314
+
315
+ if not epoch_losses: continue
316
+
317
+ all_epoch_indices_flat = np.concatenate(epoch_indices)
318
+ counts = np.bincount(all_epoch_indices_flat, minlength=codebook_size)
319
+ avg_usage = (counts > 0).sum()
320
+ with torch.no_grad(): code_variance = vq_model.vqvae.quantizer.codebook.var(dim=0).mean().item()
321
+
322
+ print(f"\n[EPOCH {epoch:03d} SUMMARY]")
323
+ print(f" Avg Loss: {np.mean(epoch_losses):.4f} (Rec: {np.mean(epoch_rec_losses):.4f}, VQ: {np.mean(epoch_vq_losses):.6f})")
324
+ print(f" Avg Perplexity: {np.mean(epoch_perplexity):.2f}")
325
+ print(f" Codebook Usage: {avg_usage}/{codebook_size} ({(avg_usage/codebook_size)*100:.1f}%) | Variance: {code_variance:.6f}")
326
+
327
+ # Use the passed-in checkpoint_dir for saving plots
328
+ hist_path = os.path.join(checkpoint_dir, f'codebook_usage_epoch_{epoch:03d}.png')
329
+ plt.figure(figsize=(12, 6)); plt.hist(all_epoch_indices_flat, bins=codebook_size); plt.title(f'Codebook Usage - Epoch {epoch}'); plt.savefig(hist_path); plt.close()
330
+
331
+ if epoch > 0 and epoch % 5 == 0:
332
+ print("\n--- Performing End-of-Epoch Tasks ---")
333
+ vq_model.eval()
334
+ with torch.no_grad():
335
+ val_data = next(iter(dataloader))
336
+ if val_data and val_data[0] is not None:
337
+ motion_batch, lengths = val_data
338
+ x_val = motion_batch.to(DEVICE)
339
+ x_recon_val, _, _, _ = vq_model(x_val, -1)
340
+ orig = x_val[0, :lengths[0]].cpu().numpy()
341
+ recon = x_recon_val[0, :lengths[0]].cpu().numpy()
342
+ mse = ((orig - recon) ** 2).mean()
343
+ print(f"Reconstruction MSE on sample: {mse:.6f}")
344
+
345
+ with torch.no_grad():
346
+ usage_threshold = 10
347
+ underutilized_indices = torch.from_numpy(np.where(counts < usage_threshold)[0]).to(DEVICE)
348
+ num_to_reset = len(underutilized_indices)
349
+ if num_to_reset > 0:
350
+ print(f"[CODEBOOK MGMT] Resetting {num_to_reset} underutilized codes.")
351
+ reset_data = next(iter(dataloader))
352
+ if reset_data and reset_data[0] is not None:
353
+ motion_batch, _ = reset_data
354
+ x_reset = motion_batch.to(DEVICE)
355
+ z_e = vq_model.vqvae.encoder(vq_model.vqvae.preprocess(x_reset))
356
+ z_e_flat = z_e.permute(0, 2, 1).reshape(-1, z_e.shape[1])
357
+ if z_e_flat.shape[0] >= num_to_reset:
358
+ indices = torch.randperm(z_e_flat.size(0))[:num_to_reset]
359
+ vq_model.vqvae.quantizer.codebook.data[underutilized_indices] = z_e_flat[indices]
360
+ vq_model.train()
361
+
362
+ if epoch > 0 and epoch % 5 == 0:
363
+ checkpoint_manager.save_checkpoint(vq_model, optimizer, epoch, np.mean(epoch_losses))
364
+
365
+ print("\n✅ Training loop finished.")
366
+ return vq_model
367
+
368
+
369
+ # ==============================================================================
370
+ # 6) MAIN EXECUTION SCRIPT (No Globals)
371
+ # ==============================================================================
372
+ def main_colab():
373
+ from google.colab import drive
374
+ drive.mount('/content/drive')
375
+ print("✅ Google Drive mounted successfully.")
376
+
377
+ GDRIVE_ROOT = '/content/drive/MyDrive'
378
+
379
+ # Define all paths locally within the main function
380
+ STATS_PATH = f'/content/dataset_stats.pt'
381
+ DATA_ROOT = f'{GDRIVE_ROOT}/kaggle_upload/npz_data/batch_1'
382
+ CHECKPOINT_DIR = f'{GDRIVE_ROOT}/Colab_Checkpoints/MotionGPT_VQVAE_Final'
383
+
384
+ # The 'global' keyword is no longer needed
385
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
386
+ print(f"Data Root: {DATA_ROOT}")
387
+ print(f"Stats Path: {STATS_PATH}")
388
+ print(f"Checkpoint Dir: {CHECKPOINT_DIR}")
389
+
390
+ smpl_dim = 182
391
+ codebook_size = 512
392
+ code_dim = 512
393
+ vq_model = MotionGPT_VQVAE_Wrapper(
394
+ smpl_dim=smpl_dim, codebook_size=codebook_size, code_dim=code_dim,
395
+ quantizer="ema_reset", width=512, depth=3, down_t=3, stride_t=2,
396
+ dilation_growth_rate=3, activation='relu', norm=None
397
+ ).to(DEVICE)
398
+
399
+ motion_dataset = NpzMotionDataset(
400
+ root_dir=DATA_ROOT,
401
+ stats_path=STATS_PATH,
402
+ min_seq_len=64
403
+ )
404
+
405
+ # Pass CHECKPOINT_DIR as an argument to the training function
406
+ vq_model = train_vqvae_colab(
407
+ vq_model,
408
+ motion_dataset,
409
+ checkpoint_dir=CHECKPOINT_DIR, # Pass the path here
410
+ num_epochs=1000,
411
+ batch_size=32,
412
+ lr=2e-4
413
+ )
414
+
415
+ print("\n" + "="*70 + "\nVQ-VAE TRAINING COMPLETED SUCCESSFULLY!\n" + "="*70)
416
+ final_model_path = os.path.join(CHECKPOINT_DIR, 'final_vqvae_model.pt')
417
+ torch.save({'model_state_dict': vq_model.state_dict()}, final_model_path)
418
+ print(f"Final model saved to: {final_model_path}")
419
+
420
+ if __name__ == "__main__":
421
+ main_colab()
visualize.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization script to convert motion tokens to SMPL-X 3D animation.
3
+ Requires VQ-VAE checkpoint, dataset stats, and SMPL-X model files.
4
+
5
+ Usage:
6
+ # Visualize from LLM output string
7
+ python visualize.py --tokens "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
8
+
9
+ # Visualize from saved file
10
+ python visualize.py --input motion_output.txt
11
+
12
+ # Generate and visualize in one go
13
+ python visualize.py --prompt "walking" --stage 3
14
+
15
+ # Custom paths
16
+ python visualize.py --tokens "..." --vqvae-ckpt /path/to/vqvae.pt --smplx-dir /path/to/smplx
17
+ """
18
+ import os
19
+ import sys
20
+ import re
21
+ import argparse
22
+ from pathlib import Path
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ from config import WORK_DIR, DATA_DIR
29
+
30
+ # Try importing visualization dependencies
31
+ try:
32
+ import plotly.graph_objects as go
33
+ except ImportError:
34
+ print("Installing plotly...")
35
+ os.system("pip install -q plotly")
36
+ import plotly.graph_objects as go
37
+
38
+ try:
39
+ import smplx
40
+ except ImportError:
41
+ print("Installing smplx...")
42
+ os.system("pip install -q smplx==0.1.28")
43
+ import smplx
44
+
45
+ # =====================================================================
46
+ # Configuration - can be overridden via command-line or environment
47
+ # =====================================================================
48
+ # VQ-VAE checkpoint path (trained motion encoder/decoder)
49
+ VQVAE_CHECKPOINT = os.environ.get(
50
+ "VQVAE_CHECKPOINT",
51
+ os.path.join(DATA_DIR, "vqvae_model.pt")
52
+ )
53
+
54
+ # Dataset normalization stats (mean/std used during VQ-VAE training)
55
+ STATS_PATH = os.environ.get(
56
+ "VQVAE_STATS_PATH",
57
+ os.path.join(DATA_DIR, "vqvae_stats.pt")
58
+ )
59
+
60
+ # SMPL-X model directory (contains SMPLX_NEUTRAL.npz, etc.)
61
+ SMPLX_MODEL_DIR = os.environ.get(
62
+ "SMPLX_MODEL_DIR",
63
+ os.path.join(DATA_DIR, "smplx_models")
64
+ )
65
+
66
+ # Output directory for HTML animations
67
+ OUTPUT_DIR = os.environ.get("VIS_OUTPUT_DIR", WORK_DIR)
68
+
69
+ # Device
70
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+
72
+ # VQ-VAE architecture params (must match training config)
73
+ SMPL_DIM = 182
74
+ CODEBOOK_SIZE = 512
75
+ CODE_DIM = 512
76
+ VQ_ARGS = dict(
77
+ width=512,
78
+ depth=3,
79
+ down_t=2,
80
+ stride_t=2,
81
+ dilation_growth_rate=3,
82
+ activation='relu',
83
+ norm=None,
84
+ quantizer="ema_reset"
85
+ )
86
+
87
+ # SMPL-X parameter layout (must match VQ-VAE training)
88
+ PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
89
+ PARAM_NAMES = ["betas", "body_pose", "left_hand_pose", "right_hand_pose",
90
+ "trans", "expression", "jaw_pose", "eye_pose"]
91
+
92
+ # =====================================================================
93
+ # Import VQ-VAE architecture
94
+ # =====================================================================
95
+ try:
96
+ # Add SignMotionGPT to path if not already
97
+ sign_mgpt_dir = os.path.join(os.path.dirname(__file__))
98
+ if sign_mgpt_dir not in sys.path:
99
+ sys.path.insert(0, sign_mgpt_dir)
100
+
101
+ from mGPT.archs.mgpt_vq import VQVae
102
+ except ImportError as e:
103
+ print(f"❌ Could not import VQVae: {e}")
104
+ print("Make sure mGPT/archs/mgpt_vq.py exists in the project.")
105
+ sys.exit(1)
106
+
107
+
108
+ # =====================================================================
109
+ # VQ-VAE Wrapper
110
+ # =====================================================================
111
+ class MotionGPT_VQVAE_Wrapper(nn.Module):
112
+ """Wrapper matching the VQ-VAE training setup"""
113
+ def __init__(self, smpl_dim=SMPL_DIM, codebook_size=CODEBOOK_SIZE,
114
+ code_dim=CODE_DIM, **kwargs):
115
+ super().__init__()
116
+ self.vqvae = VQVae(
117
+ nfeats=smpl_dim,
118
+ code_num=codebook_size,
119
+ code_dim=code_dim,
120
+ output_emb_width=code_dim,
121
+ **kwargs
122
+ )
123
+
124
+
125
+ # =====================================================================
126
+ # Token Parsing
127
+ # =====================================================================
128
+ def parse_motion_tokens(token_str):
129
+ """
130
+ Parse motion tokens from LLM output string.
131
+ Accepts:
132
+ - "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
133
+ - "177 135 152 200 46..."
134
+ - List/array of ints
135
+
136
+ Returns:
137
+ List of token integers
138
+ """
139
+ if isinstance(token_str, (list, tuple, np.ndarray)):
140
+ return [int(x) for x in token_str]
141
+
142
+ if not isinstance(token_str, str):
143
+ raise ValueError("Tokens must be string or list-like")
144
+
145
+ # Try extracting <motion_ID> tokens
146
+ matches = re.findall(r'<motion_(\d+)>', token_str)
147
+ if matches:
148
+ return [int(x) for x in matches]
149
+
150
+ # Try space-separated numbers
151
+ token_str = token_str.strip()
152
+ if token_str:
153
+ try:
154
+ return [int(x) for x in token_str.split()]
155
+ except ValueError:
156
+ pass
157
+
158
+ raise ValueError(f"Could not parse motion tokens from: {token_str[:100]}...")
159
+
160
+
161
+ # =====================================================================
162
+ # Model Loading
163
+ # =====================================================================
164
+ def load_vqvae(checkpoint_path, device=DEVICE, vq_args=VQ_ARGS):
165
+ """Load trained VQ-VAE model from checkpoint"""
166
+ if not os.path.exists(checkpoint_path):
167
+ raise FileNotFoundError(
168
+ f"VQ-VAE checkpoint not found: {checkpoint_path}\n"
169
+ f"Please download it and set VQVAE_CHECKPOINT environment variable "
170
+ f"or use --vqvae-ckpt argument."
171
+ )
172
+
173
+ print(f"Loading VQ-VAE from: {checkpoint_path}")
174
+ model = MotionGPT_VQVAE_Wrapper(
175
+ smpl_dim=SMPL_DIM,
176
+ codebook_size=CODEBOOK_SIZE,
177
+ code_dim=CODE_DIM,
178
+ **vq_args
179
+ ).to(device)
180
+
181
+ ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
182
+ state_dict = ckpt.get('model_state_dict', ckpt)
183
+ model.load_state_dict(state_dict, strict=False)
184
+ model.eval()
185
+
186
+ print(f"✅ VQ-VAE loaded (codebook size: {CODEBOOK_SIZE})")
187
+ return model
188
+
189
+
190
+ def load_stats(stats_path):
191
+ """Load normalization statistics (mean/std) used during VQ-VAE training"""
192
+ if not stats_path or not os.path.exists(stats_path):
193
+ print(f"⚠️ Stats file not found: {stats_path}")
194
+ print(" Will skip denormalization (may affect quality)")
195
+ return None, None
196
+
197
+ print(f"Loading stats from: {stats_path}")
198
+ st = torch.load(stats_path, map_location='cpu', weights_only=False)
199
+ mean = st.get('mean', 0)
200
+ std = st.get('std', 1)
201
+
202
+ # Convert to numpy
203
+ if torch.is_tensor(mean):
204
+ mean = mean.cpu().numpy()
205
+ if torch.is_tensor(std):
206
+ std = std.cpu().numpy()
207
+
208
+ print(f"✅ Stats loaded (mean shape: {np.array(mean).shape})")
209
+ return mean, std
210
+
211
+
212
+ def load_smplx_model(model_dir, device=DEVICE):
213
+ """Load SMPL-X body model"""
214
+ if not os.path.exists(model_dir):
215
+ raise FileNotFoundError(
216
+ f"SMPL-X model directory not found: {model_dir}\n"
217
+ f"Please download SMPL-X models and set SMPLX_MODEL_DIR environment variable "
218
+ f"or use --smplx-dir argument."
219
+ )
220
+
221
+ print(f"Loading SMPL-X from: {model_dir}")
222
+ model = smplx.SMPLX(
223
+ model_path=model_dir,
224
+ model_type='smplx',
225
+ gender='neutral',
226
+ use_pca=False,
227
+ create_global_orient=True,
228
+ create_body_pose=True,
229
+ create_betas=True,
230
+ create_expression=True,
231
+ create_jaw_pose=True,
232
+ create_left_hand_pose=True,
233
+ create_right_hand_pose=True,
234
+ create_transl=True
235
+ ).to(device)
236
+
237
+ print(f"✅ SMPL-X loaded")
238
+ return model
239
+
240
+
241
+ # =====================================================================
242
+ # Token Decoding
243
+ # =====================================================================
244
+ def decode_tokens_to_params(tokens, vqvae_model, mean=None, std=None, device=DEVICE):
245
+ """
246
+ Decode motion tokens to SMPL-X parameters.
247
+
248
+ Args:
249
+ tokens: List of motion token IDs
250
+ vqvae_model: Trained VQ-VAE model
251
+ mean: Optional normalization mean
252
+ std: Optional normalization std
253
+ device: Device to run on
254
+
255
+ Returns:
256
+ numpy array of shape (T, SMPL_DIM) with SMPL-X parameters
257
+ """
258
+ if not tokens:
259
+ return np.zeros((0, SMPL_DIM), dtype=np.float32)
260
+
261
+ # Prepare token indices
262
+ idx = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0) # (1, T_q)
263
+ T_q = idx.shape[1]
264
+
265
+ quantizer = vqvae_model.vqvae.quantizer
266
+
267
+ # Get code dimension
268
+ if hasattr(quantizer, "codebook"):
269
+ codebook = quantizer.codebook.to(device)
270
+ code_dim = codebook.shape[1]
271
+ else:
272
+ code_dim = CODE_DIM
273
+
274
+ # Dequantize tokens
275
+ x_quantized = None
276
+ if hasattr(quantizer, "dequantize"):
277
+ try:
278
+ with torch.no_grad():
279
+ dq = quantizer.dequantize(idx)
280
+ if dq is not None:
281
+ dq = dq.contiguous()
282
+ # Ensure shape is (N, code_dim, T_q)
283
+ if dq.ndim == 3 and dq.shape[1] == code_dim:
284
+ x_quantized = dq
285
+ elif dq.ndim == 3 and dq.shape[1] == T_q:
286
+ x_quantized = dq.permute(0, 2, 1).contiguous()
287
+ else:
288
+ x_quantized = None
289
+ except Exception:
290
+ x_quantized = None
291
+
292
+ # Fallback: manual codebook lookup
293
+ if x_quantized is None:
294
+ if not hasattr(quantizer, "codebook"):
295
+ raise RuntimeError("No dequantize method and no codebook available")
296
+ with torch.no_grad():
297
+ emb = codebook[idx] # (1, T_q, code_dim)
298
+ x_quantized = emb.permute(0, 2, 1).contiguous() # (1, code_dim, T_q)
299
+
300
+ # Decode through VQ-VAE decoder
301
+ with torch.no_grad():
302
+ x_dec = vqvae_model.vqvae.decoder(x_quantized)
303
+ smpl_out = vqvae_model.vqvae.postprocess(x_dec) # (1, T_out, SMPL_DIM)
304
+ params_np = smpl_out.squeeze(0).cpu().numpy() # (T_out, SMPL_DIM)
305
+
306
+ # Denormalize if stats provided
307
+ if (mean is not None) and (std is not None):
308
+ mean_arr = np.array(mean).reshape(1, -1)
309
+ std_arr = np.array(std).reshape(1, -1)
310
+ params_np = (params_np * std_arr) + mean_arr
311
+
312
+ return params_np
313
+
314
+
315
+ # =====================================================================
316
+ # SMPL-X Parameter to Vertices
317
+ # =====================================================================
318
+ def params_to_vertices(params_seq, smplx_model, batch_size=32):
319
+ """
320
+ Convert SMPL-X parameters to 3D vertices.
321
+
322
+ Args:
323
+ params_seq: numpy array (T, SMPL_DIM)
324
+ smplx_model: loaded SMPL-X model
325
+ batch_size: batch size for processing
326
+
327
+ Returns:
328
+ verts: numpy array (T, V, 3)
329
+ faces: numpy array (F, 3)
330
+ """
331
+ # Compute parameter slicing indices
332
+ starts = np.cumsum([0] + PARAM_DIMS[:-1])
333
+ ends = starts + np.array(PARAM_DIMS)
334
+
335
+ T = params_seq.shape[0]
336
+ all_verts = []
337
+
338
+ # Infer number of body joints
339
+ num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)
340
+
341
+ with torch.no_grad():
342
+ for s in range(0, T, batch_size):
343
+ batch = params_seq[s:s+batch_size] # (B, SMPL_DIM)
344
+ B = batch.shape[0]
345
+
346
+ # Extract parameters
347
+ np_parts = {}
348
+ for name, st, ed in zip(PARAM_NAMES, starts, ends):
349
+ np_parts[name] = batch[:, st:ed].astype(np.float32)
350
+
351
+ # Convert to tensors
352
+ tensor_parts = {
353
+ name: torch.from_numpy(arr).to(DEVICE)
354
+ for name, arr in np_parts.items()
355
+ }
356
+
357
+ # Handle body pose (may or may not include global orient)
358
+ body_t = tensor_parts['body_pose']
359
+ L_body = body_t.shape[1]
360
+ expected_no_go = num_body_joints * 3
361
+ expected_with_go = (num_body_joints + 1) * 3
362
+
363
+ if L_body == expected_with_go:
364
+ global_orient = body_t[:, :3].contiguous()
365
+ body_pose_only = body_t[:, 3:].contiguous()
366
+ elif L_body == expected_no_go:
367
+ global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
368
+ body_pose_only = body_t
369
+ else:
370
+ # Best-effort fallback
371
+ if L_body > expected_no_go:
372
+ global_orient = body_t[:, :3].contiguous()
373
+ body_pose_only = body_t[:, 3:].contiguous()
374
+ else:
375
+ pad_len = max(0, expected_no_go - L_body)
376
+ body_pose_only = F.pad(body_t, (0, pad_len))
377
+ global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
378
+
379
+ # Call SMPL-X
380
+ out = smplx_model(
381
+ betas=tensor_parts['betas'],
382
+ global_orient=global_orient,
383
+ body_pose=body_pose_only,
384
+ left_hand_pose=tensor_parts['left_hand_pose'],
385
+ right_hand_pose=tensor_parts['right_hand_pose'],
386
+ expression=tensor_parts['expression'],
387
+ jaw_pose=tensor_parts['jaw_pose'],
388
+ leye_pose=tensor_parts['eye_pose'],
389
+ reye_pose=tensor_parts['eye_pose'],
390
+ transl=tensor_parts['trans'],
391
+ return_verts=True
392
+ )
393
+
394
+ verts = out.vertices.detach().cpu().numpy() # (B, V, 3)
395
+ all_verts.append(verts)
396
+
397
+ verts_all = np.concatenate(all_verts, axis=0) # (T, V, 3)
398
+ faces = smplx_model.faces.astype(np.int32)
399
+
400
+ return verts_all, faces
401
+
402
+
403
+ # =====================================================================
404
+ # Visualization
405
+ # =====================================================================
406
+ def animate_motion(verts, faces, title="Generated Motion", output_path=None, fps=20):
407
+ """
408
+ Create interactive 3D animation using Plotly.
409
+
410
+ Args:
411
+ verts: numpy array (T, V, 3)
412
+ faces: numpy array (F, 3)
413
+ title: Plot title
414
+ output_path: Path to save HTML file
415
+ fps: Frames per second for animation
416
+
417
+ Returns:
418
+ Plotly figure object
419
+ """
420
+ T, V, _ = verts.shape
421
+ i, j, k = faces.T.tolist()
422
+
423
+ # Initial mesh
424
+ mesh = go.Mesh3d(
425
+ x=verts[0, :, 0],
426
+ y=verts[0, :, 1],
427
+ z=verts[0, :, 2],
428
+ i=i, j=j, k=k,
429
+ name=title,
430
+ flatshading=True,
431
+ opacity=0.7
432
+ )
433
+
434
+ # Create frames
435
+ frames = [
436
+ go.Frame(
437
+ data=[go.Mesh3d(
438
+ x=verts[t, :, 0],
439
+ y=verts[t, :, 1],
440
+ z=verts[t, :, 2],
441
+ i=i, j=j, k=k,
442
+ flatshading=True,
443
+ opacity=0.7
444
+ )],
445
+ name=str(t)
446
+ )
447
+ for t in range(T)
448
+ ]
449
+
450
+ # Create figure
451
+ fig = go.Figure(data=[mesh], frames=frames)
452
+
453
+ fig.update_layout(
454
+ title_text=title,
455
+ scene=dict(
456
+ aspectmode='data',
457
+ xaxis=dict(visible=False),
458
+ yaxis=dict(visible=False),
459
+ zaxis=dict(visible=False),
460
+ camera=dict(eye=dict(x=0, y=-2, z=0.7))
461
+ ),
462
+ updatemenus=[dict(
463
+ type="buttons",
464
+ buttons=[
465
+ dict(
466
+ label="Play",
467
+ method="animate",
468
+ args=[None, {
469
+ "frame": {"duration": 1000//fps, "redraw": True},
470
+ "fromcurrent": True
471
+ }]
472
+ ),
473
+ dict(
474
+ label="Pause",
475
+ method="animate",
476
+ args=[[None], {
477
+ "frame": {"duration": 0, "redraw": False}
478
+ }]
479
+ )
480
+ ]
481
+ )]
482
+ )
483
+
484
+ # Save HTML
485
+ if output_path:
486
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
487
+ fig.write_html(output_path)
488
+ print(f"✅ Animation saved to: {output_path}")
489
+
490
+ return fig
491
+
492
+
493
+ # =====================================================================
494
+ # Main Visualization Pipeline
495
+ # =====================================================================
496
+ def visualize(
497
+ tokens,
498
+ vqvae_ckpt=VQVAE_CHECKPOINT,
499
+ stats_path=STATS_PATH,
500
+ smplx_dir=SMPLX_MODEL_DIR,
501
+ output_html=None,
502
+ title="Generated Motion",
503
+ fps=20
504
+ ):
505
+ """
506
+ Complete visualization pipeline: tokens -> vertices -> animation.
507
+
508
+ Args:
509
+ tokens: Motion tokens (string or list of ints)
510
+ vqvae_ckpt: Path to VQ-VAE checkpoint
511
+ stats_path: Path to normalization stats
512
+ smplx_dir: Path to SMPL-X model directory
513
+ output_html: Path to save HTML animation
514
+ title: Animation title
515
+ fps: Frames per second
516
+
517
+ Returns:
518
+ Plotly figure object
519
+ """
520
+ print("="*60)
521
+ print("Motion Visualization Pipeline")
522
+ print("="*60)
523
+
524
+ # Parse tokens
525
+ print("\n[1/5] Parsing tokens...")
526
+ token_list = parse_motion_tokens(tokens)
527
+ print(f" Parsed {len(token_list)} tokens")
528
+ if not token_list:
529
+ print("❌ No tokens to visualize")
530
+ return None
531
+
532
+ # Load models
533
+ print("\n[2/5] Loading VQ-VAE...")
534
+ vq_model = load_vqvae(vqvae_ckpt, device=DEVICE)
535
+
536
+ print("\n[3/5] Loading normalization stats...")
537
+ mean, std = load_stats(stats_path)
538
+
539
+ print("\n[4/5] Loading SMPL-X model...")
540
+ smplx_model = load_smplx_model(smplx_dir, device=DEVICE)
541
+
542
+ # Decode tokens
543
+ print("\n[5/5] Decoding and rendering...")
544
+ print(" Decoding tokens to SMPL-X parameters...")
545
+ params = decode_tokens_to_params(token_list, vq_model, mean, std, device=DEVICE)
546
+ print(f" Decoded params shape: {params.shape}")
547
+
548
+ if params.shape[0] == 0:
549
+ print("❌ No frames produced from decoder")
550
+ return None
551
+
552
+ # Convert to vertices
553
+ print(" Converting parameters to vertices...")
554
+ verts, faces = params_to_vertices(params, smplx_model, batch_size=32)
555
+ print(f" Vertices shape: {verts.shape}, Faces: {faces.shape}")
556
+
557
+ # Create animation
558
+ print(" Creating animation...")
559
+ if output_html is None:
560
+ output_html = os.path.join(OUTPUT_DIR, "motion_animation.html")
561
+
562
+ fig = animate_motion(verts, faces, title=title, output_path=output_html, fps=fps)
563
+
564
+ print("\n" + "="*60)
565
+ print("✅ Visualization complete!")
566
+ print("="*60)
567
+
568
+ return fig
569
+
570
+
571
+ # =====================================================================
572
+ # CLI
573
+ # =====================================================================
574
+ def main():
575
+ parser = argparse.ArgumentParser(
576
+ description="Visualize motion tokens as 3D SMPL-X animation"
577
+ )
578
+
579
+ # Input options (mutually exclusive)
580
+ input_group = parser.add_mutually_exclusive_group(required=True)
581
+ input_group.add_argument(
582
+ "--tokens",
583
+ type=str,
584
+ help="Motion tokens string (e.g., '<MOT_BEGIN><motion_177>...<MOT_END>' or '177 135 152...')"
585
+ )
586
+ input_group.add_argument(
587
+ "--input",
588
+ type=str,
589
+ help="Path to file containing motion tokens"
590
+ )
591
+ input_group.add_argument(
592
+ "--prompt",
593
+ type=str,
594
+ help="Generate tokens from text prompt first (requires --stage)"
595
+ )
596
+
597
+ # Generation options (if using --prompt)
598
+ parser.add_argument(
599
+ "--stage",
600
+ type=int,
601
+ default=3,
602
+ choices=[1, 2, 3],
603
+ help="Stage model to use for generation (default: 3)"
604
+ )
605
+
606
+ # Model paths
607
+ parser.add_argument(
608
+ "--vqvae-ckpt",
609
+ type=str,
610
+ default=VQVAE_CHECKPOINT,
611
+ help=f"Path to VQ-VAE checkpoint (default: {VQVAE_CHECKPOINT})"
612
+ )
613
+ parser.add_argument(
614
+ "--stats",
615
+ type=str,
616
+ default=STATS_PATH,
617
+ help=f"Path to normalization stats (default: {STATS_PATH})"
618
+ )
619
+ parser.add_argument(
620
+ "--smplx-dir",
621
+ type=str,
622
+ default=SMPLX_MODEL_DIR,
623
+ help=f"Path to SMPL-X model directory (default: {SMPLX_MODEL_DIR})"
624
+ )
625
+
626
+ # Output options
627
+ parser.add_argument(
628
+ "--output",
629
+ type=str,
630
+ default=None,
631
+ help="Path to save HTML animation (default: motion_animation.html)"
632
+ )
633
+ parser.add_argument(
634
+ "--title",
635
+ type=str,
636
+ default="Generated Motion",
637
+ help="Animation title"
638
+ )
639
+ parser.add_argument(
640
+ "--fps",
641
+ type=int,
642
+ default=20,
643
+ help="Frames per second for animation (default: 20)"
644
+ )
645
+
646
+ args = parser.parse_args()
647
+
648
+ # Get tokens
649
+ if args.prompt:
650
+ # Generate tokens first using inference.py
651
+ print("Generating motion tokens from prompt...")
652
+ from inference import inference
653
+ tokens = inference(
654
+ prompt=args.prompt,
655
+ stage=args.stage,
656
+ output_file=None,
657
+ per_prompt_vocab=True
658
+ )
659
+ elif args.input:
660
+ # Read from file
661
+ with open(args.input, 'r') as f:
662
+ tokens = f.read().strip()
663
+ else:
664
+ # Direct token string
665
+ tokens = args.tokens
666
+
667
+ # Visualize
668
+ visualize(
669
+ tokens=tokens,
670
+ vqvae_ckpt=args.vqvae_ckpt,
671
+ stats_path=args.stats,
672
+ smplx_dir=args.smplx_dir,
673
+ output_html=args.output,
674
+ title=args.title,
675
+ fps=args.fps
676
+ )
677
+
678
+
679
+ if __name__ == "__main__":
680
+ main()
681
+