| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import math |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
|
|
| def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: |
| """ |
| Different from the original ROPE used for flux. |
| Megatron attention takes the out product and calculate sin/cos inside, so we only need to get the freqs here |
| in the shape of [seq, ..., dim] |
| """ |
| assert dim % 2 == 0, "The dimension must be even." |
|
|
| scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim |
| omega = 1.0 / (theta**scale) |
|
|
| out = torch.einsum("...n,d->...nd", pos, omega) |
|
|
| return out.float() |
|
|
|
|
| class EmbedND(nn.Module): |
| ''' |
| Generate Rope matrix with preset axes dimensions. |
| ''' |
|
|
| def __init__(self, dim: int, theta: int, axes_dim: list[int]): |
| |
| super().__init__() |
| self.dim = dim |
| self.theta = theta |
| self.axes_dim = axes_dim |
|
|
| def forward(self, ids: torch.Tensor) -> torch.Tensor: |
| |
| n_axes = ids.shape[-1] |
| emb = torch.cat( |
| [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], |
| dim=-1, |
| ) |
| emb = emb.unsqueeze(1).permute(2, 0, 1, 3) |
| return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) |
|
|
|
|
| class MLPEmbedder(nn.Module): |
| ''' |
| MLP embedder with two projection layers and Silu in between. |
| ''' |
|
|
| def __init__(self, in_dim: int, hidden_dim: int): |
| |
| super().__init__() |
| self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) |
| self.silu = nn.SiLU() |
| self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| |
| return self.out_layer(self.silu(self.in_layer(x))) |
|
|
|
|
| def get_timestep_embedding( |
| timesteps: torch.Tensor, |
| embedding_dim: int, |
| flip_sin_to_cos: bool = True, |
| downscale_freq_shift: float = 0, |
| scale: float = 1, |
| max_period: int = 10000, |
| ): |
| """ |
| This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. |
| |
| Args |
| timesteps (torch.Tensor): |
| a 1-D Tensor of N indices, one per batch element. These may be fractional. |
| embedding_dim (int): |
| the dimension of the output. |
| flip_sin_to_cos (bool): |
| Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) |
| downscale_freq_shift (float): |
| Controls the delta between frequencies between dimensions |
| scale (float): |
| Scaling factor applied to the embeddings. |
| max_period (int): |
| Controls the maximum frequency of the embeddings |
| Returns |
| torch.Tensor: an [N x dim] Tensor of positional embeddings. |
| """ |
| assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" |
|
|
| half_dim = embedding_dim // 2 |
| exponent = -math.log(max_period) * torch.arange( |
| start=0, end=half_dim, dtype=torch.float32, device=timesteps.device |
| ) |
| exponent = exponent / (half_dim - downscale_freq_shift) |
|
|
| emb = torch.exp(exponent) |
| emb = timesteps[:, None].float() * emb[None, :] |
|
|
| |
| emb = scale * emb |
|
|
| |
| emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) |
|
|
| |
| if flip_sin_to_cos: |
| emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) |
|
|
| |
| if embedding_dim % 2 == 1: |
| emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) |
| return emb |
|
|
|
|
| |
| class Timesteps(nn.Module): |
| def __init__( |
| self, |
| embedding_dim: int, |
| flip_sin_to_cos: bool = True, |
| downscale_freq_shift: float = 0, |
| scale: float = 1, |
| max_period: int = 10000, |
| ): |
| super().__init__() |
| self.embedding_dim = embedding_dim |
| self.flip_sin_to_cos = flip_sin_to_cos |
| self.downscale_freq_shift = downscale_freq_shift |
| self.scale = scale |
| self.max_period = max_period |
|
|
| def forward(self, timesteps: torch.Tensor) -> torch.Tensor: |
| t_emb = get_timestep_embedding( |
| timesteps, |
| self.embedding_dim, |
| flip_sin_to_cos=self.flip_sin_to_cos, |
| downscale_freq_shift=self.downscale_freq_shift, |
| scale=self.scale, |
| max_period=self.max_period, |
| ) |
| return t_emb |
|
|
|
|
| |
|
|
|
|
| class TimeStepEmbedder(nn.Module): |
| """ |
| A neural network module that embeds timesteps for use in models such as diffusion models. |
| It projects the input timesteps to a higher-dimensional space and then embeds them using |
| an MLP (Multilayer Perceptron). The projection and embedding provide a learned representation |
| of the timestep that can be used in further computations. |
| |
| Args: |
| embedding_dim (int): |
| The dimensionality of the timestep embedding space. |
| hidden_dim (int): |
| The dimensionality of the hidden layer in the MLPEmbedder. |
| flip_sin_to_cos (bool, optional): |
| Whether to flip the sine and cosine components during the projection (default is True). |
| downscale_freq_shift (float, optional): |
| A scaling factor for the frequency shift during the projection (default is 0). |
| scale (float, optional): |
| A scaling factor applied to the timestep projections (default is 1). |
| max_period (int, optional): |
| The maximum period for the sine and cosine functions used in projection (default is 10000). |
| |
| Methods: |
| forward: Takes a tensor of timesteps and returns their embedded representation. |
| """ |
|
|
| def __init__( |
| self, |
| embedding_dim: int, |
| hidden_dim: int, |
| flip_sin_to_cos: bool = True, |
| downscale_freq_shift: float = 0, |
| scale: float = 1, |
| max_period: int = 10000, |
| ): |
|
|
| super().__init__() |
|
|
| self.time_proj = Timesteps( |
| embedding_dim=embedding_dim, |
| flip_sin_to_cos=flip_sin_to_cos, |
| downscale_freq_shift=downscale_freq_shift, |
| scale=scale, |
| max_period=max_period, |
| ) |
| self.time_embedder = MLPEmbedder(in_dim=embedding_dim, hidden_dim=hidden_dim) |
|
|
| def forward(self, timesteps: torch.Tensor) -> torch.Tensor: |
| |
| timesteps_proj = self.time_proj(timesteps) |
| timesteps_emb = self.time_embedder(timesteps_proj) |
|
|
| return timesteps_emb |
|
|