import inspect
import logging
import math
import time
from typing import Dict, Optional

import numpy as np
import scipy.stats as stats
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from omegaconf import DictConfig

from models.config import instantiate_from_config
from models.utils.utils import count_parameters, extract_into_tensor, sum_flat

logger = logging.getLogger(__name__)


def exponential_pdf(x, a):
    """Evaluate the density ``C * exp(a * x)`` normalised on the interval [0, 1].

    Args:
        x: Point(s) at which to evaluate the density (scalar or numpy array).
        a: Rate parameter. Must be non-zero; ``a == 0`` divides by zero in the
            normaliser.

    Returns:
        The density value(s), using ``C = a / (exp(a) - 1)`` so that the density
        integrates to 1 over [0, 1].
    """
    C = a / (np.exp(a) - 1)
    return C * np.exp(a * x)


# Custom scipy continuous distribution whose density is `exponential_pdf`.
# The density is only normalised on [0, 1], so instances presumably need support
# bounds a=0, b=1 when constructed -- NOTE(review): confirm at the call site.
class ExponentialPDF(stats.rv_continuous):
    def _pdf(self, x, a):
        return exponential_pdf(x, a)


def sample_t(exponential_pdf, num_samples, a=2, t_min=1e-5, t_max=1 - 1e-5):
    """Sample timesteps from a mirrored exponential density.

    Args:
        exponential_pdf: A scipy distribution object exposing
            ``rvs(size=..., a=...)``, e.g. an ``ExponentialPDF`` instance. (This
            parameter shadows the module-level function of the same name.)
        num_samples: Number of timesteps to return.
        a: Shape parameter forwarded to ``rvs``.
        t_min: Lower end of the output range.
        t_max: Upper end of the output range.

    Returns:
        A float32 tensor of shape ``(num_samples,)``. Samples are mapped affinely
        from [0, 1] to [t_min, t_max].
    """
    raw = exponential_pdf.rvs(size=num_samples, a=a)
    t = torch.from_numpy(raw).float()
    # Pool every sample with its mirror image 1 - t, shuffle, and keep
    # num_samples, so both ends of the unit interval are covered equally.
    t = torch.cat([t, 1 - t], dim=0)
    t = t[torch.randperm(t.shape[0])][:num_samples]
    # Scale t to [t_min, t_max]
    return t * (t_max - t_min) + t_min


def sample_beta_distribution(num_samples, alpha=2, beta=0.8, t_min=1e-5, t_max=1 - 1e-5):
    """
    Samples from a Beta distribution with the specified parameters.

    Args:
        num_samples (int): Number of samples to generate.
        alpha (float): Alpha parameter of the Beta distribution (shape1).
        beta (float): Beta parameter of the Beta distribution (shape2).
        t_min (float): Minimum value for scaling the samples (default is near 0).
        t_max (float): Maximum value for scaling the samples (default is near 1).

    Returns:
        torch.Tensor: Tensor of shape ``(num_samples,)``, on the default (CPU)
        device, with values mapped affinely from [0, 1] to [t_min, t_max].
    """
    # Define the Beta distribution
    beta_dist = dist.Beta(alpha, beta)
    # Sample values from the Beta distribution
    samples = beta_dist.sample((num_samples,))
    # Scale the samples to the range [t_min, t_max]
    scaled_samples = samples * (t_max - t_min) + t_min
    return scaled_samples


def sample_t_fast(num_samples, a=2, t_min=1e-5, t_max=1 - 1e-5):
    """Sample mirrored exponential timesteps via closed-form inverse-CDF sampling.

    The density ``a * exp(a * x) / (exp(a) - 1)`` on [0, 1] has the CDF
    ``F(x) = (exp(a * x) - 1) / (exp(a) - 1)``, whose inverse is
    ``F^-1(u) = ln(1 + u * (exp(a) - 1)) / a``. This avoids scipy's generic
    ``rvs`` machinery used by ``sample_t``.

    Returns:
        A tensor of shape ``(num_samples,)`` with values in [t_min, t_max].
    """
    # Generate uniform samples
    u = torch.rand(num_samples * 2)
    # Inverse transform sampling: F^(-1)(u) = (1/a) * ln(1 + u*(exp(a) - 1))
    t = (1 / a) * torch.log(1 + u * (np.exp(a) - 1))
    # Combine t and 1 - t, shuffle, and keep the first num_samples.
    t = torch.cat([t, 1 - t])
    t = t[torch.randperm(t.shape[0])][:num_samples]
    # Scale to [t_min, t_max]
    return t * (t_max - t_min) + t_min


def sample_cosmap(num_samples, t_min=1e-5, t_max=1 - 1e-5, device='cpu'):
    """
    CosMap sampling.

    Draws ``u ~ U(0, 1)`` and maps it through ``t = 1 - 1 / (tan(pi/2 * u) + 1)``.

    Args:
        num_samples: Number of samples to generate.
        t_min, t_max: Range limits to avoid numerical issues.
        device: Device on which the uniform samples are created.

    Returns:
        A tensor of shape ``(num_samples,)`` with values in [t_min, t_max].
    """
    # Generate uniform samples
    u = torch.rand(num_samples, device=device)
    # Apply the cosine mapping
    pi_half = torch.pi / 2
    t = 1 - 1 / (torch.tan(pi_half * u) + 1)
    # Scale to [t_min, t_max]
    t = t * (t_max - t_min) + t_min
    return t


def reshape_coefs(t):
    """Reshape a per-sample vector of shape (B,) to (B, 1, 1, 1) for broadcasting."""
    return t.reshape((t.shape[0], 1, 1, 1))


class GestureLSM(torch.nn.Module):
    """Flow-matching gesture generator with shortcut/consistency training.

    Holds a modality encoder and a denoiser, both built from ``cfg.model``.
    ``forward`` integrates the learned flow from Gaussian noise. ``train_forward``
    and ``train_reflow`` compute a flow-matching loss on the first 3/4 of the
    batch and a consistency (shortcut) loss on the remainder.
    """

    def __init__(self, cfg) -> None:
        """Build the sub-modules from ``cfg.model`` and cache the sampling hyperparameters."""
        super().__init__()
        self.cfg = cfg

        # Initialize model components
        self.modality_encoder = instantiate_from_config(cfg.model.modality_encoder)
        self.denoiser = instantiate_from_config(cfg.model.denoiser)

        # Model hyperparameters
        self.do_classifier_free_guidance = cfg.model.do_classifier_free_guidance
        self.guidance_scale = cfg.model.guidance_scale
        self.num_inference_steps = cfg.model.n_steps

        # Loss functions
        self.smooth_l1_loss = torch.nn.SmoothL1Loss(reduction='none')

        self.num_joints = self.denoiser.joint_num
        self.seq_len = self.denoiser.seq_len
        self.input_dim = self.denoiser.input_dim

        # Flow matching mode: 'v' for velocity prediction, 'x1' for direct position prediction
        self.flow_mode = cfg.model.get("flow_mode", "v")
        assert self.flow_mode in [
            "v",
            "x1",
        ], f"Flow mode must be 'v' or 'x1', got {self.flow_mode}"
        logger.info(f"Using flow mode: {self.flow_mode}")

    def summarize_parameters(self) -> None:
        """Log the parameter counts of the denoiser and the modality encoder."""
        logger.info(f'Denoiser: {count_parameters(self.denoiser)}M')
        logger.info(f'Encoder: {count_parameters(self.modality_encoder)}M')

    def _to_velocity(self, model_output, noise):
        """Convert a denoiser output into a velocity along x_t = t * x1 + (1 - t) * x0.

        In 'v' mode the denoiser already predicts the velocity. In 'x1' mode it
        predicts the clean sample, so the velocity is ``x1_pred - x0``. This is the
        same rule ``forward`` uses when integrating.

        Args:
            model_output: Raw denoiser output.
            noise: The x0 noise matching ``model_output``'s batch rows.
        """
        if self.flow_mode == "v":
            return model_output
        return model_output - noise

    @staticmethod
    def _double_batch(values, batch_size):
        """Repeat a conditioning tensor to match a CFG batch of size ``2 * batch_size``.

        A single-element tensor (e.g. one timestep shared by the whole batch) is
        broadcast to ``2 * batch_size`` entries. A per-sample tensor is concatenated
        with itself, so row ``i`` and row ``i + batch_size`` carry the same value.
        """
        if values.numel() == 1:
            return values.reshape(1).expand(batch_size * 2)
        return torch.cat([values, values], dim=0)

    def apply_classifier_free_guidance(self, x, timesteps, seed, at_feat, cond_time=None, guidance_scale=1.0):
        """
        Apply classifier-free guidance by running both conditional and unconditional predictions.

        Args:
            x: Input tensor
            timesteps: Timestep tensor (a single shared value, or one per sample)
            seed: Seed vectors
            at_feat: Audio features, indexed as (batch, length, channels) by the
                length-adjustment below
            cond_time: Conditional time tensor (a single shared value, or one per sample)
            guidance_scale: Guidance scale (1.0 means no guidance)

        Returns:
            Guided output tensor: ``uncond + guidance_scale * (cond - uncond)``.
        """
        if guidance_scale <= 1.0:
            # No guidance needed, run normal forward pass
            return self.denoiser(
                x=x,
                timesteps=timesteps,
                seed=seed,
                at_feat=at_feat,
                cond_time=cond_time,
            )

        batch_size = x.shape[0]

        # The first half of the doubled batch is conditional, the second half unconditional.
        x_doubled = torch.cat([x, x], dim=0)
        seed_doubled = torch.cat([seed, seed], dim=0)
        timesteps_doubled = self._double_batch(timesteps, batch_size)
        cond_time_doubled = None if cond_time is None else self._double_batch(cond_time, batch_size)

        # The null embedding has a fixed length, so resample the conditional
        # features along the length axis when they differ.
        cond_len = self.denoiser.null_cond_embed.shape[0]
        if at_feat.shape[1] != cond_len:
            got_len = at_feat.shape[1]  # captured before interpolation so the log is accurate
            at_feat = F.interpolate(
                at_feat.transpose(1, 2),
                size=cond_len,
                mode="linear",
                align_corners=False,
            ).transpose(1, 2)
            logger.warning(
                "Adjusted conditional feature length to match denoiser (got=%d, expected=%d)",
                got_len,
                cond_len,
            )

        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)
        at_feat_uncond = null_cond_embed.unsqueeze(0).expand(at_feat.shape[0], -1, -1)
        at_feat_combined = torch.cat([at_feat, at_feat_uncond], dim=0)

        # Run both conditional and unconditional predictions
        output = self.denoiser(
            x=x_doubled,
            timesteps=timesteps_doubled,
            seed=seed_doubled,
            at_feat=at_feat_combined,
            cond_time=cond_time_doubled,
        )

        # Split predictions and apply guidance
        pred_cond, pred_uncond = output.chunk(2, dim=0)
        return pred_uncond + guidance_scale * (pred_cond - pred_uncond)

    def apply_conditional_dropout(self, at_feat, cond_drop_prob=0.1):
        """
        Apply conditional dropout during training to simulate classifier-free guidance.

        Args:
            at_feat: Audio features tensor. Its per-sample slice must match the shape
                of ``denoiser.null_cond_embed``.
            cond_drop_prob: Probability of dropping conditions (default 0.1)

        Returns:
            A copy of ``at_feat`` in which each sample is independently replaced by
            the null embedding with probability ``cond_drop_prob``.
        """
        batch_size = at_feat.shape[0]

        # Create dropout mask
        keep_mask = torch.rand(batch_size, device=at_feat.device) > cond_drop_prob

        # Create null condition embeddings
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)

        # Apply dropout: replace dropped conditions with null embeddings
        at_feat_dropped = at_feat.clone()
        at_feat_dropped[~keep_mask] = null_cond_embed.unsqueeze(0).expand((~keep_mask).sum(), -1, -1)

        return at_feat_dropped

    def apply_force_cfg(self, at_feat, force_cfg):
        """
        Apply forced conditional dropout based on the force_cfg mask.

        Args:
            at_feat: Audio features tensor
            force_cfg: Boolean mask (numpy array, list, or tensor) with one entry per
                sample; True entries are replaced by the null condition

        Returns:
            A copy of ``at_feat`` in which the masked samples are replaced by null embeddings.
        """
        # Create null condition embeddings
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)

        # Apply forced dropout: replace forced conditions with null embeddings.
        # dtype=bool keeps this a mask (not integer indexing) whatever container is passed in.
        at_feat_forced = at_feat.clone()
        force_cfg_tensor = torch.as_tensor(force_cfg, dtype=torch.bool, device=at_feat.device)
        at_feat_forced[force_cfg_tensor] = null_cond_embed.unsqueeze(0).expand(force_cfg_tensor.sum(), -1, -1)

        return at_feat_forced

    def forward(self, condition_dict: Dict[str, Dict]) -> Dict[str, torch.Tensor]:
        """Forward pass for inference.

        Integrates the learned flow from Gaussian noise with ``num_inference_steps``
        Euler steps over [1e-8, 1 - 1e-8]. Each step applies classifier-free
        guidance with ``self.guidance_scale``.

        Args:
            condition_dict: Dictionary containing input conditions including audio,
                word tokens, and other features (read from ``condition_dict['y']``).

        Returns:
            Dictionary with keys 'seed', 'at_feat', 'init_noise' and 'latents'.
            'latents' has shape (batch, input_dim * num_joints, 1, seq_len).
        """
        # Extract input features
        audio = condition_dict['y']['audio_onset']
        word_tokens = condition_dict['y']['word']
        ids = condition_dict['y']['id']  # read but not used below
        seed_vectors = condition_dict['y']['seed']
        style_features = condition_dict['y']['style_feature']  # read but not used below
        wavlm_features = condition_dict['y'].get('wavlm')

        return_dict = {}
        return_dict['seed'] = seed_vectors

        # Encode input modalities
        audio_features = self.modality_encoder(audio, word_tokens, wavlm_features)
        return_dict['at_feat'] = audio_features

        # Initialize generation
        device = audio_features.device
        batch_size = audio_features.shape[0]
        latent_shape = (batch_size, self.input_dim * self.num_joints, 1, self.seq_len)

        # Sampling parameters
        x_t = torch.randn(latent_shape, device=device)
        init_noise = x_t
        return_dict['init_noise'] = init_noise

        epsilon = 1e-8
        delta_t = torch.tensor(1 / self.num_inference_steps).to(device)
        timesteps = torch.linspace(epsilon, 1 - epsilon, self.num_inference_steps + 1).to(device)

        # Generation loop (explicit Euler integration of the velocity field)
        for step in range(1, len(timesteps)):
            current_t = timesteps[step - 1].unsqueeze(0)
            current_delta = delta_t.unsqueeze(0)

            with torch.no_grad():
                model_output = self.apply_classifier_free_guidance(
                    x=x_t,
                    timesteps=current_t,
                    seed=seed_vectors,
                    at_feat=audio_features,
                    cond_time=current_delta,
                    guidance_scale=self.guidance_scale,
                )

            step_size = timesteps[step] - timesteps[step - 1]
            # 'v': the output is the velocity; 'x1': velocity = x1_pred - initial noise.
            x_t = x_t + step_size * self._to_velocity(model_output, init_noise)

        return_dict['latents'] = x_t
        return return_dict

    def train_forward(self, condition_dict: Dict[str, Dict], latents: torch.Tensor, train_consistency=False) -> Dict[str, torch.Tensor]:
        """Compute training losses for both flow matching and consistency.

        The first ``int(3/4 * batch)`` samples get a flow-matching loss, with random
        condition dropout. The remaining samples get a shortcut/consistency loss: the
        prediction at step size 2d is regressed onto the average of two d-sized steps.

        Args:
            condition_dict: Dictionary containing training conditions
            latents: Target latent vectors (x1); combined with (B, 1, 1, 1)
                coefficients, so presumably 4-D -- TODO confirm
            train_consistency: Currently unused; the consistency loss is always computed.

        Returns:
            Dictionary with 'flow_loss', 'consistency_loss' and their sum under 'loss'.
        """
        # Extract input features
        audio = condition_dict['y']['audio_onset']
        word_tokens = condition_dict['y']['word']
        instance_ids = condition_dict['y']['id']  # read but not used below
        seed_vectors = condition_dict['y']['seed']
        style_features = condition_dict['y']['style_feature']  # read but not used below

        # Encode input modalities.
        # NOTE(review): `forward` also passes WavLM features to the encoder; confirm
        # that omitting them here is intended.
        audio_features = self.modality_encoder(audio, word_tokens)

        # Initialize noise (x0)
        x0_noise = torch.randn_like(latents)

        # Candidate step sizes d in {1/2, 1/4, ..., 1/128}, sampled uniformly.
        deltas = 1 / torch.tensor([2 ** i for i in range(1, 8)]).to(latents.device)
        delta_probs = torch.ones((deltas.shape[0],)).to(latents.device) / deltas.shape[0]

        batch_size = latents.shape[0]
        flow_batch_size = int(batch_size * 3 / 4)

        # Apply conditional dropout during training for flow matching loss
        audio_features_flow = self.apply_conditional_dropout(audio_features[:flow_batch_size], cond_drop_prob=0.1)

        # Sample random coefficients
        t = sample_beta_distribution(batch_size, alpha=2, beta=1.2).to(latents.device)
        d = deltas[delta_probs.multinomial(batch_size, replacement=True)]
        # Flow-matching samples are conditioned on step size 0.
        d[:flow_batch_size] = 0

        # Prepare inputs: linear interpolation between noise (t=0) and data (t=1)
        t_coef = reshape_coefs(t)
        x_t = t_coef * latents + (1 - t_coef) * x0_noise
        t = t_coef.flatten()

        # Flow matching loss
        model_output = self.denoiser(
            x=x_t[:flow_batch_size],
            timesteps=t[:flow_batch_size],
            seed=seed_vectors[:flow_batch_size],
            at_feat=audio_features_flow,
            cond_time=d[:flow_batch_size],
        )

        if self.flow_mode == "v":
            # Velocity prediction mode: target is x1 - x0
            flow_target = latents[:flow_batch_size] - x0_noise[:flow_batch_size]
        else:  # 'x1' mode
            # Direct position prediction mode: target is x1
            flow_target = latents[:flow_batch_size]

        losses = {}
        # NOTE(review): F.mse_loss uses its default reduction="mean", so it returns a
        # scalar; dividing by the vector t and averaging scales that scalar by
        # mean(1/t) rather than weighting each sample by its own 1/t. Kept as-is.
        losses["flow_loss"] = (F.mse_loss(flow_target, model_output) / t[:flow_batch_size]).mean()

        # Consistency loss computation.
        # Train CFG at the same time: ~80% of the consistency samples use the null
        # condition (p=[0.8, 0.2]).
        force_cfg = np.random.choice(
            [True, False], size=batch_size - flow_batch_size, p=[0.8, 0.2]
        )
        # Apply force_cfg externally
        audio_features_consistency = self.apply_force_cfg(audio_features[flow_batch_size:], force_cfg)

        x_t_c = x_t[flow_batch_size:]
        t_c = t[flow_batch_size:]
        d_c = d[flow_batch_size:]
        seed_c = seed_vectors[flow_batch_size:]
        noise_c = x0_noise[flow_batch_size:]

        # The bootstrap target (the average of two d-sized steps) is fully
        # stop-gradient, so only the 2d prediction below is trained.
        # NOTE(review): t + d is not clamped and can exceed 1 for large t.
        with torch.no_grad():
            pred_t = self.denoiser(
                x=x_t_c,
                timesteps=t_c,
                seed=seed_c,
                at_feat=audio_features_consistency,
                cond_time=d_c,
            )
            speed_t = self._to_velocity(pred_t, noise_c)

            x_td = x_t_c + reshape_coefs(d_c) * speed_t
            pred_td = self.denoiser(
                x=x_td,
                timesteps=t_c + d_c,
                seed=seed_c,
                at_feat=audio_features_consistency,
                cond_time=d_c,
            )
            speed_td = self._to_velocity(pred_td, noise_c)

            speed_target = (speed_t + speed_td) / 2

        model_pred = self.denoiser(
            x=x_t_c,
            timesteps=t_c,
            seed=seed_c,
            at_feat=audio_features_consistency,
            cond_time=2 * d_c,
        )
        speed_pred = self._to_velocity(model_pred, noise_c)

        losses["consistency_loss"] = F.mse_loss(speed_pred, speed_target, reduction="mean")
        losses["loss"] = sum(losses.values())

        return losses

    def train_reflow(self, latents, audio_features, x0_noise, seed_vectors) -> Dict[str, torch.Tensor]:
        """Compute flow-matching and consistency losses on precomputed (noise, latent) pairs.

        Same split and losses as ``train_forward``, but the caller supplies the encoded
        features and the paired noise. The flow target is always the velocity
        ``latents - x0_noise``, whatever ``self.flow_mode`` is.

        Args:
            latents: Target latent vectors (x1).
            audio_features: Already-encoded conditioning features, one per sample.
            x0_noise: Noise paired with ``latents`` (x0), same shape as ``latents``.
            seed_vectors: Seed vectors, one per sample.

        Returns:
            Dictionary with 'flow_loss', 'consistency_loss' and their sum under 'loss'.
        """
        # Candidate step sizes d in {1/2, 1/4, ..., 1/128}, sampled uniformly.
        deltas = 1 / torch.tensor([2 ** i for i in range(1, 8)]).to(latents.device)
        delta_probs = torch.ones((deltas.shape[0],)).to(latents.device) / deltas.shape[0]

        batch_size = latents.shape[0]
        flow_batch_size = int(batch_size * 3 / 4)

        # Sample random coefficients
        t = sample_beta_distribution(batch_size, alpha=2, beta=1.2).to(latents.device)
        d = deltas[delta_probs.multinomial(batch_size, replacement=True)]
        d[:flow_batch_size] = 0

        # Prepare inputs
        t_coef = reshape_coefs(t)
        x_t = t_coef * latents + (1 - t_coef) * x0_noise
        t = t_coef.flatten()

        # Flow matching loss
        flow_pred = self.denoiser(
            x=x_t[:flow_batch_size],
            timesteps=t[:flow_batch_size],
            seed=seed_vectors[:flow_batch_size],
            at_feat=audio_features[:flow_batch_size],
            cond_time=d[:flow_batch_size],
        )
        flow_target = latents[:flow_batch_size] - x0_noise[:flow_batch_size]

        losses = {}
        # Divide by the flow samples' t only, matching train_forward.
        # NOTE(review): as in train_forward, mse_loss is already a scalar here.
        losses['flow_loss'] = (F.mse_loss(flow_target, flow_pred) / t[:flow_batch_size]).mean()

        # Consistency loss computation
        # NOTE(review): this mask (~80% True) is drawn but never applied here, unlike
        # train_forward, which passes it through apply_force_cfg -- confirm intent.
        force_cfg = np.random.choice([True, False], size=batch_size - flow_batch_size, p=[0.8, 0.2])

        x_t_c = x_t[flow_batch_size:]
        t_c = t[flow_batch_size:]
        d_c = d[flow_batch_size:]
        seed_c = seed_vectors[flow_batch_size:]
        feat_c = audio_features[flow_batch_size:]

        # Bootstrap target is fully stop-gradient (see train_forward).
        with torch.no_grad():
            speed_t = self.denoiser(
                x=x_t_c,
                timesteps=t_c,
                seed=seed_c,
                at_feat=feat_c,
                cond_time=d_c,
            )

            x_td = x_t_c + reshape_coefs(d_c) * speed_t
            speed_td = self.denoiser(
                x=x_td,
                timesteps=t_c + d_c,
                seed=seed_c,
                at_feat=feat_c,
                cond_time=d_c,
            )

            speed_target = (speed_t + speed_td) / 2

        speed_pred = self.denoiser(
            x=x_t_c,
            timesteps=t_c,
            seed=seed_c,
            at_feat=feat_c,
            cond_time=2 * d_c,
        )

        losses['consistency_loss'] = F.mse_loss(speed_pred, speed_target, reduction="mean")
        losses['loss'] = sum(losses.values())

        return losses