import copy
import math
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # (5000, 128)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (5000, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str,
                 out_dim=None, post_act_fn=None, cond_proj_dim=None,
                 zero_init_cond: bool = True) -> None:
        super(TimestepEmbedding, self).__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
            if zero_init_cond:
                self.cond_proj.weight.data.fill_(0.0)
        else:
            self.cond_proj = None

        # gelu or silu activation
        self.act = torch.nn.GELU() if act_fn == 'gelu' else torch.nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = torch.nn.GELU() if post_act_fn == 'gelu' else torch.nn.SiLU()

    def forward(self, sample: torch.Tensor, timestep_cond=None) -> torch.Tensor:
        if timestep_cond is not None:
            sample = sample + self.cond_proj(timestep_cond)
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        # pe[timesteps]: (bs, 1, latent_dim) -> permute to (1, bs, latent_dim)
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)


class InputProcess(nn.Module):
    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        x = x.permute((0, 1, 3, 2))  # (bs, njoints, nframes, nfeats)
        x = self.poseEmbedding(x)  # (bs, njoints, nframes, latent_dim)
        return x


class OutputProcess(nn.Module):
    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, output):
        bs, n_joints, nframes, d = output.shape
        output = self.poseFinal(output)
        output = output.permute(0, 1, 3, 2)  # [bs, njoints, nfeats, nframes]
        # assumes input_feats == 128
        output = output.reshape(bs, n_joints * 128, 1, nframes)
        return output
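

# --- Hedged usage sketch (added for illustration, not part of the original pipeline) ---
# Shows how the embedding/projection modules above might be wired together.
# All dimensions (bs, n_joints, n_feats, n_frames, latent_dim) are hypothetical
# and chosen only so the shapes line up; n_feats = 128 matches the hard-coded
# reshape in OutputProcess.
def _demo_feature_modules():
    bs, n_joints, n_feats, n_frames, latent_dim = 4, 2, 128, 16, 256
    input_process = InputProcess(n_feats, latent_dim)
    output_process = OutputProcess(n_feats, latent_dim)

    x = torch.randn(bs, n_joints, n_feats, n_frames)
    h = input_process(x)                              # (bs, n_joints, n_frames, latent_dim)
    y = output_process(h)                             # (bs, n_joints * 128, 1, n_frames)

    pos_enc = PositionalEncoding(d_model=latent_dim)
    t_embedder = TimestepEmbedder(latent_dim, pos_enc)
    t_emb = t_embedder(torch.randint(0, 1000, (bs,)))  # (1, bs, latent_dim)
    return y, t_emb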


class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n = x.shape[-2]
        t = torch.arange(n, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    # Split the last dimension in half and swap the halves with a sign flip,
    # as used by rotary position embeddings.
    x = rearrange(x, 'b ... (r d) -> b (...) r d', r=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, freqs):
    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
    return q, k


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift)
        return t_emb


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    # assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def reparameterize(mu, logvar):
    # VAE reparameterization trick: z = mu + sigma * eps.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


def init_weight(m):
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
        nn.init.xavier_normal_(m.weight)
        # m.bias.data.fill_(0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def init_weight_skcnn(m):
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
        # m.bias.data.fill_(0.01)
        if m.bias is not None:
            # nn.init.constant_(m.bias, 0)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(m.bias, -bound, bound)


def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True):
    # Sample a token index from the logits of the last position.
    logits = logits[:, -1, :] / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(logits, dim=-1)
    if sample_logits:
        idx = torch.multinomial(probs, num_samples=1)
    else:
        _, idx = torch.topk(probs, k=1, dim=-1)
    return idx, probs


### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering).
        top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
        min_tokens_to_keep: make sure we keep at least this many tokens per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
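

# --- Hedged usage sketch (added for illustration, not part of the original pipeline) ---
# Shows how the rotary-embedding helpers and the top-k / top-p sampler above
# might be called. All tensor shapes (batch 2, sequence length 10, head dim 64,
# vocabulary 1000) and the temperature/top_k/top_p values are hypothetical.
def _demo_rotary_and_sampling():
    q = torch.randn(2, 10, 64)
    k = torch.randn(2, 10, 64)
    rotary = SinusoidalEmbeddings(dim=64)
    freqs = rotary(q)                                  # (10, 64)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, freqs)   # same shapes as q, k

    # Nucleus/top-k sampling over the last position of a (batch, seq, vocab) logit tensor.
    logits = torch.randn(2, 10, 1000)
    idx, probs = sample(logits, temperature=0.9, top_k=50, top_p=0.95)
    return q_rot, k_rot, idx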


class FlowMatchScheduler():
    def __init__(self, num_inference_steps=20, num_train_timesteps=1000, shift=3.0,
                 sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False,
                 extra_one_step=False, reverse_sigmas=False):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps, training=True)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
        sigma_start = self.sigma_min + \
            (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(
                sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(
                sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        self.sigmas = self.shift * self.sigmas / \
            (1 + (self.shift - 1) * self.sigmas)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            # Gaussian-bump weights used by training_weight() to reweight the loss per timestep.
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * \
                (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing

    def step(self, model_output, timestep, sample, to_final=False):
        # One Euler step of the flow-matching ODE, from sigma to the next sigma.
        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)
        self.sigmas = self.sigmas.to(model_output.device)
        self.timesteps = self.timesteps.to(model_output.device)
        timestep_id = torch.argmin(
            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
        if to_final or (timestep_id + 1 >= len(self.timesteps)).any():
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def add_noise(self, original_samples, noise, timestep):
        """
        Diffusion forward corruption process.
        Input:
        - original_samples: the clean latent with shape [B*T, C, H, W]
        - noise: the noise with shape [B*T, C, H, W]
        - timestep: the timestep with shape [B*T]
        Output: the corrupted latent with shape [B*T, C, H, W]
        """
        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)
        self.sigmas = self.sigmas.to(noise.device)
        self.timesteps = self.timesteps.to(noise.device)
        timestep_id = torch.argmin(
            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
        # Linear interpolation between data and noise, as in rectified flow.
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample.type_as(noise)

    def training_target(self, sample, noise, timestep):
        # Flow-matching target: the velocity pointing from data to noise.
        target = noise - sample
        return target

    def training_weight(self, timestep):
        """
        Input:
        - timestep: the timestep with shape [B*T]
        Output: the corresponding weighting [B*T]
        """
        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)
        # keep buffers on the same device as the incoming timesteps
        self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
        self.timesteps = self.timesteps.to(timestep.device)
        timestep_id = torch.argmin(
            (self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0)
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
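

# --- Hedged usage sketch (added for illustration, not part of the original pipeline) ---
# Shows one possible way to exercise FlowMatchScheduler: training-style corruption
# with add_noise / training_target / training_weight, and a single inference-style
# Euler update with step(). The latent shape (8, 4, 16, 16) and the mock model
# prediction are hypothetical.
def _demo_flow_match_scheduler():
    scheduler = FlowMatchScheduler(num_inference_steps=20)

    # Training-style corruption at randomly drawn scheduler timesteps.
    clean = torch.randn(8, 4, 16, 16)
    noise = torch.randn_like(clean)
    timestep = scheduler.timesteps[torch.randint(0, len(scheduler.timesteps), (8,))]
    noisy = scheduler.add_noise(clean, noise, timestep)
    target = scheduler.training_target(clean, noise, timestep)  # velocity = noise - clean
    weight = scheduler.training_weight(timestep)

    # Inference-style update: one Euler step from a mock model prediction.
    prediction = torch.randn_like(noisy)
    prev_sample = scheduler.step(prediction, timestep, noisy)
    return noisy, target, weight, prev_sample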