|
|
import math |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from einops import rearrange, reduce, repeat |
|
|
from einops.layers.torch import Rearrange |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Adds a precomputed sin/cos position signal to the input and applies
    dropout. Supports both (seq, batch, d_model) and (batch, seq, d_model)
    input layouts via ``batch_first``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len).unsqueeze(1)
        # Log-spaced inverse frequencies for the even feature indices.
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Stored as (max_len, 1, d_model) so the sequence-first path
        # broadcasts directly over the batch dimension.
        self.register_buffer("pe", table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout."""
        if self.batch_first:
            # View the buffer as (1, max_len, d_model), trimmed to seq len.
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
|
|
|
|
|
|
|
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding for scalar positions (e.g. diffusion timesteps).

    Maps a 1-D tensor ``x`` of positions to a (len(x), dim) embedding whose
    first half is sine and second half is cosine over log-spaced frequencies.
    NOTE(review): assumes ``dim`` is even and > 2 (``dim == 2`` divides by
    zero below) — confirm against callers.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of ``shape`` where each entry is True with
    probability ``prob``.

    The ``prob in {0, 1}`` cases short-circuit without consuming RNG state.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    # Draw uniforms in [0, 1) and threshold them against prob.
    draws = torch.zeros(shape, device=device).float().uniform_(0, 1)
    return draws < prob
|
|
|
|
|
|
|
|
def extract(a, t, x_shape):
    """Gather per-timestep values ``a[t]`` and reshape for broadcasting.

    Returns a tensor of shape (batch, 1, ..., 1) with as many trailing
    singleton dims as needed to broadcast against a tensor of ``x_shape``.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
|
|
|
|
|
|
|
|
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """Build a diffusion beta schedule as a float64 numpy array.

    Args:
        schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
        n_timestep: number of diffusion steps (length of the result).
        linear_start: first beta for the linear-family schedules.
        linear_end: last beta for the linear-family schedules.
        cosine_s: small offset for the cosine schedule (Nichol & Dhariwal).

    Returns:
        np.ndarray of shape (n_timestep,), dtype float64.

    Raises:
        ValueError: if ``schedule`` is not a known name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, then squared back — standard DDPM-style
        # schedule with endpoints linear_start / linear_end.
        betas = (
            torch.linspace(
                linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Fix: clamp with torch instead of np.clip. Calling np.clip on a
        # torch.Tensor relies on fragile numpy<->torch interop and can return
        # an ndarray, which would break the .numpy() call below.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
|
|
|