H-Liu1997 committed on
Commit
35ee6b1
·
1 Parent(s): 93735c1

Clean rewrite: Python 3.10 venv + script-based inference

Browse files
Files changed (1) hide show
  1. EMAGE_Colab_Demo.ipynb +151 -330
EMAGE_Colab_Demo.ipynb CHANGED
@@ -11,424 +11,244 @@
11
  "- **DisCo**: Upper body gesture generation with diffusion\n",
12
  "- **EMAGE**: Full body + face gesture generation\n",
13
  "\n",
14
- "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/#fileId=https%3A//huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE_Colab_Demo.ipynb)\n",
15
- "\n",
16
- "[Project Page](https://pantomatrix.github.io/EMAGE/) | [GitHub](https://github.com/PantoMatrix/PantoMatrix) | [Paper](https://arxiv.org/abs/2401.00374)"
17
  ]
18
  },
19
  {
20
  "cell_type": "markdown",
21
  "metadata": {},
22
  "source": [
23
- "## 1. Setup Environment"
 
 
24
  ]
25
  },
26
  {
27
  "cell_type": "code",
 
28
  "metadata": {},
 
29
  "source": [
30
- "# Step 1: Install dependencies (using Colab's Python 3.11)\n",
31
- "import sys\n",
32
- "print(f\"Python version: {sys.version}\")\n",
 
33
  "\n",
34
- "# Install PyTorch with CUDA 12.1\n",
35
- "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
36
- "!pip install -q numpy librosa soundfile transformers huggingface_hub\n",
37
- "!pip install -q smplx trimesh scipy easydict omegaconf\n",
38
  "\n",
39
- "# Verify installation\n",
40
- "import torch\n",
41
- "print(f\"PyTorch: {torch.__version__}\")\n",
42
- "print(f\"CUDA available: {torch.cuda.is_available()}\")"
43
- ],
44
- "execution_count": null,
45
- "outputs": []
 
 
 
 
 
46
  },
47
  {
48
  "cell_type": "code",
 
49
  "metadata": {},
 
50
  "source": [
51
- "# Step 2: Clone code repositories\n",
52
  "!apt-get install -y git-lfs > /dev/null 2>&1\n",
53
  "!git lfs install\n",
54
  "\n",
55
- "# Clone PantoMatrix from GitHub\n",
56
- "!git clone https://github.com/PantoMatrix/PantoMatrix.git\n",
57
- "\n",
58
- "# Clone evaluation tools (contains SMPLX models)\n",
59
- "!git clone https://huggingface.co/H-Liu1997/emage_evaltools PantoMatrix/emage_evaltools\n",
60
- "%cd PantoMatrix/emage_evaltools\n",
61
  "!git lfs pull\n",
62
- "%cd /content/PantoMatrix\n",
63
  "\n",
64
- "print(\"Code cloned successfully!\")"
65
- ],
66
- "execution_count": null,
67
- "outputs": []
68
  },
69
  {
70
  "cell_type": "code",
 
71
  "metadata": {},
 
72
  "source": [
73
- "# Step 3: Import libraries\n",
74
- "import sys\n",
75
- "import os\n",
76
- "import torch\n",
77
- "import numpy as np\n",
78
- "import librosa\n",
79
- "import soundfile as sf\n",
80
- "from IPython.display import Video, Audio, display\n",
81
  "\n",
 
 
 
82
  "from models.camn_audio import CamnAudioModel\n",
83
  "from models.disco_audio import DiscoAudioModel\n",
84
  "from models.emage_audio import EmageAudioModel, EmageVQVAEConv, EmageVAEConv, EmageVQModel\n",
85
  "from emage_utils.motion_io import beat_format_save\n",
86
  "from emage_utils.npz2pose import render2d\n",
87
- "from torchvision.io import write_video\n",
88
- "import torch.nn.functional as F\n",
89
- "\n",
90
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
91
- "print(f\"Python: {sys.version}\")\n",
92
- "print(f\"PyTorch: {torch.__version__}\")\n",
93
- "print(f\"NumPy: {np.__version__}\")\n",
94
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
95
- "print(f\"Using device: {device}\")\n",
96
  "\n",
97
- "os.makedirs(\"./outputs\", exist_ok=True)\n",
98
- "\n",
99
- "# Check if evaluation tools are available\n",
100
- "EVAL_AVAILABLE = os.path.exists(\"./emage_evaltools/metric.py\")\n",
101
- "print(f\"Evaluation tools available: {EVAL_AVAILABLE}\")"
102
- ],
103
- "execution_count": null,
104
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  },
106
  {
107
  "cell_type": "markdown",
108
  "metadata": {},
109
  "source": [
110
- "## 2. Upload Your Audio\n",
111
  "\n",
112
- "Upload a `.wav` file or use the example audio."
113
  ]
114
  },
115
  {
116
  "cell_type": "code",
 
117
  "metadata": {},
 
118
  "source": [
119
- "# Option 1: Use example audio\n",
120
- "audio_path = \"./examples/audio/2_scott_0_103_103_10s.wav\"\n",
121
  "\n",
122
- "# Option 2: Upload your own audio\n",
123
  "# from google.colab import files\n",
124
  "# uploaded = files.upload()\n",
125
- "# audio_path = list(uploaded.keys())[0]\n",
126
  "\n",
 
127
  "display(Audio(audio_path))"
128
- ],
129
- "execution_count": null,
130
- "outputs": []
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "metadata": {},
135
- "source": [
136
- "## 3. CaMN Model (Upper Body)\n",
137
- "\n",
138
- "CaMN generates upper body gestures from speech audio."
139
  ]
140
  },
141
  {
142
  "cell_type": "code",
143
- "metadata": {},
144
- "source": [
145
- "# Load CaMN model\n",
146
- "model_camn = CamnAudioModel.from_pretrained(\"H-Liu1997/camn_audio\").to(device).eval()\n",
147
- "print(\"CaMN model loaded!\")"
148
- ],
149
- "execution_count": null,
150
- "outputs": []
151
- },
152
- {
153
- "cell_type": "code",
154
- "metadata": {},
155
- "source": [
156
- "# CaMN Inference\n",
157
- "sr_model = model_camn.cfg.audio_sr\n",
158
- "pose_fps = model_camn.cfg.pose_fps\n",
159
- "seed_frames = model_camn.cfg.seed_frames\n",
160
- "\n",
161
- "audio_loaded, _ = librosa.load(audio_path, sr=sr_model)\n",
162
- "audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device)\n",
163
- "sid = torch.zeros(1, 1).long().to(device)\n",
164
- "\n",
165
- "with torch.no_grad():\n",
166
- " motion_pred = model_camn(audio_t, sid, seed_frames=seed_frames)[\"motion_axis_angle\"]\n",
167
- "\n",
168
- "t = motion_pred.shape[1]\n",
169
- "motion_pred_np = motion_pred.cpu().numpy().reshape(t, -1)\n",
170
- "\n",
171
- "# Save motion\n",
172
- "camn_npz_path = \"./outputs/camn_output.npz\"\n",
173
- "beat_format_save(camn_npz_path, motion_pred_np, upsample=30 // pose_fps)\n",
174
- "print(f\"CaMN motion saved to {camn_npz_path}\")"
175
- ],
176
- "execution_count": null,
177
- "outputs": []
178
- },
179
- {
180
- "cell_type": "code",
181
- "metadata": {},
182
- "source": [
183
- "# Visualize CaMN result\n",
184
- "motion_dict = np.load(camn_npz_path, allow_pickle=True)\n",
185
- "v2d = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
186
- "camn_video_path = \"./outputs/camn_output.mp4\"\n",
187
- "write_video(camn_video_path, v2d.permute(0, 2, 3, 1), fps=30)\n",
188
- "print(\"CaMN visualization:\")\n",
189
- "display(Video(camn_video_path, embed=True, width=480))"
190
- ],
191
  "execution_count": null,
192
- "outputs": []
193
- },
194
- {
195
- "cell_type": "markdown",
196
  "metadata": {},
 
197
  "source": [
198
- "## 4. DisCo Model (Upper Body with Diffusion)\n",
 
 
199
  "\n",
200
- "DisCo uses diffusion for more diverse upper body gesture generation."
 
201
  ]
202
  },
203
  {
204
  "cell_type": "code",
205
- "metadata": {},
206
- "source": [
207
- "# Load DisCo model\n",
208
- "model_disco = DiscoAudioModel.from_pretrained(\"H-Liu1997/disco_audio\").to(device).eval()\n",
209
- "print(\"DisCo model loaded!\")"
210
- ],
211
  "execution_count": null,
212
- "outputs": []
213
- },
214
- {
215
- "cell_type": "code",
216
  "metadata": {},
 
217
  "source": [
218
- "# DisCo Inference\n",
219
- "sr_model = model_disco.cfg.audio_sr\n",
220
- "pose_fps = model_disco.cfg.pose_fps\n",
221
- "seed_frames = model_disco.cfg.seed_frames\n",
222
- "\n",
223
- "audio_loaded, _ = librosa.load(audio_path, sr=sr_model)\n",
224
- "audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device)\n",
225
- "sid = torch.zeros(1, 1).long().to(device)\n",
226
- "\n",
227
- "with torch.no_grad():\n",
228
- " motion_pred = model_disco(audio_t, sid, seed_frames=seed_frames, seed_motion=None)[\"motion_axis_angle\"]\n",
229
- "\n",
230
- "t = motion_pred.shape[1]\n",
231
- "motion_pred_np = motion_pred.cpu().numpy().reshape(t, -1)\n",
232
  "\n",
233
- "# Save motion\n",
234
- "disco_npz_path = \"./outputs/disco_output.npz\"\n",
235
- "beat_format_save(disco_npz_path, motion_pred_np, upsample=30 // pose_fps)\n",
236
- "print(f\"DisCo motion saved to {disco_npz_path}\")"
237
- ],
238
- "execution_count": null,
239
- "outputs": []
240
- },
241
- {
242
- "cell_type": "code",
243
- "metadata": {},
244
- "source": [
245
- "# Visualize DisCo result\n",
246
- "motion_dict = np.load(disco_npz_path, allow_pickle=True)\n",
247
- "v2d = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
248
- "disco_video_path = \"./outputs/disco_output.mp4\"\n",
249
- "write_video(disco_video_path, v2d.permute(0, 2, 3, 1), fps=30)\n",
250
- "print(\"DisCo visualization:\")\n",
251
- "display(Video(disco_video_path, embed=True, width=480))"
252
- ],
253
- "execution_count": null,
254
- "outputs": []
255
- },
256
- {
257
- "cell_type": "markdown",
258
- "metadata": {},
259
- "source": [
260
- "## 5. EMAGE Model (Full Body + Face)\n",
261
- "\n",
262
- "EMAGE generates full body gestures including face expressions."
263
  ]
264
  },
265
  {
266
  "cell_type": "code",
267
- "metadata": {},
268
- "source": [
269
- "# Load EMAGE model and VQ components\n",
270
- "face_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/face\").to(device).eval()\n",
271
- "upper_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/upper\").to(device).eval()\n",
272
- "lower_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/lower\").to(device).eval()\n",
273
- "hands_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/hands\").to(device).eval()\n",
274
- "global_motion_ae = EmageVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/global\").to(device).eval()\n",
275
- "\n",
276
- "emage_vq_model = EmageVQModel(\n",
277
- " face_model=face_motion_vq, \n",
278
- " upper_model=upper_motion_vq,\n",
279
- " lower_model=lower_motion_vq, \n",
280
- " hands_model=hands_motion_vq,\n",
281
- " global_model=global_motion_ae\n",
282
- ").to(device).eval()\n",
283
- "\n",
284
- "model_emage = EmageAudioModel.from_pretrained(\"H-Liu1997/emage_audio\").to(device).eval()\n",
285
- "print(\"EMAGE model loaded!\")"
286
- ],
287
- "execution_count": null,
288
- "outputs": []
289
- },
290
- {
291
- "cell_type": "code",
292
- "metadata": {},
293
- "source": [
294
- "# EMAGE Inference\n",
295
- "sr_model = model_emage.cfg.audio_sr\n",
296
- "pose_fps = model_emage.cfg.pose_fps\n",
297
- "\n",
298
- "audio_loaded, _ = librosa.load(audio_path, sr=sr_model)\n",
299
- "audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device)\n",
300
- "sid = torch.zeros(1, 1).long().to(device)\n",
301
- "\n",
302
- "with torch.no_grad():\n",
303
- " latent_dict = model_emage.inference(audio_t, sid, emage_vq_model, masked_motion=None, mask=None)\n",
304
- " \n",
305
- " face_latent = latent_dict[\"rec_face\"] if model_emage.cfg.lf > 0 and model_emage.cfg.cf == 0 else None\n",
306
- " upper_latent = latent_dict[\"rec_upper\"] if model_emage.cfg.lu > 0 and model_emage.cfg.cu == 0 else None\n",
307
- " hands_latent = latent_dict[\"rec_hands\"] if model_emage.cfg.lh > 0 and model_emage.cfg.ch == 0 else None\n",
308
- " lower_latent = latent_dict[\"rec_lower\"] if model_emage.cfg.ll > 0 and model_emage.cfg.cl == 0 else None\n",
309
- "\n",
310
- " face_index = torch.max(F.log_softmax(latent_dict[\"cls_face\"], dim=2), dim=2)[1] if model_emage.cfg.cf > 0 else None\n",
311
- " upper_index = torch.max(F.log_softmax(latent_dict[\"cls_upper\"], dim=2), dim=2)[1] if model_emage.cfg.cu > 0 else None\n",
312
- " hands_index = torch.max(F.log_softmax(latent_dict[\"cls_hands\"], dim=2), dim=2)[1] if model_emage.cfg.ch > 0 else None\n",
313
- " lower_index = torch.max(F.log_softmax(latent_dict[\"cls_lower\"], dim=2), dim=2)[1] if model_emage.cfg.cl > 0 else None\n",
314
- "\n",
315
- " ref_trans = torch.zeros(1, 1, 3).to(device)\n",
316
- " all_pred = emage_vq_model.decode(\n",
317
- " face_latent=face_latent, \n",
318
- " upper_latent=upper_latent, \n",
319
- " lower_latent=lower_latent, \n",
320
- " hands_latent=hands_latent,\n",
321
- " face_index=face_index, \n",
322
- " upper_index=upper_index, \n",
323
- " lower_index=lower_index, \n",
324
- " hands_index=hands_index,\n",
325
- " get_global_motion=True, \n",
326
- " ref_trans=ref_trans[:, 0]\n",
327
- " )\n",
328
- "\n",
329
- "motion_pred = all_pred[\"motion_axis_angle\"]\n",
330
- "t = motion_pred.shape[1]\n",
331
- "motion_pred_np = motion_pred.cpu().numpy().reshape(t, -1)\n",
332
- "face_pred = all_pred[\"expression\"].cpu().numpy().reshape(t, -1)\n",
333
- "trans_pred = all_pred[\"trans\"].cpu().numpy().reshape(t, -1)\n",
334
- "\n",
335
- "# Save motion\n",
336
- "emage_npz_path = \"./outputs/emage_output.npz\"\n",
337
- "beat_format_save(emage_npz_path, motion_pred_np, upsample=30 // pose_fps, expressions=face_pred, trans=trans_pred)\n",
338
- "print(f\"EMAGE motion saved to {emage_npz_path}\")"
339
- ],
340
- "execution_count": null,
341
- "outputs": []
342
- },
343
- {
344
- "cell_type": "code",
345
- "metadata": {},
346
- "source": [
347
- "# Visualize EMAGE body result\n",
348
- "motion_dict = np.load(emage_npz_path, allow_pickle=True)\n",
349
- "v2d_body = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
350
- "emage_body_path = \"./outputs/emage_body.mp4\"\n",
351
- "write_video(emage_body_path, v2d_body.permute(0, 2, 3, 1), fps=30)\n",
352
- "print(\"EMAGE body visualization:\")\n",
353
- "display(Video(emage_body_path, embed=True, width=480))"
354
- ],
355
  "execution_count": null,
356
- "outputs": []
357
- },
358
- {
359
- "cell_type": "markdown",
360
  "metadata": {},
 
361
  "source": [
362
- "## 7. Evaluation (Optional)\n",
 
 
363
  "\n",
364
- "Compute metrics like FGD (Frechet Gesture Distance), BC (Beat Consistency), L1 Diversity."
 
365
  ]
366
  },
367
- {
368
- "cell_type": "code",
369
- "metadata": {},
370
- "source": [
371
- "# Evaluation requires ground truth motion data\n",
372
- "# This is a demo showing how to use the evaluation API\n",
373
- "\n",
374
- "if EVAL_AVAILABLE:\n",
375
- " from emage_evaltools.metric import FGD, BC, L1Div\n",
376
- " \n",
377
- " # Initialize evaluators\n",
378
- " fgd_evaluator = FGD(download_path=\"./emage_evaltools/\")\n",
379
- " bc_evaluator = BC(download_path=\"./emage_evaltools/\", sigma=0.3, order=7)\n",
380
- " l1div_evaluator = L1Div()\n",
381
- " \n",
382
- " print(\"Evaluation tools loaded!\")\n",
383
- " print(\"Note: Full evaluation requires ground truth motion data from BEAT2 dataset\")\n",
384
- " print(\"Download BEAT2: git clone https://huggingface.co/datasets/H-Liu1997/BEAT2\")\n",
385
- "else:\n",
386
- " print(\"Evaluation tools not available. Clone from GitHub to enable.\")"
387
- ],
388
- "execution_count": null,
389
- "outputs": []
390
- },
391
- {
392
- "cell_type": "code",
393
- "metadata": {},
394
- "source": [
395
- "# Visualize EMAGE face result\n",
396
- "v2d_face = render2d(motion_dict, (720, 480), face_only=True, remove_global=True)\n",
397
- "emage_face_path = \"./outputs/emage_face.mp4\"\n",
398
- "write_video(emage_face_path, v2d_face.permute(0, 2, 3, 1), fps=30)\n",
399
- "print(\"EMAGE face visualization:\")\n",
400
- "display(Video(emage_face_path, embed=True, width=480))"
401
- ],
402
- "execution_count": null,
403
- "outputs": []
404
- },
405
  {
406
  "cell_type": "markdown",
407
  "metadata": {},
408
  "source": [
409
- "## 6. Download Results\n",
410
  "\n",
411
- "Download the generated motion files (`.npz`) for use with Blender."
412
  ]
413
  },
414
  {
415
  "cell_type": "code",
 
416
  "metadata": {},
 
417
  "source": [
418
- "from google.colab import files\n",
419
- "\n",
420
  "print(\"Generated files:\")\n",
421
- "print(f\" - CaMN: {camn_npz_path}\")\n",
422
- "print(f\" - DisCo: {disco_npz_path}\")\n",
423
- "print(f\" - EMAGE: {emage_npz_path}\")\n",
424
  "\n",
425
- "# Uncomment to download\n",
426
- "# files.download(camn_npz_path)\n",
427
- "# files.download(disco_npz_path)\n",
428
- "# files.download(emage_npz_path)"
429
- ],
430
- "execution_count": null,
431
- "outputs": []
432
  },
433
  {
434
  "cell_type": "markdown",
@@ -436,9 +256,10 @@
436
  "source": [
437
  "## Notes\n",
438
  "\n",
439
- "- **Motion Format**: The `.npz` files contain SMPL-X format motion data\n",
440
- "- **Blender Visualization**: Use the [Blender Add-on](https://huggingface.co/datasets/H-Liu1997/BEAT2_Tools/blob/main/smplx_blender_addon_20230921.zip) for high-quality rendering\n",
441
- "- **HuggingFace Space**: Try the [interactive demo](https://huggingface.co/spaces/H-Liu1997/EMAGE) for quick testing"
 
442
  ]
443
  }
444
  ],
@@ -455,4 +276,4 @@
455
  },
456
  "nbformat": 4,
457
  "nbformat_minor": 4
458
- }
 
11
  "- **DisCo**: Upper body gesture generation with diffusion\n",
12
  "- **EMAGE**: Full body + face gesture generation\n",
13
  "\n",
14
+ "[Project Page](https://pantomatrix.github.io/EMAGE/) | [GitHub](https://github.com/PantoMatrix/PantoMatrix) | [Paper](https://arxiv.org/abs/2401.00374) | [HF Space](https://huggingface.co/spaces/H-Liu1997/EMAGE)"
 
 
15
  ]
16
  },
17
  {
18
  "cell_type": "markdown",
19
  "metadata": {},
20
  "source": [
21
+ "## 1. Setup Environment\n",
22
+ "\n",
23
+ "Install Python 3.10, create virtual environment, and install dependencies."
24
  ]
25
  },
26
  {
27
  "cell_type": "code",
28
+ "execution_count": null,
29
  "metadata": {},
30
+ "outputs": [],
31
  "source": [
32
+ "# Install Python 3.10 and create virtual environment\n",
33
+ "!sudo add-apt-repository -y ppa:deadsnakes/ppa > /dev/null 2>&1\n",
34
+ "!sudo apt-get update -qq\n",
35
+ "!sudo apt-get install -y python3.10 python3.10-venv python3.10-dev > /dev/null 2>&1\n",
36
  "\n",
37
+ "ENV_PATH = \"/content/py310_env\"\n",
38
+ "!python3.10 -m venv {ENV_PATH}\n",
 
 
39
  "\n",
40
+ "PYTHON = f\"{ENV_PATH}/bin/python\"\n",
41
+ "PIP = f\"{ENV_PATH}/bin/pip\"\n",
42
+ "\n",
43
+ "# Install dependencies\n",
44
+ "!{PIP} install -q --upgrade pip\n",
45
+ "!{PIP} install -q torch==2.1.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
46
+ "!{PIP} install -q numpy==1.23.0 librosa soundfile transformers huggingface_hub\n",
47
+ "!{PIP} install -q smplx trimesh scipy easydict omegaconf\n",
48
+ "\n",
49
+ "# Verify\n",
50
+ "!{PYTHON} -c \"import torch; print(f'Python 3.10 + PyTorch {torch.__version__} + CUDA: {torch.cuda.is_available()}')\""
51
+ ]
52
  },
53
  {
54
  "cell_type": "code",
55
+ "execution_count": null,
56
  "metadata": {},
57
+ "outputs": [],
58
  "source": [
59
+ "# Clone code repositories\n",
60
  "!apt-get install -y git-lfs > /dev/null 2>&1\n",
61
  "!git lfs install\n",
62
  "\n",
63
+ "!git clone https://github.com/PantoMatrix/PantoMatrix.git /content/PantoMatrix\n",
64
+ "!git clone https://huggingface.co/H-Liu1997/emage_evaltools /content/PantoMatrix/emage_evaltools\n",
65
+ "%cd /content/PantoMatrix/emage_evaltools\n",
 
 
 
66
  "!git lfs pull\n",
67
+ "%cd /content\n",
68
  "\n",
69
+ "print(\"Code ready!\")"
70
+ ]
 
 
71
  },
72
  {
73
  "cell_type": "code",
74
+ "execution_count": null,
75
  "metadata": {},
76
+ "outputs": [],
77
  "source": [
78
+ "%%writefile /content/run_inference.py\n",
79
+ "import sys, os\n",
80
+ "sys.path.insert(0, '/content/PantoMatrix')\n",
81
+ "os.chdir('/content/PantoMatrix')\n",
 
 
 
 
82
  "\n",
83
+ "import torch, numpy as np, librosa, argparse\n",
84
+ "import torch.nn.functional as F\n",
85
+ "from torchvision.io import write_video\n",
86
  "from models.camn_audio import CamnAudioModel\n",
87
  "from models.disco_audio import DiscoAudioModel\n",
88
  "from models.emage_audio import EmageAudioModel, EmageVQVAEConv, EmageVAEConv, EmageVQModel\n",
89
  "from emage_utils.motion_io import beat_format_save\n",
90
  "from emage_utils.npz2pose import render2d\n",
 
 
 
 
 
 
 
 
 
91
  "\n",
92
+ "def main():\n",
93
+ " parser = argparse.ArgumentParser()\n",
94
+ " parser.add_argument('--audio', type=str, required=True)\n",
95
+ " parser.add_argument('--model', type=str, default='camn', choices=['camn', 'disco', 'emage'])\n",
96
+ " parser.add_argument('--output_dir', type=str, default='/content/outputs')\n",
97
+ " args = parser.parse_args()\n",
98
+ " \n",
99
+ " os.makedirs(args.output_dir, exist_ok=True)\n",
100
+ " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
101
+ " print(f'Using device: {device}')\n",
102
+ " \n",
103
+ " if args.model == 'camn':\n",
104
+ " model = CamnAudioModel.from_pretrained('H-Liu1997/camn_audio').to(device).eval()\n",
105
+ " sr, fps, seed = model.cfg.audio_sr, model.cfg.pose_fps, model.cfg.seed_frames\n",
106
+ " audio, _ = librosa.load(args.audio, sr=sr)\n",
107
+ " audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(device)\n",
108
+ " with torch.no_grad():\n",
109
+ " motion = model(audio_t, torch.zeros(1,1).long().to(device), seed_frames=seed)['motion_axis_angle']\n",
110
+ " npz_path = os.path.join(args.output_dir, 'camn_output.npz')\n",
111
+ " beat_format_save(npz_path, motion.cpu().numpy().reshape(motion.shape[1], -1), upsample=30//fps)\n",
112
+ " \n",
113
+ " elif args.model == 'disco':\n",
114
+ " model = DiscoAudioModel.from_pretrained('H-Liu1997/disco_audio').to(device).eval()\n",
115
+ " sr, fps, seed = model.cfg.audio_sr, model.cfg.pose_fps, model.cfg.seed_frames\n",
116
+ " audio, _ = librosa.load(args.audio, sr=sr)\n",
117
+ " audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(device)\n",
118
+ " with torch.no_grad():\n",
119
+ " motion = model(audio_t, torch.zeros(1,1).long().to(device), seed_frames=seed, seed_motion=None)['motion_axis_angle']\n",
120
+ " npz_path = os.path.join(args.output_dir, 'disco_output.npz')\n",
121
+ " beat_format_save(npz_path, motion.cpu().numpy().reshape(motion.shape[1], -1), upsample=30//fps)\n",
122
+ " \n",
123
+ " else: # emage\n",
124
+ " vq_models = {k: EmageVQVAEConv.from_pretrained('H-Liu1997/emage_audio', subfolder=f'emage_vq/{k}').to(device).eval() \n",
125
+ " for k in ['face', 'upper', 'lower', 'hands']}\n",
126
+ " global_ae = EmageVAEConv.from_pretrained('H-Liu1997/emage_audio', subfolder='emage_vq/global').to(device).eval()\n",
127
+ " vq = EmageVQModel(face_model=vq_models['face'], upper_model=vq_models['upper'],\n",
128
+ " lower_model=vq_models['lower'], hands_model=vq_models['hands'], global_model=global_ae).to(device).eval()\n",
129
+ " model = EmageAudioModel.from_pretrained('H-Liu1997/emage_audio').to(device).eval()\n",
130
+ " sr, fps = model.cfg.audio_sr, model.cfg.pose_fps\n",
131
+ " audio, _ = librosa.load(args.audio, sr=sr)\n",
132
+ " audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(device)\n",
133
+ " with torch.no_grad():\n",
134
+ " lat = model.inference(audio_t, torch.zeros(1,1).long().to(device), vq, masked_motion=None, mask=None)\n",
135
+ " get = lambda k, c: lat[f'rec_{k}'] if getattr(model.cfg, f'l{k[0]}') > 0 and getattr(model.cfg, f'c{k[0]}') == 0 else None\n",
136
+ " idx = lambda k: torch.max(F.log_softmax(lat[f'cls_{k}'], dim=2), dim=2)[1] if getattr(model.cfg, f'c{k[0]}') > 0 else None\n",
137
+ " pred = vq.decode(face_latent=get('face','f'), upper_latent=get('upper','u'), lower_latent=get('lower','l'), hands_latent=get('hands','h'),\n",
138
+ " face_index=idx('face'), upper_index=idx('upper'), lower_index=idx('lower'), hands_index=idx('hands'),\n",
139
+ " get_global_motion=True, ref_trans=torch.zeros(1,3).to(device))\n",
140
+ " motion = pred['motion_axis_angle']\n",
141
+ " npz_path = os.path.join(args.output_dir, 'emage_output.npz')\n",
142
+ " beat_format_save(npz_path, motion.cpu().numpy().reshape(motion.shape[1], -1), upsample=30//fps,\n",
143
+ " expressions=pred['expression'].cpu().numpy().reshape(motion.shape[1], -1),\n",
144
+ " trans=pred['trans'].cpu().numpy().reshape(motion.shape[1], -1))\n",
145
+ " \n",
146
+ " # Render 2D visualization\n",
147
+ " motion_dict = np.load(npz_path, allow_pickle=True)\n",
148
+ " v2d = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
149
+ " video_path = npz_path.replace('.npz', '_2d.mp4')\n",
150
+ " write_video(video_path, v2d.permute(0, 2, 3, 1), fps=30)\n",
151
+ " print(f'Saved: {npz_path}')\n",
152
+ " print(f'Video: {video_path}')\n",
153
+ "\n",
154
+ "if __name__ == '__main__': main()"
155
+ ]
156
  },
157
  {
158
  "cell_type": "markdown",
159
  "metadata": {},
160
  "source": [
161
+ "## 2. Run Inference\n",
162
  "\n",
163
+ "Choose your audio and model, then run inference."
164
  ]
165
  },
166
  {
167
  "cell_type": "code",
168
+ "execution_count": null,
169
  "metadata": {},
170
+ "outputs": [],
171
  "source": [
172
+ "# Audio file (use example or upload your own)\n",
173
+ "audio_path = \"/content/PantoMatrix/examples/audio/2_scott_0_103_103_10s.wav\"\n",
174
  "\n",
175
+ "# Uncomment to upload your own audio:\n",
176
  "# from google.colab import files\n",
177
  "# uploaded = files.upload()\n",
178
+ "# audio_path = \"/content/\" + list(uploaded.keys())[0]\n",
179
  "\n",
180
+ "from IPython.display import Audio\n",
181
  "display(Audio(audio_path))"
 
 
 
 
 
 
 
 
 
 
 
182
  ]
183
  },
184
  {
185
  "cell_type": "code",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  "execution_count": null,
 
 
 
 
187
  "metadata": {},
188
+ "outputs": [],
189
  "source": [
190
+ "# Run CaMN (Upper Body)\n",
191
+ "PYTHON = \"/content/py310_env/bin/python\"\n",
192
+ "!{PYTHON} /content/run_inference.py --audio {audio_path} --model camn\n",
193
  "\n",
194
+ "from IPython.display import Video\n",
195
+ "display(Video(\"/content/outputs/camn_output_2d.mp4\", embed=True, width=600))"
196
  ]
197
  },
198
  {
199
  "cell_type": "code",
 
 
 
 
 
 
200
  "execution_count": null,
 
 
 
 
201
  "metadata": {},
202
+ "outputs": [],
203
  "source": [
204
+ "# Run DisCo (Upper Body with Diffusion)\n",
205
+ "PYTHON = \"/content/py310_env/bin/python\"\n",
206
+ "!{PYTHON} /content/run_inference.py --audio {audio_path} --model disco\n",
 
 
 
 
 
 
 
 
 
 
 
207
  "\n",
208
+ "from IPython.display import Video\n",
209
+ "display(Video(\"/content/outputs/disco_output_2d.mp4\", embed=True, width=600))"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  ]
211
  },
212
  {
213
  "cell_type": "code",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  "execution_count": null,
 
 
 
 
215
  "metadata": {},
216
+ "outputs": [],
217
  "source": [
218
+ "# Run EMAGE (Full Body + Face)\n",
219
+ "PYTHON = \"/content/py310_env/bin/python\"\n",
220
+ "!{PYTHON} /content/run_inference.py --audio {audio_path} --model emage\n",
221
  "\n",
222
+ "from IPython.display import Video\n",
223
+ "display(Video(\"/content/outputs/emage_output_2d.mp4\", embed=True, width=600))"
224
  ]
225
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  {
227
  "cell_type": "markdown",
228
  "metadata": {},
229
  "source": [
230
+ "## 3. Download Results\n",
231
  "\n",
232
+ "Download motion files (`.npz`) for use with Blender."
233
  ]
234
  },
235
  {
236
  "cell_type": "code",
237
+ "execution_count": null,
238
  "metadata": {},
239
+ "outputs": [],
240
  "source": [
241
+ "import os\n",
 
242
  "print(\"Generated files:\")\n",
243
+ "for f in os.listdir(\"/content/outputs\"):\n",
244
+ " print(f\" /content/outputs/{f}\")\n",
 
245
  "\n",
246
+ "# Uncomment to download:\n",
247
+ "# from google.colab import files\n",
248
+ "# files.download(\"/content/outputs/camn_output.npz\")\n",
249
+ "# files.download(\"/content/outputs/disco_output.npz\")\n",
250
+ "# files.download(\"/content/outputs/emage_output.npz\")"
251
+ ]
 
252
  },
253
  {
254
  "cell_type": "markdown",
 
256
  "source": [
257
  "## Notes\n",
258
  "\n",
259
+ "- **Environment**: Python 3.10.x + PyTorch 2.1.2 + CUDA 12.1\n",
260
+ "- **Motion Format**: `.npz` files contain SMPL-X format motion data\n",
261
+ "- **Visualization**: Use the [Blender Add-on](https://huggingface.co/datasets/H-Liu1997/BEAT2_Tools/blob/main/smplx_blender_addon_20230921.zip) for high-quality rendering\n",
262
+ "- **Interactive Demo**: [HuggingFace Space](https://huggingface.co/spaces/H-Liu1997/EMAGE)"
263
  ]
264
  }
265
  ],
 
276
  },
277
  "nbformat": 4,
278
  "nbformat_minor": 4
279
+ }