megalado
commited on
Commit
·
434c5bc
1
Parent(s):
ac3046f
Add local checkpoint path and checkpoints folder
Browse files- .gitattributes +1 -0
- app.py +3 -16
.gitattributes
CHANGED
|
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
motion-diffusion-model/**/*.pt filter=lfs diff=lfs merge=lfs -text
|
| 37 |
motion-diffusion-model/**/*.npy filter=lfs diff=lfs merge=lfs -text
|
| 38 |
motion-diffusion-model/**/*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 36 |
motion-diffusion-model/**/*.pt filter=lfs diff=lfs merge=lfs -text
|
| 37 |
motion-diffusion-model/**/*.npy filter=lfs diff=lfs merge=lfs -text
|
| 38 |
motion-diffusion-model/**/*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
checkpoints/*.pt filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
| 1 |
-
cat > app.py << 'EOL'
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import os
|
| 5 |
import sys
|
| 6 |
import numpy as np
|
| 7 |
import subprocess
|
| 8 |
-
import gdown
|
| 9 |
from pathlib import Path
|
| 10 |
|
| 11 |
# Setup progress tracking
|
|
@@ -25,17 +23,6 @@ def setup_environment():
|
|
| 25 |
progress_status = "Installing Spacy language model..."
|
| 26 |
subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
|
| 27 |
|
| 28 |
-
# Create directories for checkpoints
|
| 29 |
-
os.makedirs("motion-diffusion-model/save/humanml_trans_enc_512", exist_ok=True)
|
| 30 |
-
|
| 31 |
-
# Download the model checkpoint if not already present
|
| 32 |
-
model_path = "motion-diffusion-model/save/humanml_trans_enc_512/model000200000.pt"
|
| 33 |
-
if not Path(model_path).exists():
|
| 34 |
-
progress_status = "Downloading model checkpoint (may take a few minutes)..."
|
| 35 |
-
# Model checkpoint from Google Drive
|
| 36 |
-
url = "https://drive.google.com/uc?id=1dbBIlDsYwvAcMctb3Zc_DLA-L-yrBNtN"
|
| 37 |
-
gdown.download(url, model_path, quiet=False)
|
| 38 |
-
|
| 39 |
# Download other necessary files if they don't exist
|
| 40 |
if not Path("motion-diffusion-model/data/smpl").exists():
|
| 41 |
progress_status = "Downloading SMPL files..."
|
|
@@ -82,8 +69,8 @@ def text_to_motion(text_prompt, motion_length=3.0, seed=0):
|
|
| 82 |
if seed is not None:
|
| 83 |
torch.manual_seed(seed)
|
| 84 |
|
| 85 |
-
#
|
| 86 |
-
model_path = "
|
| 87 |
|
| 88 |
# Generate the motion
|
| 89 |
progress_status = "Running MDM generation..."
|
|
@@ -149,4 +136,4 @@ with gr.Blocks() as demo:
|
|
| 149 |
|
| 150 |
# Launch the app
|
| 151 |
if __name__ == "__main__":
|
| 152 |
-
demo.launch()
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import os
|
| 4 |
import sys
|
| 5 |
import numpy as np
|
| 6 |
import subprocess
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
# Setup progress tracking
|
|
|
|
| 23 |
progress_status = "Installing Spacy language model..."
|
| 24 |
subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# Download other necessary files if they don't exist
|
| 27 |
if not Path("motion-diffusion-model/data/smpl").exists():
|
| 28 |
progress_status = "Downloading SMPL files..."
|
|
|
|
| 69 |
if seed is not None:
|
| 70 |
torch.manual_seed(seed)
|
| 71 |
|
| 72 |
+
# Use the local checkpoint path
|
| 73 |
+
model_path = "checkpoints/mld_humanml.pt"
|
| 74 |
|
| 75 |
# Generate the motion
|
| 76 |
progress_status = "Running MDM generation..."
|
|
|
|
| 136 |
|
| 137 |
# Launch the app
|
| 138 |
if __name__ == "__main__":
|
| 139 |
+
demo.launch()
|