Spaces:
Running on Zero
Running on Zero
Clean rewrite: Python 3.10 venv + script-based inference
Browse files- EMAGE_Colab_Demo.ipynb +151 -330
EMAGE_Colab_Demo.ipynb
CHANGED
|
@@ -11,424 +11,244 @@
|
|
| 11 |
"- **DisCo**: Upper body gesture generation with diffusion\n",
|
| 12 |
"- **EMAGE**: Full body + face gesture generation\n",
|
| 13 |
"\n",
|
| 14 |
-
"[
|
| 15 |
-
"\n",
|
| 16 |
-
"[Project Page](https://pantomatrix.github.io/EMAGE/) | [GitHub](https://github.com/PantoMatrix/PantoMatrix) | [Paper](https://arxiv.org/abs/2401.00374)"
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"cell_type": "markdown",
|
| 21 |
"metadata": {},
|
| 22 |
"source": [
|
| 23 |
-
"## 1. Setup Environment"
|
|
|
|
|
|
|
| 24 |
]
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
|
|
|
| 28 |
"metadata": {},
|
|
|
|
| 29 |
"source": [
|
| 30 |
-
"#
|
| 31 |
-
"
|
| 32 |
-
"
|
|
|
|
| 33 |
"\n",
|
| 34 |
-
"
|
| 35 |
-
"!
|
| 36 |
-
"!pip install -q numpy librosa soundfile transformers huggingface_hub\n",
|
| 37 |
-
"!pip install -q smplx trimesh scipy easydict omegaconf\n",
|
| 38 |
"\n",
|
| 39 |
-
"
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
},
|
| 47 |
{
|
| 48 |
"cell_type": "code",
|
|
|
|
| 49 |
"metadata": {},
|
|
|
|
| 50 |
"source": [
|
| 51 |
-
"#
|
| 52 |
"!apt-get install -y git-lfs > /dev/null 2>&1\n",
|
| 53 |
"!git lfs install\n",
|
| 54 |
"\n",
|
| 55 |
-
"
|
| 56 |
-
"!git clone https://
|
| 57 |
-
"\n",
|
| 58 |
-
"# Clone evaluation tools (contains SMPLX models)\n",
|
| 59 |
-
"!git clone https://huggingface.co/H-Liu1997/emage_evaltools PantoMatrix/emage_evaltools\n",
|
| 60 |
-
"%cd PantoMatrix/emage_evaltools\n",
|
| 61 |
"!git lfs pull\n",
|
| 62 |
-
"%cd /content
|
| 63 |
"\n",
|
| 64 |
-
"print(\"Code
|
| 65 |
-
]
|
| 66 |
-
"execution_count": null,
|
| 67 |
-
"outputs": []
|
| 68 |
},
|
| 69 |
{
|
| 70 |
"cell_type": "code",
|
|
|
|
| 71 |
"metadata": {},
|
|
|
|
| 72 |
"source": [
|
| 73 |
-
"
|
| 74 |
-
"import sys\n",
|
| 75 |
-
"
|
| 76 |
-
"
|
| 77 |
-
"import numpy as np\n",
|
| 78 |
-
"import librosa\n",
|
| 79 |
-
"import soundfile as sf\n",
|
| 80 |
-
"from IPython.display import Video, Audio, display\n",
|
| 81 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 82 |
"from models.camn_audio import CamnAudioModel\n",
|
| 83 |
"from models.disco_audio import DiscoAudioModel\n",
|
| 84 |
"from models.emage_audio import EmageAudioModel, EmageVQVAEConv, EmageVAEConv, EmageVQModel\n",
|
| 85 |
"from emage_utils.motion_io import beat_format_save\n",
|
| 86 |
"from emage_utils.npz2pose import render2d\n",
|
| 87 |
-
"from torchvision.io import write_video\n",
|
| 88 |
-
"import torch.nn.functional as F\n",
|
| 89 |
-
"\n",
|
| 90 |
-
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 91 |
-
"print(f\"Python: {sys.version}\")\n",
|
| 92 |
-
"print(f\"PyTorch: {torch.__version__}\")\n",
|
| 93 |
-
"print(f\"NumPy: {np.__version__}\")\n",
|
| 94 |
-
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
| 95 |
-
"print(f\"Using device: {device}\")\n",
|
| 96 |
"\n",
|
| 97 |
-
"
|
| 98 |
-
"\n",
|
| 99 |
-
"
|
| 100 |
-
"
|
| 101 |
-
"
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
},
|
| 106 |
{
|
| 107 |
"cell_type": "markdown",
|
| 108 |
"metadata": {},
|
| 109 |
"source": [
|
| 110 |
-
"## 2.
|
| 111 |
"\n",
|
| 112 |
-
"
|
| 113 |
]
|
| 114 |
},
|
| 115 |
{
|
| 116 |
"cell_type": "code",
|
|
|
|
| 117 |
"metadata": {},
|
|
|
|
| 118 |
"source": [
|
| 119 |
-
"#
|
| 120 |
-
"audio_path = \"
|
| 121 |
"\n",
|
| 122 |
-
"#
|
| 123 |
"# from google.colab import files\n",
|
| 124 |
"# uploaded = files.upload()\n",
|
| 125 |
-
"# audio_path = list(uploaded.keys())[0]\n",
|
| 126 |
"\n",
|
|
|
|
| 127 |
"display(Audio(audio_path))"
|
| 128 |
-
],
|
| 129 |
-
"execution_count": null,
|
| 130 |
-
"outputs": []
|
| 131 |
-
},
|
| 132 |
-
{
|
| 133 |
-
"cell_type": "markdown",
|
| 134 |
-
"metadata": {},
|
| 135 |
-
"source": [
|
| 136 |
-
"## 3. CaMN Model (Upper Body)\n",
|
| 137 |
-
"\n",
|
| 138 |
-
"CaMN generates upper body gestures from speech audio."
|
| 139 |
]
|
| 140 |
},
|
| 141 |
{
|
| 142 |
"cell_type": "code",
|
| 143 |
-
"metadata": {},
|
| 144 |
-
"source": [
|
| 145 |
-
"# Load CaMN model\n",
|
| 146 |
-
"model_camn = CamnAudioModel.from_pretrained(\"H-Liu1997/camn_audio\").to(device).eval()\n",
|
| 147 |
-
"print(\"CaMN model loaded!\")"
|
| 148 |
-
],
|
| 149 |
-
"execution_count": null,
|
| 150 |
-
"outputs": []
|
| 151 |
-
},
|
| 152 |
-
{
|
| 153 |
-
"cell_type": "code",
|
| 154 |
-
"metadata": {},
|
| 155 |
-
"source": [
|
| 156 |
-
"# CaMN Inference\n",
|
| 157 |
-
"sr_model = model_camn.cfg.audio_sr\n",
|
| 158 |
-
"pose_fps = model_camn.cfg.pose_fps\n",
|
| 159 |
-
"seed_frames = model_camn.cfg.seed_frames\n",
|
| 160 |
-
"\n",
|
| 161 |
-
"audio_loaded, _ = librosa.load(audio_path, sr=sr_model)\n",
|
| 162 |
-
"audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device)\n",
|
| 163 |
-
"sid = torch.zeros(1, 1).long().to(device)\n",
|
| 164 |
-
"\n",
|
| 165 |
-
"with torch.no_grad():\n",
|
| 166 |
-
" motion_pred = model_camn(audio_t, sid, seed_frames=seed_frames)[\"motion_axis_angle\"]\n",
|
| 167 |
-
"\n",
|
| 168 |
-
"t = motion_pred.shape[1]\n",
|
| 169 |
-
"motion_pred_np = motion_pred.cpu().numpy().reshape(t, -1)\n",
|
| 170 |
-
"\n",
|
| 171 |
-
"# Save motion\n",
|
| 172 |
-
"camn_npz_path = \"./outputs/camn_output.npz\"\n",
|
| 173 |
-
"beat_format_save(camn_npz_path, motion_pred_np, upsample=30 // pose_fps)\n",
|
| 174 |
-
"print(f\"CaMN motion saved to {camn_npz_path}\")"
|
| 175 |
-
],
|
| 176 |
-
"execution_count": null,
|
| 177 |
-
"outputs": []
|
| 178 |
-
},
|
| 179 |
-
{
|
| 180 |
-
"cell_type": "code",
|
| 181 |
-
"metadata": {},
|
| 182 |
-
"source": [
|
| 183 |
-
"# Visualize CaMN result\n",
|
| 184 |
-
"motion_dict = np.load(camn_npz_path, allow_pickle=True)\n",
|
| 185 |
-
"v2d = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
|
| 186 |
-
"camn_video_path = \"./outputs/camn_output.mp4\"\n",
|
| 187 |
-
"write_video(camn_video_path, v2d.permute(0, 2, 3, 1), fps=30)\n",
|
| 188 |
-
"print(\"CaMN visualization:\")\n",
|
| 189 |
-
"display(Video(camn_video_path, embed=True, width=480))"
|
| 190 |
-
],
|
| 191 |
"execution_count": null,
|
| 192 |
-
"outputs": []
|
| 193 |
-
},
|
| 194 |
-
{
|
| 195 |
-
"cell_type": "markdown",
|
| 196 |
"metadata": {},
|
|
|
|
| 197 |
"source": [
|
| 198 |
-
"#
|
|
|
|
|
|
|
| 199 |
"\n",
|
| 200 |
-
"
|
|
|
|
| 201 |
]
|
| 202 |
},
|
| 203 |
{
|
| 204 |
"cell_type": "code",
|
| 205 |
-
"metadata": {},
|
| 206 |
-
"source": [
|
| 207 |
-
"# Load DisCo model\n",
|
| 208 |
-
"model_disco = DiscoAudioModel.from_pretrained(\"H-Liu1997/disco_audio\").to(device).eval()\n",
|
| 209 |
-
"print(\"DisCo model loaded!\")"
|
| 210 |
-
],
|
| 211 |
"execution_count": null,
|
| 212 |
-
"outputs": []
|
| 213 |
-
},
|
| 214 |
-
{
|
| 215 |
-
"cell_type": "code",
|
| 216 |
"metadata": {},
|
|
|
|
| 217 |
"source": [
|
| 218 |
-
"# DisCo
|
| 219 |
-
"
|
| 220 |
-
"
|
| 221 |
-
"seed_frames = model_disco.cfg.seed_frames\n",
|
| 222 |
-
"\n",
|
| 223 |
-
"audio_loaded, _ = librosa.load(audio_path, sr=sr_model)\n",
|
| 224 |
-
"audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device)\n",
|
| 225 |
-
"sid = torch.zeros(1, 1).long().to(device)\n",
|
| 226 |
-
"\n",
|
| 227 |
-
"with torch.no_grad():\n",
|
| 228 |
-
" motion_pred = model_disco(audio_t, sid, seed_frames=seed_frames, seed_motion=None)[\"motion_axis_angle\"]\n",
|
| 229 |
-
"\n",
|
| 230 |
-
"t = motion_pred.shape[1]\n",
|
| 231 |
-
"motion_pred_np = motion_pred.cpu().numpy().reshape(t, -1)\n",
|
| 232 |
"\n",
|
| 233 |
-
"
|
| 234 |
-
"
|
| 235 |
-
"beat_format_save(disco_npz_path, motion_pred_np, upsample=30 // pose_fps)\n",
|
| 236 |
-
"print(f\"DisCo motion saved to {disco_npz_path}\")"
|
| 237 |
-
],
|
| 238 |
-
"execution_count": null,
|
| 239 |
-
"outputs": []
|
| 240 |
-
},
|
| 241 |
-
{
|
| 242 |
-
"cell_type": "code",
|
| 243 |
-
"metadata": {},
|
| 244 |
-
"source": [
|
| 245 |
-
"# Visualize DisCo result\n",
|
| 246 |
-
"motion_dict = np.load(disco_npz_path, allow_pickle=True)\n",
|
| 247 |
-
"v2d = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
|
| 248 |
-
"disco_video_path = \"./outputs/disco_output.mp4\"\n",
|
| 249 |
-
"write_video(disco_video_path, v2d.permute(0, 2, 3, 1), fps=30)\n",
|
| 250 |
-
"print(\"DisCo visualization:\")\n",
|
| 251 |
-
"display(Video(disco_video_path, embed=True, width=480))"
|
| 252 |
-
],
|
| 253 |
-
"execution_count": null,
|
| 254 |
-
"outputs": []
|
| 255 |
-
},
|
| 256 |
-
{
|
| 257 |
-
"cell_type": "markdown",
|
| 258 |
-
"metadata": {},
|
| 259 |
-
"source": [
|
| 260 |
-
"## 5. EMAGE Model (Full Body + Face)\n",
|
| 261 |
-
"\n",
|
| 262 |
-
"EMAGE generates full body gestures including face expressions."
|
| 263 |
]
|
| 264 |
},
|
| 265 |
{
|
| 266 |
"cell_type": "code",
|
| 267 |
-
"metadata": {},
|
| 268 |
-
"source": [
|
| 269 |
-
"# Load EMAGE model and VQ components\n",
|
| 270 |
-
"face_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/face\").to(device).eval()\n",
|
| 271 |
-
"upper_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/upper\").to(device).eval()\n",
|
| 272 |
-
"lower_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/lower\").to(device).eval()\n",
|
| 273 |
-
"hands_motion_vq = EmageVQVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/hands\").to(device).eval()\n",
|
| 274 |
-
"global_motion_ae = EmageVAEConv.from_pretrained(\"H-Liu1997/emage_audio\", subfolder=\"emage_vq/global\").to(device).eval()\n",
|
| 275 |
-
"\n",
|
| 276 |
-
"emage_vq_model = EmageVQModel(\n",
|
| 277 |
-
" face_model=face_motion_vq, \n",
|
| 278 |
-
" upper_model=upper_motion_vq,\n",
|
| 279 |
-
" lower_model=lower_motion_vq, \n",
|
| 280 |
-
" hands_model=hands_motion_vq,\n",
|
| 281 |
-
" global_model=global_motion_ae\n",
|
| 282 |
-
").to(device).eval()\n",
|
| 283 |
-
"\n",
|
| 284 |
-
"model_emage = EmageAudioModel.from_pretrained(\"H-Liu1997/emage_audio\").to(device).eval()\n",
|
| 285 |
-
"print(\"EMAGE model loaded!\")"
|
| 286 |
-
],
|
| 287 |
-
"execution_count": null,
|
| 288 |
-
"outputs": []
|
| 289 |
-
},
|
| 290 |
-
{
|
| 291 |
-
"cell_type": "code",
|
| 292 |
-
"metadata": {},
|
| 293 |
-
"source": [
|
| 294 |
-
"# EMAGE Inference\n",
|
| 295 |
-
"sr_model = model_emage.cfg.audio_sr\n",
|
| 296 |
-
"pose_fps = model_emage.cfg.pose_fps\n",
|
| 297 |
-
"\n",
|
| 298 |
-
"audio_loaded, _ = librosa.load(audio_path, sr=sr_model)\n",
|
| 299 |
-
"audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device)\n",
|
| 300 |
-
"sid = torch.zeros(1, 1).long().to(device)\n",
|
| 301 |
-
"\n",
|
| 302 |
-
"with torch.no_grad():\n",
|
| 303 |
-
" latent_dict = model_emage.inference(audio_t, sid, emage_vq_model, masked_motion=None, mask=None)\n",
|
| 304 |
-
" \n",
|
| 305 |
-
" face_latent = latent_dict[\"rec_face\"] if model_emage.cfg.lf > 0 and model_emage.cfg.cf == 0 else None\n",
|
| 306 |
-
" upper_latent = latent_dict[\"rec_upper\"] if model_emage.cfg.lu > 0 and model_emage.cfg.cu == 0 else None\n",
|
| 307 |
-
" hands_latent = latent_dict[\"rec_hands\"] if model_emage.cfg.lh > 0 and model_emage.cfg.ch == 0 else None\n",
|
| 308 |
-
" lower_latent = latent_dict[\"rec_lower\"] if model_emage.cfg.ll > 0 and model_emage.cfg.cl == 0 else None\n",
|
| 309 |
-
"\n",
|
| 310 |
-
" face_index = torch.max(F.log_softmax(latent_dict[\"cls_face\"], dim=2), dim=2)[1] if model_emage.cfg.cf > 0 else None\n",
|
| 311 |
-
" upper_index = torch.max(F.log_softmax(latent_dict[\"cls_upper\"], dim=2), dim=2)[1] if model_emage.cfg.cu > 0 else None\n",
|
| 312 |
-
" hands_index = torch.max(F.log_softmax(latent_dict[\"cls_hands\"], dim=2), dim=2)[1] if model_emage.cfg.ch > 0 else None\n",
|
| 313 |
-
" lower_index = torch.max(F.log_softmax(latent_dict[\"cls_lower\"], dim=2), dim=2)[1] if model_emage.cfg.cl > 0 else None\n",
|
| 314 |
-
"\n",
|
| 315 |
-
" ref_trans = torch.zeros(1, 1, 3).to(device)\n",
|
| 316 |
-
" all_pred = emage_vq_model.decode(\n",
|
| 317 |
-
" face_latent=face_latent, \n",
|
| 318 |
-
" upper_latent=upper_latent, \n",
|
| 319 |
-
" lower_latent=lower_latent, \n",
|
| 320 |
-
" hands_latent=hands_latent,\n",
|
| 321 |
-
" face_index=face_index, \n",
|
| 322 |
-
" upper_index=upper_index, \n",
|
| 323 |
-
" lower_index=lower_index, \n",
|
| 324 |
-
" hands_index=hands_index,\n",
|
| 325 |
-
" get_global_motion=True, \n",
|
| 326 |
-
" ref_trans=ref_trans[:, 0]\n",
|
| 327 |
-
" )\n",
|
| 328 |
-
"\n",
|
| 329 |
-
"motion_pred = all_pred[\"motion_axis_angle\"]\n",
|
| 330 |
-
"t = motion_pred.shape[1]\n",
|
| 331 |
-
"motion_pred_np = motion_pred.cpu().numpy().reshape(t, -1)\n",
|
| 332 |
-
"face_pred = all_pred[\"expression\"].cpu().numpy().reshape(t, -1)\n",
|
| 333 |
-
"trans_pred = all_pred[\"trans\"].cpu().numpy().reshape(t, -1)\n",
|
| 334 |
-
"\n",
|
| 335 |
-
"# Save motion\n",
|
| 336 |
-
"emage_npz_path = \"./outputs/emage_output.npz\"\n",
|
| 337 |
-
"beat_format_save(emage_npz_path, motion_pred_np, upsample=30 // pose_fps, expressions=face_pred, trans=trans_pred)\n",
|
| 338 |
-
"print(f\"EMAGE motion saved to {emage_npz_path}\")"
|
| 339 |
-
],
|
| 340 |
-
"execution_count": null,
|
| 341 |
-
"outputs": []
|
| 342 |
-
},
|
| 343 |
-
{
|
| 344 |
-
"cell_type": "code",
|
| 345 |
-
"metadata": {},
|
| 346 |
-
"source": [
|
| 347 |
-
"# Visualize EMAGE body result\n",
|
| 348 |
-
"motion_dict = np.load(emage_npz_path, allow_pickle=True)\n",
|
| 349 |
-
"v2d_body = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
|
| 350 |
-
"emage_body_path = \"./outputs/emage_body.mp4\"\n",
|
| 351 |
-
"write_video(emage_body_path, v2d_body.permute(0, 2, 3, 1), fps=30)\n",
|
| 352 |
-
"print(\"EMAGE body visualization:\")\n",
|
| 353 |
-
"display(Video(emage_body_path, embed=True, width=480))"
|
| 354 |
-
],
|
| 355 |
"execution_count": null,
|
| 356 |
-
"outputs": []
|
| 357 |
-
},
|
| 358 |
-
{
|
| 359 |
-
"cell_type": "markdown",
|
| 360 |
"metadata": {},
|
|
|
|
| 361 |
"source": [
|
| 362 |
-
"#
|
|
|
|
|
|
|
| 363 |
"\n",
|
| 364 |
-
"
|
|
|
|
| 365 |
]
|
| 366 |
},
|
| 367 |
-
{
|
| 368 |
-
"cell_type": "code",
|
| 369 |
-
"metadata": {},
|
| 370 |
-
"source": [
|
| 371 |
-
"# Evaluation requires ground truth motion data\n",
|
| 372 |
-
"# This is a demo showing how to use the evaluation API\n",
|
| 373 |
-
"\n",
|
| 374 |
-
"if EVAL_AVAILABLE:\n",
|
| 375 |
-
" from emage_evaltools.metric import FGD, BC, L1Div\n",
|
| 376 |
-
" \n",
|
| 377 |
-
" # Initialize evaluators\n",
|
| 378 |
-
" fgd_evaluator = FGD(download_path=\"./emage_evaltools/\")\n",
|
| 379 |
-
" bc_evaluator = BC(download_path=\"./emage_evaltools/\", sigma=0.3, order=7)\n",
|
| 380 |
-
" l1div_evaluator = L1Div()\n",
|
| 381 |
-
" \n",
|
| 382 |
-
" print(\"Evaluation tools loaded!\")\n",
|
| 383 |
-
" print(\"Note: Full evaluation requires ground truth motion data from BEAT2 dataset\")\n",
|
| 384 |
-
" print(\"Download BEAT2: git clone https://huggingface.co/datasets/H-Liu1997/BEAT2\")\n",
|
| 385 |
-
"else:\n",
|
| 386 |
-
" print(\"Evaluation tools not available. Clone from GitHub to enable.\")"
|
| 387 |
-
],
|
| 388 |
-
"execution_count": null,
|
| 389 |
-
"outputs": []
|
| 390 |
-
},
|
| 391 |
-
{
|
| 392 |
-
"cell_type": "code",
|
| 393 |
-
"metadata": {},
|
| 394 |
-
"source": [
|
| 395 |
-
"# Visualize EMAGE face result\n",
|
| 396 |
-
"v2d_face = render2d(motion_dict, (720, 480), face_only=True, remove_global=True)\n",
|
| 397 |
-
"emage_face_path = \"./outputs/emage_face.mp4\"\n",
|
| 398 |
-
"write_video(emage_face_path, v2d_face.permute(0, 2, 3, 1), fps=30)\n",
|
| 399 |
-
"print(\"EMAGE face visualization:\")\n",
|
| 400 |
-
"display(Video(emage_face_path, embed=True, width=480))"
|
| 401 |
-
],
|
| 402 |
-
"execution_count": null,
|
| 403 |
-
"outputs": []
|
| 404 |
-
},
|
| 405 |
{
|
| 406 |
"cell_type": "markdown",
|
| 407 |
"metadata": {},
|
| 408 |
"source": [
|
| 409 |
-
"##
|
| 410 |
"\n",
|
| 411 |
-
"Download
|
| 412 |
]
|
| 413 |
},
|
| 414 |
{
|
| 415 |
"cell_type": "code",
|
|
|
|
| 416 |
"metadata": {},
|
|
|
|
| 417 |
"source": [
|
| 418 |
-
"
|
| 419 |
-
"\n",
|
| 420 |
"print(\"Generated files:\")\n",
|
| 421 |
-
"
|
| 422 |
-
"print(f\"
|
| 423 |
-
"print(f\" - EMAGE: {emage_npz_path}\")\n",
|
| 424 |
"\n",
|
| 425 |
-
"# Uncomment to download\n",
|
| 426 |
-
"#
|
| 427 |
-
"# files.download(
|
| 428 |
-
"# files.download(
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
"outputs": []
|
| 432 |
},
|
| 433 |
{
|
| 434 |
"cell_type": "markdown",
|
|
@@ -436,9 +256,10 @@
|
|
| 436 |
"source": [
|
| 437 |
"## Notes\n",
|
| 438 |
"\n",
|
| 439 |
-
"- **
|
| 440 |
-
"- **
|
| 441 |
-
"- **
|
|
|
|
| 442 |
]
|
| 443 |
}
|
| 444 |
],
|
|
@@ -455,4 +276,4 @@
|
|
| 455 |
},
|
| 456 |
"nbformat": 4,
|
| 457 |
"nbformat_minor": 4
|
| 458 |
-
}
|
|
|
|
| 11 |
"- **DisCo**: Upper body gesture generation with diffusion\n",
|
| 12 |
"- **EMAGE**: Full body + face gesture generation\n",
|
| 13 |
"\n",
|
| 14 |
+
"[Project Page](https://pantomatrix.github.io/EMAGE/) | [GitHub](https://github.com/PantoMatrix/PantoMatrix) | [Paper](https://arxiv.org/abs/2401.00374) | [HF Space](https://huggingface.co/spaces/H-Liu1997/EMAGE)"
|
|
|
|
|
|
|
| 15 |
]
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"cell_type": "markdown",
|
| 19 |
"metadata": {},
|
| 20 |
"source": [
|
| 21 |
+
"## 1. Setup Environment\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"Install Python 3.10, create a virtual environment, and install dependencies."
|
| 24 |
]
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
| 28 |
+
"execution_count": null,
|
| 29 |
"metadata": {},
|
| 30 |
+
"outputs": [],
|
| 31 |
"source": [
|
| 32 |
+
"# Install Python 3.10 and create virtual environment\n",
"!sudo add-apt-repository -y ppa:deadsnakes/ppa > /dev/null 2>&1\n",
"!sudo apt-get update -qq\n",
"!sudo apt-get install -y python3.10 python3.10-venv python3.10-dev > /dev/null 2>&1\n",
"\n",
"ENV_PATH = \"/content/py310_env\"\n",
"!python3.10 -m venv {ENV_PATH}\n",
"\n",
"PYTHON = f\"{ENV_PATH}/bin/python\"\n",
"PIP = f\"{ENV_PATH}/bin/pip\"\n",
"\n",
"# Install dependencies (torchvision/torchaudio pinned to the releases built for torch 2.1.2)\n",
"!{PIP} install -q --upgrade pip\n",
"!{PIP} install -q torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121\n",
"!{PIP} install -q numpy==1.23.0 librosa soundfile transformers huggingface_hub\n",
"!{PIP} install -q smplx trimesh scipy easydict omegaconf\n",
"# torchvision.io.write_video (used by run_inference.py) needs PyAV, which nothing above installs.\n",
"# NOTE(review): capped below 14 because older torchvision is reported to break with PyAV 14 - confirm.\n",
"!{PIP} install -q \"av<14\"\n",
"\n",
"# Verify. No f-string here: IPython expands braces in `!` lines using the notebook namespace,\n",
"# so a braced torch expression would be evaluated in this kernel instead of the venv interpreter\n",
"# (and, if that evaluation fails, the whole line - PYTHON included - is passed to the shell unexpanded).\n",
"!{PYTHON} -c \"import torch; print('Python 3.10 + PyTorch', torch.__version__, '+ CUDA:', torch.cuda.is_available())\""
|
| 51 |
+
]
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"cell_type": "code",
|
| 55 |
+
"execution_count": null,
|
| 56 |
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
"source": [
|
| 59 |
+
"# Clone code repositories (each clone is skipped when already present, so the cell can be re-run)\n",
"!apt-get install -y git-lfs > /dev/null 2>&1\n",
"!git lfs install\n",
"\n",
"![ -d /content/PantoMatrix/.git ] || git clone https://github.com/PantoMatrix/PantoMatrix.git /content/PantoMatrix\n",
"![ -d /content/PantoMatrix/emage_evaltools/.git ] || git clone https://huggingface.co/H-Liu1997/emage_evaltools /content/PantoMatrix/emage_evaltools\n",
"%cd /content/PantoMatrix/emage_evaltools\n",
"!git lfs pull\n",
"%cd /content\n",
"\n",
"# Check the checkout instead of reporting success unconditionally;\n",
"# run_inference.py imports the `models` package from this directory.\n",
"import os\n",
"assert os.path.isdir(\"/content/PantoMatrix/models\"), \"PantoMatrix clone failed\"\n",
"print(\"Code ready!\")"
|
| 70 |
+
]
|
|
|
|
|
|
|
| 71 |
},
|
| 72 |
{
|
| 73 |
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
"source": [
|
| 78 |
+
"%%writefile /content/run_inference.py\n",
|
| 79 |
+
"import sys, os\n",
|
| 80 |
+
"sys.path.insert(0, '/content/PantoMatrix')\n",
|
| 81 |
+
"os.chdir('/content/PantoMatrix')\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
"\n",
|
| 83 |
+
"import torch, numpy as np, librosa, argparse\n",
|
| 84 |
+
"import torch.nn.functional as F\n",
|
| 85 |
+
"from torchvision.io import write_video\n",
|
| 86 |
"from models.camn_audio import CamnAudioModel\n",
|
| 87 |
"from models.disco_audio import DiscoAudioModel\n",
|
| 88 |
"from models.emage_audio import EmageAudioModel, EmageVQVAEConv, EmageVAEConv, EmageVQModel\n",
|
| 89 |
"from emage_utils.motion_io import beat_format_save\n",
|
| 90 |
"from emage_utils.npz2pose import render2d\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
"\n",
|
| 92 |
+
"def _load_audio(path, sr, device):\n",
"    \"\"\"Load `path` resampled to `sr`; return a float tensor with a leading batch axis on `device`.\"\"\"\n",
"    audio, _ = librosa.load(path, sr=sr)\n",
"    return torch.from_numpy(audio).float().unsqueeze(0).to(device)\n",
"\n",
"\n",
"def main():\n",
"    \"\"\"Run the selected model on one audio file, then save the motion (.npz) and a 2D preview (.mp4).\n",
"\n",
"    Command-line arguments:\n",
"        --audio       path of the input audio file (required)\n",
"        --model       one of 'camn', 'disco', 'emage' (default 'camn')\n",
"        --output_dir  directory for the results (default '/content/outputs')\n",
"    \"\"\"\n",
"    parser = argparse.ArgumentParser()\n",
"    parser.add_argument('--audio', type=str, required=True)\n",
"    parser.add_argument('--model', type=str, default='camn', choices=['camn', 'disco', 'emage'])\n",
"    parser.add_argument('--output_dir', type=str, default='/content/outputs')\n",
"    args = parser.parse_args()\n",
"\n",
"    # Fail before any checkpoint is downloaded if the input is missing\n",
"    if not os.path.isfile(args.audio):\n",
"        parser.error(f'audio file not found: {args.audio}')\n",
"\n",
"    os.makedirs(args.output_dir, exist_ok=True)\n",
"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"    print(f'Using device: {device}')\n",
"\n",
"    sid = torch.zeros(1, 1).long().to(device)  # presumably speaker id 0 - TODO confirm\n",
"    extras = {}  # additional keyword arrays for beat_format_save (filled by EMAGE only)\n",
"\n",
"    if args.model == 'camn':\n",
"        model = CamnAudioModel.from_pretrained('H-Liu1997/camn_audio').to(device).eval()\n",
"        fps = model.cfg.pose_fps\n",
"        audio_t = _load_audio(args.audio, model.cfg.audio_sr, device)\n",
"        with torch.no_grad():\n",
"            motion = model(audio_t, sid, seed_frames=model.cfg.seed_frames)['motion_axis_angle']\n",
"\n",
"    elif args.model == 'disco':\n",
"        model = DiscoAudioModel.from_pretrained('H-Liu1997/disco_audio').to(device).eval()\n",
"        fps = model.cfg.pose_fps\n",
"        audio_t = _load_audio(args.audio, model.cfg.audio_sr, device)\n",
"        with torch.no_grad():\n",
"            motion = model(audio_t, sid, seed_frames=model.cfg.seed_frames, seed_motion=None)['motion_axis_angle']\n",
"\n",
"    else:  # emage\n",
"        vq_models = {k: EmageVQVAEConv.from_pretrained('H-Liu1997/emage_audio', subfolder=f'emage_vq/{k}').to(device).eval()\n",
"                     for k in ['face', 'upper', 'lower', 'hands']}\n",
"        global_ae = EmageVAEConv.from_pretrained('H-Liu1997/emage_audio', subfolder='emage_vq/global').to(device).eval()\n",
"        vq = EmageVQModel(face_model=vq_models['face'], upper_model=vq_models['upper'],\n",
"                          lower_model=vq_models['lower'], hands_model=vq_models['hands'], global_model=global_ae).to(device).eval()\n",
"        model = EmageAudioModel.from_pretrained('H-Liu1997/emage_audio').to(device).eval()\n",
"        cfg = model.cfg\n",
"        fps = cfg.pose_fps\n",
"        audio_t = _load_audio(args.audio, cfg.audio_sr, device)\n",
"        with torch.no_grad():\n",
"            lat = model.inference(audio_t, sid, vq, masked_motion=None, mask=None)\n",
"\n",
"            def rec(part):\n",
"                # 'rec_<part>' output, used when cfg.l<x> > 0 and cfg.c<x> == 0 (x = first letter of part)\n",
"                x = part[0]\n",
"                return lat[f'rec_{part}'] if getattr(cfg, f'l{x}') > 0 and getattr(cfg, f'c{x}') == 0 else None\n",
"\n",
"            def idx(part):\n",
"                # argmax over dim 2 of 'cls_<part>', used when cfg.c<x> > 0\n",
"                if getattr(cfg, f'c{part[0]}') > 0:\n",
"                    return torch.max(F.log_softmax(lat[f'cls_{part}'], dim=2), dim=2)[1]\n",
"                return None\n",
"\n",
"            pred = vq.decode(face_latent=rec('face'), upper_latent=rec('upper'),\n",
"                             lower_latent=rec('lower'), hands_latent=rec('hands'),\n",
"                             face_index=idx('face'), upper_index=idx('upper'),\n",
"                             lower_index=idx('lower'), hands_index=idx('hands'),\n",
"                             get_global_motion=True, ref_trans=torch.zeros(1, 3).to(device))\n",
"        motion = pred['motion_axis_angle']\n",
"        extras = {\n",
"            'expressions': pred['expression'].cpu().numpy().reshape(motion.shape[1], -1),\n",
"            'trans': pred['trans'].cpu().numpy().reshape(motion.shape[1], -1),\n",
"        }\n",
"\n",
"    # Save once for every model: one row per entry of motion's dim 1\n",
"    npz_path = os.path.join(args.output_dir, f'{args.model}_output.npz')\n",
"    beat_format_save(npz_path, motion.cpu().numpy().reshape(motion.shape[1], -1), upsample=30 // fps, **extras)\n",
"\n",
"    # Render 2D visualization\n",
"    motion_dict = np.load(npz_path, allow_pickle=True)\n",
"    v2d = render2d(motion_dict, (720, 480), face_only=False, remove_global=True)\n",
"    # splitext only touches the extension; str.replace would also rewrite '.npz' inside output_dir\n",
"    video_path = os.path.splitext(npz_path)[0] + '_2d.mp4'\n",
"    write_video(video_path, v2d.permute(0, 2, 3, 1), fps=30)\n",
"    print(f'Saved: {npz_path}')\n",
"    print(f'Video: {video_path}')\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"if __name__ == '__main__': main()"
|
| 155 |
+
]
|
| 156 |
},
|
| 157 |
{
|
| 158 |
"cell_type": "markdown",
|
| 159 |
"metadata": {},
|
| 160 |
"source": [
|
| 161 |
+
"## 2. Run Inference\n",
|
| 162 |
"\n",
|
| 163 |
+
"Choose your audio and model, then run inference."
|
| 164 |
]
|
| 165 |
},
|
| 166 |
{
|
| 167 |
"cell_type": "code",
|
| 168 |
+
"execution_count": null,
|
| 169 |
"metadata": {},
|
| 170 |
+
"outputs": [],
|
| 171 |
"source": [
|
| 172 |
+
"# Audio file (use example or upload your own)\n",
"from IPython.display import Audio, display\n",
"\n",
"audio_path = \"/content/PantoMatrix/examples/audio/2_scott_0_103_103_10s.wav\"\n",
"\n",
"# Uncomment to upload your own audio:\n",
"# import os\n",
"# from google.colab import files\n",
"# uploaded = files.upload()\n",
"# audio_path = os.path.abspath(list(uploaded.keys())[0])  # upload() saves into the current directory\n",
"\n",
"display(Audio(audio_path))"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
]
|
| 183 |
},
|
| 184 |
{
|
| 185 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
"metadata": {},
|
| 188 |
+
"outputs": [],
|
| 189 |
"source": [
|
| 190 |
+
"# Run CaMN (Upper Body)\n",
"import shlex\n",
"from IPython.display import Video, display\n",
"\n",
"PYTHON = \"/content/py310_env/bin/python\"\n",
"# shlex.quote keeps paths with spaces or shell metacharacters (e.g. uploaded files) as one argument\n",
"!{PYTHON} /content/run_inference.py --audio {shlex.quote(audio_path)} --model camn\n",
"\n",
"display(Video(\"/content/outputs/camn_output_2d.mp4\", embed=True, width=600))"
|
| 196 |
]
|
| 197 |
},
|
| 198 |
{
|
| 199 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
"metadata": {},
|
| 202 |
+
"outputs": [],
|
| 203 |
"source": [
|
| 204 |
+
"# Run DisCo (Upper Body with Diffusion)\n",
"import shlex\n",
"from IPython.display import Video, display\n",
"\n",
"PYTHON = \"/content/py310_env/bin/python\"\n",
"# shlex.quote keeps paths with spaces or shell metacharacters (e.g. uploaded files) as one argument\n",
"!{PYTHON} /content/run_inference.py --audio {shlex.quote(audio_path)} --model disco\n",
"\n",
"display(Video(\"/content/outputs/disco_output_2d.mp4\", embed=True, width=600))"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
]
|
| 211 |
},
|
| 212 |
{
|
| 213 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
"metadata": {},
|
| 216 |
+
"outputs": [],
|
| 217 |
"source": [
|
| 218 |
+
"# Run EMAGE (Full Body + Face)\n",
"import shlex\n",
"from IPython.display import Video, display\n",
"\n",
"PYTHON = \"/content/py310_env/bin/python\"\n",
"# shlex.quote keeps paths with spaces or shell metacharacters (e.g. uploaded files) as one argument\n",
"!{PYTHON} /content/run_inference.py --audio {shlex.quote(audio_path)} --model emage\n",
"\n",
"display(Video(\"/content/outputs/emage_output_2d.mp4\", embed=True, width=600))"
|
| 224 |
]
|
| 225 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
{
|
| 227 |
"cell_type": "markdown",
|
| 228 |
"metadata": {},
|
| 229 |
"source": [
|
| 230 |
+
"## 3. Download Results\n",
|
| 231 |
"\n",
|
| 232 |
+
"Download motion files (`.npz`) for use with Blender."
|
| 233 |
]
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "code",
|
| 237 |
+
"execution_count": null,
|
| 238 |
"metadata": {},
|
| 239 |
+
"outputs": [],
|
| 240 |
"source": [
|
| 241 |
+
"import os\n",
"\n",
"OUTPUT_DIR = \"/content/outputs\"\n",
"\n",
"print(\"Generated files:\")\n",
"# The directory is created by run_inference.py, so it may not exist yet; sort for a stable listing\n",
"for name in sorted(os.listdir(OUTPUT_DIR)) if os.path.isdir(OUTPUT_DIR) else []:\n",
"    print(f\"  {OUTPUT_DIR}/{name}\")\n",
"\n",
"# Uncomment to download:\n",
"# from google.colab import files\n",
"# files.download(\"/content/outputs/camn_output.npz\")\n",
"# files.download(\"/content/outputs/disco_output.npz\")\n",
"# files.download(\"/content/outputs/emage_output.npz\")"
|
| 251 |
+
]
|
|
|
|
| 252 |
},
|
| 253 |
{
|
| 254 |
"cell_type": "markdown",
|
|
|
|
| 256 |
"source": [
|
| 257 |
"## Notes\n",
|
| 258 |
"\n",
|
| 259 |
+
"- **Environment**: Python 3.10.x + PyTorch 2.1.2 + CUDA 12.1\n",
|
| 260 |
+
"- **Motion Format**: `.npz` files contain SMPL-X format motion data\n",
|
| 261 |
+
"- **Visualization**: Use the [Blender Add-on](https://huggingface.co/datasets/H-Liu1997/BEAT2_Tools/blob/main/smplx_blender_addon_20230921.zip) for high-quality rendering\n",
|
| 262 |
+
"- **Interactive Demo**: [HuggingFace Space](https://huggingface.co/spaces/H-Liu1997/EMAGE)"
|
| 263 |
]
|
| 264 |
}
|
| 265 |
],
|
|
|
|
| 276 |
},
|
| 277 |
"nbformat": 4,
|
| 278 |
"nbformat_minor": 4
|
| 279 |
+
}
|