import logging
from functools import partial
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.config import instantiate_from_config
from models.utils.utils import count_parameters

logger = logging.getLogger(__name__)

def print_memory_usage(location: str, device: torch.device = None):
    """Print current GPU memory usage."""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        allocated = torch.cuda.memory_allocated(device) / 1024**3  # GB
        reserved = torch.cuda.memory_reserved(device) / 1024**3  # GB
        max_allocated = torch.cuda.max_memory_allocated(device) / 1024**3  # GB
        print(f"[{location}] GPU Memory - Allocated: {allocated:.3f}GB, Reserved: {reserved:.3f}GB, Max: {max_allocated:.3f}GB")
    else:
        print(f"[{location}] Using CPU device")

def clear_gpu_cache():
    """Clear GPU cache to free memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU cache cleared")

def find_attention_modules(module, attention_modules=None):
    """Recursively find all attention modules in a model."""
    if attention_modules is None:
        attention_modules = []
    for name, child in module.named_children():
        if hasattr(child, 'set_force_no_fused_attn'):
            attention_modules.append(child)
        find_attention_modules(child, attention_modules)
    return attention_modules
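
# Note: PyTorch's fused scaled-dot-product attention kernels generally do not
# support forward-mode AD, so GestureMF._disable_fused_attn_for_jvp uses this
# helper to switch attention blocks to the unfused math path before calling
# torch.func.jvp.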

def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return torch.mean(x, dim=list(range(1, len(x.size()))))

def reshape_coefs(t):
    """Reshape coefficients for broadcasting."""
    return t.reshape((t.shape[0], 1, 1, 1))
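
# Illustrative example (not called in this file): reshape_coefs turns a
# per-sample scalar into a coefficient that broadcasts against latents shaped
# (B, C, 1, T) as used in GestureMF.forward, e.g.
#   t = torch.rand(8)             # (8,)
#   reshape_coefs(t).shape        # torch.Size([8, 1, 1, 1])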

class GestureMF(torch.nn.Module):
    """
    MeanFlow loss calculator for gesture generation, designed to be similar to GestureLSM.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Initialize model components
        self.modality_encoder = instantiate_from_config(cfg.model.modality_encoder)
        self.denoiser = instantiate_from_config(cfg.model.denoiser)

        # Model hyperparameters
        self.do_classifier_free_guidance = cfg.model.do_classifier_free_guidance
        self.guidance_scale = cfg.model.guidance_scale
        self.num_inference_steps = cfg.model.n_steps

        # MeanFlow arguments
        self.weighting = cfg.model.weighting
        self.path_type = cfg.model.path_type
        self.noise_dist = cfg.model.noise_dist
        self.data_proportion = cfg.model.data_proportion
        self.cfg_min_t = cfg.model.cfg_min_t
        self.cfg_max_t = cfg.model.cfg_max_t
        self.time_mu = cfg.model.time_mu
        self.time_sigma = cfg.model.time_sigma
        self.time_min = cfg.model.time_min
        self.time_max = cfg.model.time_max

        # CFG parameters
        self.cfg_omega = cfg.model.get("cfg_omega", 0.5)
        self.cfg_kappa = cfg.model.get("cfg_kappa", 0.5)
        self.adaptive_p = cfg.model.get("adaptive_p", 0.5)

        self.num_joints = self.denoiser.joint_num
        self.seq_len = self.denoiser.seq_len
        self.input_dim = self.denoiser.input_dim
        self.latent_dim = self.denoiser.latent_dim

        # Flow matching mode: 'v' for velocity prediction, 'x1' for direct position prediction
        self.flow_mode = cfg.model.get("flow_mode", "v")
        assert self.flow_mode in [
            "v",
            "x1",
        ], f"Flow mode must be 'v' or 'x1', got {self.flow_mode}"
        logger.info(f"Using flow mode: {self.flow_mode}")

        # Set up JVP function for computing derivatives
        self.jvp_fn = torch.func.jvp
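        # Sketch of how jvp_fn is typically used in MeanFlow-style training
        # (illustrative only; the actual loss code lives elsewhere and the
        # names below are hypothetical). With fn(z, r, t) the denoiser's mean
        # velocity, a single jvp along the tangent (v_t, 0, 1) returns u and
        # its total time derivative:
        #   u, dudt = self.jvp_fn(fn, (z_t, r, t),
        #                         (v_t, torch.zeros_like(r), torch.ones_like(t)))
        # so the regression target v_t - (t - r) * dudt can be formed without
        # a second backward pass.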

    def summarize_parameters(self) -> None:
        logger.info(f'Denoiser: {count_parameters(self.denoiser)}M')
        logger.info(f'Encoder: {count_parameters(self.modality_encoder)}M')

    def _disable_fused_attn_for_jvp(self):
        """Temporarily disable fused attention to avoid forward AD issues."""
        # Find all attention modules in the denoiser
        attention_modules = find_attention_modules(self.denoiser)
        if attention_modules:
            # Disable fused attention for all found modules
            for attn_module in attention_modules:
                attn_module.set_force_no_fused_attn(True)
            return attention_modules, False
        else:
            # Fallback: check if denoiser itself has the method
            if hasattr(self.denoiser, 'set_force_no_fused_attn'):
                self.denoiser.set_force_no_fused_attn(True)
                return self.denoiser, True
            return None, None

    def _restore_fused_attn(self, original_state, is_simple):
        """Restore original fused attention setting."""
        if original_state is None:
            return
        if is_simple:
            # Restore for denoiser itself
            if hasattr(original_state, 'set_force_no_fused_attn'):
                original_state.set_force_no_fused_attn(False)
        else:
            # Restore for each block
            for attn in original_state:
                if hasattr(attn, 'set_force_no_fused_attn'):
                    attn.set_force_no_fused_attn(False)

    def _logit_normal_dist(self, bz, device):
        rnd_normal = torch.randn((bz, 1, 1, 1), device=device)
        return torch.sigmoid(rnd_normal * self.time_sigma + self.time_mu)

    def _uniform_dist(self, bz, device):
        return torch.rand((bz, 1, 1, 1), device=device)

    def interpolate(self, t):
        """Interpolation coefficients and their time derivatives for the chosen path."""
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t
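
    # For reference, assuming the usual flow-matching construction: a noisy
    # sample is z_t = alpha_t * x + sigma_t * e with e ~ N(0, I), and its
    # instantaneous velocity is v_t = d_alpha_t * x + d_sigma_t * e. For the
    # linear path above this reduces to z_t = (1 - t) * x + t * e and
    # v_t = e - x.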

    def sample_tr(self, bz, device):
        """Sample time parameters t and r."""
        if self.noise_dist == "logit_normal":
            t = self._logit_normal_dist(bz, device)
            r = self._logit_normal_dist(bz, device)
        elif self.noise_dist == "uniform":
            t = self._uniform_dist(bz, device)
            r = self._uniform_dist(bz, device)
        else:
            raise ValueError(f"Unknown noise distribution: {self.noise_dist}")
        t, r = torch.maximum(t, r), torch.minimum(t, r)
        data_size = int(bz * self.data_proportion)
        zero_mask = (torch.arange(bz, device=t.device) < data_size).view(bz, 1, 1, 1)
        r = torch.where(zero_mask, t, r)
        return t, r
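
    # Note: after the maximum/minimum swap above, t >= r elementwise. For the
    # first `data_proportion` fraction of the batch, r is forced equal to t;
    # in the MeanFlow formulation the mean velocity over [r, t] then collapses
    # to the instantaneous velocity, so those samples behave like plain flow
    # matching.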

    def apply_classifier_free_guidance(self, x, timesteps, seed, at_feat, cond_time=None, guidance_scale=1.0):
        """
        Apply classifier-free guidance by running both conditional and unconditional predictions.

        Args:
            x: Input tensor
            timesteps: Timestep tensor
            seed: Seed vectors
            at_feat: Audio features
            cond_time: Conditional time tensor
            guidance_scale: Guidance scale (1.0 means no guidance)

        Returns:
            Guided output tensor
        """
        if guidance_scale <= 1.0:
            # No guidance needed, run normal forward pass
            return self.denoiser(
                x=x,
                timesteps=timesteps,
                seed=seed,
                at_feat=at_feat,
                cond_time=cond_time,
            )

        # Double the batch for classifier free guidance
        x_doubled = torch.cat([x] * 2, dim=0)
        seed_doubled = torch.cat([seed] * 2, dim=0)

        # Properly expand timesteps to match doubled batch size
        batch_size = x.shape[0]
        timesteps_doubled = timesteps.expand(batch_size * 2)
        if cond_time is not None:
            cond_time_doubled = cond_time.expand(batch_size * 2)
        else:
            cond_time_doubled = None

        # Create conditional and unconditional audio features
        batch_size = at_feat.shape[0]
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)
        at_feat_uncond = null_cond_embed.unsqueeze(0).expand(batch_size, -1, -1)
        at_feat_combined = torch.cat([at_feat, at_feat_uncond], dim=0)

        # Run both conditional and unconditional predictions
        output = self.denoiser(
            x=x_doubled,
            timesteps=timesteps_doubled,
            seed=seed_doubled,
            at_feat=at_feat_combined,
            cond_time=cond_time_doubled,
        )

        # Split predictions and apply guidance
        pred_cond, pred_uncond = output.chunk(2, dim=0)
        guided_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
        return guided_output
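
    # The combination above is the standard CFG rule
    #   u_guided = u_uncond + w * (u_cond - u_uncond),
    # so w = 1.0 would recover the purely conditional prediction; that case is
    # short-circuited before the doubled forward pass.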

    def apply_conditional_dropout(self, at_feat, cond_drop_prob=0.1):
        """
        Apply conditional dropout during training to simulate classifier-free guidance.

        Args:
            at_feat: Audio features tensor
            cond_drop_prob: Probability of dropping conditions (default 0.1)

        Returns:
            Tuple of (audio features with dropped conditions replaced by null
            embeddings, boolean keep mask)
        """
        batch_size = at_feat.shape[0]

        # Create dropout mask
        keep_mask = torch.rand(batch_size, device=at_feat.device) > cond_drop_prob

        # Create null condition embeddings
        null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype)

        # Apply dropout: replace dropped conditions with null embeddings
        at_feat_dropped = at_feat.clone()
        num_dropped = int((~keep_mask).sum())
        at_feat_dropped[~keep_mask] = null_cond_embed.unsqueeze(0).expand(num_dropped, -1, -1)
        return at_feat_dropped, keep_mask

    def forward(self, condition_dict: Dict[str, Dict]) -> Dict[str, torch.Tensor]:
        """Forward pass for inference.

        Args:
            condition_dict: Dictionary containing input conditions including audio, word tokens,
                and other features

        Returns:
            Dictionary containing generated latents
        """
        # Extract input features
        audio = condition_dict['y']['audio_onset']
        word_tokens = condition_dict['y']['word']
        ids = condition_dict['y']['id']
        seed_vectors = condition_dict['y']['seed']
        style_features = condition_dict['y']['style_feature']
        if 'wavlm' in condition_dict['y']:
            wavlm_features = condition_dict['y']['wavlm']
        else:
            wavlm_features = None

        return_dict = {}
        return_dict['seed'] = seed_vectors

        # Encode input modalities
        audio_features = self.modality_encoder(audio, word_tokens, wavlm_features)
        return_dict['at_feat'] = audio_features

        # Initialize generation
        batch_size = audio_features.shape[0]
        latent_shape = (batch_size, self.input_dim * self.num_joints, 1, self.seq_len)

        # Sampling parameters
        x_t = torch.randn(latent_shape, device=audio_features.device)
        return_dict['init_noise'] = x_t

        if self.num_inference_steps == 1:
            cond_time = torch.zeros(1, device=audio_features.device)
            timestep = torch.ones(1, device=audio_features.device)
            model_output = self.apply_classifier_free_guidance(
                x=x_t,
                timesteps=timestep,
                seed=seed_vectors,
                at_feat=audio_features,
                cond_time=cond_time,
                guidance_scale=self.guidance_scale
            )
            # One-step MeanFlow update: x_0 = x_1 - (t - r) * u with t = 1, r = 0
            x_t = x_t - model_output
        else:
            epsilon = 1e-8
            timesteps = torch.linspace(1 - epsilon, 0, self.num_inference_steps + 1).to(audio_features.device)

            # Generation loop
            for step in range(len(timesteps) - 1):
                current_t = timesteps[step].unsqueeze(0)
                current_r = timesteps[step + 1].unsqueeze(0)
                model_output = self.apply_classifier_free_guidance(
                    x=x_t,
                    timesteps=current_t,
                    cond_time=current_r,
                    seed=seed_vectors,
                    at_feat=audio_features,
                    guidance_scale=self.guidance_scale
                )
                # Only v-prediction mode is supported for now.
                # Update x_t using the predicted mean velocity over [r, t].
                x_t = x_t - (current_t - current_r) * model_output

        return_dict['latents'] = x_t
        return return_dict
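
# Illustrative usage (hypothetical names; the actual config and batch layout
# come from the surrounding training/inference code):
#
#   cfg = OmegaConf.load("configs/gesture_mf.yaml")   # assumed config path
#   model = GestureMF(cfg).cuda().eval()
#   with torch.no_grad():
#       out = model({'y': {
#           'audio_onset': audio,        # onset-augmented audio features
#           'word': word_tokens,         # word token ids
#           'id': speaker_ids,
#           'seed': seed_gestures,       # seed gesture frames
#           'style_feature': style,
#       }})
#   latents = out['latents']             # (B, input_dim * joint_num, 1, seq_len)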