# GestureLSM/models/Diffusion.py
import time
import inspect
import logging
from typing import Optional
import numpy as np
from omegaconf import DictConfig
import torch
import torch.nn.functional as F
from models.config import instantiate_from_config
from models.utils.utils import count_parameters, extract_into_tensor, sum_flat
logger = logging.getLogger(__name__)
class GestureDiffusion(torch.nn.Module):
    def __init__(self, cfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.modality_encoder = instantiate_from_config(cfg.model.modality_encoder)
        self.denoiser = instantiate_from_config(cfg.model.denoiser)
        self.scheduler = instantiate_from_config(cfg.model.scheduler)
        # Per-timestep coefficients of the forward-noising relation
        #     x_t = alphas[t] * x_0 + sigmas[t] * eps
        # reused by predicted_origin() and the v-prediction target in train_forward()
        self.alphas = torch.sqrt(self.scheduler.alphas_cumprod)
        self.sigmas = torch.sqrt(1 - self.scheduler.alphas_cumprod)
        self.do_classifier_free_guidance = cfg.model.do_classifier_free_guidance
        self.guidance_scale = cfg.model.guidance_scale
        self.smooth_l1_loss = torch.nn.SmoothL1Loss(reduction='none')
        self.seq_len = self.denoiser.seq_len
        self.input_dim = self.denoiser.input_dim
        self.num_joints = self.denoiser.joint_num
    def summarize_parameters(self) -> None:
        logger.info(f'Denoiser: {count_parameters(self.denoiser)}M')
        logger.info(f'Modality encoder: {count_parameters(self.modality_encoder)}M')
    def apply_classifier_free_guidance(self, x, timesteps, seed, at_feat, guidance_scale=1.0):
        """
        Apply classifier-free guidance by running both conditional and unconditional predictions.

        Args:
            x: Input tensor
            timesteps: Timestep tensor
            seed: Seed vectors
            at_feat: Audio features
            guidance_scale: Guidance scale (1.0 means no guidance)

        Returns:
            Guided output tensor
        """
        if guidance_scale <= 1.0:
            # No guidance needed, run normal forward pass
            return self.denoiser(
                x=x,
                timesteps=timesteps,
                seed=seed,
                at_feat=at_feat,
                cond_drop_prob=0.0,
                null_cond=False
            )

        # Double the batch for classifier-free guidance
        x_doubled = torch.cat([x] * 2, dim=0)
        seed_doubled = torch.cat([seed] * 2, dim=0)
        at_feat_doubled = torch.cat([at_feat] * 2, dim=0)

        # Expand timesteps to match the doubled batch size
        batch_size = x.shape[0]
        timesteps_doubled = timesteps.expand(batch_size * 2)

        # Create conditional and unconditional audio features
        batch_size = at_feat.shape[0]
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)
        at_feat_uncond = null_cond_embed.unsqueeze(0).expand(batch_size, -1, -1)
        at_feat_combined = torch.cat([at_feat, at_feat_uncond], dim=0)

        # Run both conditional and unconditional predictions
        output = self.denoiser(
            x=x_doubled,
            timesteps=timesteps_doubled,
            seed=seed_doubled,
            at_feat=at_feat_combined,
        )

        # Split predictions and apply guidance
        pred_cond, pred_uncond = output.chunk(2, dim=0)
        guided_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
        return guided_output
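
    # Worked example of the guidance combination above (illustrative numbers only):
    # with pred_uncond = 0.2, pred_cond = 0.5 and guidance_scale = 2.0,
    #     guided = 0.2 + 2.0 * (0.5 - 0.2) = 0.8,
    # i.e. guidance_scale = 1.0 reproduces the conditional prediction exactly (hence the
    # early return above), while larger scales push the output further away from the
    # unconditional prediction.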
    def apply_conditional_dropout(self, at_feat, cond_drop_prob=0.1):
        """
        Apply conditional dropout during training to simulate classifier-free guidance.

        Args:
            at_feat: Audio features tensor
            cond_drop_prob: Probability of dropping conditions (default 0.1)

        Returns:
            Modified audio features with some conditions replaced by null embeddings
        """
        batch_size = at_feat.shape[0]

        # Create dropout mask
        keep_mask = torch.rand(batch_size, device=at_feat.device) > cond_drop_prob

        # Create null condition embeddings
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)

        # Apply dropout: replace dropped conditions with null embeddings
        at_feat_dropped = at_feat.clone()
        at_feat_dropped[~keep_mask] = null_cond_embed.unsqueeze(0).expand((~keep_mask).sum(), -1, -1)
        return at_feat_dropped
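
    # Rough intuition for the dropout above: with cond_drop_prob=0.1 and a batch of 32,
    # on average ~3 entries of keep_mask are False, so those sequences are denoised
    # against null_cond_embed instead of their audio features -- this is what trains the
    # unconditional branch that apply_classifier_free_guidance uses at inference time.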
    def predicted_origin(self, model_output: torch.Tensor, timesteps: torch.Tensor, sample: torch.Tensor) -> tuple:
        self.alphas = self.alphas.to(model_output.device)
        self.sigmas = self.sigmas.to(model_output.device)
        alphas = extract_into_tensor(self.alphas, timesteps, sample.shape)
        sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape)

        # Recover the clean sample x_0 and the noise epsilon from the model output,
        # according to the scheduler's prediction type
        if self.scheduler.config.prediction_type == "epsilon":
            pred_original_sample = (sample - sigmas * model_output) / alphas
            pred_epsilon = model_output
        elif self.scheduler.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alphas * model_output) / sigmas
        elif self.scheduler.config.prediction_type == "v_prediction":
            pred_original_sample = alphas * sample - sigmas * model_output
            pred_epsilon = alphas * model_output + sigmas * sample
        else:
            raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.")
        return pred_original_sample, pred_epsilon
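
    # All three branches of predicted_origin follow from the forward-noising relation
    #     x_t = alpha_t * x_0 + sigma_t * eps,
    # with alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]):
    #   - "epsilon":      x_0 = (x_t - sigma_t * eps_hat) / alpha_t
    #   - "sample":       eps = (x_t - alpha_t * x0_hat) / sigma_t
    #   - "v_prediction": with v = alpha_t * eps - sigma_t * x_0,
    #                     x_0 = alpha_t * x_t - sigma_t * v_hat  and
    #                     eps = alpha_t * v_hat + sigma_t * x_t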
    def forward(self, cond_: dict) -> dict:
        audio = cond_['y']['audio_onset']
        word = cond_['y']['word']
        id = cond_['y']['id']
        seed = cond_['y']['seed']
        style_feature = cond_['y']['style_feature']

        audio_feat = self.modality_encoder(audio, word)

        bs = audio_feat.shape[0]
        shape_ = (bs, self.input_dim * self.num_joints, 1, self.seq_len)
        latents = torch.randn(shape_, device=audio_feat.device)
        latents = self._diffusion_reverse(latents, seed, audio_feat, guidance_scale=self.guidance_scale)
        return latents
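
    # Note: despite the variable name, forward() returns the dict built by
    # _diffusion_reverse ('latents', 'init_noise', 'at_feat', 'seed'); the generated
    # gesture latents are read from the 'latents' entry.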
    def _diffusion_reverse(
        self,
        latents: torch.Tensor,
        seed: torch.Tensor,
        at_feat: torch.Tensor,
        guidance_scale: float = 1,
    ) -> dict:
        return_dict = {}

        # Scale the initial noise by the standard deviation required by the scheduler, as in Stable Diffusion.
        # This initial noise also needs to be returned for rectified-flow training.
        latents = latents * self.scheduler.init_noise_sigma
        noise = latents
        return_dict["init_noise"] = latents
        return_dict['at_feat'] = at_feat
        return_dict['seed'] = seed

        # set timesteps
        self.scheduler.set_timesteps(self.cfg.model.scheduler.num_inference_steps)
        timesteps = self.scheduler.timesteps.to(at_feat.device)
        # Start from add_noise(zeros, noise, t0), i.e. sqrt(1 - alphas_cumprod[t0]) * noise,
        # rather than from the raw scaled noise
        latents = torch.zeros_like(latents)
        latents = self.scheduler.add_noise(latents, noise, timesteps[0])

        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler and must be in [0, 1].
        extra_step_kwargs = {}
        if "eta" in set(
                inspect.signature(self.scheduler.step).parameters.keys()):
            extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta

        for i, t in enumerate(timesteps):
            latent_model_input = latents
            # scale_model_input is a no-op for the DDIM scheduler, but is kept for scheduler-API compatibility
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual (with classifier-free guidance if enabled)
            model_output = self.apply_classifier_free_guidance(
                x=latent_model_input,
                timesteps=t,
                seed=seed,
                at_feat=at_feat,
                guidance_scale=guidance_scale)

            latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample

        return_dict['latents'] = latents
        return return_dict
    def _diffusion_process(self,
                           latents: torch.Tensor,
                           audio_feat: torch.Tensor,
                           id: torch.Tensor,
                           seed: torch.Tensor,
                           style_feature: torch.Tensor
                           ) -> dict:
        # [batch_size, n_frame, latent_dim]
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (bsz,),
            device=latents.device
        )
        timesteps = timesteps.long()

        noisy_latents = self.scheduler.add_noise(latents.clone(), noise, timesteps)

        model_output = self.denoiser(
            x=noisy_latents,
            timesteps=timesteps,
            seed=seed,
            at_feat=audio_feat,
        )

        latents_pred, noise_pred = self.predicted_origin(model_output, timesteps, noisy_latents)

        n_set = {
            "noise": noise,
            "noise_pred": noise_pred,
            "sample_pred": latents_pred,
            "sample_gt": latents,
            "timesteps": timesteps,
            "model_output": model_output,
        }
        return n_set
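
    # _diffusion_process is standard forward diffusion for training: assuming the usual
    # diffusers add_noise convention,
    #     noisy_latents = sqrt(alphas_cumprod[t]) * latents + sqrt(1 - alphas_cumprod[t]) * noise,
    # and the denoiser is asked to undo this corruption at a randomly sampled timestep t.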
    def train_forward(self, cond_: dict, x0: torch.Tensor) -> dict:
        audio = cond_['y']['audio_onset']
        word = cond_['y']['word']
        id = cond_['y']['id']
        seed = cond_['y']['seed']
        style_feature = cond_['y']['style_feature']

        audio_feat = self.modality_encoder(audio, word)

        # Apply conditional dropout during training
        audio_feat = self.apply_conditional_dropout(audio_feat, cond_drop_prob=0.1)

        n_set = self._diffusion_process(x0, audio_feat, id, seed, style_feature)

        loss_dict = dict()

        # Diffusion loss
        if self.scheduler.config.prediction_type == "epsilon":
            model_pred, target = n_set['noise_pred'], n_set['noise']
        elif self.scheduler.config.prediction_type == "sample":
            model_pred, target = n_set['sample_pred'], n_set['sample_gt']
        elif self.scheduler.config.prediction_type == "v_prediction":
            # For v_prediction, we need to compute the v target:
            # v = alpha * noise - sigma * x0
            timesteps = n_set['timesteps']
            self.alphas = self.alphas.to(x0.device)
            self.sigmas = self.sigmas.to(x0.device)
            alphas = extract_into_tensor(self.alphas, timesteps, x0.shape)
            sigmas = extract_into_tensor(self.sigmas, timesteps, x0.shape)
            v_target = alphas * n_set['noise'] - sigmas * n_set['sample_gt']
            model_pred, target = n_set['model_output'], v_target  # the model output is the v prediction
        else:
            raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.")

        # MSE loss
        diff_loss = F.mse_loss(target, model_pred, reduction="mean")
        loss_dict['diff_loss'] = diff_loss

        total_loss = sum(loss_dict.values())
        loss_dict['loss'] = total_loss
        return loss_dict
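

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module; the config path, tensor names
# and shapes below are assumptions for illustration only -- real values come
# from the project's config files and dataloader):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/diffusion.yaml")   # hypothetical path
#   model = GestureDiffusion(cfg).to("cuda")
#
#   cond_ = {"y": {
#       "audio_onset": audio_onset,       # audio features from the dataloader
#       "word": word_tokens,              # word/token conditioning
#       "id": speaker_id,                 # speaker id (not used by the calls in this file)
#       "seed": seed_frames,              # seed gesture frames
#       "style_feature": style_feature,   # style vector (not used by the calls in this file)
#   }}
#
#   # Training step: returns {'diff_loss': ..., 'loss': ...}
#   loss_dict = model.train_forward(cond_, x0_latents)
#
#   # Sampling: returns {'latents', 'init_noise', 'at_feat', 'seed'}
#   out = model(cond_)
#   gestures = out["latents"]
# ---------------------------------------------------------------------------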