| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .utils import get_activation_fn |
|
|
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """Build sinusoidal embeddings for a 1-D batch of timesteps.

    With ``half_dim = embedding_dim // 2``, frequency ``i`` is
    ``exp(-log(max_period) * i / (half_dim - downscale_freq_shift))`` for
    ``0 <= i < half_dim``. Each timestep is multiplied by every frequency and
    by ``scale``; the sine and cosine of those values are then concatenated
    along the last dimension.

    Args:
        timesteps: 1-D tensor with one entry per batch element. It is cast
            to float32 before use.
        embedding_dim: Size of the last dimension of the result.
        flip_sin_to_cos: If True the layout is ``[cos, sin]``; otherwise it
            is ``[sin, cos]``.
        downscale_freq_shift: Value subtracted from ``half_dim`` in the
            denominator of the frequency exponent.
        scale: Multiplier applied to the arguments of sin/cos.
        max_period: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(len(timesteps), embedding_dim)`` on the
        same device as ``timesteps``. When ``embedding_dim`` is odd, the last
        column is zero padding.

    Raises:
        ValueError: If ``timesteps`` is not 1-D, or if
            ``embedding_dim // 2 == downscale_freq_shift`` (with a non-empty
            frequency range), which would divide the exponent by zero.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if timesteps.dim() != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half_dim = embedding_dim // 2
    denominator = half_dim - downscale_freq_shift
    # A zero denominator would turn the exponent into NaN/inf and silently
    # poison the whole embedding. With half_dim == 0 there are no
    # frequencies, so nothing is divided and the guard is not needed.
    if half_dim > 0 and denominator == 0:
        raise ValueError(
            "embedding_dim // 2 must differ from downscale_freq_shift "
            f"(got embedding_dim={embedding_dim}, "
            f"downscale_freq_shift={downscale_freq_shift}); "
            "the frequency exponent would be divided by zero"
        )

    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    freqs = torch.exp(exponent / denominator)

    # Outer product: one row per timestep, one column per frequency.
    args = timesteps[:, None].float() * freqs[None, :]
    args = scale * args

    sin_part = torch.sin(args)
    cos_part = torch.cos(args)
    # Concatenate once, directly in the requested order.
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = torch.cat(halves, dim=-1)

    # sin/cos only fill 2 * half_dim columns; pad one zero column on the
    # right when embedding_dim is odd.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
|
|
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to a timestep embedding.

    The forward pass is ``linear_1 -> act -> linear_2``, followed by
    ``post_act`` when one is configured. An optional conditioning tensor is
    projected by a bias-free linear layer and added to the input before
    ``linear_1``.

    Args:
        in_channels: Input feature size of ``linear_1``; also the output
            size of the conditioning projection.
        time_embed_dim: Output size of ``linear_1`` (hidden width).
        act_fn: Name passed to ``get_activation_fn`` for the activation
            between the two linear layers.
        out_dim: Output size of ``linear_2``; ``time_embed_dim`` when None.
        post_act_fn: Name passed to ``get_activation_fn`` for an activation
            applied after ``linear_2``; no such activation when None.
        cond_proj_dim: Feature size of the conditioning input. When None no
            projection is created and ``cond_proj`` is None.
        zero_init_cond: If True, the conditioning projection weight starts
            at zero.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str,
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
        zero_init_cond: bool = True,
    ) -> None:
        super().__init__()

        # Submodules are created in a fixed order (linear_1, cond_proj, act,
        # linear_2, post_act) so initialisation and state_dict order are stable.
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False)
            if cond_proj_dim is not None
            else None
        )
        if self.cond_proj is not None and zero_init_cond:
            nn.init.zeros_(self.cond_proj.weight)

        self.act = get_activation_fn(act_fn)

        hidden_out = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, hidden_out)

        self.post_act = (
            None if post_act_fn is None else get_activation_fn(post_act_fn)
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep_cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MLP on ``sample``.

        Args:
            sample: Input whose last dimension feeds ``linear_1``.
            timestep_cond: Optional conditioning tensor. When given, its
                projection through ``cond_proj`` is added to ``sample``
                first; this requires the module to have been built with
                ``cond_proj_dim``.

        Returns:
            The output of ``linear_2``, passed through ``post_act`` when one
            is configured.
        """
        hidden = sample
        if timestep_cond is not None:
            hidden = hidden + self.cond_proj(timestep_cond)

        hidden = self.linear_2(self.act(self.linear_1(hidden)))

        if self.post_act is None:
            return hidden
        return self.post_act(hidden)
|
|
|
|
class Timesteps(nn.Module):
    """Parameter-free module that delegates to ``get_timestep_embedding``.

    The constructor stores the embedding configuration; ``forward`` passes
    it, together with the timesteps, to ``get_timestep_embedding``.

    Args:
        num_channels: Passed as ``embedding_dim``.
        flip_sin_to_cos: Passed through as ``flip_sin_to_cos``.
        downscale_freq_shift: Passed through as ``downscale_freq_shift``.
        scale: Passed through as ``scale``. The default of 1 equals the
            default of ``get_timestep_embedding``.
        max_period: Passed through as ``max_period``. The default of 10000
            equals the default of ``get_timestep_embedding``.
    """

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool,
        downscale_freq_shift: float,
        scale: float = 1,
        max_period: int = 10000,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        # Trailing parameters with defaults, so existing three-argument
        # construction keeps producing the same embeddings.
        self.scale = scale
        self.max_period = max_period

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return the embedding of ``timesteps`` under the stored settings."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
            max_period=self.max_period,
        )
|
|