from copy import deepcopy

import torch
import torch.nn as nn
from hydra.utils import instantiate

from genmo.diffusion_utils.model_util import create_gaussian_diffusion
from genmo.diffusion_utils.resample import create_named_schedule_sampler
from genmo.utils.net_utils import length_to_mask

from .genmo_cfg_sampler import ClassifierFreeSampleModel


class GENMODiffusion(nn.Module): |
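    """Conditional motion diffusion model around a Hydra-instantiated denoiser.

    Maintains separate GaussianDiffusion objects for training, testing, and
    generation-only sampling (the last with its own timestep respacing), plus
    a named schedule sampler for drawing training timesteps.
    """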
    def __init__(
        self,
        model_cfg,
        max_len=120,
        cliffcam_dim=3,
        cam_angvel_dim=6,
        cam_t_vel_dim=3,
        imgseq_dim=1024,
        observed_motion_3d_dim=151,
        encoded_music_dim=438,
        encoded_audio_dim=128,
        latent_dim=512,
        dropout=0.1,
        args=None,
        cond_merge_strategy="add",
        cond_exists_dim=512,
        music_mask_prob=0.1,
        img_process_modules=None,
        img_process_modules_enable_grad=None,  # None instead of a shared mutable {} default
        multi_text_module_cfg=None,  # None instead of a shared mutable {} default
        **kwargs,
    ):
        super().__init__()
        self.model_cfg = model_cfg
        self.args = args
        self.max_len = max_len

        self.regression_input_type = self.args.get("regression_input_type", "zero")

        self.denoiser = instantiate(self.model_cfg.denoiser)
        self.init_diffusion()
        self.text_encoder, self.tokenizer = None, None

def init_diffusion(self): |
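        """Build the train/test diffusion processes and the timestep sampler.

        A third, generation-only diffusion is created with its own (typically
        coarser) timestep respacing so pure generation can use fewer steps.
        """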
        self.train_diffusion = create_gaussian_diffusion(
            self.model_cfg.diffusion, training=True
        )
        self.test_diffusion = create_gaussian_diffusion(
            self.model_cfg.diffusion, training=False
        )
        # Generation-only sampling gets its own timestep respacing (default "50").
        gen_only_diffusion = deepcopy(self.model_cfg.diffusion)
        gen_only_diffusion.test_timestep_respacing = self.model_cfg.diffusion.get(
            "gen_only_test_timestep_respacing", "50"
        )
        print(
            f"Gen only test timestep respacing: {gen_only_diffusion.test_timestep_respacing}"
        )
        self.test_gen_only_diffusion = create_gaussian_diffusion(
            gen_only_diffusion, training=False
        )
        self.schedule_sampler = create_named_schedule_sampler(
            self.model_cfg.diffusion.schedule_sampler_type, self.train_diffusion
        )

def forward_train(self, inputs, mode): |
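        """Run one training pass in either regression or diffusion mode.

        "regression" makes a single denoiser call at the final timestep from a
        fixed input (zeros or Gaussian noise); "diffusion" denoises at sampled
        timesteps, substituting the detached regression prediction as the
        target on frames that only have 2D supervision.
        """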
        assert self.training, "forward_train should only be called during training"
        diffusion = self.train_diffusion  # the assert above guarantees training mode
        length = inputs["length"]

        motion = inputs["motion"]
        f_cond = inputs["f_cond"]
        B, L, _ = motion.shape

        vis_mask = length_to_mask(length, L)
        valid_mask = inputs["mask"]["valid"]
        assert (vis_mask == valid_mask).all()

        denoiser_kwargs = {
            "y": {
                "text": inputs.get("caption", [""] * B),
                "f_cond": f_cond,
                "mask": vis_mask,
                "length": length,
            },
            "inputs": inputs,
        }
        if "encoded_text" in inputs:
            denoiser_kwargs["y"]["encoded_text"] = inputs["encoded_text"]
        if "observed_motion_3d" in inputs:
            denoiser_kwargs["observed_motion_3d"] = inputs["observed_motion_3d"]
            denoiser_kwargs["motion_mask_3d"] = inputs["motion_mask_3d"]
            denoiser_kwargs["rm_text_flag"] = inputs["rm_text_flag"]
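
        # The two modes below implement two-stage training: the regression pass
        # first produces an initial full-body estimate, and the diffusion pass
        # then learns to denoise around it. On 2D-only frames the diffusion
        # target is the detached regression prediction rather than raw motion.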
        if mode == "regression":
            # Condition on the last (most noised) timestep and feed a fixed
            # input, so this pass behaves like direct regression.
            t = (
                (torch.ones(B) * (diffusion.original_num_steps - 1))
                .long()
                .to(motion.device)
            )
            t_weights = torch.ones(B).to(motion.device)
            x_start = motion
            if self.regression_input_type == "zero":
                x_t = torch.zeros_like(motion)
            elif self.regression_input_type == "normal":
                x_t = torch.randn_like(motion)
            else:
                raise ValueError(
                    f"Unsupported regression_input_type: {self.regression_input_type}"
                )
        elif mode == "diffusion":
            t, t_weights = self.schedule_sampler.sample(motion.shape[0], motion.device)
            if "regression_outputs" in inputs:
                pred_x_start_regression = inputs["regression_outputs"]["model_output"][
                    "pred_x_start"
                ].detach()
            else:
                raise ValueError("No regression outputs found")

            # Replace the target on 2D-only frames with the regression estimate.
            x_start_reg = pred_x_start_regression.clone()
            x_start = motion.clone()
            x_start[inputs["mask"]["2d_only"]] = x_start_reg[inputs["mask"]["2d_only"]]

            # Forward process: sample x_t ~ q(x_t | x_0) at the drawn timesteps.
            # This stays inside the diffusion branch so it does not clobber the
            # fixed x_t chosen above for regression mode.
            noise = torch.randn_like(x_start)
            x_t = self.train_diffusion.q_sample(x_start.clone(), t, noise=noise)
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        denoise_out = self.denoiser(
            x_t, diffusion._scale_timesteps(t), return_aux=False, **denoiser_kwargs
        )

        output = {
            "target_x_start": x_start,
            "t_weights": t_weights,
        }
        output.update(denoise_out)
        for x in self.args.out_attr:
            assert x in output, f"Output {x} not found in denoise_out"

        return output

def forward_test(self, inputs, progress=False): |
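        """Generate motion by running the gen-only diffusion sampling loop."""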
        assert not self.training, "forward_test should only be called during inference"
        diffusion = self.test_gen_only_diffusion

        denoiser = self.denoiser
        length = inputs["length"]
        B, L = inputs["B"], inputs["L"]

        motion = inputs["motion"]
        f_cond, f_uncond = inputs["f_cond"], inputs["f_uncond"]

        vis_mask = length_to_mask(length, L)

        denoiser_kwargs = {
            "y": {
                "text": inputs.get("caption", [""] * B),
                "f_cond": f_cond,
                "f_uncond": f_uncond,
                "mask": vis_mask,
                "length": length,
            },
            "inputs": inputs,
        }
        if "encoded_text" in inputs:
            denoiser_kwargs["y"]["encoded_text"] = inputs["encoded_text"]
        if "meta" in inputs and "multi_text_data" in inputs["meta"][0]:
            denoiser_kwargs["y"]["multi_text_data"] = inputs["meta"][0][
                "multi_text_data"
            ]
        if "observed_motion_3d" in inputs:
            denoiser_kwargs["observed_motion_3d"] = inputs["observed_motion_3d"]
            denoiser_kwargs["motion_mask_3d"] = inputs["motion_mask_3d"]
            denoiser_kwargs["rm_text_flag"] = inputs.get("rm_text_flag", None)

        if self.args.get("use_cfg_sampler_for_gen", False):
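            # Classifier-free guidance: the wrapper runs conditional and
            # unconditional denoiser passes and blends them, conventionally as
            #   out = uncond + scale * (cond - uncond).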
            denoiser = ClassifierFreeSampleModel(denoiser)
            denoiser_kwargs["y"]["scale"] = self.model_cfg.diffusion.guidance_param

        diff_sampler = self.model_cfg.diffusion.get("sampler", "ddim")
        if diff_sampler == "ddim":
            sample_fn = diffusion.ddim_sample_loop_with_aux
            kwargs = {"eta": self.model_cfg.diffusion.ddim_eta}
        else:
            raise NotImplementedError(f"Sampler {diff_sampler} not implemented")

        # Initial x_T: zeros when force_zero_noise is set; otherwise Gaussian
        # noise (force_rand_noise and the default draw the same distribution).
        if self.args.get("force_zero_noise", False):
            noise = torch.zeros_like(motion)
        else:
            noise = torch.randn_like(motion)

        if self.args.get("return_mid", False):
            kwargs["return_mid"] = True

        denoise_out = sample_fn(
            denoiser,
            motion.shape,
            clip_denoised=False,
            model_kwargs=denoiser_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=progress,
            dump_steps=None,
            noise=noise,
            const_noise=False,
            **kwargs,
        )
        output = denoise_out.copy()

        for x in self.args.out_attr:
            assert x in output, f"Output {x} not found in denoise_out"
        return output

    def forward(
        self,
        inputs,
        train=False,
        postproc=False,
        static_cam=False,
        mode=None,
        test_mode=None,
        normalizer_stats=None,
    ):
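        """Dispatch to forward_train or forward_test based on `train`."""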
        if train:
            return self.forward_train(inputs, mode=mode)
        else:
            return self.forward_test(inputs)
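
# Usage sketch (hypothetical wiring: `cfg` and `batch` stand in for the Hydra
# config and the data-pipeline batch, which live elsewhere in the repo):
#
#     model = GENMODiffusion(model_cfg=cfg.model, args=cfg.args)
#
#     model.train()
#     reg_out = model(batch, train=True, mode="regression")
#     batch["regression_outputs"] = {"model_output": reg_out}
#     diff_out = model(batch, train=True, mode="diffusion")
#
#     model.eval()
#     with torch.no_grad():
#         samples = model(batch, train=False)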