Commit
·
eb339cb
0
Parent(s):
[Init]
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +12 -0
- LICENSE +25 -0
- README.md +11 -0
- app.py +236 -0
- configs/mld_t2m.yaml +104 -0
- configs/modules/denoiser.yaml +28 -0
- configs/modules/motion_vae.yaml +18 -0
- configs/modules/noise_optimizer.yaml +15 -0
- configs/modules/scheduler_ddim.yaml +14 -0
- configs/modules/scheduler_lcm.yaml +19 -0
- configs/modules/text_encoder.yaml +5 -0
- configs/modules/traj_encoder.yaml +17 -0
- configs/motionlcm_control_s.yaml +113 -0
- configs/motionlcm_control_t.yaml +111 -0
- configs/motionlcm_t2m.yaml +109 -0
- configs/motionlcm_t2m_clt.yaml +69 -0
- configs/vae.yaml +103 -0
- configs_v1/modules/denoiser.yaml +28 -0
- configs_v1/modules/motion_vae.yaml +18 -0
- configs_v1/modules/scheduler_lcm.yaml +11 -0
- configs_v1/modules/text_encoder.yaml +5 -0
- configs_v1/modules/traj_encoder.yaml +17 -0
- configs_v1/motionlcm_control_t.yaml +114 -0
- configs_v1/motionlcm_t2m.yaml +109 -0
- demo.py +196 -0
- fit.py +136 -0
- mld/__init__.py +0 -0
- mld/config.py +52 -0
- mld/data/__init__.py +0 -0
- mld/data/base.py +58 -0
- mld/data/data.py +73 -0
- mld/data/get_data.py +79 -0
- mld/data/humanml/__init__.py +0 -0
- mld/data/humanml/common/quaternion.py +29 -0
- mld/data/humanml/dataset.py +348 -0
- mld/data/humanml/scripts/motion_process.py +51 -0
- mld/data/humanml/utils/__init__.py +0 -0
- mld/data/humanml/utils/paramUtil.py +62 -0
- mld/data/humanml/utils/plot_script.py +98 -0
- mld/data/humanml/utils/word_vectorizer.py +82 -0
- mld/data/utils.py +52 -0
- mld/launch/__init__.py +0 -0
- mld/launch/blender.py +23 -0
- mld/models/__init__.py +0 -0
- mld/models/architectures/__init__.py +0 -0
- mld/models/architectures/dno.py +79 -0
- mld/models/architectures/mld_clip.py +72 -0
- mld/models/architectures/mld_denoiser.py +200 -0
- mld/models/architectures/mld_traj_encoder.py +64 -0
- mld/models/architectures/mld_vae.py +136 -0
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**/*.pyc
|
| 2 |
+
.idea/
|
| 3 |
+
__pycache__/
|
| 4 |
+
|
| 5 |
+
deps/
|
| 6 |
+
datasets/
|
| 7 |
+
experiments_t2m/
|
| 8 |
+
experiments_t2m_test/
|
| 9 |
+
experiments_control/
|
| 10 |
+
experiments_control_test/
|
| 11 |
+
experiments_recons/
|
| 12 |
+
experiments_recons_test/
|
LICENSE
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright Tsinghua University and Shanghai AI Laboratory. All Rights Reserved.
|
| 2 |
+
|
| 3 |
+
License for Non-commercial Scientific Research Purposes.
|
| 4 |
+
|
| 5 |
+
For more information see <https://github.com/Dai-Wenxun/MotionLCM>.
|
| 6 |
+
If you use this software, please cite the corresponding publications
|
| 7 |
+
listed on the above website.
|
| 8 |
+
|
| 9 |
+
Permission to use, copy, modify, and distribute this software and its
|
| 10 |
+
documentation for educational, research, and non-profit purposes only.
|
| 11 |
+
Any modification based on this work must be open-source and prohibited
|
| 12 |
+
for commercial, pornographic, military, or surveillance use.
|
| 13 |
+
|
| 14 |
+
The authors grant you a non-exclusive, worldwide, non-transferable,
|
| 15 |
+
non-sublicensable, revocable, royalty-free, and limited license under
|
| 16 |
+
our copyright interests to reproduce, distribute, and create derivative
|
| 17 |
+
works of the text, videos, and codes solely for your non-commercial
|
| 18 |
+
research purposes.
|
| 19 |
+
|
| 20 |
+
You must retain, in the source form of any derivative works that you
|
| 21 |
+
distribute, all copyright, patent, trademark, and attribution notices
|
| 22 |
+
from the source form of this work.
|
| 23 |
+
|
| 24 |
+
For commercial uses of this software, please send email to all people
|
| 25 |
+
in the author list.
|
README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MotionLCM
|
| 3 |
+
emoji: 🏎️💨
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
python_version: 3.10.12
|
| 11 |
+
---
|
app.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import random
|
| 4 |
+
import datetime
|
| 5 |
+
import os.path as osp
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import tqdm
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
from mld.config import get_module_config
|
| 15 |
+
from mld.data.get_data import get_dataset
|
| 16 |
+
from mld.models.modeltype.mld import MLD
|
| 17 |
+
from mld.utils.utils import set_seed
|
| 18 |
+
from mld.data.humanml.utils.plot_script import plot_3d_motion
|
| 19 |
+
|
| 20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 21 |
+
|
| 22 |
+
WEBSITE = """
|
| 23 |
+
<div class="embed_hidden">
|
| 24 |
+
<h1 style='text-align: center'> MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model </h1>
|
| 25 |
+
<h2 style='text-align: center'>
|
| 26 |
+
<a href="https://github.com/Dai-Wenxun/" target="_blank"><nobr>Wenxun Dai</nobr><sup>1</sup></a>  
|
| 27 |
+
<a href="https://lhchen.top/" target="_blank"><nobr>Ling-Hao Chen</nobr></a><sup>1</sup>  
|
| 28 |
+
<a href="https://wangjingbo1219.github.io/" target="_blank"><nobr>Jingbo Wang</nobr></a><sup>2</sup>  
|
| 29 |
+
<a href="https://moonsliu.github.io/" target="_blank"><nobr>Jinpeng Liu</nobr></a><sup>1</sup>  
|
| 30 |
+
<a href="https://daibo.info/" target="_blank"><nobr>Bo Dai</nobr></a><sup>2</sup>  
|
| 31 |
+
<a href="https://andytang15.github.io/" target="_blank"><nobr>Yansong Tang</nobr></a><sup>1</sup>
|
| 32 |
+
</h2>
|
| 33 |
+
<h2 style='text-align: center'>
|
| 34 |
+
<nobr><sup>1</sup>Tsinghua University</nobr>  
|
| 35 |
+
<nobr><sup>2</sup>Shanghai AI Laboratory</nobr>
|
| 36 |
+
</h2>
|
| 37 |
+
</div>
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
WEBSITE_bottom = """
|
| 41 |
+
<div class="embed_hidden">
|
| 42 |
+
<p>
|
| 43 |
+
Space adapted from <a href="https://huggingface.co/spaces/Mathux/TMR" target="_blank">TMR</a>
|
| 44 |
+
and <a href="https://huggingface.co/spaces/MeYourHint/MoMask" target="_blank">MoMask</a>.
|
| 45 |
+
</p>
|
| 46 |
+
</div>
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
EXAMPLES = [
|
| 50 |
+
"a person does a jump",
|
| 51 |
+
"a person waves both arms in the air.",
|
| 52 |
+
"The person takes 4 steps backwards.",
|
| 53 |
+
"this person bends forward as if to bow.",
|
| 54 |
+
"The person was pushed but did not fall.",
|
| 55 |
+
"a man walks forward in a snake like pattern.",
|
| 56 |
+
"a man paces back and forth along the same line.",
|
| 57 |
+
"with arms out to the sides a person walks forward",
|
| 58 |
+
"A man bends down and picks something up with his right hand.",
|
| 59 |
+
"The man walked forward, spun right on one foot and walked back to his original position.",
|
| 60 |
+
"a person slightly bent over with right hand pressing against the air walks forward slowly"
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
if not os.path.exists("./experiments_t2m/"):
|
| 64 |
+
os.system("bash prepare/download_pretrained_models.sh")
|
| 65 |
+
if not os.path.exists('./deps/glove/'):
|
| 66 |
+
os.system("bash prepare/download_glove.sh")
|
| 67 |
+
if not os.path.exists('./deps/sentence-t5-large/'):
|
| 68 |
+
os.system("bash prepare/prepare_t5.sh")
|
| 69 |
+
if not os.path.exists('./deps/t2m/'):
|
| 70 |
+
os.system("bash prepare/download_t2m_evaluators.sh")
|
| 71 |
+
if not os.path.exists('./datasets/humanml3d/'):
|
| 72 |
+
os.system("bash prepare/prepare_tiny_humanml3d.sh")
|
| 73 |
+
|
| 74 |
+
DEFAULT_TEXT = "cheerfully walking forward with each step."
|
| 75 |
+
MAX_VIDEOS = 8
|
| 76 |
+
NUM_ROWS = 2
|
| 77 |
+
NUM_COLS = MAX_VIDEOS // NUM_ROWS
|
| 78 |
+
EXAMPLES_PER_PAGE = 12
|
| 79 |
+
T2M_CFG = "./configs/mld_t2m.yaml"
|
| 80 |
+
step_map = {1: 10, 2: 25, 4: 50}
|
| 81 |
+
|
| 82 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 83 |
+
print("device: ", device)
|
| 84 |
+
|
| 85 |
+
cfg = OmegaConf.load(T2M_CFG)
|
| 86 |
+
cfg_root = os.path.dirname(T2M_CFG)
|
| 87 |
+
cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
|
| 88 |
+
cfg = OmegaConf.merge(cfg, cfg_model)
|
| 89 |
+
set_seed(cfg.SEED_VALUE)
|
| 90 |
+
|
| 91 |
+
name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
|
| 92 |
+
cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
|
| 93 |
+
vis_dir = osp.join(cfg.output_dir, 'samples')
|
| 94 |
+
os.makedirs(cfg.output_dir, exist_ok=False)
|
| 95 |
+
os.makedirs(vis_dir, exist_ok=False)
|
| 96 |
+
|
| 97 |
+
state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
|
| 98 |
+
print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
|
| 99 |
+
|
| 100 |
+
is_lcm = False
|
| 101 |
+
lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG
|
| 102 |
+
if lcm_key in state_dict:
|
| 103 |
+
is_lcm = True
|
| 104 |
+
time_cond_proj_dim = state_dict[lcm_key].shape[1]
|
| 105 |
+
cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
|
| 106 |
+
print(f'Is LCM: {is_lcm}')
|
| 107 |
+
|
| 108 |
+
dataset = get_dataset(cfg)
|
| 109 |
+
model = MLD(cfg, dataset)
|
| 110 |
+
model.to(device)
|
| 111 |
+
model.eval()
|
| 112 |
+
model.requires_grad_(False)
|
| 113 |
+
model.load_state_dict(state_dict)
|
| 114 |
+
|
| 115 |
+
FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@torch.no_grad()
|
| 119 |
+
def generate(text_, motion_len_):
|
| 120 |
+
batch = {"text": [text_] * MAX_VIDEOS, "length": [motion_len_] * MAX_VIDEOS}
|
| 121 |
+
|
| 122 |
+
s = time.time()
|
| 123 |
+
joints = model(batch)[0]
|
| 124 |
+
runtime_infer = round(time.time() - s, 3)
|
| 125 |
+
|
| 126 |
+
s = time.time()
|
| 127 |
+
path = []
|
| 128 |
+
for i in tqdm.tqdm(range(len(joints))):
|
| 129 |
+
uid = random.randrange(999999999)
|
| 130 |
+
video_path = osp.join(vis_dir, f"sample_{uid}.mp4")
|
| 131 |
+
plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=FPS)
|
| 132 |
+
path.append(video_path)
|
| 133 |
+
runtime_draw = round(time.time() - s, 3)
|
| 134 |
+
|
| 135 |
+
runtime_info = f'Inference {len(joints)} motions, Runtime (Inference): {runtime_infer}s, ' \
|
| 136 |
+
f'Runtime (Draw Skeleton): {runtime_draw}s, device: {device} '
|
| 137 |
+
|
| 138 |
+
return path, runtime_info
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def generate_component(generate_function, text_, motion_len_, num_inference_steps_, guidance_scale_):
|
| 142 |
+
if text_ == "" or text_ is None:
|
| 143 |
+
return [None] * MAX_VIDEOS + ["Please modify the text prompt."]
|
| 144 |
+
|
| 145 |
+
model.cfg.model.scheduler.num_inference_steps = step_map[num_inference_steps_]
|
| 146 |
+
model.guidance_scale = guidance_scale_
|
| 147 |
+
motion_len_ = max(36, min(int(float(motion_len_) * FPS), 196))
|
| 148 |
+
paths, info = generate_function(text_, motion_len_)
|
| 149 |
+
paths = paths + [None] * (MAX_VIDEOS - len(paths))
|
| 150 |
+
return paths + [info]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray")
|
| 154 |
+
generate_and_show = partial(generate_component, generate)
|
| 155 |
+
|
| 156 |
+
with gr.Blocks(theme=theme) as demo:
|
| 157 |
+
gr.HTML(WEBSITE)
|
| 158 |
+
videos = []
|
| 159 |
+
|
| 160 |
+
with gr.Row():
|
| 161 |
+
with gr.Column(scale=3):
|
| 162 |
+
text = gr.Textbox(
|
| 163 |
+
show_label=True,
|
| 164 |
+
label="Text prompt",
|
| 165 |
+
value=DEFAULT_TEXT,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
with gr.Row():
|
| 169 |
+
with gr.Column(scale=2):
|
| 170 |
+
motion_len = gr.Slider(
|
| 171 |
+
minimum=1.8,
|
| 172 |
+
maximum=9.8,
|
| 173 |
+
step=0.2,
|
| 174 |
+
value=5.0,
|
| 175 |
+
label="Motion length",
|
| 176 |
+
info="Motion duration in seconds: [1.8s, 9.8s] (FPS = 20)."
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
with gr.Column(scale=1):
|
| 180 |
+
num_inference_steps = gr.Radio(
|
| 181 |
+
[1, 2, 4],
|
| 182 |
+
label="Inference steps",
|
| 183 |
+
value=4,
|
| 184 |
+
info="Number of inference steps.",
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
cfg = gr.Slider(
|
| 188 |
+
minimum=1,
|
| 189 |
+
maximum=15,
|
| 190 |
+
step=0.5,
|
| 191 |
+
value=7.5,
|
| 192 |
+
label="CFG",
|
| 193 |
+
info="Classifier-free diffusion guidance.",
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
gen_btn = gr.Button("Generate", variant="primary")
|
| 197 |
+
clear = gr.Button("Clear", variant="secondary")
|
| 198 |
+
|
| 199 |
+
results = gr.Textbox(show_label=True,
|
| 200 |
+
label='Inference info (runtime and device)',
|
| 201 |
+
info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.',
|
| 202 |
+
interactive=False)
|
| 203 |
+
|
| 204 |
+
with gr.Column(scale=2):
|
| 205 |
+
examples = gr.Examples(
|
| 206 |
+
examples=EXAMPLES,
|
| 207 |
+
inputs=[text],
|
| 208 |
+
examples_per_page=EXAMPLES_PER_PAGE)
|
| 209 |
+
|
| 210 |
+
for i in range(NUM_ROWS):
|
| 211 |
+
with gr.Row():
|
| 212 |
+
for j in range(NUM_COLS):
|
| 213 |
+
video = gr.Video(autoplay=True, loop=True)
|
| 214 |
+
videos.append(video)
|
| 215 |
+
|
| 216 |
+
# gr.HTML(WEBSITE_bottom)
|
| 217 |
+
|
| 218 |
+
gen_btn.click(
|
| 219 |
+
fn=generate_and_show,
|
| 220 |
+
inputs=[text, motion_len, num_inference_steps, cfg],
|
| 221 |
+
outputs=videos + [results],
|
| 222 |
+
)
|
| 223 |
+
text.submit(
|
| 224 |
+
fn=generate_and_show,
|
| 225 |
+
inputs=[text, motion_len, num_inference_steps, cfg],
|
| 226 |
+
outputs=videos + [results],
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def clear_videos():
|
| 231 |
+
return [None] * MAX_VIDEOS + [DEFAULT_TEXT] + [None]
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
clear.click(fn=clear_videos, outputs=videos + [text] + [results])
|
| 235 |
+
|
| 236 |
+
demo.launch()
|
configs/mld_t2m.yaml
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_t2m'
|
| 2 |
+
TEST_FOLDER: './experiments_t2m_test'
|
| 3 |
+
|
| 4 |
+
NAME: 'mld_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TRAIN:
|
| 9 |
+
BATCH_SIZE: 64
|
| 10 |
+
SPLIT: 'train'
|
| 11 |
+
NUM_WORKERS: 8
|
| 12 |
+
PERSISTENT_WORKERS: true
|
| 13 |
+
|
| 14 |
+
PRETRAINED: 'experiments_recons/vae_humanml/vae_humanml.ckpt'
|
| 15 |
+
|
| 16 |
+
validation_steps: -1
|
| 17 |
+
validation_epochs: 50
|
| 18 |
+
checkpointing_steps: -1
|
| 19 |
+
checkpointing_epochs: 50
|
| 20 |
+
max_train_steps: -1
|
| 21 |
+
max_train_epochs: 3000
|
| 22 |
+
learning_rate: 1e-4
|
| 23 |
+
lr_scheduler: "cosine"
|
| 24 |
+
lr_warmup_steps: 1000
|
| 25 |
+
adam_beta1: 0.9
|
| 26 |
+
adam_beta2: 0.999
|
| 27 |
+
adam_weight_decay: 0.0
|
| 28 |
+
adam_epsilon: 1e-08
|
| 29 |
+
max_grad_norm: 1.0
|
| 30 |
+
model_ema: false
|
| 31 |
+
model_ema_steps: 32
|
| 32 |
+
model_ema_decay: 0.999
|
| 33 |
+
|
| 34 |
+
VAL:
|
| 35 |
+
BATCH_SIZE: 32
|
| 36 |
+
SPLIT: 'test'
|
| 37 |
+
NUM_WORKERS: 12
|
| 38 |
+
PERSISTENT_WORKERS: true
|
| 39 |
+
|
| 40 |
+
TEST:
|
| 41 |
+
BATCH_SIZE: 32
|
| 42 |
+
SPLIT: 'test'
|
| 43 |
+
NUM_WORKERS: 12
|
| 44 |
+
PERSISTENT_WORKERS: true
|
| 45 |
+
|
| 46 |
+
CHECKPOINTS: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
|
| 47 |
+
|
| 48 |
+
# Testing Args
|
| 49 |
+
REPLICATION_TIMES: 20
|
| 50 |
+
MM_NUM_SAMPLES: 100
|
| 51 |
+
MM_NUM_REPEATS: 30
|
| 52 |
+
MM_NUM_TIMES: 10
|
| 53 |
+
DIVERSITY_TIMES: 300
|
| 54 |
+
DO_MM_TEST: true
|
| 55 |
+
|
| 56 |
+
DATASET:
|
| 57 |
+
NAME: 'humanml3d'
|
| 58 |
+
SMPL_PATH: './deps/smpl'
|
| 59 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 60 |
+
HUMANML3D:
|
| 61 |
+
FRAME_RATE: 20.0
|
| 62 |
+
UNIT_LEN: 4
|
| 63 |
+
ROOT: './datasets/humanml3d'
|
| 64 |
+
CONTROL_ARGS:
|
| 65 |
+
CONTROL: false
|
| 66 |
+
TEMPORAL: false
|
| 67 |
+
TRAIN_JOINTS: [0]
|
| 68 |
+
TEST_JOINTS: [0]
|
| 69 |
+
TRAIN_DENSITY: 'random'
|
| 70 |
+
TEST_DENSITY: 100
|
| 71 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 72 |
+
SAMPLER:
|
| 73 |
+
MAX_LEN: 200
|
| 74 |
+
MIN_LEN: 40
|
| 75 |
+
MAX_TEXT_LEN: 20
|
| 76 |
+
PADDING_TO_MAX: false
|
| 77 |
+
WINDOW_SIZE: null
|
| 78 |
+
|
| 79 |
+
METRIC:
|
| 80 |
+
DIST_SYNC_ON_STEP: true
|
| 81 |
+
TYPE: ['TM2TMetrics']
|
| 82 |
+
|
| 83 |
+
model:
|
| 84 |
+
target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_ddim', 'noise_optimizer']
|
| 85 |
+
latent_dim: [16, 32]
|
| 86 |
+
guidance_scale: 7.5
|
| 87 |
+
guidance_uncondp: 0.1
|
| 88 |
+
|
| 89 |
+
t2m_textencoder:
|
| 90 |
+
dim_word: 300
|
| 91 |
+
dim_pos_ohot: 15
|
| 92 |
+
dim_text_hidden: 512
|
| 93 |
+
dim_coemb_hidden: 512
|
| 94 |
+
|
| 95 |
+
t2m_motionencoder:
|
| 96 |
+
dim_move_hidden: 512
|
| 97 |
+
dim_move_latent: 512
|
| 98 |
+
dim_motion_hidden: 1024
|
| 99 |
+
dim_motion_latent: 512
|
| 100 |
+
|
| 101 |
+
bert_path: './deps/distilbert-base-uncased'
|
| 102 |
+
clip_path: './deps/clip-vit-large-patch14'
|
| 103 |
+
t5_path: './deps/sentence-t5-large'
|
| 104 |
+
t2m_path: './deps/t2m/'
|
configs/modules/denoiser.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
denoiser:
|
| 2 |
+
target: mld.models.architectures.mld_denoiser.MldDenoiser
|
| 3 |
+
params:
|
| 4 |
+
latent_dim: ${model.latent_dim}
|
| 5 |
+
hidden_dim: 256
|
| 6 |
+
text_dim: 768
|
| 7 |
+
time_dim: 768
|
| 8 |
+
ff_size: 1024
|
| 9 |
+
num_layers: 9
|
| 10 |
+
num_heads: 4
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
normalize_before: false
|
| 13 |
+
norm_eps: 1e-5
|
| 14 |
+
activation: 'gelu'
|
| 15 |
+
norm_post: true
|
| 16 |
+
activation_post: null
|
| 17 |
+
flip_sin_to_cos: true
|
| 18 |
+
freq_shift: 0
|
| 19 |
+
time_act_fn: 'silu'
|
| 20 |
+
time_post_act_fn: null
|
| 21 |
+
position_embedding: 'learned'
|
| 22 |
+
arch: 'trans_enc'
|
| 23 |
+
add_mem_pos: true
|
| 24 |
+
force_pre_post_proj: true
|
| 25 |
+
text_act_fn: null
|
| 26 |
+
zero_init_cond: true
|
| 27 |
+
controlnet_embed_dim: 256
|
| 28 |
+
controlnet_act_fn: 'silu'
|
configs/modules/motion_vae.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
motion_vae:
|
| 2 |
+
target: mld.models.architectures.mld_vae.MldVae
|
| 3 |
+
params:
|
| 4 |
+
nfeats: ${DATASET.NFEATS}
|
| 5 |
+
latent_dim: ${model.latent_dim}
|
| 6 |
+
hidden_dim: 256
|
| 7 |
+
force_pre_post_proj: true
|
| 8 |
+
ff_size: 1024
|
| 9 |
+
num_layers: 9
|
| 10 |
+
num_heads: 4
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
arch: 'encoder_decoder'
|
| 13 |
+
normalize_before: false
|
| 14 |
+
norm_eps: 1e-5
|
| 15 |
+
activation: 'gelu'
|
| 16 |
+
norm_post: true
|
| 17 |
+
activation_post: null
|
| 18 |
+
position_embedding: 'learned'
|
configs/modules/noise_optimizer.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
noise_optimizer:
|
| 2 |
+
target: mld.models.architectures.dno.DNO
|
| 3 |
+
params:
|
| 4 |
+
optimize: false
|
| 5 |
+
max_train_steps: 400
|
| 6 |
+
learning_rate: 0.1
|
| 7 |
+
lr_scheduler: 'cosine'
|
| 8 |
+
lr_warmup_steps: 50
|
| 9 |
+
clip_grad: true
|
| 10 |
+
loss_hint_type: 'l2'
|
| 11 |
+
loss_diff_penalty: 0.000
|
| 12 |
+
loss_correlate_penalty: 100
|
| 13 |
+
visualize_samples: 0
|
| 14 |
+
visualize_ske_steps: []
|
| 15 |
+
output_dir: ${output_dir}
|
configs/modules/scheduler_ddim.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
scheduler:
|
| 2 |
+
target: diffusers.DDIMScheduler
|
| 3 |
+
num_inference_steps: 50
|
| 4 |
+
eta: 0.0
|
| 5 |
+
params:
|
| 6 |
+
num_train_timesteps: 1000
|
| 7 |
+
beta_start: 0.00085
|
| 8 |
+
beta_end: 0.012
|
| 9 |
+
beta_schedule: 'scaled_linear'
|
| 10 |
+
prediction_type: 'epsilon'
|
| 11 |
+
clip_sample: false
|
| 12 |
+
# below are for ddim
|
| 13 |
+
set_alpha_to_one: false
|
| 14 |
+
steps_offset: 1
|
configs/modules/scheduler_lcm.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
scheduler:
|
| 2 |
+
target: mld.models.schedulers.scheduling_lcm.LCMScheduler
|
| 3 |
+
num_inference_steps: 1
|
| 4 |
+
cfg_step_map:
|
| 5 |
+
1: 8.0
|
| 6 |
+
2: 12.5
|
| 7 |
+
4: 13.5
|
| 8 |
+
params:
|
| 9 |
+
num_train_timesteps: 1000
|
| 10 |
+
beta_start: 0.00085
|
| 11 |
+
beta_end: 0.012
|
| 12 |
+
beta_schedule: 'scaled_linear'
|
| 13 |
+
clip_sample: false
|
| 14 |
+
set_alpha_to_one: false
|
| 15 |
+
original_inference_steps: 10
|
| 16 |
+
timesteps_step_map:
|
| 17 |
+
1: [799]
|
| 18 |
+
2: [699, 299]
|
| 19 |
+
4: [699, 399, 299, 299]
|
configs/modules/text_encoder.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
text_encoder:
|
| 2 |
+
target: mld.models.architectures.mld_clip.MldTextEncoder
|
| 3 |
+
params:
|
| 4 |
+
last_hidden_state: false
|
| 5 |
+
modelpath: ${model.t5_path}
|
configs/modules/traj_encoder.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
traj_encoder:
|
| 2 |
+
target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
|
| 3 |
+
params:
|
| 4 |
+
nfeats: ${DATASET.NJOINTS}
|
| 5 |
+
latent_dim: ${model.latent_dim}
|
| 6 |
+
hidden_dim: 256
|
| 7 |
+
force_post_proj: true
|
| 8 |
+
ff_size: 1024
|
| 9 |
+
num_layers: 9
|
| 10 |
+
num_heads: 4
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
normalize_before: false
|
| 13 |
+
norm_eps: 1e-5
|
| 14 |
+
activation: 'gelu'
|
| 15 |
+
norm_post: true
|
| 16 |
+
activation_post: null
|
| 17 |
+
position_embedding: 'learned'
|
configs/motionlcm_control_s.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_control/spatial'
|
| 2 |
+
TEST_FOLDER: './experiments_control_test/spatial'
|
| 3 |
+
|
| 4 |
+
NAME: 'motionlcm_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TRAIN:
|
| 9 |
+
DATASET: 'humanml3d'
|
| 10 |
+
BATCH_SIZE: 128
|
| 11 |
+
SPLIT: 'train'
|
| 12 |
+
NUM_WORKERS: 8
|
| 13 |
+
PERSISTENT_WORKERS: true
|
| 14 |
+
|
| 15 |
+
PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
|
| 16 |
+
|
| 17 |
+
validation_steps: -1
|
| 18 |
+
validation_epochs: 50
|
| 19 |
+
checkpointing_steps: -1
|
| 20 |
+
checkpointing_epochs: 50
|
| 21 |
+
max_train_steps: -1
|
| 22 |
+
max_train_epochs: 1000
|
| 23 |
+
learning_rate: 1e-4
|
| 24 |
+
learning_rate_spatial: 1e-4
|
| 25 |
+
lr_scheduler: "cosine"
|
| 26 |
+
lr_warmup_steps: 1000
|
| 27 |
+
adam_beta1: 0.9
|
| 28 |
+
adam_beta2: 0.999
|
| 29 |
+
adam_weight_decay: 0.0
|
| 30 |
+
adam_epsilon: 1e-08
|
| 31 |
+
max_grad_norm: 1.0
|
| 32 |
+
|
| 33 |
+
VAL:
|
| 34 |
+
DATASET: 'humanml3d'
|
| 35 |
+
BATCH_SIZE: 32
|
| 36 |
+
SPLIT: 'test'
|
| 37 |
+
NUM_WORKERS: 12
|
| 38 |
+
PERSISTENT_WORKERS: true
|
| 39 |
+
|
| 40 |
+
TEST:
|
| 41 |
+
DATASET: 'humanml3d'
|
| 42 |
+
BATCH_SIZE: 32
|
| 43 |
+
SPLIT: 'test'
|
| 44 |
+
NUM_WORKERS: 12
|
| 45 |
+
PERSISTENT_WORKERS: true
|
| 46 |
+
|
| 47 |
+
CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_pelvis.ckpt'
|
| 48 |
+
# CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_all.ckpt'
|
| 49 |
+
|
| 50 |
+
# Testing Args
|
| 51 |
+
REPLICATION_TIMES: 1
|
| 52 |
+
DIVERSITY_TIMES: 300
|
| 53 |
+
DO_MM_TEST: false
|
| 54 |
+
MAX_NUM_SAMPLES: 1024
|
| 55 |
+
|
| 56 |
+
DATASET:
|
| 57 |
+
NAME: 'humanml3d'
|
| 58 |
+
SMPL_PATH: './deps/smpl'
|
| 59 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 60 |
+
HUMANML3D:
|
| 61 |
+
FRAME_RATE: 20.0
|
| 62 |
+
UNIT_LEN: 4
|
| 63 |
+
ROOT: './datasets/humanml3d'
|
| 64 |
+
CONTROL_ARGS:
|
| 65 |
+
CONTROL: true
|
| 66 |
+
TEMPORAL: false
|
| 67 |
+
TRAIN_JOINTS: [0]
|
| 68 |
+
TEST_JOINTS: [0]
|
| 69 |
+
TRAIN_DENSITY: 'random'
|
| 70 |
+
TEST_DENSITY: 100
|
| 71 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 72 |
+
SAMPLER:
|
| 73 |
+
MAX_LEN: 200
|
| 74 |
+
MIN_LEN: 40
|
| 75 |
+
MAX_TEXT_LEN: 20
|
| 76 |
+
PADDING_TO_MAX: false
|
| 77 |
+
WINDOW_SIZE: null
|
| 78 |
+
|
| 79 |
+
METRIC:
|
| 80 |
+
DIST_SYNC_ON_STEP: true
|
| 81 |
+
TYPE: ['TM2TMetrics', 'ControlMetrics']
|
| 82 |
+
|
| 83 |
+
model:
|
| 84 |
+
target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer']
|
| 85 |
+
latent_dim: [16, 32]
|
| 86 |
+
guidance_scale: 'dynamic'
|
| 87 |
+
|
| 88 |
+
# ControlNet Args
|
| 89 |
+
is_controlnet: true
|
| 90 |
+
vaeloss: true
|
| 91 |
+
vaeloss_type: 'mask'
|
| 92 |
+
cond_ratio: 1.0
|
| 93 |
+
control_loss_func: 'l1_smooth'
|
| 94 |
+
use_3d: true
|
| 95 |
+
lcm_w_min_nax: [5, 15]
|
| 96 |
+
lcm_num_ddim_timesteps: 10
|
| 97 |
+
|
| 98 |
+
t2m_textencoder:
|
| 99 |
+
dim_word: 300
|
| 100 |
+
dim_pos_ohot: 15
|
| 101 |
+
dim_text_hidden: 512
|
| 102 |
+
dim_coemb_hidden: 512
|
| 103 |
+
|
| 104 |
+
t2m_motionencoder:
|
| 105 |
+
dim_move_hidden: 512
|
| 106 |
+
dim_move_latent: 512
|
| 107 |
+
dim_motion_hidden: 1024
|
| 108 |
+
dim_motion_latent: 512
|
| 109 |
+
|
| 110 |
+
bert_path: './deps/distilbert-base-uncased'
|
| 111 |
+
clip_path: './deps/clip-vit-large-patch14'
|
| 112 |
+
t5_path: './deps/sentence-t5-large'
|
| 113 |
+
t2m_path: './deps/t2m/'
|
configs/motionlcm_control_t.yaml
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_control/temporal'
|
| 2 |
+
TEST_FOLDER: './experiments_control_test/temporal'
|
| 3 |
+
|
| 4 |
+
NAME: 'motionlcm_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TRAIN:
|
| 9 |
+
DATASET: 'humanml3d'
|
| 10 |
+
BATCH_SIZE: 128
|
| 11 |
+
SPLIT: 'train'
|
| 12 |
+
NUM_WORKERS: 8
|
| 13 |
+
PERSISTENT_WORKERS: true
|
| 14 |
+
|
| 15 |
+
PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
|
| 16 |
+
|
| 17 |
+
validation_steps: -1
|
| 18 |
+
validation_epochs: 50
|
| 19 |
+
checkpointing_steps: -1
|
| 20 |
+
checkpointing_epochs: 50
|
| 21 |
+
max_train_steps: -1
|
| 22 |
+
max_train_epochs: 1000
|
| 23 |
+
learning_rate: 1e-4
|
| 24 |
+
learning_rate_spatial: 1e-4
|
| 25 |
+
lr_scheduler: "cosine"
|
| 26 |
+
lr_warmup_steps: 1000
|
| 27 |
+
adam_beta1: 0.9
|
| 28 |
+
adam_beta2: 0.999
|
| 29 |
+
adam_weight_decay: 0.0
|
| 30 |
+
adam_epsilon: 1e-08
|
| 31 |
+
max_grad_norm: 1.0
|
| 32 |
+
|
| 33 |
+
VAL:
|
| 34 |
+
DATASET: 'humanml3d'
|
| 35 |
+
BATCH_SIZE: 32
|
| 36 |
+
SPLIT: 'test'
|
| 37 |
+
NUM_WORKERS: 12
|
| 38 |
+
PERSISTENT_WORKERS: true
|
| 39 |
+
|
| 40 |
+
TEST:
|
| 41 |
+
DATASET: 'humanml3d'
|
| 42 |
+
BATCH_SIZE: 32
|
| 43 |
+
SPLIT: 'test'
|
| 44 |
+
NUM_WORKERS: 12
|
| 45 |
+
PERSISTENT_WORKERS: true
|
| 46 |
+
|
| 47 |
+
CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t.ckpt'
|
| 48 |
+
|
| 49 |
+
# Testing Args
|
| 50 |
+
REPLICATION_TIMES: 20
|
| 51 |
+
DIVERSITY_TIMES: 300
|
| 52 |
+
DO_MM_TEST: false
|
| 53 |
+
|
| 54 |
+
DATASET:
|
| 55 |
+
NAME: 'humanml3d'
|
| 56 |
+
SMPL_PATH: './deps/smpl'
|
| 57 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 58 |
+
HUMANML3D:
|
| 59 |
+
FRAME_RATE: 20.0
|
| 60 |
+
UNIT_LEN: 4
|
| 61 |
+
ROOT: './datasets/humanml3d'
|
| 62 |
+
CONTROL_ARGS:
|
| 63 |
+
CONTROL: true
|
| 64 |
+
TEMPORAL: true
|
| 65 |
+
TRAIN_JOINTS: [0, 10, 11, 15, 20, 21]
|
| 66 |
+
TEST_JOINTS: [0, 10, 11, 15, 20, 21]
|
| 67 |
+
TRAIN_DENSITY: [25, 25]
|
| 68 |
+
TEST_DENSITY: 25
|
| 69 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 70 |
+
SAMPLER:
|
| 71 |
+
MAX_LEN: 200
|
| 72 |
+
MIN_LEN: 40
|
| 73 |
+
MAX_TEXT_LEN: 20
|
| 74 |
+
PADDING_TO_MAX: false
|
| 75 |
+
WINDOW_SIZE: null
|
| 76 |
+
|
| 77 |
+
METRIC:
|
| 78 |
+
DIST_SYNC_ON_STEP: true
|
| 79 |
+
TYPE: ['TM2TMetrics', 'ControlMetrics']
|
| 80 |
+
|
| 81 |
+
model:
|
| 82 |
+
target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer']
|
| 83 |
+
latent_dim: [16, 32]
|
| 84 |
+
guidance_scale: 'dynamic'
|
| 85 |
+
|
| 86 |
+
# ControlNet Args
|
| 87 |
+
is_controlnet: true
|
| 88 |
+
vaeloss: true
|
| 89 |
+
vaeloss_type: 'sum'
|
| 90 |
+
cond_ratio: 1.0
|
| 91 |
+
control_loss_func: 'l2'
|
| 92 |
+
use_3d: false
|
| 93 |
+
lcm_w_min_nax: [5, 15]
|
| 94 |
+
lcm_num_ddim_timesteps: 10
|
| 95 |
+
|
| 96 |
+
t2m_textencoder:
|
| 97 |
+
dim_word: 300
|
| 98 |
+
dim_pos_ohot: 15
|
| 99 |
+
dim_text_hidden: 512
|
| 100 |
+
dim_coemb_hidden: 512
|
| 101 |
+
|
| 102 |
+
t2m_motionencoder:
|
| 103 |
+
dim_move_hidden: 512
|
| 104 |
+
dim_move_latent: 512
|
| 105 |
+
dim_motion_hidden: 1024
|
| 106 |
+
dim_motion_latent: 512
|
| 107 |
+
|
| 108 |
+
bert_path: './deps/distilbert-base-uncased'
|
| 109 |
+
clip_path: './deps/clip-vit-large-patch14'
|
| 110 |
+
t5_path: './deps/sentence-t5-large'
|
| 111 |
+
t2m_path: './deps/t2m/'
|
configs/motionlcm_t2m.yaml
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_t2m'
|
| 2 |
+
TEST_FOLDER: './experiments_t2m_test'
|
| 3 |
+
|
| 4 |
+
NAME: 'motionlcm_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TRAIN:
|
| 9 |
+
BATCH_SIZE: 128
|
| 10 |
+
SPLIT: 'train'
|
| 11 |
+
NUM_WORKERS: 8
|
| 12 |
+
PERSISTENT_WORKERS: true
|
| 13 |
+
|
| 14 |
+
PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
|
| 15 |
+
|
| 16 |
+
validation_steps: -1
|
| 17 |
+
validation_epochs: 50
|
| 18 |
+
checkpointing_steps: -1
|
| 19 |
+
checkpointing_epochs: 50
|
| 20 |
+
max_train_steps: -1
|
| 21 |
+
max_train_epochs: 1000
|
| 22 |
+
learning_rate: 2e-4
|
| 23 |
+
lr_scheduler: "cosine"
|
| 24 |
+
lr_warmup_steps: 1000
|
| 25 |
+
adam_beta1: 0.9
|
| 26 |
+
adam_beta2: 0.999
|
| 27 |
+
adam_weight_decay: 0.0
|
| 28 |
+
adam_epsilon: 1e-08
|
| 29 |
+
max_grad_norm: 1.0
|
| 30 |
+
|
| 31 |
+
# Latent Consistency Distillation Specific Arguments
|
| 32 |
+
w_min: 5.0
|
| 33 |
+
w_max: 15.0
|
| 34 |
+
num_ddim_timesteps: 10
|
| 35 |
+
loss_type: 'huber'
|
| 36 |
+
huber_c: 0.5
|
| 37 |
+
unet_time_cond_proj_dim: 256
|
| 38 |
+
ema_decay: 0.95
|
| 39 |
+
|
| 40 |
+
VAL:
|
| 41 |
+
BATCH_SIZE: 32
|
| 42 |
+
SPLIT: 'test'
|
| 43 |
+
NUM_WORKERS: 12
|
| 44 |
+
PERSISTENT_WORKERS: true
|
| 45 |
+
|
| 46 |
+
TEST:
|
| 47 |
+
BATCH_SIZE: 32
|
| 48 |
+
SPLIT: 'test'
|
| 49 |
+
NUM_WORKERS: 12
|
| 50 |
+
PERSISTENT_WORKERS: true
|
| 51 |
+
|
| 52 |
+
CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
|
| 53 |
+
|
| 54 |
+
# Testing Args
|
| 55 |
+
REPLICATION_TIMES: 20
|
| 56 |
+
MM_NUM_SAMPLES: 100
|
| 57 |
+
MM_NUM_REPEATS: 30
|
| 58 |
+
MM_NUM_TIMES: 10
|
| 59 |
+
DIVERSITY_TIMES: 300
|
| 60 |
+
DO_MM_TEST: true
|
| 61 |
+
|
| 62 |
+
DATASET:
|
| 63 |
+
NAME: 'humanml3d'
|
| 64 |
+
SMPL_PATH: './deps/smpl'
|
| 65 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 66 |
+
HUMANML3D:
|
| 67 |
+
FRAME_RATE: 20.0
|
| 68 |
+
UNIT_LEN: 4
|
| 69 |
+
ROOT: './datasets/humanml3d'
|
| 70 |
+
CONTROL_ARGS:
|
| 71 |
+
CONTROL: false
|
| 72 |
+
TEMPORAL: false
|
| 73 |
+
TRAIN_JOINTS: [0]
|
| 74 |
+
TEST_JOINTS: [0]
|
| 75 |
+
TRAIN_DENSITY: 'random'
|
| 76 |
+
TEST_DENSITY: 100
|
| 77 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 78 |
+
SAMPLER:
|
| 79 |
+
MAX_LEN: 200
|
| 80 |
+
MIN_LEN: 40
|
| 81 |
+
MAX_TEXT_LEN: 20
|
| 82 |
+
PADDING_TO_MAX: false
|
| 83 |
+
WINDOW_SIZE: null
|
| 84 |
+
|
| 85 |
+
METRIC:
|
| 86 |
+
DIST_SYNC_ON_STEP: true
|
| 87 |
+
TYPE: ['TM2TMetrics']
|
| 88 |
+
|
| 89 |
+
model:
|
| 90 |
+
target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer']
|
| 91 |
+
latent_dim: [16, 32]
|
| 92 |
+
guidance_scale: 'dynamic'
|
| 93 |
+
|
| 94 |
+
t2m_textencoder:
|
| 95 |
+
dim_word: 300
|
| 96 |
+
dim_pos_ohot: 15
|
| 97 |
+
dim_text_hidden: 512
|
| 98 |
+
dim_coemb_hidden: 512
|
| 99 |
+
|
| 100 |
+
t2m_motionencoder:
|
| 101 |
+
dim_move_hidden: 512
|
| 102 |
+
dim_move_latent: 512
|
| 103 |
+
dim_motion_hidden: 1024
|
| 104 |
+
dim_motion_latent: 512
|
| 105 |
+
|
| 106 |
+
bert_path: './deps/distilbert-base-uncased'
|
| 107 |
+
clip_path: './deps/clip-vit-large-patch14'
|
| 108 |
+
t5_path: './deps/sentence-t5-large'
|
| 109 |
+
t2m_path: './deps/t2m/'
|
configs/motionlcm_t2m_clt.yaml
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_t2m'
|
| 2 |
+
TEST_FOLDER: './experiments_t2m_test'
|
| 3 |
+
|
| 4 |
+
NAME: 'motionlcm_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TEST:
|
| 9 |
+
BATCH_SIZE: 1
|
| 10 |
+
SPLIT: 'test'
|
| 11 |
+
NUM_WORKERS: 12
|
| 12 |
+
PERSISTENT_WORKERS: true
|
| 13 |
+
|
| 14 |
+
CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
|
| 15 |
+
|
| 16 |
+
# Testing Args
|
| 17 |
+
REPLICATION_TIMES: 1
|
| 18 |
+
DIVERSITY_TIMES: 300
|
| 19 |
+
DO_MM_TEST: false
|
| 20 |
+
MAX_NUM_SAMPLES: 1024
|
| 21 |
+
|
| 22 |
+
DATASET:
|
| 23 |
+
NAME: 'humanml3d'
|
| 24 |
+
SMPL_PATH: './deps/smpl'
|
| 25 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 26 |
+
HUMANML3D:
|
| 27 |
+
FRAME_RATE: 20.0
|
| 28 |
+
UNIT_LEN: 4
|
| 29 |
+
ROOT: './datasets/humanml3d'
|
| 30 |
+
CONTROL_ARGS:
|
| 31 |
+
CONTROL: true
|
| 32 |
+
TEMPORAL: false
|
| 33 |
+
TRAIN_JOINTS: [0]
|
| 34 |
+
TEST_JOINTS: [0]
|
| 35 |
+
TRAIN_DENSITY: 'random'
|
| 36 |
+
TEST_DENSITY: 100
|
| 37 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 38 |
+
SAMPLER:
|
| 39 |
+
MAX_LEN: 200
|
| 40 |
+
MIN_LEN: 40
|
| 41 |
+
MAX_TEXT_LEN: 20
|
| 42 |
+
PADDING_TO_MAX: false
|
| 43 |
+
WINDOW_SIZE: null
|
| 44 |
+
|
| 45 |
+
METRIC:
|
| 46 |
+
DIST_SYNC_ON_STEP: true
|
| 47 |
+
TYPE: ['TM2TMetrics', 'ControlMetrics']
|
| 48 |
+
|
| 49 |
+
model:
|
| 50 |
+
target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer']
|
| 51 |
+
latent_dim: [16, 32]
|
| 52 |
+
guidance_scale: 'dynamic'
|
| 53 |
+
|
| 54 |
+
t2m_textencoder:
|
| 55 |
+
dim_word: 300
|
| 56 |
+
dim_pos_ohot: 15
|
| 57 |
+
dim_text_hidden: 512
|
| 58 |
+
dim_coemb_hidden: 512
|
| 59 |
+
|
| 60 |
+
t2m_motionencoder:
|
| 61 |
+
dim_move_hidden: 512
|
| 62 |
+
dim_move_latent: 512
|
| 63 |
+
dim_motion_hidden: 1024
|
| 64 |
+
dim_motion_latent: 512
|
| 65 |
+
|
| 66 |
+
bert_path: './deps/distilbert-base-uncased'
|
| 67 |
+
clip_path: './deps/clip-vit-large-patch14'
|
| 68 |
+
t5_path: './deps/sentence-t5-large'
|
| 69 |
+
t2m_path: './deps/t2m/'
|
configs/vae.yaml
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_recons'
|
| 2 |
+
TEST_FOLDER: './experiments_recons_test'
|
| 3 |
+
|
| 4 |
+
NAME: 'vae_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TRAIN:
|
| 9 |
+
BATCH_SIZE: 128
|
| 10 |
+
SPLIT: 'train'
|
| 11 |
+
NUM_WORKERS: 8
|
| 12 |
+
PERSISTENT_WORKERS: true
|
| 13 |
+
PRETRAINED: ''
|
| 14 |
+
|
| 15 |
+
validation_steps: -1
|
| 16 |
+
validation_epochs: 100
|
| 17 |
+
checkpointing_steps: -1
|
| 18 |
+
checkpointing_epochs: 100
|
| 19 |
+
max_train_steps: -1
|
| 20 |
+
max_train_epochs: 6000
|
| 21 |
+
learning_rate: 2e-4
|
| 22 |
+
lr_scheduler: "cosine"
|
| 23 |
+
lr_warmup_steps: 1000
|
| 24 |
+
adam_beta1: 0.9
|
| 25 |
+
adam_beta2: 0.999
|
| 26 |
+
adam_weight_decay: 0.0
|
| 27 |
+
adam_epsilon: 1e-08
|
| 28 |
+
max_grad_norm: 1.0
|
| 29 |
+
|
| 30 |
+
VAL:
|
| 31 |
+
BATCH_SIZE: 32
|
| 32 |
+
SPLIT: 'test'
|
| 33 |
+
NUM_WORKERS: 12
|
| 34 |
+
PERSISTENT_WORKERS: true
|
| 35 |
+
|
| 36 |
+
TEST:
|
| 37 |
+
BATCH_SIZE: 32
|
| 38 |
+
SPLIT: 'test'
|
| 39 |
+
NUM_WORKERS: 12
|
| 40 |
+
PERSISTENT_WORKERS: true
|
| 41 |
+
|
| 42 |
+
CHECKPOINTS: 'experiments_recons/vae_humanml/vae_humanml.ckpt'
|
| 43 |
+
|
| 44 |
+
# Testing Args
|
| 45 |
+
REPLICATION_TIMES: 20
|
| 46 |
+
DIVERSITY_TIMES: 300
|
| 47 |
+
DO_MM_TEST: false
|
| 48 |
+
|
| 49 |
+
DATASET:
|
| 50 |
+
NAME: 'humanml3d'
|
| 51 |
+
SMPL_PATH: './deps/smpl'
|
| 52 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 53 |
+
HUMANML3D:
|
| 54 |
+
FRAME_RATE: 20.0
|
| 55 |
+
UNIT_LEN: 4
|
| 56 |
+
ROOT: './datasets/humanml3d'
|
| 57 |
+
CONTROL_ARGS:
|
| 58 |
+
CONTROL: false
|
| 59 |
+
TEMPORAL: false
|
| 60 |
+
TRAIN_JOINTS: [0]
|
| 61 |
+
TEST_JOINTS: [0]
|
| 62 |
+
TRAIN_DENSITY: 'random'
|
| 63 |
+
TEST_DESITY: 100
|
| 64 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 65 |
+
SAMPLER:
|
| 66 |
+
MAX_LEN: 200
|
| 67 |
+
MIN_LEN: 40
|
| 68 |
+
MAX_TEXT_LEN: 20
|
| 69 |
+
PADDING_TO_MAX: true
|
| 70 |
+
WINDOW_SIZE: 64
|
| 71 |
+
|
| 72 |
+
METRIC:
|
| 73 |
+
DIST_SYNC_ON_STEP: true
|
| 74 |
+
TYPE: ['TM2TMetrics', "PosMetrics"]
|
| 75 |
+
|
| 76 |
+
model:
|
| 77 |
+
target: ['motion_vae']
|
| 78 |
+
latent_dim: [16, 32]
|
| 79 |
+
|
| 80 |
+
# VAE Args
|
| 81 |
+
rec_feats_ratio: 1.0
|
| 82 |
+
rec_joints_ratio: 1.0
|
| 83 |
+
rec_velocity_ratio: 0.0
|
| 84 |
+
kl_ratio: 1e-4
|
| 85 |
+
|
| 86 |
+
rec_feats_loss: 'l1_smooth'
|
| 87 |
+
rec_joints_loss: 'l1_smooth'
|
| 88 |
+
rec_velocity_loss: 'l1_smooth'
|
| 89 |
+
mask_loss: true
|
| 90 |
+
|
| 91 |
+
t2m_textencoder:
|
| 92 |
+
dim_word: 300
|
| 93 |
+
dim_pos_ohot: 15
|
| 94 |
+
dim_text_hidden: 512
|
| 95 |
+
dim_coemb_hidden: 512
|
| 96 |
+
|
| 97 |
+
t2m_motionencoder:
|
| 98 |
+
dim_move_hidden: 512
|
| 99 |
+
dim_move_latent: 512
|
| 100 |
+
dim_motion_hidden: 1024
|
| 101 |
+
dim_motion_latent: 512
|
| 102 |
+
|
| 103 |
+
t2m_path: './deps/t2m/'
|
configs_v1/modules/denoiser.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
denoiser:
|
| 2 |
+
target: mld.models.architectures.mld_denoiser.MldDenoiser
|
| 3 |
+
params:
|
| 4 |
+
latent_dim: ${model.latent_dim}
|
| 5 |
+
hidden_dim: null
|
| 6 |
+
text_dim: 768
|
| 7 |
+
time_dim: 768
|
| 8 |
+
ff_size: 1024
|
| 9 |
+
num_layers: 9
|
| 10 |
+
num_heads: 4
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
normalize_before: false
|
| 13 |
+
norm_eps: 1e-5
|
| 14 |
+
activation: 'gelu'
|
| 15 |
+
norm_post: true
|
| 16 |
+
activation_post: null
|
| 17 |
+
flip_sin_to_cos: true
|
| 18 |
+
freq_shift: 0
|
| 19 |
+
time_act_fn: 'silu'
|
| 20 |
+
time_post_act_fn: null
|
| 21 |
+
position_embedding: 'learned'
|
| 22 |
+
arch: 'trans_enc'
|
| 23 |
+
add_mem_pos: true
|
| 24 |
+
force_pre_post_proj: false
|
| 25 |
+
text_act_fn: 'relu'
|
| 26 |
+
zero_init_cond: true
|
| 27 |
+
controlnet_embed_dim: 256
|
| 28 |
+
controlnet_act_fn: null
|
configs_v1/modules/motion_vae.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
motion_vae:
|
| 2 |
+
target: mld.models.architectures.mld_vae.MldVae
|
| 3 |
+
params:
|
| 4 |
+
nfeats: ${DATASET.NFEATS}
|
| 5 |
+
latent_dim: ${model.latent_dim}
|
| 6 |
+
hidden_dim: null
|
| 7 |
+
force_pre_post_proj: false
|
| 8 |
+
ff_size: 1024
|
| 9 |
+
num_layers: 9
|
| 10 |
+
num_heads: 4
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
arch: 'encoder_decoder'
|
| 13 |
+
normalize_before: false
|
| 14 |
+
norm_eps: 1e-5
|
| 15 |
+
activation: 'gelu'
|
| 16 |
+
norm_post: true
|
| 17 |
+
activation_post: null
|
| 18 |
+
position_embedding: 'learned'
|
configs_v1/modules/scheduler_lcm.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
scheduler:
|
| 2 |
+
target: diffusers.LCMScheduler
|
| 3 |
+
num_inference_steps: 1
|
| 4 |
+
params:
|
| 5 |
+
num_train_timesteps: 1000
|
| 6 |
+
beta_start: 0.00085
|
| 7 |
+
beta_end: 0.012
|
| 8 |
+
beta_schedule: 'scaled_linear'
|
| 9 |
+
clip_sample: false
|
| 10 |
+
set_alpha_to_one: false
|
| 11 |
+
original_inference_steps: 50
|
configs_v1/modules/text_encoder.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
text_encoder:
|
| 2 |
+
target: mld.models.architectures.mld_clip.MldTextEncoder
|
| 3 |
+
params:
|
| 4 |
+
last_hidden_state: false
|
| 5 |
+
modelpath: ${model.t5_path}
|
configs_v1/modules/traj_encoder.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
traj_encoder:
|
| 2 |
+
target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
|
| 3 |
+
params:
|
| 4 |
+
nfeats: ${DATASET.NJOINTS}
|
| 5 |
+
latent_dim: ${model.latent_dim}
|
| 6 |
+
hidden_dim: null
|
| 7 |
+
force_post_proj: false
|
| 8 |
+
ff_size: 1024
|
| 9 |
+
num_layers: 9
|
| 10 |
+
num_heads: 4
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
normalize_before: false
|
| 13 |
+
norm_eps: 1e-5
|
| 14 |
+
activation: 'gelu'
|
| 15 |
+
norm_post: true
|
| 16 |
+
activation_post: null
|
| 17 |
+
position_embedding: 'learned'
|
configs_v1/motionlcm_control_t.yaml
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_control/temporal'
|
| 2 |
+
TEST_FOLDER: './experiments_control_test/temporal'
|
| 3 |
+
|
| 4 |
+
NAME: 'motionlcm_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TRAIN:
|
| 9 |
+
DATASET: 'humanml3d'
|
| 10 |
+
BATCH_SIZE: 128
|
| 11 |
+
SPLIT: 'train'
|
| 12 |
+
NUM_WORKERS: 8
|
| 13 |
+
PERSISTENT_WORKERS: true
|
| 14 |
+
|
| 15 |
+
PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt'
|
| 16 |
+
|
| 17 |
+
validation_steps: -1
|
| 18 |
+
validation_epochs: 50
|
| 19 |
+
checkpointing_steps: -1
|
| 20 |
+
checkpointing_epochs: 50
|
| 21 |
+
max_train_steps: -1
|
| 22 |
+
max_train_epochs: 1000
|
| 23 |
+
learning_rate: 1e-4
|
| 24 |
+
learning_rate_spatial: 1e-4
|
| 25 |
+
lr_scheduler: "cosine"
|
| 26 |
+
lr_warmup_steps: 1000
|
| 27 |
+
adam_beta1: 0.9
|
| 28 |
+
adam_beta2: 0.999
|
| 29 |
+
adam_weight_decay: 0.0
|
| 30 |
+
adam_epsilon: 1e-08
|
| 31 |
+
max_grad_norm: 1.0
|
| 32 |
+
|
| 33 |
+
VAL:
|
| 34 |
+
DATASET: 'humanml3d'
|
| 35 |
+
BATCH_SIZE: 32
|
| 36 |
+
SPLIT: 'test'
|
| 37 |
+
NUM_WORKERS: 12
|
| 38 |
+
PERSISTENT_WORKERS: true
|
| 39 |
+
|
| 40 |
+
TEST:
|
| 41 |
+
DATASET: 'humanml3d'
|
| 42 |
+
BATCH_SIZE: 32
|
| 43 |
+
SPLIT: 'test'
|
| 44 |
+
NUM_WORKERS: 12
|
| 45 |
+
PERSISTENT_WORKERS: true
|
| 46 |
+
|
| 47 |
+
CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t_v1.ckpt'
|
| 48 |
+
|
| 49 |
+
# Testing Args
|
| 50 |
+
REPLICATION_TIMES: 20
|
| 51 |
+
MM_NUM_SAMPLES: 100
|
| 52 |
+
MM_NUM_REPEATS: 30
|
| 53 |
+
MM_NUM_TIMES: 10
|
| 54 |
+
DIVERSITY_TIMES: 300
|
| 55 |
+
DO_MM_TEST: false
|
| 56 |
+
|
| 57 |
+
DATASET:
|
| 58 |
+
NAME: 'humanml3d'
|
| 59 |
+
SMPL_PATH: './deps/smpl'
|
| 60 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 61 |
+
HUMANML3D:
|
| 62 |
+
FRAME_RATE: 20.0
|
| 63 |
+
UNIT_LEN: 4
|
| 64 |
+
ROOT: './datasets/humanml3d'
|
| 65 |
+
CONTROL_ARGS:
|
| 66 |
+
CONTROL: true
|
| 67 |
+
TEMPORAL: true
|
| 68 |
+
TRAIN_JOINTS: [0, 10, 11, 15, 20, 21]
|
| 69 |
+
TEST_JOINTS: [0, 10, 11, 15, 20, 21]
|
| 70 |
+
TRAIN_DENSITY: [25, 25]
|
| 71 |
+
TEST_DENSITY: 25
|
| 72 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 73 |
+
SAMPLER:
|
| 74 |
+
MAX_LEN: 200
|
| 75 |
+
MIN_LEN: 40
|
| 76 |
+
MAX_TEXT_LEN: 20
|
| 77 |
+
PADDING_TO_MAX: false
|
| 78 |
+
WINDOW_SIZE: null
|
| 79 |
+
|
| 80 |
+
METRIC:
|
| 81 |
+
DIST_SYNC_ON_STEP: true
|
| 82 |
+
TYPE: ['TM2TMetrics', 'ControlMetrics']
|
| 83 |
+
|
| 84 |
+
model:
|
| 85 |
+
target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder']
|
| 86 |
+
latent_dim: [1, 256]
|
| 87 |
+
guidance_scale: 7.5
|
| 88 |
+
|
| 89 |
+
# ControlNet Args
|
| 90 |
+
is_controlnet: true
|
| 91 |
+
vaeloss: true
|
| 92 |
+
vaeloss_type: 'sum'
|
| 93 |
+
cond_ratio: 1.0
|
| 94 |
+
control_loss_func: 'l2'
|
| 95 |
+
use_3d: false
|
| 96 |
+
lcm_w_min_nax: null
|
| 97 |
+
lcm_num_ddim_timesteps: null
|
| 98 |
+
|
| 99 |
+
t2m_textencoder:
|
| 100 |
+
dim_word: 300
|
| 101 |
+
dim_pos_ohot: 15
|
| 102 |
+
dim_text_hidden: 512
|
| 103 |
+
dim_coemb_hidden: 512
|
| 104 |
+
|
| 105 |
+
t2m_motionencoder:
|
| 106 |
+
dim_move_hidden: 512
|
| 107 |
+
dim_move_latent: 512
|
| 108 |
+
dim_motion_hidden: 1024
|
| 109 |
+
dim_motion_latent: 512
|
| 110 |
+
|
| 111 |
+
bert_path: './deps/distilbert-base-uncased'
|
| 112 |
+
clip_path: './deps/clip-vit-large-patch14'
|
| 113 |
+
t5_path: './deps/sentence-t5-large'
|
| 114 |
+
t2m_path: './deps/t2m/'
|
configs_v1/motionlcm_t2m.yaml
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FOLDER: './experiments_t2m'
|
| 2 |
+
TEST_FOLDER: './experiments_t2m_test'
|
| 3 |
+
|
| 4 |
+
NAME: 'motionlcm_humanml'
|
| 5 |
+
|
| 6 |
+
SEED_VALUE: 1234
|
| 7 |
+
|
| 8 |
+
TRAIN:
|
| 9 |
+
BATCH_SIZE: 256
|
| 10 |
+
SPLIT: 'train'
|
| 11 |
+
NUM_WORKERS: 8
|
| 12 |
+
PERSISTENT_WORKERS: true
|
| 13 |
+
|
| 14 |
+
PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml_v1.ckpt'
|
| 15 |
+
|
| 16 |
+
validation_steps: -1
|
| 17 |
+
validation_epochs: 50
|
| 18 |
+
checkpointing_steps: -1
|
| 19 |
+
checkpointing_epochs: 50
|
| 20 |
+
max_train_steps: -1
|
| 21 |
+
max_train_epochs: 1000
|
| 22 |
+
learning_rate: 2e-4
|
| 23 |
+
lr_scheduler: "cosine"
|
| 24 |
+
lr_warmup_steps: 1000
|
| 25 |
+
adam_beta1: 0.9
|
| 26 |
+
adam_beta2: 0.999
|
| 27 |
+
adam_weight_decay: 0.0
|
| 28 |
+
adam_epsilon: 1e-08
|
| 29 |
+
max_grad_norm: 1.0
|
| 30 |
+
|
| 31 |
+
# Latent Consistency Distillation Specific Arguments
|
| 32 |
+
w_min: 5.0
|
| 33 |
+
w_max: 15.0
|
| 34 |
+
num_ddim_timesteps: 50
|
| 35 |
+
loss_type: 'huber'
|
| 36 |
+
huber_c: 0.001
|
| 37 |
+
unet_time_cond_proj_dim: 256
|
| 38 |
+
ema_decay: 0.95
|
| 39 |
+
|
| 40 |
+
VAL:
|
| 41 |
+
BATCH_SIZE: 32
|
| 42 |
+
SPLIT: 'test'
|
| 43 |
+
NUM_WORKERS: 12
|
| 44 |
+
PERSISTENT_WORKERS: true
|
| 45 |
+
|
| 46 |
+
TEST:
|
| 47 |
+
BATCH_SIZE: 32
|
| 48 |
+
SPLIT: 'test'
|
| 49 |
+
NUM_WORKERS: 12
|
| 50 |
+
PERSISTENT_WORKERS: true
|
| 51 |
+
|
| 52 |
+
CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt'
|
| 53 |
+
|
| 54 |
+
# Testing Args
|
| 55 |
+
REPLICATION_TIMES: 20
|
| 56 |
+
MM_NUM_SAMPLES: 100
|
| 57 |
+
MM_NUM_REPEATS: 30
|
| 58 |
+
MM_NUM_TIMES: 10
|
| 59 |
+
DIVERSITY_TIMES: 300
|
| 60 |
+
DO_MM_TEST: true
|
| 61 |
+
|
| 62 |
+
DATASET:
|
| 63 |
+
NAME: 'humanml3d'
|
| 64 |
+
SMPL_PATH: './deps/smpl'
|
| 65 |
+
WORD_VERTILIZER_PATH: './deps/glove/'
|
| 66 |
+
HUMANML3D:
|
| 67 |
+
FRAME_RATE: 20.0
|
| 68 |
+
UNIT_LEN: 4
|
| 69 |
+
ROOT: './datasets/humanml3d'
|
| 70 |
+
CONTROL_ARGS:
|
| 71 |
+
CONTROL: false
|
| 72 |
+
TEMPORAL: false
|
| 73 |
+
TRAIN_JOINTS: [0]
|
| 74 |
+
TEST_JOINTS: [0]
|
| 75 |
+
TRAIN_DENSITY: 'random'
|
| 76 |
+
TEST_DENSITY: 100
|
| 77 |
+
MEAN_STD_PATH: './datasets/humanml_spatial_norm'
|
| 78 |
+
SAMPLER:
|
| 79 |
+
MAX_LEN: 200
|
| 80 |
+
MIN_LEN: 40
|
| 81 |
+
MAX_TEXT_LEN: 20
|
| 82 |
+
PADDING_TO_MAX: false
|
| 83 |
+
WINDOW_SIZE: null
|
| 84 |
+
|
| 85 |
+
METRIC:
|
| 86 |
+
DIST_SYNC_ON_STEP: true
|
| 87 |
+
TYPE: ['TM2TMetrics']
|
| 88 |
+
|
| 89 |
+
model:
|
| 90 |
+
target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm']
|
| 91 |
+
latent_dim: [1, 256]
|
| 92 |
+
guidance_scale: 7.5
|
| 93 |
+
|
| 94 |
+
t2m_textencoder:
|
| 95 |
+
dim_word: 300
|
| 96 |
+
dim_pos_ohot: 15
|
| 97 |
+
dim_text_hidden: 512
|
| 98 |
+
dim_coemb_hidden: 512
|
| 99 |
+
|
| 100 |
+
t2m_motionencoder:
|
| 101 |
+
dim_move_hidden: 512
|
| 102 |
+
dim_move_latent: 512
|
| 103 |
+
dim_motion_hidden: 1024
|
| 104 |
+
dim_motion_latent: 512
|
| 105 |
+
|
| 106 |
+
bert_path: './deps/distilbert-base-uncased'
|
| 107 |
+
clip_path: './deps/clip-vit-large-patch14'
|
| 108 |
+
t5_path: './deps/sentence-t5-large'
|
| 109 |
+
t2m_path: './deps/t2m/'
|
demo.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import sys
|
| 4 |
+
import datetime
|
| 5 |
+
import logging
|
| 6 |
+
import os.path as osp
|
| 7 |
+
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from mld.config import parse_args
|
| 13 |
+
from mld.data.get_data import get_dataset
|
| 14 |
+
from mld.models.modeltype.mld import MLD
|
| 15 |
+
from mld.models.modeltype.vae import VAE
|
| 16 |
+
from mld.utils.utils import set_seed, move_batch_to_device
|
| 17 |
+
from mld.data.humanml.utils.plot_script import plot_3d_motion
|
| 18 |
+
from mld.utils.temos_utils import remove_padding
|
| 19 |
+
|
| 20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_example_hint_input(text_path: str) -> tuple:
|
| 24 |
+
with open(text_path, "r") as f:
|
| 25 |
+
lines = f.readlines()
|
| 26 |
+
|
| 27 |
+
n_frames, control_type_ids, control_hint_ids = [], [], []
|
| 28 |
+
for line in lines:
|
| 29 |
+
s = line.strip()
|
| 30 |
+
n_frame, control_type_id, control_hint_id = s.split(' ')
|
| 31 |
+
n_frames.append(int(n_frame))
|
| 32 |
+
control_type_ids.append(int(control_type_id))
|
| 33 |
+
control_hint_ids.append(int(control_hint_id))
|
| 34 |
+
|
| 35 |
+
return n_frames, control_type_ids, control_hint_ids
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_example_input(text_path: str) -> tuple:
|
| 39 |
+
with open(text_path, "r") as f:
|
| 40 |
+
lines = f.readlines()
|
| 41 |
+
|
| 42 |
+
texts, lens = [], []
|
| 43 |
+
for line in lines:
|
| 44 |
+
s = line.strip()
|
| 45 |
+
s_l = s.split(" ")[0]
|
| 46 |
+
s_t = s[(len(s_l) + 1):]
|
| 47 |
+
lens.append(int(s_l))
|
| 48 |
+
texts.append(s_t)
|
| 49 |
+
return texts, lens
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def main():
|
| 53 |
+
cfg = parse_args()
|
| 54 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 55 |
+
set_seed(cfg.SEED_VALUE)
|
| 56 |
+
|
| 57 |
+
name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
|
| 58 |
+
cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
|
| 59 |
+
vis_dir = osp.join(cfg.output_dir, 'samples')
|
| 60 |
+
os.makedirs(cfg.output_dir, exist_ok=False)
|
| 61 |
+
os.makedirs(vis_dir, exist_ok=False)
|
| 62 |
+
|
| 63 |
+
steam_handler = logging.StreamHandler(sys.stdout)
|
| 64 |
+
file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log'))
|
| 65 |
+
logging.basicConfig(level=logging.INFO,
|
| 66 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 67 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 68 |
+
handlers=[steam_handler, file_handler])
|
| 69 |
+
logger = logging.getLogger(__name__)
|
| 70 |
+
|
| 71 |
+
OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml'))
|
| 72 |
+
|
| 73 |
+
state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
|
| 74 |
+
logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
|
| 75 |
+
|
| 76 |
+
# Step 1: Check if the checkpoint is VAE-based.
|
| 77 |
+
is_vae = False
|
| 78 |
+
vae_key = 'vae.skel_embedding.weight'
|
| 79 |
+
if vae_key in state_dict:
|
| 80 |
+
is_vae = True
|
| 81 |
+
logger.info(f'Is VAE: {is_vae}')
|
| 82 |
+
|
| 83 |
+
# Step 2: Check if the checkpoint is MLD-based.
|
| 84 |
+
is_mld = False
|
| 85 |
+
mld_key = 'denoiser.time_embedding.linear_1.weight'
|
| 86 |
+
if mld_key in state_dict:
|
| 87 |
+
is_mld = True
|
| 88 |
+
logger.info(f'Is MLD: {is_mld}')
|
| 89 |
+
|
| 90 |
+
# Step 3: Check if the checkpoint is LCM-based.
|
| 91 |
+
is_lcm = False
|
| 92 |
+
lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG
|
| 93 |
+
if lcm_key in state_dict:
|
| 94 |
+
is_lcm = True
|
| 95 |
+
time_cond_proj_dim = state_dict[lcm_key].shape[1]
|
| 96 |
+
cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
|
| 97 |
+
logger.info(f'Is LCM: {is_lcm}')
|
| 98 |
+
|
| 99 |
+
# Step 4: Check if the checkpoint is Controlnet-based.
|
| 100 |
+
cn_key = "controlnet.controlnet_cond_embedding.0.weight"
|
| 101 |
+
is_controlnet = True if cn_key in state_dict else False
|
| 102 |
+
cfg.model.is_controlnet = is_controlnet
|
| 103 |
+
logger.info(f'Is Controlnet: {is_controlnet}')
|
| 104 |
+
|
| 105 |
+
if is_mld or is_lcm or is_controlnet:
|
| 106 |
+
target_model_class = MLD
|
| 107 |
+
else:
|
| 108 |
+
target_model_class = VAE
|
| 109 |
+
|
| 110 |
+
if cfg.optimize:
|
| 111 |
+
assert cfg.model.get('noise_optimizer') is not None
|
| 112 |
+
cfg.model.noise_optimizer.params.optimize = True
|
| 113 |
+
logger.info('Optimization enabled. Set the batch size to 1.')
|
| 114 |
+
logger.info(f'Original batch size: {cfg.TEST.BATCH_SIZE}')
|
| 115 |
+
cfg.TEST.BATCH_SIZE = 1
|
| 116 |
+
|
| 117 |
+
dataset = get_dataset(cfg)
|
| 118 |
+
model = target_model_class(cfg, dataset)
|
| 119 |
+
model.to(device)
|
| 120 |
+
model.eval()
|
| 121 |
+
model.requires_grad_(False)
|
| 122 |
+
logger.info(model.load_state_dict(state_dict))
|
| 123 |
+
|
| 124 |
+
FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE")
|
| 125 |
+
|
| 126 |
+
if cfg.example is not None and not is_controlnet:
|
| 127 |
+
text, length = load_example_input(cfg.example)
|
| 128 |
+
for t, l in zip(text, length):
|
| 129 |
+
logger.info(f"{l}: {t}")
|
| 130 |
+
|
| 131 |
+
batch = {"length": length, "text": text}
|
| 132 |
+
|
| 133 |
+
for rep_i in range(cfg.replication):
|
| 134 |
+
with torch.no_grad():
|
| 135 |
+
joints = model(batch)[0]
|
| 136 |
+
|
| 137 |
+
num_samples = len(joints)
|
| 138 |
+
for i in range(num_samples):
|
| 139 |
+
res = dict()
|
| 140 |
+
pkl_path = osp.join(vis_dir, f"sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
|
| 141 |
+
res['joints'] = joints[i].detach().cpu().numpy()
|
| 142 |
+
res['text'] = text[i]
|
| 143 |
+
res['length'] = length[i]
|
| 144 |
+
res['hint'] = None
|
| 145 |
+
with open(pkl_path, 'wb') as f:
|
| 146 |
+
pickle.dump(res, f)
|
| 147 |
+
logger.info(f"Motions are generated here:\n{pkl_path}")
|
| 148 |
+
|
| 149 |
+
if not cfg.no_plot:
|
| 150 |
+
plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), text[i], fps=FPS)
|
| 151 |
+
|
| 152 |
+
else:
|
| 153 |
+
test_dataloader = dataset.test_dataloader()
|
| 154 |
+
for rep_i in range(cfg.replication):
|
| 155 |
+
for batch_id, batch in enumerate(test_dataloader):
|
| 156 |
+
batch = move_batch_to_device(batch, device)
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
joints, joints_ref = model(batch)
|
| 159 |
+
|
| 160 |
+
num_samples = len(joints)
|
| 161 |
+
text = batch['text']
|
| 162 |
+
length = batch['length']
|
| 163 |
+
if 'hint' in batch:
|
| 164 |
+
hint, hint_mask = batch['hint'], batch['hint_mask']
|
| 165 |
+
hint = dataset.denorm_spatial(hint) * hint_mask
|
| 166 |
+
hint = remove_padding(hint, lengths=length)
|
| 167 |
+
else:
|
| 168 |
+
hint = None
|
| 169 |
+
|
| 170 |
+
for i in range(num_samples):
|
| 171 |
+
res = dict()
|
| 172 |
+
pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
|
| 173 |
+
res['joints'] = joints[i].detach().cpu().numpy()
|
| 174 |
+
res['text'] = text[i]
|
| 175 |
+
res['length'] = length[i]
|
| 176 |
+
res['hint'] = hint[i].detach().cpu().numpy() if hint is not None else None
|
| 177 |
+
with open(pkl_path, 'wb') as f:
|
| 178 |
+
pickle.dump(res, f)
|
| 179 |
+
logger.info(f"Motions are generated here:\n{pkl_path}")
|
| 180 |
+
|
| 181 |
+
if not cfg.no_plot:
|
| 182 |
+
plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(),
|
| 183 |
+
text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
|
| 184 |
+
|
| 185 |
+
if rep_i == 0:
|
| 186 |
+
res['joints'] = joints_ref[i].detach().cpu().numpy()
|
| 187 |
+
with open(pkl_path.replace('.pkl', '_ref.pkl'), 'wb') as f:
|
| 188 |
+
pickle.dump(res, f)
|
| 189 |
+
logger.info(f"Motions are generated here:\n{pkl_path.replace('.pkl', '_ref.pkl')}")
|
| 190 |
+
if not cfg.no_plot:
|
| 191 |
+
plot_3d_motion(pkl_path.replace('.pkl', '_ref.mp4'), joints_ref[i].detach().cpu().numpy(),
|
| 192 |
+
text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
main()
|
fit.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# borrow from optimization https://github.com/wangsen1312/joints2smpl
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import pickle
|
| 5 |
+
|
| 6 |
+
import h5py
|
| 7 |
+
import natsort
|
| 8 |
+
import smplx
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from mld.transforms.joints2rots import config
|
| 13 |
+
from mld.transforms.joints2rots.smplify import SMPLify3D
|
| 14 |
+
|
| 15 |
+
parser = argparse.ArgumentParser()
|
| 16 |
+
parser.add_argument("--pkl", type=str, default=None, help="pkl motion file")
|
| 17 |
+
parser.add_argument("--dir", type=str, default=None, help="pkl motion folder")
|
| 18 |
+
parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters")
|
| 19 |
+
parser.add_argument("--cuda", type=bool, default=True, help="enables cuda")
|
| 20 |
+
parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids")
|
| 21 |
+
parser.add_argument("--num_joints", type=int, default=22, help="joint number")
|
| 22 |
+
parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence")
|
| 23 |
+
parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not")
|
| 24 |
+
opt = parser.parse_args()
|
| 25 |
+
print(opt)
|
| 26 |
+
|
| 27 |
+
if opt.pkl:
|
| 28 |
+
paths = [opt.pkl]
|
| 29 |
+
elif opt.dir:
|
| 30 |
+
paths = []
|
| 31 |
+
file_list = natsort.natsorted(os.listdir(opt.dir))
|
| 32 |
+
for item in file_list:
|
| 33 |
+
if item.endswith('.pkl') and not item.endswith("_mesh.pkl"):
|
| 34 |
+
paths.append(os.path.join(opt.dir, item))
|
| 35 |
+
else:
|
| 36 |
+
raise ValueError(f'{opt.pkl} and {opt.dir} are both None!')
|
| 37 |
+
|
| 38 |
+
for path in paths:
|
| 39 |
+
# load joints
|
| 40 |
+
if os.path.exists(path.replace('.pkl', '_mesh.pkl')):
|
| 41 |
+
print(f"{path} is rendered! skip!")
|
| 42 |
+
continue
|
| 43 |
+
|
| 44 |
+
with open(path, 'rb') as f:
|
| 45 |
+
data = pickle.load(f)
|
| 46 |
+
|
| 47 |
+
joints = data['joints']
|
| 48 |
+
# load predefined something
|
| 49 |
+
device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu")
|
| 50 |
+
print(config.SMPL_MODEL_DIR)
|
| 51 |
+
smplxmodel = smplx.create(
|
| 52 |
+
config.SMPL_MODEL_DIR,
|
| 53 |
+
model_type="smpl",
|
| 54 |
+
gender="neutral",
|
| 55 |
+
ext="pkl",
|
| 56 |
+
batch_size=joints.shape[0],
|
| 57 |
+
).to(device)
|
| 58 |
+
|
| 59 |
+
# load the mean pose as original
|
| 60 |
+
smpl_mean_file = config.SMPL_MEAN_FILE
|
| 61 |
+
|
| 62 |
+
file = h5py.File(smpl_mean_file, "r")
|
| 63 |
+
init_mean_pose = (
|
| 64 |
+
torch.from_numpy(file["pose"][:])
|
| 65 |
+
.unsqueeze(0).repeat(joints.shape[0], 1)
|
| 66 |
+
.float()
|
| 67 |
+
.to(device)
|
| 68 |
+
)
|
| 69 |
+
init_mean_shape = (
|
| 70 |
+
torch.from_numpy(file["shape"][:])
|
| 71 |
+
.unsqueeze(0).repeat(joints.shape[0], 1)
|
| 72 |
+
.float()
|
| 73 |
+
.to(device)
|
| 74 |
+
)
|
| 75 |
+
cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)
|
| 76 |
+
|
| 77 |
+
# initialize SMPLify
|
| 78 |
+
smplify = SMPLify3D(
|
| 79 |
+
smplxmodel=smplxmodel,
|
| 80 |
+
batch_size=joints.shape[0],
|
| 81 |
+
joints_category=opt.joint_category,
|
| 82 |
+
num_iters=opt.num_smplify_iters,
|
| 83 |
+
device=device,
|
| 84 |
+
)
|
| 85 |
+
print("initialize SMPLify3D done!")
|
| 86 |
+
|
| 87 |
+
print("Start SMPLify!")
|
| 88 |
+
keypoints_3d = torch.Tensor(joints).to(device).float()
|
| 89 |
+
|
| 90 |
+
if opt.joint_category == "AMASS":
|
| 91 |
+
confidence_input = torch.ones(opt.num_joints)
|
| 92 |
+
# make sure the foot and ankle
|
| 93 |
+
if opt.fix_foot:
|
| 94 |
+
confidence_input[7] = 1.5
|
| 95 |
+
confidence_input[8] = 1.5
|
| 96 |
+
confidence_input[10] = 1.5
|
| 97 |
+
confidence_input[11] = 1.5
|
| 98 |
+
else:
|
| 99 |
+
print("Such category not settle down!")
|
| 100 |
+
|
| 101 |
+
# ----- from initial to fitting -------
|
| 102 |
+
(
|
| 103 |
+
new_opt_vertices,
|
| 104 |
+
new_opt_joints,
|
| 105 |
+
new_opt_pose,
|
| 106 |
+
new_opt_betas,
|
| 107 |
+
new_opt_cam_t,
|
| 108 |
+
new_opt_joint_loss,
|
| 109 |
+
) = smplify(
|
| 110 |
+
init_mean_pose.detach(),
|
| 111 |
+
init_mean_shape.detach(),
|
| 112 |
+
cam_trans_zero.detach(),
|
| 113 |
+
keypoints_3d,
|
| 114 |
+
conf_3d=confidence_input.to(device)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# fix shape
|
| 118 |
+
betas = torch.zeros_like(new_opt_betas)
|
| 119 |
+
root = keypoints_3d[:, 0, :]
|
| 120 |
+
|
| 121 |
+
output = smplxmodel(
|
| 122 |
+
betas=betas,
|
| 123 |
+
global_orient=new_opt_pose[:, :3],
|
| 124 |
+
body_pose=new_opt_pose[:, 3:],
|
| 125 |
+
transl=root,
|
| 126 |
+
return_verts=True
|
| 127 |
+
)
|
| 128 |
+
vertices = output.vertices.detach().cpu().numpy()
|
| 129 |
+
floor_height = vertices[..., 1].min()
|
| 130 |
+
vertices[..., 1] -= floor_height
|
| 131 |
+
data['vertices'] = vertices
|
| 132 |
+
|
| 133 |
+
save_file = path.replace('.pkl', '_mesh.pkl')
|
| 134 |
+
with open(save_file, 'wb') as f:
|
| 135 |
+
pickle.dump(data, f)
|
| 136 |
+
print(f'vertices saved in {save_file}')
|
mld/__init__.py
ADDED
|
File without changes
|
mld/config.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import importlib
|
| 3 |
+
from typing import Type, TypeVar
|
| 4 |
+
from argparse import ArgumentParser
|
| 5 |
+
|
| 6 |
+
from omegaconf import OmegaConf, DictConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_module_config(cfg_model: DictConfig, paths: list[str], cfg_root: str) -> DictConfig:
|
| 10 |
+
files = [os.path.join(cfg_root, 'modules', p+'.yaml') for p in paths]
|
| 11 |
+
for file in files:
|
| 12 |
+
assert os.path.exists(file), f'{file} is not exists.'
|
| 13 |
+
with open(file, 'r') as f:
|
| 14 |
+
cfg_model.merge_with(OmegaConf.load(f))
|
| 15 |
+
return cfg_model
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_obj_from_str(string: str, reload: bool = False) -> Type:
|
| 19 |
+
module, cls = string.rsplit(".", 1)
|
| 20 |
+
if reload:
|
| 21 |
+
module_imp = importlib.import_module(module)
|
| 22 |
+
importlib.reload(module_imp)
|
| 23 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def instantiate_from_config(config: DictConfig) -> TypeVar:
|
| 27 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_args() -> DictConfig:
|
| 31 |
+
parser = ArgumentParser()
|
| 32 |
+
parser.add_argument("--cfg", type=str, required=True, help="The main config file")
|
| 33 |
+
parser.add_argument('--example', type=str, required=False, help="The input texts and lengths with txt format")
|
| 34 |
+
parser.add_argument('--example_hint', type=str, required=False, help="The input hint ids and lengths with txt format")
|
| 35 |
+
parser.add_argument('--no-plot', action="store_true", required=False, help="Whether to plot the skeleton-based motion")
|
| 36 |
+
parser.add_argument('--replication', type=int, default=1, help="The number of replications of sampling")
|
| 37 |
+
parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="The visualization backends: tensorboard or swanlab")
|
| 38 |
+
parser.add_argument('--optimize', action='store_true', help="Enable optimization for motion control")
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
|
| 41 |
+
cfg = OmegaConf.load(args.cfg)
|
| 42 |
+
cfg_root = os.path.dirname(args.cfg)
|
| 43 |
+
cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
|
| 44 |
+
cfg = OmegaConf.merge(cfg, cfg_model)
|
| 45 |
+
|
| 46 |
+
cfg.example = args.example
|
| 47 |
+
cfg.example_hint = args.example_hint
|
| 48 |
+
cfg.no_plot = args.no_plot
|
| 49 |
+
cfg.replication = args.replication
|
| 50 |
+
cfg.vis = args.vis
|
| 51 |
+
cfg.optimize = args.optimize
|
| 52 |
+
return cfg
|
mld/data/__init__.py
ADDED
|
File without changes
|
mld/data/base.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from os.path import join as pjoin
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseDataModule:
|
| 9 |
+
def __init__(self, collate_fn: Callable) -> None:
|
| 10 |
+
super(BaseDataModule, self).__init__()
|
| 11 |
+
self.collate_fn = collate_fn
|
| 12 |
+
self.is_mm = False
|
| 13 |
+
|
| 14 |
+
def get_sample_set(self, overrides: dict) -> Any:
|
| 15 |
+
sample_params = copy.deepcopy(self.hparams)
|
| 16 |
+
sample_params.update(overrides)
|
| 17 |
+
split_file = pjoin(
|
| 18 |
+
eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"),
|
| 19 |
+
self.cfg.TEST.SPLIT + ".txt"
|
| 20 |
+
)
|
| 21 |
+
return self.Dataset(split_file=split_file, **sample_params)
|
| 22 |
+
|
| 23 |
+
def __getattr__(self, item: str) -> Any:
|
| 24 |
+
if item.endswith("_dataset") and not item.startswith("_"):
|
| 25 |
+
subset = item[:-len("_dataset")].upper()
|
| 26 |
+
item_c = "_" + item
|
| 27 |
+
if item_c not in self.__dict__:
|
| 28 |
+
split_file = pjoin(
|
| 29 |
+
eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"),
|
| 30 |
+
eval(f"self.cfg.{subset}.SPLIT") + ".txt"
|
| 31 |
+
)
|
| 32 |
+
self.__dict__[item_c] = self.Dataset(split_file=split_file, **self.hparams)
|
| 33 |
+
return getattr(self, item_c)
|
| 34 |
+
classname = self.__class__.__name__
|
| 35 |
+
raise AttributeError(f"'{classname}' object has no attribute '{item}'")
|
| 36 |
+
|
| 37 |
+
def get_dataloader_options(self, stage: str) -> dict:
|
| 38 |
+
stage_args = eval(f"self.cfg.{stage.upper()}")
|
| 39 |
+
dataloader_options = {
|
| 40 |
+
"batch_size": stage_args.BATCH_SIZE,
|
| 41 |
+
"num_workers": stage_args.NUM_WORKERS,
|
| 42 |
+
"collate_fn": self.collate_fn,
|
| 43 |
+
"persistent_workers": stage_args.PERSISTENT_WORKERS,
|
| 44 |
+
}
|
| 45 |
+
return dataloader_options
|
| 46 |
+
|
| 47 |
+
def train_dataloader(self) -> DataLoader:
|
| 48 |
+
dataloader_options = self.get_dataloader_options('TRAIN')
|
| 49 |
+
return DataLoader(self.train_dataset, shuffle=True, **dataloader_options)
|
| 50 |
+
|
| 51 |
+
def val_dataloader(self) -> DataLoader:
|
| 52 |
+
dataloader_options = self.get_dataloader_options('VAL')
|
| 53 |
+
return DataLoader(self.val_dataset, shuffle=False, **dataloader_options)
|
| 54 |
+
|
| 55 |
+
def test_dataloader(self) -> DataLoader:
|
| 56 |
+
dataloader_options = self.get_dataloader_options('TEST')
|
| 57 |
+
dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
|
| 58 |
+
return DataLoader(self.test_dataset, shuffle=False, **dataloader_options)
|
mld/data/data.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Callable, Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from omegaconf import DictConfig
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .base import BaseDataModule
|
| 10 |
+
from .humanml.dataset import Text2MotionDataset, MotionDataset
|
| 11 |
+
from .humanml.scripts.motion_process import recover_from_ric
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# (nfeats, njoints)
|
| 15 |
+
dataset_map = {'humanml3d': (263, 22), 'kit': (251, 21)}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DataModule(BaseDataModule):
|
| 19 |
+
|
| 20 |
+
def __init__(self,
|
| 21 |
+
name: str,
|
| 22 |
+
cfg: DictConfig,
|
| 23 |
+
motion_only: bool,
|
| 24 |
+
collate_fn: Optional[Callable] = None,
|
| 25 |
+
**kwargs) -> None:
|
| 26 |
+
super().__init__(collate_fn=collate_fn)
|
| 27 |
+
self.cfg = cfg
|
| 28 |
+
self.name = name
|
| 29 |
+
self.nfeats, self.njoints = dataset_map[name]
|
| 30 |
+
self.hparams = copy.deepcopy({**kwargs, 'njoints': self.njoints})
|
| 31 |
+
self.Dataset = MotionDataset if motion_only else Text2MotionDataset
|
| 32 |
+
sample_overrides = {"tiny": True, "progress_bar": False}
|
| 33 |
+
self._sample_set = self.get_sample_set(overrides=sample_overrides)
|
| 34 |
+
|
| 35 |
+
def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
|
| 37 |
+
raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
|
| 38 |
+
hint = hint * raw_std + raw_mean
|
| 39 |
+
return hint
|
| 40 |
+
|
| 41 |
+
def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
|
| 43 |
+
raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
|
| 44 |
+
hint = (hint - raw_mean) / raw_std
|
| 45 |
+
return hint
|
| 46 |
+
|
| 47 |
+
def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
mean = torch.tensor(self.hparams['mean']).to(features)
|
| 49 |
+
std = torch.tensor(self.hparams['std']).to(features)
|
| 50 |
+
features = features * std + mean
|
| 51 |
+
return recover_from_ric(features, self.njoints)
|
| 52 |
+
|
| 53 |
+
def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
|
| 54 |
+
# renorm to t2m norms for using t2m evaluators
|
| 55 |
+
ori_mean = torch.tensor(self.hparams['mean']).to(features)
|
| 56 |
+
ori_std = torch.tensor(self.hparams['std']).to(features)
|
| 57 |
+
eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
|
| 58 |
+
eval_std = torch.tensor(self.hparams['std_eval']).to(features)
|
| 59 |
+
features = features * ori_std + ori_mean
|
| 60 |
+
features = (features - eval_mean) / eval_std
|
| 61 |
+
return features
|
| 62 |
+
|
| 63 |
+
def mm_mode(self, mm_on: bool = True) -> None:
|
| 64 |
+
if mm_on:
|
| 65 |
+
self.is_mm = True
|
| 66 |
+
self.name_list = self.test_dataset.name_list
|
| 67 |
+
self.mm_list = np.random.choice(self.name_list,
|
| 68 |
+
self.cfg.TEST.MM_NUM_SAMPLES,
|
| 69 |
+
replace=False)
|
| 70 |
+
self.test_dataset.name_list = self.mm_list
|
| 71 |
+
else:
|
| 72 |
+
self.is_mm = False
|
| 73 |
+
self.test_dataset.name_list = self.name_list
|
mld/data/get_data.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from os.path import join as pjoin
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
|
| 8 |
+
from .data import DataModule
|
| 9 |
+
from .base import BaseDataModule
|
| 10 |
+
from .utils import mld_collate, mld_collate_motion_only
|
| 11 |
+
from .humanml.utils.word_vectorizer import WordVectorizer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
|
| 15 |
+
name = "t2m" if dataset_name == "humanml3d" else dataset_name
|
| 16 |
+
assert name in ["t2m", "kit"]
|
| 17 |
+
if phase in ["val"]:
|
| 18 |
+
if name == 't2m':
|
| 19 |
+
data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta")
|
| 20 |
+
elif name == 'kit':
|
| 21 |
+
data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta")
|
| 22 |
+
else:
|
| 23 |
+
raise ValueError("Only support t2m and kit")
|
| 24 |
+
mean = np.load(pjoin(data_root, "mean.npy"))
|
| 25 |
+
std = np.load(pjoin(data_root, "std.npy"))
|
| 26 |
+
else:
|
| 27 |
+
data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
|
| 28 |
+
mean = np.load(pjoin(data_root, "Mean.npy"))
|
| 29 |
+
std = np.load(pjoin(data_root, "Std.npy"))
|
| 30 |
+
|
| 31 |
+
return mean, std
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]:
|
| 35 |
+
if dataset_name.lower() in ["humanml3d", "kit"]:
|
| 36 |
+
return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")
|
| 37 |
+
else:
|
| 38 |
+
raise ValueError("Only support WordVectorizer for HumanML3D and KIT")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
dataset_module_map = {"humanml3d": DataModule, "kit": DataModule}
|
| 42 |
+
motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule:
|
| 46 |
+
dataset_name = cfg.DATASET.NAME
|
| 47 |
+
if dataset_name.lower() in ["humanml3d", "kit"]:
|
| 48 |
+
data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
|
| 49 |
+
mean, std = get_mean_std('train', cfg, dataset_name)
|
| 50 |
+
mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)
|
| 51 |
+
wordVectorizer = None if motion_only else get_WordVectorizer(cfg, dataset_name)
|
| 52 |
+
collate_fn = mld_collate_motion_only if motion_only else mld_collate
|
| 53 |
+
dataset = dataset_module_map[dataset_name.lower()](
|
| 54 |
+
name=dataset_name.lower(),
|
| 55 |
+
cfg=cfg,
|
| 56 |
+
motion_only=motion_only,
|
| 57 |
+
collate_fn=collate_fn,
|
| 58 |
+
mean=mean,
|
| 59 |
+
std=std,
|
| 60 |
+
mean_eval=mean_eval,
|
| 61 |
+
std_eval=std_eval,
|
| 62 |
+
w_vectorizer=wordVectorizer,
|
| 63 |
+
text_dir=pjoin(data_root, "texts"),
|
| 64 |
+
motion_dir=pjoin(data_root, motion_subdir[dataset_name]),
|
| 65 |
+
max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
|
| 66 |
+
min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
|
| 67 |
+
max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
|
| 68 |
+
unit_length=eval(f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"),
|
| 69 |
+
fps=eval(f"cfg.DATASET.{dataset_name.upper()}.FRAME_RATE"),
|
| 70 |
+
padding_to_max=cfg.DATASET.PADDING_TO_MAX,
|
| 71 |
+
window_size=cfg.DATASET.WINDOW_SIZE,
|
| 72 |
+
control_args=eval(f"cfg.DATASET.{dataset_name.upper()}.CONTROL_ARGS"))
|
| 73 |
+
|
| 74 |
+
cfg.DATASET.NFEATS = dataset.nfeats
|
| 75 |
+
cfg.DATASET.NJOINTS = dataset.njoints
|
| 76 |
+
return dataset
|
| 77 |
+
|
| 78 |
+
elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]:
|
| 79 |
+
raise NotImplementedError
|
mld/data/humanml/__init__.py
ADDED
|
File without changes
|
mld/data/humanml/common/quaternion.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def qinv(q: torch.Tensor) -> torch.Tensor:
|
| 5 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
| 6 |
+
mask = torch.ones_like(q)
|
| 7 |
+
mask[..., 1:] = -mask[..., 1:]
|
| 8 |
+
return q * mask
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
"""
|
| 13 |
+
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
| 14 |
+
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
| 15 |
+
where * denotes any number of dimensions.
|
| 16 |
+
Returns a tensor of shape (*, 3).
|
| 17 |
+
"""
|
| 18 |
+
assert q.shape[-1] == 4
|
| 19 |
+
assert v.shape[-1] == 3
|
| 20 |
+
assert q.shape[:-1] == v.shape[:-1]
|
| 21 |
+
|
| 22 |
+
original_shape = list(v.shape)
|
| 23 |
+
q = q.contiguous().view(-1, 4)
|
| 24 |
+
v = v.contiguous().view(-1, 3)
|
| 25 |
+
|
| 26 |
+
qvec = q[:, 1:]
|
| 27 |
+
uv = torch.cross(qvec, v, dim=1)
|
| 28 |
+
uuv = torch.cross(qvec, uv, dim=1)
|
| 29 |
+
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
mld/data/humanml/dataset.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import logging
|
| 4 |
+
import codecs as cs
|
| 5 |
+
from os.path import join as pjoin
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from rich.progress import track
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
|
| 13 |
+
from .scripts.motion_process import recover_from_ric
|
| 14 |
+
from .utils.word_vectorizer import WordVectorizer
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MotionDataset(Dataset):
|
| 20 |
+
def __init__(self, mean: np.ndarray, std: np.ndarray,
|
| 21 |
+
split_file: str, motion_dir: str, window_size: int,
|
| 22 |
+
tiny: bool = False, progress_bar: bool = True, **kwargs) -> None:
|
| 23 |
+
self.data = []
|
| 24 |
+
self.lengths = []
|
| 25 |
+
id_list = []
|
| 26 |
+
with cs.open(split_file, "r") as f:
|
| 27 |
+
for line in f.readlines():
|
| 28 |
+
id_list.append(line.strip())
|
| 29 |
+
|
| 30 |
+
maxdata = 10 if tiny else 1e10
|
| 31 |
+
if progress_bar:
|
| 32 |
+
enumerator = enumerate(
|
| 33 |
+
track(
|
| 34 |
+
id_list,
|
| 35 |
+
f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
|
| 36 |
+
))
|
| 37 |
+
else:
|
| 38 |
+
enumerator = enumerate(id_list)
|
| 39 |
+
|
| 40 |
+
count = 0
|
| 41 |
+
for i, name in enumerator:
|
| 42 |
+
if count > maxdata:
|
| 43 |
+
break
|
| 44 |
+
try:
|
| 45 |
+
motion = np.load(pjoin(motion_dir, name + '.npy'))
|
| 46 |
+
if motion.shape[0] < window_size:
|
| 47 |
+
continue
|
| 48 |
+
self.lengths.append(motion.shape[0] - window_size)
|
| 49 |
+
self.data.append(motion)
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(e)
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
self.cumsum = np.cumsum([0] + self.lengths)
|
| 55 |
+
if not tiny:
|
| 56 |
+
logger.info("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
|
| 57 |
+
|
| 58 |
+
self.mean = mean
|
| 59 |
+
self.std = std
|
| 60 |
+
self.window_size = window_size
|
| 61 |
+
|
| 62 |
+
def __len__(self) -> int:
|
| 63 |
+
return self.cumsum[-1]
|
| 64 |
+
|
| 65 |
+
def __getitem__(self, item: int) -> tuple:
|
| 66 |
+
if item != 0:
|
| 67 |
+
motion_id = np.searchsorted(self.cumsum, item) - 1
|
| 68 |
+
idx = item - self.cumsum[motion_id] - 1
|
| 69 |
+
else:
|
| 70 |
+
motion_id = 0
|
| 71 |
+
idx = 0
|
| 72 |
+
motion = self.data[motion_id][idx:idx + self.window_size]
|
| 73 |
+
"Z Normalization"
|
| 74 |
+
motion = (motion - self.mean) / self.std
|
| 75 |
+
return motion, self.window_size
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Text2MotionDataset(Dataset):
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
mean: np.ndarray,
|
| 83 |
+
std: np.ndarray,
|
| 84 |
+
split_file: str,
|
| 85 |
+
w_vectorizer: WordVectorizer,
|
| 86 |
+
max_motion_length: int,
|
| 87 |
+
min_motion_length: int,
|
| 88 |
+
max_text_len: int,
|
| 89 |
+
unit_length: int,
|
| 90 |
+
motion_dir: str,
|
| 91 |
+
text_dir: str,
|
| 92 |
+
fps: int,
|
| 93 |
+
padding_to_max: bool,
|
| 94 |
+
njoints: int,
|
| 95 |
+
tiny: bool = False,
|
| 96 |
+
progress_bar: bool = True,
|
| 97 |
+
**kwargs,
|
| 98 |
+
) -> None:
|
| 99 |
+
self.w_vectorizer = w_vectorizer
|
| 100 |
+
self.max_motion_length = max_motion_length
|
| 101 |
+
self.min_motion_length = min_motion_length
|
| 102 |
+
self.max_text_len = max_text_len
|
| 103 |
+
self.unit_length = unit_length
|
| 104 |
+
self.padding_to_max = padding_to_max
|
| 105 |
+
self.njoints = njoints
|
| 106 |
+
|
| 107 |
+
data_dict = {}
|
| 108 |
+
id_list = []
|
| 109 |
+
with cs.open(split_file, "r") as f:
|
| 110 |
+
for line in f.readlines():
|
| 111 |
+
id_list.append(line.strip())
|
| 112 |
+
self.id_list = id_list
|
| 113 |
+
|
| 114 |
+
maxdata = 10 if tiny else 1e10
|
| 115 |
+
if progress_bar:
|
| 116 |
+
enumerator = enumerate(
|
| 117 |
+
track(
|
| 118 |
+
id_list,
|
| 119 |
+
f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
|
| 120 |
+
))
|
| 121 |
+
else:
|
| 122 |
+
enumerator = enumerate(id_list)
|
| 123 |
+
count = 0
|
| 124 |
+
bad_count = 0
|
| 125 |
+
new_name_list = []
|
| 126 |
+
length_list = []
|
| 127 |
+
for i, name in enumerator:
|
| 128 |
+
if count > maxdata:
|
| 129 |
+
break
|
| 130 |
+
try:
|
| 131 |
+
motion = np.load(pjoin(motion_dir, name + ".npy"))
|
| 132 |
+
if len(motion) < self.min_motion_length or len(motion) >= self.max_motion_length:
|
| 133 |
+
bad_count += 1
|
| 134 |
+
continue
|
| 135 |
+
text_data = []
|
| 136 |
+
flag = False
|
| 137 |
+
with cs.open(pjoin(text_dir, name + ".txt")) as f:
|
| 138 |
+
for line in f.readlines():
|
| 139 |
+
text_dict = {}
|
| 140 |
+
line_split = line.strip().split("#")
|
| 141 |
+
caption = line_split[0]
|
| 142 |
+
tokens = line_split[1].split(" ")
|
| 143 |
+
f_tag = float(line_split[2])
|
| 144 |
+
to_tag = float(line_split[3])
|
| 145 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
| 146 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
| 147 |
+
|
| 148 |
+
text_dict["caption"] = caption
|
| 149 |
+
text_dict["tokens"] = tokens
|
| 150 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
| 151 |
+
flag = True
|
| 152 |
+
text_data.append(text_dict)
|
| 153 |
+
else:
|
| 154 |
+
try:
|
| 155 |
+
n_motion = motion[int(f_tag * fps): int(to_tag * fps)]
|
| 156 |
+
if (len(n_motion)) < self.min_motion_length or \
|
| 157 |
+
len(n_motion) >= self.max_motion_length:
|
| 158 |
+
continue
|
| 159 |
+
new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
|
| 160 |
+
while new_name in data_dict:
|
| 161 |
+
new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
|
| 162 |
+
data_dict[new_name] = {
|
| 163 |
+
"motion": n_motion,
|
| 164 |
+
"length": len(n_motion),
|
| 165 |
+
"text": [text_dict],
|
| 166 |
+
}
|
| 167 |
+
new_name_list.append(new_name)
|
| 168 |
+
length_list.append(len(n_motion))
|
| 169 |
+
except ValueError:
|
| 170 |
+
print(line_split)
|
| 171 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
| 172 |
+
|
| 173 |
+
if flag:
|
| 174 |
+
data_dict[name] = {
|
| 175 |
+
"motion": motion,
|
| 176 |
+
"length": len(motion),
|
| 177 |
+
"text": text_data,
|
| 178 |
+
}
|
| 179 |
+
new_name_list.append(name)
|
| 180 |
+
length_list.append(len(motion))
|
| 181 |
+
count += 1
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print(e)
|
| 184 |
+
pass
|
| 185 |
+
|
| 186 |
+
name_list, length_list = zip(
|
| 187 |
+
*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
| 188 |
+
|
| 189 |
+
if not tiny:
|
| 190 |
+
logger.info(f"Reading {len(self.id_list)} motions from {split_file}.")
|
| 191 |
+
logger.info(f"Total {len(name_list)} motions are used.")
|
| 192 |
+
logger.info(f"{bad_count} motion sequences not within the length range of "
|
| 193 |
+
f"[{self.min_motion_length}, {self.max_motion_length}) are filtered out.")
|
| 194 |
+
|
| 195 |
+
self.mean = mean
|
| 196 |
+
self.std = std
|
| 197 |
+
|
| 198 |
+
control_args = kwargs['control_args']
|
| 199 |
+
self.control_mode = None
|
| 200 |
+
if os.path.exists(control_args.MEAN_STD_PATH):
|
| 201 |
+
self.raw_mean = np.load(pjoin(control_args.MEAN_STD_PATH, 'Mean_raw.npy'))
|
| 202 |
+
self.raw_std = np.load(pjoin(control_args.MEAN_STD_PATH, 'Std_raw.npy'))
|
| 203 |
+
else:
|
| 204 |
+
self.raw_mean = self.raw_std = None
|
| 205 |
+
if not tiny and control_args.CONTROL:
|
| 206 |
+
self.t_ctrl = control_args.TEMPORAL
|
| 207 |
+
self.training_control_joints = np.array(control_args.TRAIN_JOINTS)
|
| 208 |
+
self.testing_control_joints = np.array(control_args.TEST_JOINTS)
|
| 209 |
+
self.training_density = control_args.TRAIN_DENSITY
|
| 210 |
+
self.testing_density = control_args.TEST_DENSITY
|
| 211 |
+
|
| 212 |
+
self.control_mode = 'val' if ('test' in split_file or 'val' in split_file) else 'train'
|
| 213 |
+
if self.control_mode == 'train':
|
| 214 |
+
logger.info(f'Training Control Joints: {self.training_control_joints}')
|
| 215 |
+
logger.info(f'Training Control Density: {self.training_density}')
|
| 216 |
+
else:
|
| 217 |
+
logger.info(f'Testing Control Joints: {self.testing_control_joints}')
|
| 218 |
+
logger.info(f'Testing Control Density: {self.testing_density}')
|
| 219 |
+
logger.info(f"Temporal Control: {self.t_ctrl}")
|
| 220 |
+
|
| 221 |
+
self.data_dict = data_dict
|
| 222 |
+
self.name_list = name_list
|
| 223 |
+
|
| 224 |
+
def __len__(self) -> int:
|
| 225 |
+
return len(self.name_list)
|
| 226 |
+
|
| 227 |
+
def random_mask(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 228 |
+
choose_joint = self.testing_control_joints
|
| 229 |
+
|
| 230 |
+
length = joints.shape[0]
|
| 231 |
+
density = self.testing_density
|
| 232 |
+
if density in [1, 2, 5]:
|
| 233 |
+
choose_seq_num = density
|
| 234 |
+
else:
|
| 235 |
+
choose_seq_num = int(length * density / 100)
|
| 236 |
+
|
| 237 |
+
if self.t_ctrl:
|
| 238 |
+
choose_seq = np.arange(0, choose_seq_num)
|
| 239 |
+
else:
|
| 240 |
+
choose_seq = np.random.choice(length, choose_seq_num, replace=False)
|
| 241 |
+
choose_seq.sort()
|
| 242 |
+
|
| 243 |
+
mask_seq = np.zeros((length, self.njoints, 3))
|
| 244 |
+
for cj in choose_joint:
|
| 245 |
+
mask_seq[choose_seq, cj] = 1.0
|
| 246 |
+
|
| 247 |
+
joints = (joints - self.raw_mean) / self.raw_std
|
| 248 |
+
joints = joints * mask_seq
|
| 249 |
+
return joints, mask_seq
|
| 250 |
+
|
| 251 |
+
def random_mask_train(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 252 |
+
if self.t_ctrl:
|
| 253 |
+
choose_joint = self.training_control_joints
|
| 254 |
+
else:
|
| 255 |
+
num_joints = len(self.training_control_joints)
|
| 256 |
+
num_joints_control = 1
|
| 257 |
+
choose_joint = np.random.choice(num_joints, num_joints_control, replace=False)
|
| 258 |
+
choose_joint = self.training_control_joints[choose_joint]
|
| 259 |
+
|
| 260 |
+
length = joints.shape[0]
|
| 261 |
+
|
| 262 |
+
if self.training_density == 'random':
|
| 263 |
+
choose_seq_num = np.random.choice(length - 1, 1) + 1
|
| 264 |
+
else:
|
| 265 |
+
choose_seq_num = int(length * random.uniform(self.training_density[0], self.training_density[1]) / 100)
|
| 266 |
+
|
| 267 |
+
if self.t_ctrl:
|
| 268 |
+
choose_seq = np.arange(0, choose_seq_num)
|
| 269 |
+
else:
|
| 270 |
+
choose_seq = np.random.choice(length, choose_seq_num, replace=False)
|
| 271 |
+
choose_seq.sort()
|
| 272 |
+
|
| 273 |
+
mask_seq = np.zeros((length, self.njoints, 3))
|
| 274 |
+
for cj in choose_joint:
|
| 275 |
+
mask_seq[choose_seq, cj] = 1
|
| 276 |
+
|
| 277 |
+
joints = (joints - self.raw_mean) / self.raw_std
|
| 278 |
+
joints = joints * mask_seq
|
| 279 |
+
return joints, mask_seq
|
| 280 |
+
|
| 281 |
+
def __getitem__(self, idx: int) -> tuple:
|
| 282 |
+
data = self.data_dict[self.name_list[idx]]
|
| 283 |
+
motion, m_length, text_list = data["motion"], data["length"], data["text"]
|
| 284 |
+
# Randomly select a caption
|
| 285 |
+
text_data = random.choice(text_list)
|
| 286 |
+
caption, tokens = text_data["caption"], text_data["tokens"]
|
| 287 |
+
|
| 288 |
+
if len(tokens) < self.max_text_len:
|
| 289 |
+
# pad with "unk"
|
| 290 |
+
tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
|
| 291 |
+
sent_len = len(tokens)
|
| 292 |
+
tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len)
|
| 293 |
+
else:
|
| 294 |
+
# crop
|
| 295 |
+
tokens = tokens[:self.max_text_len]
|
| 296 |
+
tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
|
| 297 |
+
sent_len = len(tokens)
|
| 298 |
+
pos_one_hots = []
|
| 299 |
+
word_embeddings = []
|
| 300 |
+
for token in tokens:
|
| 301 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
| 302 |
+
pos_one_hots.append(pos_oh[None, :])
|
| 303 |
+
word_embeddings.append(word_emb[None, :])
|
| 304 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
| 305 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
| 306 |
+
|
| 307 |
+
# Crop the motions in to times of 4, and introduce small variations
|
| 308 |
+
if self.unit_length < 10:
|
| 309 |
+
coin2 = np.random.choice(["single", "single", "double"])
|
| 310 |
+
else:
|
| 311 |
+
coin2 = "single"
|
| 312 |
+
|
| 313 |
+
if coin2 == "double":
|
| 314 |
+
m_length = (m_length // self.unit_length - 1) * self.unit_length
|
| 315 |
+
elif coin2 == "single":
|
| 316 |
+
m_length = (m_length // self.unit_length) * self.unit_length
|
| 317 |
+
idx = random.randint(0, len(motion) - m_length)
|
| 318 |
+
motion = motion[idx:idx + m_length]
|
| 319 |
+
|
| 320 |
+
hint, hint_mask = None, None
|
| 321 |
+
if self.control_mode is not None:
|
| 322 |
+
joints = recover_from_ric(torch.from_numpy(motion).float(), self.njoints)
|
| 323 |
+
joints = joints.numpy()
|
| 324 |
+
if self.control_mode == 'train':
|
| 325 |
+
hint, hint_mask = self.random_mask_train(joints)
|
| 326 |
+
else:
|
| 327 |
+
hint, hint_mask = self.random_mask(joints)
|
| 328 |
+
|
| 329 |
+
if self.padding_to_max:
|
| 330 |
+
padding = np.zeros((self.max_motion_length - m_length, *hint.shape[1:]))
|
| 331 |
+
hint = np.concatenate([hint, padding], axis=0)
|
| 332 |
+
hint_mask = np.concatenate([hint_mask, padding], axis=0)
|
| 333 |
+
|
| 334 |
+
"Z Normalization"
|
| 335 |
+
motion = (motion - self.mean) / self.std
|
| 336 |
+
|
| 337 |
+
if self.padding_to_max:
|
| 338 |
+
padding = np.zeros((self.max_motion_length - m_length, motion.shape[1]))
|
| 339 |
+
motion = np.concatenate([motion, padding], axis=0)
|
| 340 |
+
|
| 341 |
+
return (word_embeddings,
|
| 342 |
+
pos_one_hots,
|
| 343 |
+
caption,
|
| 344 |
+
sent_len,
|
| 345 |
+
motion,
|
| 346 |
+
m_length,
|
| 347 |
+
"_".join(tokens),
|
| 348 |
+
(hint, hint_mask))
|
mld/data/humanml/scripts/motion_process.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ..common.quaternion import qinv, qrot
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Recover global angle and positions for rotation dataset
|
| 7 |
+
# root_rot_velocity (B, seq_len, 1)
|
| 8 |
+
# root_linear_velocity (B, seq_len, 2)
|
| 9 |
+
# root_y (B, seq_len, 1)
|
| 10 |
+
# ric_data (B, seq_len, (joint_num - 1)*3)
|
| 11 |
+
# rot_data (B, seq_len, (joint_num - 1)*6)
|
| 12 |
+
# local_velocity (B, seq_len, joint_num*3)
|
| 13 |
+
# foot contact (B, seq_len, 4)
|
| 14 |
+
def recover_root_rot_pos(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 15 |
+
rot_vel = data[..., 0]
|
| 16 |
+
r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
|
| 17 |
+
'''Get Y-axis rotation from rotation velocity'''
|
| 18 |
+
r_rot_ang[..., 1:] = rot_vel[..., :-1]
|
| 19 |
+
r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
|
| 20 |
+
|
| 21 |
+
r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
|
| 22 |
+
r_rot_quat[..., 0] = torch.cos(r_rot_ang)
|
| 23 |
+
r_rot_quat[..., 2] = torch.sin(r_rot_ang)
|
| 24 |
+
|
| 25 |
+
r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
|
| 26 |
+
r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
|
| 27 |
+
'''Add Y-axis rotation to root position'''
|
| 28 |
+
r_pos = qrot(qinv(r_rot_quat), r_pos)
|
| 29 |
+
|
| 30 |
+
r_pos = torch.cumsum(r_pos, dim=-2)
|
| 31 |
+
|
| 32 |
+
r_pos[..., 1] = data[..., 3]
|
| 33 |
+
return r_rot_quat, r_pos
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor:
|
| 37 |
+
r_rot_quat, r_pos = recover_root_rot_pos(data)
|
| 38 |
+
positions = data[..., 4:(joints_num - 1) * 3 + 4]
|
| 39 |
+
positions = positions.view(positions.shape[:-1] + (-1, 3))
|
| 40 |
+
|
| 41 |
+
'''Add Y-axis rotation to local joints'''
|
| 42 |
+
positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
|
| 43 |
+
|
| 44 |
+
'''Add root XZ to joints'''
|
| 45 |
+
positions[..., 0] += r_pos[..., 0:1]
|
| 46 |
+
positions[..., 2] += r_pos[..., 2:3]
|
| 47 |
+
|
| 48 |
+
'''Concat root and joints'''
|
| 49 |
+
positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
|
| 50 |
+
|
| 51 |
+
return positions
|
mld/data/humanml/utils/__init__.py
ADDED
|
File without changes
|
mld/data/humanml/utils/paramUtil.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
# Define a kinematic tree for the skeletal structure
|
| 4 |
+
kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
|
| 5 |
+
|
| 6 |
+
kit_raw_offsets = np.array(
|
| 7 |
+
[
|
| 8 |
+
[0, 0, 0],
|
| 9 |
+
[0, 1, 0],
|
| 10 |
+
[0, 1, 0],
|
| 11 |
+
[0, 1, 0],
|
| 12 |
+
[0, 1, 0],
|
| 13 |
+
[1, 0, 0],
|
| 14 |
+
[0, -1, 0],
|
| 15 |
+
[0, -1, 0],
|
| 16 |
+
[-1, 0, 0],
|
| 17 |
+
[0, -1, 0],
|
| 18 |
+
[0, -1, 0],
|
| 19 |
+
[1, 0, 0],
|
| 20 |
+
[0, -1, 0],
|
| 21 |
+
[0, -1, 0],
|
| 22 |
+
[0, 0, 1],
|
| 23 |
+
[0, 0, 1],
|
| 24 |
+
[-1, 0, 0],
|
| 25 |
+
[0, -1, 0],
|
| 26 |
+
[0, -1, 0],
|
| 27 |
+
[0, 0, 1],
|
| 28 |
+
[0, 0, 1]
|
| 29 |
+
]
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
t2m_raw_offsets = np.array([[0, 0, 0],
|
| 33 |
+
[1, 0, 0],
|
| 34 |
+
[-1, 0, 0],
|
| 35 |
+
[0, 1, 0],
|
| 36 |
+
[0, -1, 0],
|
| 37 |
+
[0, -1, 0],
|
| 38 |
+
[0, 1, 0],
|
| 39 |
+
[0, -1, 0],
|
| 40 |
+
[0, -1, 0],
|
| 41 |
+
[0, 1, 0],
|
| 42 |
+
[0, 0, 1],
|
| 43 |
+
[0, 0, 1],
|
| 44 |
+
[0, 1, 0],
|
| 45 |
+
[1, 0, 0],
|
| 46 |
+
[-1, 0, 0],
|
| 47 |
+
[0, 0, 1],
|
| 48 |
+
[0, -1, 0],
|
| 49 |
+
[0, -1, 0],
|
| 50 |
+
[0, -1, 0],
|
| 51 |
+
[0, -1, 0],
|
| 52 |
+
[0, -1, 0],
|
| 53 |
+
[0, -1, 0]])
|
| 54 |
+
|
| 55 |
+
t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
|
| 56 |
+
[9, 13, 16, 18, 20]]
|
| 57 |
+
t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
|
| 58 |
+
t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
|
| 59 |
+
|
| 60 |
+
kit_tgt_skel_id = '03950'
|
| 61 |
+
|
| 62 |
+
t2m_tgt_skel_id = '000021'
|
mld/data/humanml/utils/plot_script.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from textwrap import wrap
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import mpl_toolkits.mplot3d.axes3d as p3
|
| 8 |
+
from matplotlib.animation import FuncAnimation
|
| 9 |
+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
| 10 |
+
|
| 11 |
+
import mld.data.humanml.utils.paramUtil as paramUtil
|
| 12 |
+
|
| 13 |
+
skeleton = paramUtil.t2m_kinematic_chain
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def plot_3d_motion(save_path: str, joints: np.ndarray, title: str,
|
| 17 |
+
figsize: tuple[int, int] = (3, 3),
|
| 18 |
+
fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton,
|
| 19 |
+
hint: Optional[np.ndarray] = None) -> None:
|
| 20 |
+
|
| 21 |
+
title = '\n'.join(wrap(title, 20))
|
| 22 |
+
|
| 23 |
+
def init():
|
| 24 |
+
ax.set_xlim3d([-radius / 2, radius / 2])
|
| 25 |
+
ax.set_ylim3d([0, radius])
|
| 26 |
+
ax.set_zlim3d([-radius / 3., radius * 2 / 3.])
|
| 27 |
+
fig.suptitle(title, fontsize=10)
|
| 28 |
+
ax.grid(b=False)
|
| 29 |
+
|
| 30 |
+
def plot_xzPlane(minx, maxx, miny, minz, maxz):
|
| 31 |
+
# Plot a plane XZ
|
| 32 |
+
verts = [
|
| 33 |
+
[minx, miny, minz],
|
| 34 |
+
[minx, miny, maxz],
|
| 35 |
+
[maxx, miny, maxz],
|
| 36 |
+
[maxx, miny, minz]
|
| 37 |
+
]
|
| 38 |
+
xz_plane = Poly3DCollection([verts])
|
| 39 |
+
xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
|
| 40 |
+
ax.add_collection3d(xz_plane)
|
| 41 |
+
|
| 42 |
+
# (seq_len, joints_num, 3)
|
| 43 |
+
data = joints.copy().reshape(len(joints), -1, 3)
|
| 44 |
+
|
| 45 |
+
data *= 1.3 # scale for visualization
|
| 46 |
+
if hint is not None:
|
| 47 |
+
mask = hint.sum(-1) != 0
|
| 48 |
+
hint = hint[mask]
|
| 49 |
+
hint *= 1.3
|
| 50 |
+
|
| 51 |
+
fig = plt.figure(figsize=figsize)
|
| 52 |
+
plt.tight_layout()
|
| 53 |
+
ax = p3.Axes3D(fig)
|
| 54 |
+
init()
|
| 55 |
+
MINS = data.min(axis=0).min(axis=0)
|
| 56 |
+
MAXS = data.max(axis=0).max(axis=0)
|
| 57 |
+
colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00",
|
| 58 |
+
"#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00",
|
| 59 |
+
"#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ]
|
| 60 |
+
|
| 61 |
+
frame_number = data.shape[0]
|
| 62 |
+
|
| 63 |
+
height_offset = MINS[1]
|
| 64 |
+
data[:, :, 1] -= height_offset
|
| 65 |
+
if hint is not None:
|
| 66 |
+
hint[..., 1] -= height_offset
|
| 67 |
+
trajec = data[:, 0, [0, 2]]
|
| 68 |
+
|
| 69 |
+
data[..., 0] -= data[:, 0:1, 0]
|
| 70 |
+
data[..., 2] -= data[:, 0:1, 2]
|
| 71 |
+
|
| 72 |
+
def update(index):
|
| 73 |
+
ax.lines = []
|
| 74 |
+
ax.collections = []
|
| 75 |
+
ax.view_init(elev=120, azim=-90)
|
| 76 |
+
ax.dist = 7.5
|
| 77 |
+
plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
|
| 78 |
+
MAXS[2] - trajec[index, 1])
|
| 79 |
+
|
| 80 |
+
if hint is not None:
|
| 81 |
+
ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1], hint[..., 2] - trajec[index, 1], color="#80B79A")
|
| 82 |
+
|
| 83 |
+
for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
|
| 84 |
+
if i < 5:
|
| 85 |
+
linewidth = 4.0
|
| 86 |
+
else:
|
| 87 |
+
linewidth = 2.0
|
| 88 |
+
ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
|
| 89 |
+
color=color)
|
| 90 |
+
|
| 91 |
+
plt.axis('off')
|
| 92 |
+
ax.set_xticklabels([])
|
| 93 |
+
ax.set_yticklabels([])
|
| 94 |
+
ax.set_zticklabels([])
|
| 95 |
+
|
| 96 |
+
ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
|
| 97 |
+
ani.save(save_path, fps=fps)
|
| 98 |
+
plt.close()
|
mld/data/humanml/utils/word_vectorizer.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from os.path import join as pjoin
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
POS_enumerator = {
|
| 8 |
+
'VERB': 0,
|
| 9 |
+
'NOUN': 1,
|
| 10 |
+
'DET': 2,
|
| 11 |
+
'ADP': 3,
|
| 12 |
+
'NUM': 4,
|
| 13 |
+
'AUX': 5,
|
| 14 |
+
'PRON': 6,
|
| 15 |
+
'ADJ': 7,
|
| 16 |
+
'ADV': 8,
|
| 17 |
+
'Loc_VIP': 9,
|
| 18 |
+
'Body_VIP': 10,
|
| 19 |
+
'Obj_VIP': 11,
|
| 20 |
+
'Act_VIP': 12,
|
| 21 |
+
'Desc_VIP': 13,
|
| 22 |
+
'OTHER': 14
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
|
| 26 |
+
'up', 'down', 'straight', 'curve')
|
| 27 |
+
|
| 28 |
+
Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
|
| 29 |
+
|
| 30 |
+
Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
|
| 31 |
+
|
| 32 |
+
Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
|
| 33 |
+
'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
|
| 34 |
+
'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
|
| 35 |
+
|
| 36 |
+
Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
|
| 37 |
+
'angrily', 'sadly')
|
| 38 |
+
|
| 39 |
+
VIP_dict = {
|
| 40 |
+
'Loc_VIP': Loc_list,
|
| 41 |
+
'Body_VIP': Body_list,
|
| 42 |
+
'Obj_VIP': Obj_List,
|
| 43 |
+
'Act_VIP': Act_list,
|
| 44 |
+
'Desc_VIP': Desc_list,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class WordVectorizer(object):
|
| 49 |
+
def __init__(self, meta_root: str, prefix: str) -> None:
|
| 50 |
+
vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix))
|
| 51 |
+
words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb'))
|
| 52 |
+
word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb'))
|
| 53 |
+
self.word2vec = {w: vectors[word2idx[w]] for w in words}
|
| 54 |
+
|
| 55 |
+
def _get_pos_ohot(self, pos: str) -> np.ndarray:
|
| 56 |
+
pos_vec = np.zeros(len(POS_enumerator))
|
| 57 |
+
if pos in POS_enumerator:
|
| 58 |
+
pos_vec[POS_enumerator[pos]] = 1
|
| 59 |
+
else:
|
| 60 |
+
pos_vec[POS_enumerator['OTHER']] = 1
|
| 61 |
+
return pos_vec
|
| 62 |
+
|
| 63 |
+
def __len__(self) -> int:
|
| 64 |
+
return len(self.word2vec)
|
| 65 |
+
|
| 66 |
+
def __getitem__(self, item: str) -> tuple:
|
| 67 |
+
word, pos = item.split('/')
|
| 68 |
+
if word in self.word2vec:
|
| 69 |
+
word_vec = self.word2vec[word]
|
| 70 |
+
vip_pos = None
|
| 71 |
+
for key, values in VIP_dict.items():
|
| 72 |
+
if word in values:
|
| 73 |
+
vip_pos = key
|
| 74 |
+
break
|
| 75 |
+
if vip_pos is not None:
|
| 76 |
+
pos_vec = self._get_pos_ohot(vip_pos)
|
| 77 |
+
else:
|
| 78 |
+
pos_vec = self._get_pos_ohot(pos)
|
| 79 |
+
else:
|
| 80 |
+
word_vec = self.word2vec['unk']
|
| 81 |
+
pos_vec = self._get_pos_ohot('OTHER')
|
| 82 |
+
return word_vec, pos_vec
|
mld/data/utils.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from mld.utils.temos_utils import lengths_to_mask
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def collate_tensors(batch: list) -> torch.Tensor:
|
| 7 |
+
dims = batch[0].dim()
|
| 8 |
+
max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
|
| 9 |
+
size = (len(batch), ) + tuple(max_size)
|
| 10 |
+
canvas = batch[0].new_zeros(size=size)
|
| 11 |
+
for i, b in enumerate(batch):
|
| 12 |
+
sub_tensor = canvas[i]
|
| 13 |
+
for d in range(dims):
|
| 14 |
+
sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
|
| 15 |
+
sub_tensor.add_(b)
|
| 16 |
+
return canvas
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def mld_collate(batch: list) -> dict:
|
| 20 |
+
notnone_batches = [b for b in batch if b is not None]
|
| 21 |
+
notnone_batches.sort(key=lambda x: x[3], reverse=True)
|
| 22 |
+
adapted_batch = {
|
| 23 |
+
"motion":
|
| 24 |
+
collate_tensors([torch.tensor(b[4]).float() for b in notnone_batches]),
|
| 25 |
+
"text": [b[2] for b in notnone_batches],
|
| 26 |
+
"length": [b[5] for b in notnone_batches],
|
| 27 |
+
"word_embs":
|
| 28 |
+
collate_tensors([torch.tensor(b[0]).float() for b in notnone_batches]),
|
| 29 |
+
"pos_ohot":
|
| 30 |
+
collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]),
|
| 31 |
+
"text_len":
|
| 32 |
+
collate_tensors([torch.tensor(b[3]) for b in notnone_batches]),
|
| 33 |
+
"tokens": [b[6] for b in notnone_batches]
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
mask = lengths_to_mask(adapted_batch['length'], adapted_batch['motion'].device, adapted_batch['motion'].shape[1])
|
| 37 |
+
adapted_batch['mask'] = mask
|
| 38 |
+
|
| 39 |
+
# collate trajectory
|
| 40 |
+
if notnone_batches[0][-1][0] is not None:
|
| 41 |
+
adapted_batch['hint'] = collate_tensors([torch.tensor(b[-1][0]).float() for b in notnone_batches])
|
| 42 |
+
adapted_batch['hint_mask'] = collate_tensors([torch.tensor(b[-1][1]).float() for b in notnone_batches])
|
| 43 |
+
|
| 44 |
+
return adapted_batch
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def mld_collate_motion_only(batch: list) -> dict:
|
| 48 |
+
batch = {
|
| 49 |
+
"motion": collate_tensors([torch.tensor(b[0]).float() for b in batch]),
|
| 50 |
+
"length": [b[1] for b in batch]
|
| 51 |
+
}
|
| 52 |
+
return batch
|
mld/launch/__init__.py
ADDED
|
File without changes
|
mld/launch/blender.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fix blender path
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from argparse import ArgumentParser
|
| 5 |
+
|
| 6 |
+
sys.path.insert(0, os.path.expanduser("~/.local/lib/python3.9/site-packages"))
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Monkey patch argparse such that
|
| 10 |
+
# blender / python parsing works
|
| 11 |
+
def parse_args(self, args=None, namespace=None):
|
| 12 |
+
if args is not None:
|
| 13 |
+
return self.parse_args_bak(args=args, namespace=namespace)
|
| 14 |
+
try:
|
| 15 |
+
idx = sys.argv.index("--")
|
| 16 |
+
args = sys.argv[idx + 1:] # the list after '--'
|
| 17 |
+
except ValueError as e: # '--' not in the list:
|
| 18 |
+
args = []
|
| 19 |
+
return self.parse_args_bak(args=args, namespace=namespace)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args)
|
| 23 |
+
setattr(ArgumentParser, 'parse_args', parse_args)
|
mld/models/__init__.py
ADDED
|
File without changes
|
mld/models/architectures/__init__.py
ADDED
|
File without changes
|
mld/models/architectures/dno.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DNO(object):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
optimize: bool,
|
| 12 |
+
max_train_steps: int,
|
| 13 |
+
learning_rate: float,
|
| 14 |
+
lr_scheduler: str,
|
| 15 |
+
lr_warmup_steps: int,
|
| 16 |
+
clip_grad: bool,
|
| 17 |
+
loss_hint_type: str,
|
| 18 |
+
loss_diff_penalty: float,
|
| 19 |
+
loss_correlate_penalty: float,
|
| 20 |
+
visualize_samples: int,
|
| 21 |
+
visualize_ske_steps: list[int],
|
| 22 |
+
output_dir: str
|
| 23 |
+
) -> None:
|
| 24 |
+
|
| 25 |
+
self.optimize = optimize
|
| 26 |
+
self.max_train_steps = max_train_steps
|
| 27 |
+
self.learning_rate = learning_rate
|
| 28 |
+
self.lr_scheduler = lr_scheduler
|
| 29 |
+
self.lr_warmup_steps = lr_warmup_steps
|
| 30 |
+
self.clip_grad = clip_grad
|
| 31 |
+
self.loss_hint_type = loss_hint_type
|
| 32 |
+
self.loss_diff_penalty = loss_diff_penalty
|
| 33 |
+
self.loss_correlate_penalty = loss_correlate_penalty
|
| 34 |
+
|
| 35 |
+
if loss_hint_type == 'l1':
|
| 36 |
+
self.loss_hint_func = F.l1_loss
|
| 37 |
+
elif loss_hint_type == 'l1_smooth':
|
| 38 |
+
self.loss_hint_func = F.smooth_l1_loss
|
| 39 |
+
elif loss_hint_type == 'l2':
|
| 40 |
+
self.loss_hint_func = F.mse_loss
|
| 41 |
+
else:
|
| 42 |
+
raise ValueError(f'Invalid loss type: {loss_hint_type}')
|
| 43 |
+
|
| 44 |
+
self.visualize_samples = float('inf') if visualize_samples == 'inf' else visualize_samples
|
| 45 |
+
assert self.visualize_samples >= 0
|
| 46 |
+
self.visualize_samples_done = 0
|
| 47 |
+
self.visualize_ske_steps = visualize_ske_steps
|
| 48 |
+
if len(visualize_ske_steps) > 0:
|
| 49 |
+
self.vis_dir = os.path.join(output_dir, 'vis_optimize')
|
| 50 |
+
os.makedirs(self.vis_dir)
|
| 51 |
+
|
| 52 |
+
self.writer = None
|
| 53 |
+
self.output_dir = output_dir
|
| 54 |
+
if self.visualize_samples > 0:
|
| 55 |
+
self.writer = SummaryWriter(output_dir)
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def do_visualize(self):
|
| 59 |
+
return self.visualize_samples_done < self.visualize_samples
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def noise_regularize_1d(noise: torch.Tensor, stop_at: int = 2, dim: int = 1) -> torch.Tensor:
|
| 63 |
+
size = noise.shape[dim]
|
| 64 |
+
if size & (size - 1) != 0:
|
| 65 |
+
new_size = 2 ** (size - 1).bit_length()
|
| 66 |
+
pad = new_size - size
|
| 67 |
+
pad_shape = list(noise.shape)
|
| 68 |
+
pad_shape[dim] = pad
|
| 69 |
+
pad_noise = torch.randn(*pad_shape, device=noise.device)
|
| 70 |
+
noise = torch.cat([noise, pad_noise], dim=dim)
|
| 71 |
+
size = noise.shape[dim]
|
| 72 |
+
|
| 73 |
+
loss = torch.zeros(noise.shape[0], device=noise.device)
|
| 74 |
+
while size > stop_at:
|
| 75 |
+
rolled_noise = torch.roll(noise, shifts=1, dims=dim)
|
| 76 |
+
loss += (noise * rolled_noise).mean(dim=tuple(range(1, noise.ndim))).pow(2)
|
| 77 |
+
noise = noise.view(*noise.shape[:dim], size // 2, 2, *noise.shape[dim + 1:]).mean(dim=dim + 1)
|
| 78 |
+
size //= 2
|
| 79 |
+
return loss
|
mld/models/architectures/mld_clip.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from transformers import AutoModel, AutoTokenizer
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MldTextEncoder(nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None:
|
| 11 |
+
super().__init__()
|
| 12 |
+
|
| 13 |
+
if 't5' in modelpath:
|
| 14 |
+
self.text_model = SentenceTransformer(modelpath)
|
| 15 |
+
self.tokenizer = self.text_model.tokenizer
|
| 16 |
+
else:
|
| 17 |
+
self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
|
| 18 |
+
self.text_model = AutoModel.from_pretrained(modelpath)
|
| 19 |
+
|
| 20 |
+
self.max_length = self.tokenizer.model_max_length
|
| 21 |
+
if "clip" in modelpath:
|
| 22 |
+
self.text_encoded_dim = self.text_model.config.text_config.hidden_size
|
| 23 |
+
if last_hidden_state:
|
| 24 |
+
self.name = "clip_hidden"
|
| 25 |
+
else:
|
| 26 |
+
self.name = "clip"
|
| 27 |
+
elif "bert" in modelpath:
|
| 28 |
+
self.name = "bert"
|
| 29 |
+
self.text_encoded_dim = self.text_model.config.hidden_size
|
| 30 |
+
elif 't5' in modelpath:
|
| 31 |
+
self.name = 't5'
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError(f"Model {modelpath} not supported")
|
| 34 |
+
|
| 35 |
+
def forward(self, texts: list[str]) -> torch.Tensor:
|
| 36 |
+
# get prompt text embeddings
|
| 37 |
+
if self.name in ["clip", "clip_hidden"]:
|
| 38 |
+
text_inputs = self.tokenizer(
|
| 39 |
+
texts,
|
| 40 |
+
padding="max_length",
|
| 41 |
+
truncation=True,
|
| 42 |
+
max_length=self.max_length,
|
| 43 |
+
return_tensors="pt",
|
| 44 |
+
)
|
| 45 |
+
text_input_ids = text_inputs.input_ids
|
| 46 |
+
# split into max length Clip can handle
|
| 47 |
+
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
| 48 |
+
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
|
| 49 |
+
elif self.name == "bert":
|
| 50 |
+
text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
|
| 51 |
+
|
| 52 |
+
if self.name == "clip":
|
| 53 |
+
# (batch_Size, text_encoded_dim)
|
| 54 |
+
text_embeddings = self.text_model.get_text_features(
|
| 55 |
+
text_input_ids.to(self.text_model.device))
|
| 56 |
+
# (batch_Size, 1, text_encoded_dim)
|
| 57 |
+
text_embeddings = text_embeddings.unsqueeze(1)
|
| 58 |
+
elif self.name == "clip_hidden":
|
| 59 |
+
# (batch_Size, seq_length , text_encoded_dim)
|
| 60 |
+
text_embeddings = self.text_model.text_model(
|
| 61 |
+
text_input_ids.to(self.text_model.device)).last_hidden_state
|
| 62 |
+
elif self.name == "bert":
|
| 63 |
+
# (batch_Size, seq_length , text_encoded_dim)
|
| 64 |
+
text_embeddings = self.text_model(
|
| 65 |
+
**text_inputs.to(self.text_model.device)).last_hidden_state
|
| 66 |
+
elif self.name == 't5':
|
| 67 |
+
text_embeddings = self.text_model.encode(texts, show_progress_bar=False, convert_to_tensor=True, batch_size=len(texts))
|
| 68 |
+
text_embeddings = text_embeddings.unsqueeze(1)
|
| 69 |
+
else:
|
| 70 |
+
raise NotImplementedError(f"Model {self.name} not implemented")
|
| 71 |
+
|
| 72 |
+
return text_embeddings
|
mld/models/architectures/mld_denoiser.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from mld.models.operator.embeddings import TimestepEmbedding, Timesteps
|
| 7 |
+
from mld.models.operator.attention import (SkipTransformerEncoder,
|
| 8 |
+
SkipTransformerDecoder,
|
| 9 |
+
TransformerDecoder,
|
| 10 |
+
TransformerDecoderLayer,
|
| 11 |
+
TransformerEncoder,
|
| 12 |
+
TransformerEncoderLayer)
|
| 13 |
+
from mld.models.operator.moe import MoeTransformerEncoderLayer, MoeTransformerDecoderLayer
|
| 14 |
+
from mld.models.operator.utils import get_clones, get_activation_fn, zero_module
|
| 15 |
+
from mld.models.operator.position_encoding import build_position_encoding
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_balancing_loss_func(router_logits: tuple, num_experts: int = 4, topk: int = 2):
|
| 19 |
+
router_logits = torch.cat(router_logits, dim=0)
|
| 20 |
+
routing_weights = torch.nn.functional.softmax(router_logits, dim=-1)
|
| 21 |
+
_, selected_experts = torch.topk(routing_weights, topk, dim=-1)
|
| 22 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
| 23 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
| 24 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
| 25 |
+
overall_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
| 26 |
+
return overall_loss
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class MldDenoiser(nn.Module):
|
| 30 |
+
|
| 31 |
+
def __init__(self,
|
| 32 |
+
latent_dim: list = [1, 256],
|
| 33 |
+
hidden_dim: Optional[int] = None,
|
| 34 |
+
text_dim: int = 768,
|
| 35 |
+
time_dim: int = 768,
|
| 36 |
+
ff_size: int = 1024,
|
| 37 |
+
num_layers: int = 9,
|
| 38 |
+
num_heads: int = 4,
|
| 39 |
+
dropout: float = 0.1,
|
| 40 |
+
normalize_before: bool = False,
|
| 41 |
+
norm_eps: float = 1e-5,
|
| 42 |
+
activation: str = "gelu",
|
| 43 |
+
norm_post: bool = True,
|
| 44 |
+
activation_post: Optional[str] = None,
|
| 45 |
+
flip_sin_to_cos: bool = True,
|
| 46 |
+
freq_shift: float = 0,
|
| 47 |
+
time_act_fn: str = 'silu',
|
| 48 |
+
time_post_act_fn: Optional[str] = None,
|
| 49 |
+
position_embedding: str = "learned",
|
| 50 |
+
arch: str = "trans_enc",
|
| 51 |
+
add_mem_pos: bool = True,
|
| 52 |
+
force_pre_post_proj: bool = False,
|
| 53 |
+
text_act_fn: str = 'relu',
|
| 54 |
+
time_cond_proj_dim: Optional[int] = None,
|
| 55 |
+
zero_init_cond: bool = True,
|
| 56 |
+
is_controlnet: bool = False,
|
| 57 |
+
controlnet_embed_dim: Optional[int] = None,
|
| 58 |
+
controlnet_act_fn: str = 'silu',
|
| 59 |
+
moe: bool = False,
|
| 60 |
+
moe_num_experts: int = 4,
|
| 61 |
+
moe_topk: int = 2,
|
| 62 |
+
moe_loss_weight: float = 1e-2,
|
| 63 |
+
moe_jitter_noise: Optional[float] = None
|
| 64 |
+
) -> None:
|
| 65 |
+
super(MldDenoiser, self).__init__()
|
| 66 |
+
|
| 67 |
+
self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
|
| 68 |
+
add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
|
| 69 |
+
self.latent_pre = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity()
|
| 70 |
+
self.latent_post = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity()
|
| 71 |
+
|
| 72 |
+
self.arch = arch
|
| 73 |
+
self.time_cond_proj_dim = time_cond_proj_dim
|
| 74 |
+
|
| 75 |
+
self.moe_num_experts = moe_num_experts
|
| 76 |
+
self.moe_topk = moe_topk
|
| 77 |
+
self.moe_loss_weight = moe_loss_weight
|
| 78 |
+
|
| 79 |
+
self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
|
| 80 |
+
self.time_embedding = TimestepEmbedding(time_dim, self.latent_dim, time_act_fn, post_act_fn=time_post_act_fn,
|
| 81 |
+
cond_proj_dim=time_cond_proj_dim, zero_init_cond=zero_init_cond)
|
| 82 |
+
self.emb_proj = nn.Sequential(get_activation_fn(text_act_fn), nn.Linear(text_dim, self.latent_dim))
|
| 83 |
+
|
| 84 |
+
self.query_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
|
| 85 |
+
if self.arch == "trans_enc":
|
| 86 |
+
if moe:
|
| 87 |
+
encoder_layer = MoeTransformerEncoderLayer(
|
| 88 |
+
self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
|
| 89 |
+
dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
|
| 90 |
+
else:
|
| 91 |
+
encoder_layer = TransformerEncoderLayer(
|
| 92 |
+
self.latent_dim, num_heads, ff_size, dropout,
|
| 93 |
+
activation, normalize_before, norm_eps)
|
| 94 |
+
|
| 95 |
+
encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
|
| 96 |
+
self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post,
|
| 97 |
+
is_controlnet=is_controlnet, is_moe=moe)
|
| 98 |
+
|
| 99 |
+
elif self.arch == 'trans_dec':
|
| 100 |
+
if add_mem_pos:
|
| 101 |
+
self.mem_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
|
| 102 |
+
else:
|
| 103 |
+
self.mem_pos = None
|
| 104 |
+
if moe:
|
| 105 |
+
decoder_layer = MoeTransformerDecoderLayer(
|
| 106 |
+
self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
|
| 107 |
+
dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
|
| 108 |
+
else:
|
| 109 |
+
decoder_layer = TransformerDecoderLayer(
|
| 110 |
+
self.latent_dim, num_heads, ff_size, dropout,
|
| 111 |
+
activation, normalize_before, norm_eps)
|
| 112 |
+
|
| 113 |
+
decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
|
| 114 |
+
self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post,
|
| 115 |
+
is_controlnet=is_controlnet, is_moe=moe)
|
| 116 |
+
else:
|
| 117 |
+
raise ValueError(f"Not supported architecture: {self.arch}!")
|
| 118 |
+
|
| 119 |
+
self.is_controlnet = is_controlnet
|
| 120 |
+
if self.is_controlnet:
|
| 121 |
+
embed_dim = controlnet_embed_dim if controlnet_embed_dim is not None else self.latent_dim
|
| 122 |
+
modules = [
|
| 123 |
+
nn.Linear(latent_dim[-1], embed_dim),
|
| 124 |
+
get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
|
| 125 |
+
nn.Linear(embed_dim, embed_dim),
|
| 126 |
+
get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
|
| 127 |
+
zero_module(nn.Linear(embed_dim, latent_dim[-1]))
|
| 128 |
+
]
|
| 129 |
+
self.controlnet_cond_embedding = nn.Sequential(*[m for m in modules if m is not None])
|
| 130 |
+
|
| 131 |
+
self.controlnet_down_mid_blocks = nn.ModuleList([
|
| 132 |
+
zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)])
|
| 133 |
+
|
| 134 |
+
def forward(self,
|
| 135 |
+
sample: torch.Tensor,
|
| 136 |
+
timestep: torch.Tensor,
|
| 137 |
+
encoder_hidden_states: torch.Tensor,
|
| 138 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
| 139 |
+
controlnet_cond: Optional[torch.Tensor] = None,
|
| 140 |
+
controlnet_residuals: Optional[list[torch.Tensor]] = None
|
| 141 |
+
) -> tuple:
|
| 142 |
+
|
| 143 |
+
# 0. check if controlnet
|
| 144 |
+
if self.is_controlnet:
|
| 145 |
+
sample = sample + self.controlnet_cond_embedding(controlnet_cond)
|
| 146 |
+
|
| 147 |
+
# 1. dimension matching (pre)
|
| 148 |
+
sample = sample.permute(1, 0, 2)
|
| 149 |
+
sample = self.latent_pre(sample)
|
| 150 |
+
|
| 151 |
+
# 2. time_embedding
|
| 152 |
+
timesteps = timestep.expand(sample.shape[1]).clone()
|
| 153 |
+
time_emb = self.time_proj(timesteps)
|
| 154 |
+
time_emb = time_emb.to(dtype=sample.dtype)
|
| 155 |
+
# [1, bs, latent_dim] <= [bs, latent_dim]
|
| 156 |
+
time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0)
|
| 157 |
+
|
| 158 |
+
# 3. condition + time embedding
|
| 159 |
+
# text_emb [seq_len, batch_size, text_dim] <= [batch_size, seq_len, text_dim]
|
| 160 |
+
encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2)
|
| 161 |
+
# text embedding projection
|
| 162 |
+
text_emb_latent = self.emb_proj(encoder_hidden_states)
|
| 163 |
+
emb_latent = torch.cat((time_emb, text_emb_latent), 0)
|
| 164 |
+
|
| 165 |
+
# 4. transformer
|
| 166 |
+
if self.arch == "trans_enc":
|
| 167 |
+
xseq = torch.cat((sample, emb_latent), axis=0)
|
| 168 |
+
xseq = self.query_pos(xseq)
|
| 169 |
+
tokens, intermediates, router_logits = self.encoder(xseq, controlnet_residuals=controlnet_residuals)
|
| 170 |
+
elif self.arch == 'trans_dec':
|
| 171 |
+
sample = self.query_pos(sample)
|
| 172 |
+
if self.mem_pos:
|
| 173 |
+
emb_latent = self.mem_pos(emb_latent)
|
| 174 |
+
tokens, intermediates, router_logits = self.decoder(sample, emb_latent,
|
| 175 |
+
controlnet_residuals=controlnet_residuals)
|
| 176 |
+
else:
|
| 177 |
+
raise TypeError(f"{self.arch} is not supported")
|
| 178 |
+
|
| 179 |
+
router_loss = None
|
| 180 |
+
if router_logits is not None:
|
| 181 |
+
router_loss = load_balancing_loss_func(router_logits, self.moe_num_experts, self.moe_topk)
|
| 182 |
+
router_loss = self.moe_loss_weight * router_loss
|
| 183 |
+
|
| 184 |
+
if self.is_controlnet:
|
| 185 |
+
control_res_samples = []
|
| 186 |
+
for res, block in zip(intermediates, self.controlnet_down_mid_blocks):
|
| 187 |
+
r = block(res)
|
| 188 |
+
control_res_samples.append(r)
|
| 189 |
+
return control_res_samples, router_loss
|
| 190 |
+
elif self.arch == "trans_enc":
|
| 191 |
+
sample = tokens[:sample.shape[0]]
|
| 192 |
+
elif self.arch == 'trans_dec':
|
| 193 |
+
sample = tokens
|
| 194 |
+
else:
|
| 195 |
+
raise TypeError(f"{self.arch} is not supported")
|
| 196 |
+
|
| 197 |
+
# 5. dimension matching (post)
|
| 198 |
+
sample = self.latent_post(sample)
|
| 199 |
+
sample = sample.permute(1, 0, 2)
|
| 200 |
+
return sample, router_loss
|
mld/models/architectures/mld_traj_encoder.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from mld.models.operator.attention import SkipTransformerEncoder, TransformerEncoderLayer
|
| 7 |
+
from mld.models.operator.position_encoding import build_position_encoding
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MldTrajEncoder(nn.Module):
|
| 11 |
+
|
| 12 |
+
def __init__(self,
|
| 13 |
+
nfeats: int,
|
| 14 |
+
latent_dim: list = [1, 256],
|
| 15 |
+
hidden_dim: Optional[int] = None,
|
| 16 |
+
force_post_proj: bool = False,
|
| 17 |
+
ff_size: int = 1024,
|
| 18 |
+
num_layers: int = 9,
|
| 19 |
+
num_heads: int = 4,
|
| 20 |
+
dropout: float = 0.1,
|
| 21 |
+
normalize_before: bool = False,
|
| 22 |
+
norm_eps: float = 1e-5,
|
| 23 |
+
activation: str = "gelu",
|
| 24 |
+
norm_post: bool = True,
|
| 25 |
+
activation_post: Optional[str] = None,
|
| 26 |
+
position_embedding: str = "learned") -> None:
|
| 27 |
+
super(MldTrajEncoder, self).__init__()
|
| 28 |
+
|
| 29 |
+
self.latent_size = latent_dim[0]
|
| 30 |
+
self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
|
| 31 |
+
add_post_proj = force_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
|
| 32 |
+
self.latent_proj = nn.Linear(self.latent_dim, latent_dim[-1]) if add_post_proj else nn.Identity()
|
| 33 |
+
|
| 34 |
+
self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim)
|
| 35 |
+
|
| 36 |
+
self.query_pos_encoder = build_position_encoding(
|
| 37 |
+
self.latent_dim, position_embedding=position_embedding)
|
| 38 |
+
|
| 39 |
+
encoder_layer = TransformerEncoderLayer(
|
| 40 |
+
self.latent_dim,
|
| 41 |
+
num_heads,
|
| 42 |
+
ff_size,
|
| 43 |
+
dropout,
|
| 44 |
+
activation,
|
| 45 |
+
normalize_before,
|
| 46 |
+
norm_eps
|
| 47 |
+
)
|
| 48 |
+
encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
|
| 49 |
+
self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)
|
| 50 |
+
self.global_motion_token = nn.Parameter(torch.randn(self.latent_size, self.latent_dim))
|
| 51 |
+
|
| 52 |
+
def forward(self, features: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
| 53 |
+
bs, nframes, nfeats = features.shape
|
| 54 |
+
x = self.skel_embedding(features)
|
| 55 |
+
x = x.permute(1, 0, 2)
|
| 56 |
+
dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
|
| 57 |
+
dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
|
| 58 |
+
aug_mask = torch.cat((dist_masks, mask), 1)
|
| 59 |
+
xseq = torch.cat((dist, x), 0)
|
| 60 |
+
xseq = self.query_pos_encoder(xseq)
|
| 61 |
+
global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
|
| 62 |
+
global_token = self.latent_proj(global_token)
|
| 63 |
+
global_token = global_token.permute(1, 0, 2)
|
| 64 |
+
return global_token
|
mld/models/architectures/mld_vae.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.distributions.distribution import Distribution
|
| 6 |
+
|
| 7 |
+
from mld.models.operator.attention import (
|
| 8 |
+
SkipTransformerEncoder,
|
| 9 |
+
SkipTransformerDecoder,
|
| 10 |
+
TransformerDecoder,
|
| 11 |
+
TransformerDecoderLayer,
|
| 12 |
+
TransformerEncoder,
|
| 13 |
+
TransformerEncoderLayer
|
| 14 |
+
)
|
| 15 |
+
from mld.models.operator.position_encoding import build_position_encoding
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MldVae(nn.Module):
|
| 19 |
+
|
| 20 |
+
def __init__(self,
|
| 21 |
+
nfeats: int,
|
| 22 |
+
latent_dim: list = [1, 256],
|
| 23 |
+
hidden_dim: Optional[int] = None,
|
| 24 |
+
force_pre_post_proj: bool = False,
|
| 25 |
+
ff_size: int = 1024,
|
| 26 |
+
num_layers: int = 9,
|
| 27 |
+
num_heads: int = 4,
|
| 28 |
+
dropout: float = 0.1,
|
| 29 |
+
arch: str = "encoder_decoder",
|
| 30 |
+
normalize_before: bool = False,
|
| 31 |
+
norm_eps: float = 1e-5,
|
| 32 |
+
activation: str = "gelu",
|
| 33 |
+
norm_post: bool = True,
|
| 34 |
+
activation_post: Optional[str] = None,
|
| 35 |
+
position_embedding: str = "learned") -> None:
|
| 36 |
+
super(MldVae, self).__init__()
|
| 37 |
+
|
| 38 |
+
self.latent_size = latent_dim[0]
|
| 39 |
+
self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
|
| 40 |
+
add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
|
| 41 |
+
self.latent_pre = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity()
|
| 42 |
+
self.latent_post = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity()
|
| 43 |
+
|
| 44 |
+
self.arch = arch
|
| 45 |
+
|
| 46 |
+
self.query_pos_encoder = build_position_encoding(
|
| 47 |
+
self.latent_dim, position_embedding=position_embedding)
|
| 48 |
+
|
| 49 |
+
encoder_layer = TransformerEncoderLayer(
|
| 50 |
+
self.latent_dim,
|
| 51 |
+
num_heads,
|
| 52 |
+
ff_size,
|
| 53 |
+
dropout,
|
| 54 |
+
activation,
|
| 55 |
+
normalize_before,
|
| 56 |
+
norm_eps
|
| 57 |
+
)
|
| 58 |
+
encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
|
| 59 |
+
self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)
|
| 60 |
+
|
| 61 |
+
if self.arch == "all_encoder":
|
| 62 |
+
decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
|
| 63 |
+
self.decoder = SkipTransformerEncoder(encoder_layer, num_layers, decoder_norm, activation_post)
|
| 64 |
+
elif self.arch == 'encoder_decoder':
|
| 65 |
+
self.query_pos_decoder = build_position_encoding(
|
| 66 |
+
self.latent_dim, position_embedding=position_embedding)
|
| 67 |
+
|
| 68 |
+
decoder_layer = TransformerDecoderLayer(
|
| 69 |
+
self.latent_dim,
|
| 70 |
+
num_heads,
|
| 71 |
+
ff_size,
|
| 72 |
+
dropout,
|
| 73 |
+
activation,
|
| 74 |
+
normalize_before,
|
| 75 |
+
norm_eps
|
| 76 |
+
)
|
| 77 |
+
decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
|
| 78 |
+
self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post)
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError(f"Not support architecture: {self.arch}!")
|
| 81 |
+
|
| 82 |
+
self.global_motion_token = nn.Parameter(torch.randn(self.latent_size * 2, self.latent_dim))
|
| 83 |
+
self.skel_embedding = nn.Linear(nfeats, self.latent_dim)
|
| 84 |
+
self.final_layer = nn.Linear(self.latent_dim, nfeats)
|
| 85 |
+
|
| 86 |
+
def forward(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, Distribution]:
|
| 87 |
+
z, dist = self.encode(features, mask)
|
| 88 |
+
feats_rst = self.decode(z, mask)
|
| 89 |
+
return feats_rst, z, dist
|
| 90 |
+
|
| 91 |
+
def encode(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, Distribution]:
|
| 92 |
+
bs, nframes, nfeats = features.shape
|
| 93 |
+
x = self.skel_embedding(features)
|
| 94 |
+
x = x.permute(1, 0, 2)
|
| 95 |
+
dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
|
| 96 |
+
dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
|
| 97 |
+
aug_mask = torch.cat((dist_masks, mask), 1)
|
| 98 |
+
xseq = torch.cat((dist, x), 0)
|
| 99 |
+
|
| 100 |
+
xseq = self.query_pos_encoder(xseq)
|
| 101 |
+
dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
|
| 102 |
+
dist = self.latent_pre(dist)
|
| 103 |
+
|
| 104 |
+
mu = dist[0:self.latent_size, ...]
|
| 105 |
+
logvar = dist[self.latent_size:, ...]
|
| 106 |
+
|
| 107 |
+
std = logvar.exp().pow(0.5)
|
| 108 |
+
dist = torch.distributions.Normal(mu, std)
|
| 109 |
+
latent = dist.rsample()
|
| 110 |
+
# [latent_dim[0], batch_size, latent_dim] -> [batch_size, latent_dim[0], latent_dim[1]]
|
| 111 |
+
latent = latent.permute(1, 0, 2)
|
| 112 |
+
return latent, dist
|
| 113 |
+
|
| 114 |
+
def decode(self, z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
# [batch_size, latent_dim[0], latent_dim[1]] -> [latent_dim[0], batch_size, latent_dim[1]]
|
| 116 |
+
z = self.latent_post(z)
|
| 117 |
+
z = z.permute(1, 0, 2)
|
| 118 |
+
bs, nframes = mask.shape
|
| 119 |
+
queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
|
| 120 |
+
|
| 121 |
+
if self.arch == "all_encoder":
|
| 122 |
+
xseq = torch.cat((z, queries), axis=0)
|
| 123 |
+
z_mask = torch.ones((bs, self.latent_size), dtype=torch.bool, device=z.device)
|
| 124 |
+
aug_mask = torch.cat((z_mask, mask), axis=1)
|
| 125 |
+
xseq = self.query_pos_decoder(xseq)
|
| 126 |
+
output = self.decoder(xseq, src_key_padding_mask=~aug_mask)[0][z.shape[0]:]
|
| 127 |
+
elif self.arch == "encoder_decoder":
|
| 128 |
+
queries = self.query_pos_decoder(queries)
|
| 129 |
+
output = self.decoder(tgt=queries, memory=z, tgt_key_padding_mask=~mask)[0]
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError(f"Not support architecture: {self.arch}!")
|
| 132 |
+
|
| 133 |
+
output = self.final_layer(output)
|
| 134 |
+
output[~mask.T] = 0
|
| 135 |
+
feats = output.permute(1, 0, 2)
|
| 136 |
+
return feats
|