import time
import inspect
import logging
from typing import Optional

import numpy as np
from omegaconf import DictConfig
import torch
import torch.nn.functional as F

from models.config import instantiate_from_config
from models.utils.utils import count_parameters, extract_into_tensor, sum_flat

logger = logging.getLogger(__name__)


class GestureDiffusion(torch.nn.Module):
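    """Audio-conditioned gesture diffusion model.

    Wraps a modality encoder (audio onset + word features), a denoiser network,
    and a diffusion noise scheduler, all instantiated from the config. Supports
    classifier-free guidance at inference time and conditional dropout during
    training.
    """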

    def __init__(self, cfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.modality_encoder = instantiate_from_config(cfg.model.modality_encoder)
        self.denoiser = instantiate_from_config(cfg.model.denoiser)
        self.scheduler = instantiate_from_config(cfg.model.scheduler)
        self.alphas = torch.sqrt(self.scheduler.alphas_cumprod)
        self.sigmas = torch.sqrt(1 - self.scheduler.alphas_cumprod)
        self.do_classifier_free_guidance = cfg.model.do_classifier_free_guidance
        self.guidance_scale = cfg.model.guidance_scale
        self.smooth_l1_loss = torch.nn.SmoothL1Loss(reduction='none')
        self.seq_len = self.denoiser.seq_len
        self.input_dim = self.denoiser.input_dim
        self.num_joints = self.denoiser.joint_num

    def summarize_parameters(self) -> None:
        logger.info(f'Denoiser: {count_parameters(self.denoiser)}M')
        logger.info(f'Modality encoder: {count_parameters(self.modality_encoder)}M')

    def apply_classifier_free_guidance(self, x, timesteps, seed, at_feat, guidance_scale=1.0):
        """
        Apply classifier-free guidance by running both conditional and unconditional predictions.

        Args:
            x: Input tensor
            timesteps: Timestep tensor
            seed: Seed vectors
            at_feat: Audio features
            guidance_scale: Guidance scale (1.0 means no guidance)

        Returns:
            Guided output tensor
        """
        if guidance_scale <= 1.0:
            # No guidance needed, run a normal conditional forward pass
            return self.denoiser(
                x=x,
                timesteps=timesteps,
                seed=seed,
                at_feat=at_feat,
                cond_drop_prob=0.0,
                null_cond=False
            )

        # Double the batch for classifier-free guidance
        x_doubled = torch.cat([x] * 2, dim=0)
        seed_doubled = torch.cat([seed] * 2, dim=0)

        # Expand timesteps to match the doubled batch size
        batch_size = x.shape[0]
        timesteps_doubled = timesteps.expand(batch_size * 2)

        # Build conditional and unconditional (null-embedding) audio features
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)
        at_feat_uncond = null_cond_embed.unsqueeze(0).expand(batch_size, -1, -1)
        at_feat_combined = torch.cat([at_feat, at_feat_uncond], dim=0)

        # Run conditional and unconditional predictions in a single pass
        output = self.denoiser(
            x=x_doubled,
            timesteps=timesteps_doubled,
            seed=seed_doubled,
            at_feat=at_feat_combined,
        )

        # Split predictions and apply guidance
        pred_cond, pred_uncond = output.chunk(2, dim=0)
        guided_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
        return guided_output

    def apply_conditional_dropout(self, at_feat, cond_drop_prob=0.1):
        """
        Apply conditional dropout during training to simulate classifier-free guidance.

        Args:
            at_feat: Audio features tensor
            cond_drop_prob: Probability of dropping conditions (default 0.1)

        Returns:
            Modified audio features with some conditions replaced by null embeddings
        """
        batch_size = at_feat.shape[0]

        # Create dropout mask
        keep_mask = torch.rand(batch_size, device=at_feat.device) > cond_drop_prob

        # Create null condition embeddings
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)

        # Apply dropout: replace dropped conditions with null embeddings
        at_feat_dropped = at_feat.clone()
        at_feat_dropped[~keep_mask] = null_cond_embed.unsqueeze(0).expand((~keep_mask).sum(), -1, -1)
        return at_feat_dropped

    def predicted_origin(self, model_output: torch.Tensor, timesteps: torch.Tensor, sample: torch.Tensor) -> tuple:
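        """Recover the predicted clean sample and noise from the model output.

        Using alpha_t = sqrt(alphas_cumprod_t) and sigma_t = sqrt(1 - alphas_cumprod_t),
        the noisy sample satisfies x_t = alpha_t * x_0 + sigma_t * eps. Depending on the
        scheduler's prediction_type ("epsilon", "sample", or "v_prediction"), the model
        output is converted into the pair (pred_original_sample, pred_epsilon).
        """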
        self.alphas = self.alphas.to(model_output.device)
        self.sigmas = self.sigmas.to(model_output.device)
        alphas = extract_into_tensor(self.alphas, timesteps, sample.shape)
        sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape)

        if self.scheduler.config.prediction_type == "epsilon":
            pred_original_sample = (sample - sigmas * model_output) / alphas
            pred_epsilon = model_output
        elif self.scheduler.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alphas * model_output) / sigmas
        elif self.scheduler.config.prediction_type == "v_prediction":
            pred_original_sample = alphas * sample - sigmas * model_output
            pred_epsilon = alphas * model_output + sigmas * sample
        else:
            raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.")
        return pred_original_sample, pred_epsilon

    def forward(self, cond_: dict) -> dict:
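        """Inference entry point: sample gestures conditioned on audio and word features.

        Encodes the conditioning signals, draws Gaussian noise of shape
        (bs, input_dim * num_joints, 1, seq_len), and runs the reverse diffusion loop
        with classifier-free guidance. Returns the dict produced by _diffusion_reverse.
        """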
        audio = cond_['y']['audio_onset']
        word = cond_['y']['word']
        id = cond_['y']['id']
        seed = cond_['y']['seed']
        style_feature = cond_['y']['style_feature']

        audio_feat = self.modality_encoder(audio, word)
        bs = audio_feat.shape[0]
        shape_ = (bs, self.input_dim * self.num_joints, 1, self.seq_len)
        latents = torch.randn(shape_, device=audio_feat.device)

        # _diffusion_reverse returns a dict of outputs, not a bare latent tensor
        outputs = self._diffusion_reverse(latents, seed, audio_feat, guidance_scale=self.guidance_scale)
        return outputs

    def _diffusion_reverse(
        self,
        latents: torch.Tensor,
        seed: torch.Tensor,
        at_feat: torch.Tensor,
        guidance_scale: float = 1,
    ) -> dict:
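        """Run the reverse diffusion (sampling) loop.

        Starting from scaled Gaussian noise, iteratively denoises the latents with the
        configured scheduler, applying classifier-free guidance at every step. Returns a
        dict with 'init_noise', 'at_feat', 'seed', and the final 'latents'.
        """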
        return_dict = {}

        # Scale the initial noise by the standard deviation required by the scheduler,
        # as in Stable Diffusion; the initial noise is also returned for rectified training.
        latents = latents * self.scheduler.init_noise_sigma
        noise = latents
        return_dict["init_noise"] = latents
        return_dict['at_feat'] = at_feat
        return_dict['seed'] = seed

        # set timesteps
        self.scheduler.set_timesteps(self.cfg.model.scheduler.num_inference_steps)
        timesteps = self.scheduler.timesteps.to(at_feat.device)
        latents = torch.zeros_like(latents)
        latents = self.scheduler.add_noise(latents, noise, timesteps[0])

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler and should be in [0, 1]
        extra_step_kwargs = {}
        if "eta" in set(
                inspect.signature(self.scheduler.step).parameters.keys()):
            extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta

        for i, t in enumerate(timesteps):
            latent_model_input = latents
            # this is a no-op for the DDIM scheduler
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual (with classifier-free guidance if enabled)
            model_output = self.apply_classifier_free_guidance(
                x=latent_model_input,
                timesteps=t,
                seed=seed,
                at_feat=at_feat,
                guidance_scale=guidance_scale)

            latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample

        return_dict['latents'] = latents
        return return_dict

    def _diffusion_process(self,
                           latents: torch.Tensor,
                           audio_feat: torch.Tensor,
                           id: torch.Tensor,
                           seed: torch.Tensor,
                           style_feature: torch.Tensor
                           ) -> dict:
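        """Forward diffusion for training.

        Samples random timesteps and Gaussian noise, produces noisy latents via the
        scheduler, runs the denoiser, and returns the ground-truth noise/sample together
        with the model's predictions, the sampled timesteps, and the raw model output.
        """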
        # latents: [batch_size, n_frame, latent_dim]
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (bsz,),
            device=latents.device
        )
        timesteps = timesteps.long()

        noisy_latents = self.scheduler.add_noise(latents.clone(), noise, timesteps)
        model_output = self.denoiser(
            x=noisy_latents,
            timesteps=timesteps,
            seed=seed,
            at_feat=audio_feat,
        )
        latents_pred, noise_pred = self.predicted_origin(model_output, timesteps, noisy_latents)

        n_set = {
            "noise": noise,
            "noise_pred": noise_pred,
            "sample_pred": latents_pred,
            "sample_gt": latents,
            "timesteps": timesteps,
            "model_output": model_output,
        }
        return n_set

    def train_forward(self, cond_: dict, x0: torch.Tensor) -> dict:
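        """Training entry point: compute the diffusion loss for a batch.

        Encodes the conditioning signals, applies conditional dropout (so the model learns
        the unconditional branch needed for classifier-free guidance), runs the forward
        diffusion process, and returns a dict with 'diff_loss' and the total 'loss'.
        """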
        audio = cond_['y']['audio_onset']
        word = cond_['y']['word']
        id = cond_['y']['id']
        seed = cond_['y']['seed']
        style_feature = cond_['y']['style_feature']

        audio_feat = self.modality_encoder(audio, word)
        # Apply conditional dropout during training
        audio_feat = self.apply_conditional_dropout(audio_feat, cond_drop_prob=0.1)

        n_set = self._diffusion_process(x0, audio_feat, id, seed, style_feature)

        loss_dict = dict()

        # Diffusion loss
        if self.scheduler.config.prediction_type == "epsilon":
            model_pred, target = n_set['noise_pred'], n_set['noise']
        elif self.scheduler.config.prediction_type == "sample":
            model_pred, target = n_set['sample_pred'], n_set['sample_gt']
        elif self.scheduler.config.prediction_type == "v_prediction":
            # For v_prediction, we need to compute the v target
            # v = alpha * noise - sigma * x0
            timesteps = n_set['timesteps']
            self.alphas = self.alphas.to(x0.device)
            self.sigmas = self.sigmas.to(x0.device)
            alphas = extract_into_tensor(self.alphas, timesteps, x0.shape)
            sigmas = extract_into_tensor(self.sigmas, timesteps, x0.shape)
            v_target = alphas * n_set['noise'] - sigmas * n_set['sample_gt']
            model_pred, target = n_set['model_output'], v_target  # The model output is the v prediction
        else:
            raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.")

        # mse loss
        diff_loss = F.mse_loss(target, model_pred, reduction="mean")
        loss_dict['diff_loss'] = diff_loss

        total_loss = sum(loss_dict.values())
        loss_dict['loss'] = total_loss
        return loss_dict
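

# Minimal usage sketch (assumptions: the config path below is hypothetical, and cond_ is a
# batch dict shaped like the one unpacked in forward(); neither is defined by this module):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/diffusion.yaml")   # hypothetical config file
#   model = GestureDiffusion(cfg).eval()
#   with torch.no_grad():
#       outputs = model(cond_)          # cond_['y'] holds audio_onset, word, id, seed, style_feature
#   gestures = outputs['latents']       # (bs, input_dim * num_joints, 1, seq_len)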