| from copy import deepcopy |
|
|
| import torch |
| import torch.nn as nn |
| from hydra.utils import instantiate |
|
|
| from genmo.diffusion_utils.model_util import create_gaussian_diffusion |
| from genmo.diffusion_utils.resample import create_named_schedule_sampler |
| from genmo.utils.net_utils import length_to_mask |
|
|
| from .genmo_cfg_sampler import ClassifierFreeSampleModel |
|
|
|
|
| class GENMODiffusion(nn.Module): |
| def __init__( |
| self, |
| model_cfg, |
| max_len=120, |
| |
| cliffcam_dim=3, |
| cam_angvel_dim=6, |
| cam_t_vel_dim=3, |
| imgseq_dim=1024, |
| observed_motion_3d_dim=151, |
| encoded_music_dim=438, |
| encoded_audio_dim=128, |
| latent_dim=512, |
| dropout=0.1, |
| args=None, |
| cond_merge_strategy="add", |
| cond_exists_dim=512, |
| music_mask_prob=0.1, |
| img_process_modules=None, |
| img_process_modules_enable_grad={}, |
| multi_text_module_cfg={}, |
| **kwargs, |
| ): |
| super().__init__() |
| self.model_cfg = model_cfg |
| self.args = args |
| self.max_len = max_len |
|
|
| self.regression_input_type = self.args.get("regression_input_type", "zero") |
|
|
| self.denoiser = instantiate(self.model_cfg.denoiser) |
| self.init_diffusion() |
| self.text_encoder, self.tokenizer = None, None |
|
|
| def init_diffusion(self): |
| self.train_diffusion = create_gaussian_diffusion( |
| self.model_cfg.diffusion, training=True |
| ) |
| self.test_diffusion = create_gaussian_diffusion( |
| self.model_cfg.diffusion, training=False |
| ) |
| gen_only_diffusion = deepcopy(self.model_cfg.diffusion) |
| gen_only_diffusion.test_timestep_respacing = self.model_cfg.diffusion.get( |
| "gen_only_test_timestep_respacing", "50" |
| ) |
| print( |
| f"Gen only test timestep respacing: {gen_only_diffusion.test_timestep_respacing}" |
| ) |
| self.test_gen_only_diffusion = create_gaussian_diffusion( |
| gen_only_diffusion, training=False |
| ) |
| self.schedule_sampler = create_named_schedule_sampler( |
| self.model_cfg.diffusion.schedule_sampler_type, self.train_diffusion |
| ) |
| return |
|
|
| def forward_train(self, inputs, mode): |
| assert self.training, "forward_train should only be called during training" |
| diffusion = self.train_diffusion if self.training else self.test_diffusion |
| length = inputs["length"] |
| |
| motion = inputs["motion"] |
| f_cond = inputs["f_cond"] |
| B, L, _ = motion.shape |
|
|
| vis_mask = length_to_mask(length, L) |
| valid_mask = inputs["mask"]["valid"] |
| assert (vis_mask == valid_mask).all() |
|
|
| denoiser_kwargs = { |
| "y": { |
| "text": inputs.get("caption", [""] * B), |
| "f_cond": f_cond, |
| "mask": vis_mask, |
| "length": length, |
| }, |
| "inputs": inputs, |
| } |
| if "encoded_text" in inputs: |
| denoiser_kwargs["y"]["encoded_text"] = inputs["encoded_text"] |
| if "observed_motion_3d" in inputs: |
| denoiser_kwargs["observed_motion_3d"] = inputs["observed_motion_3d"] |
| denoiser_kwargs["motion_mask_3d"] = inputs["motion_mask_3d"] |
| denoiser_kwargs["rm_text_flag"] = inputs["rm_text_flag"] |
|
|
| if mode == "regression": |
| t = ( |
| (torch.ones(B) * (diffusion.original_num_steps - 1)) |
| .long() |
| .to(motion.device) |
| ) |
| t_weights = torch.ones(B).to(motion.device) |
| x_start = motion |
| if self.regression_input_type == "zero": |
| x_t = torch.zeros_like(motion) |
| elif self.regression_input_type == "normal": |
| x_t = torch.randn_like(motion) |
| else: |
| raise ValueError( |
| f"Unsupported regression_input_type: {self.regression_input_type}" |
| ) |
| elif mode == "diffusion": |
| t, t_weights = self.schedule_sampler.sample(motion.shape[0], motion.device) |
| if "regression_outputs" in inputs: |
| pred_x_start_regression = inputs["regression_outputs"]["model_output"][ |
| "pred_x_start" |
| ].detach() |
| else: |
| raise ValueError("No regression outputs found") |
| |
| x_start_reg = pred_x_start_regression.clone() |
| x_start = motion.clone() |
| x_start[inputs["mask"]["2d_only"]] = x_start_reg[inputs["mask"]["2d_only"]] |
| |
| |
| |
| |
| |
| |
| |
| |
| noise = torch.randn_like(x_start) |
| x_t = self.train_diffusion.q_sample(x_start.clone(), t, noise=noise) |
|
|
| denoise_out = self.denoiser( |
| x_t, diffusion._scale_timesteps(t), return_aux=False, **denoiser_kwargs |
| ) |
|
|
| output = { |
| "target_x_start": x_start, |
| "t_weights": t_weights, |
| } |
| output.update(denoise_out) |
| for x in self.args.out_attr: |
| assert x in output, f"Output {x} not found in denoise_out" |
|
|
| return output |
|
|
| def forward_test(self, inputs, progress=False): |
| assert not self.training, "forward_test should only be called during inference" |
| diffusion = self.test_gen_only_diffusion |
|
|
| denoiser = self.denoiser |
| length = inputs["length"] |
| B, L = inputs["B"], inputs["L"] |
|
|
| motion = inputs["motion"] |
| f_cond, f_uncond = inputs["f_cond"], inputs["f_uncond"] |
|
|
| vis_mask = length_to_mask(length, L) |
|
|
| denoiser_kwargs = { |
| "y": { |
| "text": inputs.get("caption", [""] * B), |
| "f_cond": f_cond, |
| "f_uncond": f_uncond, |
| "mask": vis_mask, |
| "length": length, |
| }, |
| "inputs": inputs, |
| } |
| if "encoded_text" in inputs: |
| denoiser_kwargs["y"]["encoded_text"] = inputs["encoded_text"] |
| if "meta" in inputs and "multi_text_data" in inputs["meta"][0]: |
| denoiser_kwargs["y"]["multi_text_data"] = inputs["meta"][0][ |
| "multi_text_data" |
| ] |
| if "observed_motion_3d" in inputs: |
| denoiser_kwargs["observed_motion_3d"] = inputs["observed_motion_3d"] |
| denoiser_kwargs["motion_mask_3d"] = inputs["motion_mask_3d"] |
| denoiser_kwargs["rm_text_flag"] = inputs.get("rm_text_flag", None) |
|
|
| if self.args.get("use_cfg_sampler_for_gen", False): |
| denoiser = ClassifierFreeSampleModel(denoiser) |
| denoiser_kwargs["y"]["scale"] = self.model_cfg.diffusion.guidance_param |
| diff_sampler = self.model_cfg.diffusion.get("sampler", "ddim") |
| if diff_sampler == "ddim": |
| sample_fn = diffusion.ddim_sample_loop_with_aux |
| kwargs = {"eta": self.model_cfg.diffusion.ddim_eta} |
| else: |
| raise NotImplementedError(f"Sampler {diff_sampler} not implemented") |
|
|
| if self.args.get("force_zero_noise", False): |
| noise = torch.zeros_like(motion) |
| elif self.args.get("force_rand_noise", False): |
| noise = torch.randn_like(motion) |
| else: |
| noise = torch.randn_like(motion) |
|
|
| if self.args.get("return_mid", False): |
| kwargs["return_mid"] = True |
|
|
| denoise_out = sample_fn( |
| denoiser, |
| motion.shape, |
| clip_denoised=False, |
| model_kwargs=denoiser_kwargs, |
| skip_timesteps=0, |
| init_image=None, |
| progress=progress, |
| dump_steps=None, |
| noise=noise, |
| const_noise=False, |
| **kwargs, |
| ) |
| output = denoise_out.copy() |
|
|
| for x in self.args.out_attr: |
| assert x in output, f"Output {x} not found in denoise_out" |
| return output |
|
|
| def forward( |
| self, |
| inputs, |
| train=False, |
| postproc=False, |
| static_cam=False, |
| mode=None, |
| test_mode=None, |
| normalizer_stats=None, |
| ): |
| if train: |
| return self.forward_train(inputs, mode=mode) |
| else: |
| return self.forward_test(inputs) |
|
|