hmr-dataset / genmo /network /genmo_diffusion.py
zirobtc's picture
Upload folder using huggingface_hub
fbb20ff verified
from copy import deepcopy
import torch
import torch.nn as nn
from hydra.utils import instantiate
from genmo.diffusion_utils.model_util import create_gaussian_diffusion
from genmo.diffusion_utils.resample import create_named_schedule_sampler
from genmo.utils.net_utils import length_to_mask
from .genmo_cfg_sampler import ClassifierFreeSampleModel
class GENMODiffusion(nn.Module):
def __init__(
self,
model_cfg,
max_len=120,
# condition
cliffcam_dim=3,
cam_angvel_dim=6,
cam_t_vel_dim=3,
imgseq_dim=1024,
observed_motion_3d_dim=151,
encoded_music_dim=438,
encoded_audio_dim=128,
latent_dim=512,
dropout=0.1,
args=None,
cond_merge_strategy="add",
cond_exists_dim=512,
music_mask_prob=0.1,
img_process_modules=None,
img_process_modules_enable_grad={},
multi_text_module_cfg={},
**kwargs,
):
super().__init__()
self.model_cfg = model_cfg
self.args = args
self.max_len = max_len
self.regression_input_type = self.args.get("regression_input_type", "zero")
self.denoiser = instantiate(self.model_cfg.denoiser)
self.init_diffusion()
self.text_encoder, self.tokenizer = None, None
def init_diffusion(self):
self.train_diffusion = create_gaussian_diffusion(
self.model_cfg.diffusion, training=True
)
self.test_diffusion = create_gaussian_diffusion(
self.model_cfg.diffusion, training=False
)
gen_only_diffusion = deepcopy(self.model_cfg.diffusion)
gen_only_diffusion.test_timestep_respacing = self.model_cfg.diffusion.get(
"gen_only_test_timestep_respacing", "50"
)
print(
f"Gen only test timestep respacing: {gen_only_diffusion.test_timestep_respacing}"
)
self.test_gen_only_diffusion = create_gaussian_diffusion(
gen_only_diffusion, training=False
)
self.schedule_sampler = create_named_schedule_sampler(
self.model_cfg.diffusion.schedule_sampler_type, self.train_diffusion
)
return
def forward_train(self, inputs, mode):
assert self.training, "forward_train should only be called during training"
diffusion = self.train_diffusion if self.training else self.test_diffusion
length = inputs["length"]
# target_x = inputs["target_x"]
motion = inputs["motion"]
f_cond = inputs["f_cond"]
B, L, _ = motion.shape
vis_mask = length_to_mask(length, L) # (B, L)
valid_mask = inputs["mask"]["valid"]
assert (vis_mask == valid_mask).all()
denoiser_kwargs = {
"y": {
"text": inputs.get("caption", [""] * B),
"f_cond": f_cond,
"mask": vis_mask,
"length": length,
},
"inputs": inputs,
}
if "encoded_text" in inputs:
denoiser_kwargs["y"]["encoded_text"] = inputs["encoded_text"]
if "observed_motion_3d" in inputs:
denoiser_kwargs["observed_motion_3d"] = inputs["observed_motion_3d"]
denoiser_kwargs["motion_mask_3d"] = inputs["motion_mask_3d"]
denoiser_kwargs["rm_text_flag"] = inputs["rm_text_flag"]
if mode == "regression":
t = (
(torch.ones(B) * (diffusion.original_num_steps - 1))
.long()
.to(motion.device)
)
t_weights = torch.ones(B).to(motion.device)
x_start = motion
if self.regression_input_type == "zero":
x_t = torch.zeros_like(motion)
elif self.regression_input_type == "normal":
x_t = torch.randn_like(motion)
else:
raise ValueError(
f"Unsupported regression_input_type: {self.regression_input_type}"
)
elif mode == "diffusion":
t, t_weights = self.schedule_sampler.sample(motion.shape[0], motion.device)
if "regression_outputs" in inputs:
pred_x_start_regression = inputs["regression_outputs"]["model_output"][
"pred_x_start"
].detach()
else:
raise ValueError("No regression outputs found")
# pred_x_start_regression = torch.zeros_like(motion)
x_start_reg = pred_x_start_regression.clone()
x_start = motion.clone()
x_start[inputs["mask"]["2d_only"]] = x_start_reg[inputs["mask"]["2d_only"]]
# regression_mask = (
# torch.rand(B).to(motion.device) < self.args.use_regression_outputs_prob
# ).float()
# if "gen_only" in inputs and self.args.get("use_gt_for_gen_only", True):
# regression_mask[inputs["gen_only"]] = 0
# x_start = x_start_reg * regression_mask[:, None, None] + x_start_gt * (
# 1 - regression_mask[:, None, None]
# )
noise = torch.randn_like(x_start)
x_t = self.train_diffusion.q_sample(x_start.clone(), t, noise=noise)
denoise_out = self.denoiser(
x_t, diffusion._scale_timesteps(t), return_aux=False, **denoiser_kwargs
)
output = {
"target_x_start": x_start,
"t_weights": t_weights,
}
output.update(denoise_out)
for x in self.args.out_attr:
assert x in output, f"Output {x} not found in denoise_out"
return output
def forward_test(self, inputs, progress=False):
assert not self.training, "forward_test should only be called during inference"
diffusion = self.test_gen_only_diffusion
denoiser = self.denoiser
length = inputs["length"]
B, L = inputs["B"], inputs["L"]
motion = inputs["motion"]
f_cond, f_uncond = inputs["f_cond"], inputs["f_uncond"]
vis_mask = length_to_mask(length, L) # (B, L)
denoiser_kwargs = {
"y": {
"text": inputs.get("caption", [""] * B),
"f_cond": f_cond,
"f_uncond": f_uncond,
"mask": vis_mask,
"length": length,
},
"inputs": inputs,
}
if "encoded_text" in inputs:
denoiser_kwargs["y"]["encoded_text"] = inputs["encoded_text"]
if "meta" in inputs and "multi_text_data" in inputs["meta"][0]:
denoiser_kwargs["y"]["multi_text_data"] = inputs["meta"][0][
"multi_text_data"
]
if "observed_motion_3d" in inputs:
denoiser_kwargs["observed_motion_3d"] = inputs["observed_motion_3d"]
denoiser_kwargs["motion_mask_3d"] = inputs["motion_mask_3d"]
denoiser_kwargs["rm_text_flag"] = inputs.get("rm_text_flag", None)
if self.args.get("use_cfg_sampler_for_gen", False):
denoiser = ClassifierFreeSampleModel(denoiser)
denoiser_kwargs["y"]["scale"] = self.model_cfg.diffusion.guidance_param
diff_sampler = self.model_cfg.diffusion.get("sampler", "ddim")
if diff_sampler == "ddim":
sample_fn = diffusion.ddim_sample_loop_with_aux
kwargs = {"eta": self.model_cfg.diffusion.ddim_eta}
else:
raise NotImplementedError(f"Sampler {diff_sampler} not implemented")
if self.args.get("force_zero_noise", False):
noise = torch.zeros_like(motion)
elif self.args.get("force_rand_noise", False):
noise = torch.randn_like(motion)
else:
noise = torch.randn_like(motion)
if self.args.get("return_mid", False):
kwargs["return_mid"] = True
denoise_out = sample_fn(
denoiser,
motion.shape,
clip_denoised=False,
model_kwargs=denoiser_kwargs,
skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
init_image=None,
progress=progress,
dump_steps=None,
noise=noise,
const_noise=False,
**kwargs,
)
output = denoise_out.copy()
for x in self.args.out_attr:
assert x in output, f"Output {x} not found in denoise_out"
return output
def forward(
self,
inputs,
train=False,
postproc=False,
static_cam=False,
mode=None,
test_mode=None,
normalizer_stats=None,
):
if train:
return self.forward_train(inputs, mode=mode)
else:
return self.forward_test(inputs)