Spaces:
Configuration error
Configuration error
| import math | |
| import torch | |
| from einops import rearrange | |
| from visualizr.model.base import BaseModule | |
class Mish(BaseModule):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""

    def forward(self, x):
        """Return the Mish activation of ``x`` (same shape as the input)."""
        gate = torch.tanh(torch.nn.functional.softplus(x))
        return x * gate
class Upsample(BaseModule):
    """Channel-preserving transposed convolution (kernel 4, stride 2, padding 1).

    With these settings the two trailing spatial dimensions are doubled.
    """

    def __init__(self, dim):
        """Create the layer for inputs and outputs with ``dim`` channels."""
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        upsampled = self.conv(x)
        return upsampled
class Downsample(BaseModule):
    """Channel-preserving strided convolution (kernel 3, stride 2, padding 1).

    With these settings each trailing spatial dimension is halved, rounding up.
    """

    def __init__(self, dim):
        """Create the layer for inputs and outputs with ``dim`` channels."""
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the strided convolution to ``x``."""
        downsampled = self.conv(x)
        return downsampled
class Rezero(BaseModule):
    """Scale the output of a wrapped callable by a learnable scalar.

    The scalar ``g`` starts at zero, so the wrapped branch contributes nothing
    until training moves ``g`` away from its initial value.
    """

    def __init__(self, fn):
        """Wrap ``fn`` and register the zero-initialised gain ``g``."""
        super().__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``fn(x)`` multiplied by the learnable gain."""
        branch = self.fn(x)
        return branch * self.g
class Block(BaseModule):
    """Masked ``Conv2d(3x3) -> GroupNorm -> Mish`` stack."""

    def __init__(self, dim, dim_out, groups=8):
        """Build the stack mapping ``dim`` channels to ``dim_out`` channels.

        ``groups`` is the number of groups used by the ``GroupNorm`` layer.
        """
        super().__init__()
        layers = [
            torch.nn.Conv2d(dim, dim_out, kernel_size=3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            Mish(),
        ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x, mask):
        """Run the stack on ``x * mask`` and mask the result again."""
        masked_input = x * mask
        return self.block(masked_input) * mask
class ResnetBlock(BaseModule):
    """Two masked ``Block`` layers with a time-embedding bias and a skip path."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        """Build the block.

        Args:
            dim: number of input channels.
            dim_out: number of output channels.
            time_emb_dim: size of the time embedding fed to the bias MLP.
            groups: ``GroupNorm`` group count forwarded to both ``Block``s.
        """
        super().__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # A 1x1 convolution adapts the skip path only when channel counts differ.
        self.res_conv = (
            torch.nn.Identity()
            if dim == dim_out
            else torch.nn.Conv2d(dim, dim_out, kernel_size=1)
        )

    def forward(self, x, mask, time_emb):
        """Return ``block2(block1(x) + time_bias) + res_conv(x * mask)``."""
        hidden = self.block1(x, mask)
        # Append two singleton axes so the per-channel bias broadcasts over
        # the two trailing dimensions of ``hidden``.
        time_bias = self.mlp(time_emb)[..., None, None]
        hidden += time_bias
        hidden = self.block2(hidden, mask)
        skip = self.res_conv(x * mask)
        return hidden + skip
class LinearAttention(BaseModule):
    """Multi-head linear attention over the flattened spatial positions."""

    def __init__(self, dim, heads=4, dim_head=32):
        """Create the 1x1 projections.

        Args:
            dim: number of input (and output) channels.
            heads: number of attention heads.
            dim_head: channels per head; the hidden width is ``heads * dim_head``.
        """
        super().__init__()
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = torch.nn.Conv2d(dim, inner_dim * 3, kernel_size=1, bias=False)
        self.to_out = torch.nn.Conv2d(inner_dim, dim, kernel_size=1)

    def forward(self, x):
        """Attend over all ``h * w`` positions of a 4-D input and project back."""
        _, _, height, width = x.shape
        projected = self.to_qkv(x)
        q, k, v = rearrange(
            projected,
            "b (qkv heads c) h w -> qkv b heads c (h w)",
            heads=self.heads,
            qkv=3,
        )
        # Only the keys are normalised, along the position axis.
        k = k.softmax(dim=-1)
        # Summarise keys/values into a (d x e) context per head, then read it
        # out with the queries.
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        attended = torch.einsum("bhde,bhdn->bhen", context, q)
        merged = rearrange(
            attended,
            "b heads c (h w) -> b (heads c) h w",
            heads=self.heads,
            h=height,
            w=width,
        )
        return self.to_out(merged)
class Residual(BaseModule):
    """Add the input back onto the output of a wrapped callable."""

    def __init__(self, fn):
        """Store the callable ``fn`` whose output receives the skip connection."""
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Return ``fn(x, *args, **kwargs) + x``."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x
class SinusoidalPosEmb(BaseModule):
    """Sinusoidal embedding of a 1-D tensor of scalar positions.

    The output has ``2 * (dim // 2)`` features: sines first, then cosines.
    ``dim`` must be at least 4 because the frequency step divides by
    ``dim // 2 - 1``.
    """

    def __init__(self, dim):
        """Remember the requested embedding width ``dim``."""
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        """Embed ``x`` after multiplying it by ``scale``.

        Adds a trailing feature axis to ``x``; the result lives on ``x``'s
        device.
        """
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        # Geometric frequency ladder from 1 down to 1/10000.
        freqs = torch.exp(torch.arange(n_freqs, device=x.device).float() * -log_step)
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class GradLogPEstimator2d(BaseModule):
    """U-Net style estimator conditioned on a time step and, optionally, a speaker.

    The noisy input ``x`` and the prior mean ``mu`` are stacked as channels
    (plus a speaker channel when ``n_spks > 1``), passed through a
    down / middle / up stack of ``ResnetBlock`` + ``LinearAttention`` stages,
    and reduced back to a single channel.
    """

    def __init__(
        self,
        dim,
        dim_mults=(1, 2, 4),
        groups=8,
        n_spks=None,
        spk_emb_dim=64,
        n_feats=80,
        pe_scale=1000,
    ):
        """Build the network.

        Args:
            dim: base channel width; also the time-embedding size.
            dim_mults: per-resolution multipliers applied to ``dim``.
            groups: stored on the instance. NOTE(review): it is not forwarded
                to the ``ResnetBlock``s, which therefore use their own default.
            n_spks: number of speakers; ``None`` is treated as 1.
            spk_emb_dim: size of the speaker embedding passed to ``forward``.
            n_feats: size the speaker MLP projects to; must match the feature
                axis of ``x`` so the speaker channel can be stacked with it.
            pe_scale: scale handed to the sinusoidal time embedding.
        """
        super().__init__()
        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        # Normalise first and use the normalised value for every later check:
        # comparing the raw default ``None`` with ``> 1`` raises TypeError.
        self.n_spks = 1 if n_spks is None else n_spks
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale
        multi_speaker = self.n_spks > 1

        if multi_speaker:
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4),
                Mish(),
                torch.nn.Linear(spk_emb_dim * 4, n_feats),
            )
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)
        )

        # Input channels: mu and x, plus one speaker channel when multi-speaker.
        dims = [3 if multi_speaker else 2, *(dim * m for m in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        # The first (input) stage has no mirror on the way up.
        for dim_in, dim_out in reversed(in_out[1:]):
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        # ``dim_out * 2``: the skip tensor is concatenated on
                        # the channel axis before this block.
                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_in))),
                        Upsample(dim_in),
                    ]
                )
            )
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        """Run the estimator.

        Args:
            x: noisy sample; presumably ``(batch, n_feats, frames)`` — confirm
                against the caller.
            mask: mask broadcastable to ``x`` after a channel axis is inserted.
            mu: prior mean with the same shape as ``x``.
            t: 1-D tensor of diffusion times, one per batch element.
            spk: speaker embedding; required when ``n_spks > 1`` and ignored
                otherwise.

        Returns:
            The masked network output with the channel axis squeezed out.

        Raises:
            ValueError: if the model is multi-speaker and ``spk`` is ``None``.
        """
        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            if spk is None:
                raise ValueError("spk must be provided when n_spks > 1")
            # Keep the speaker projection local: a module-level global would
            # leak a stale tensor from a previous call into this one.
            s = self.spk_mlp(spk)
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            # Subsample the mask on the last axis to follow the stride-2 conv.
            masks.append(mask_down[:, :, :, ::2])

        # The last stage does not downsample, so its extra mask is dropped.
        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)
        return (output * mask).squeeze(1)
def get_noise(t, beta_init, beta_term, cumulative=False):
    """Evaluate the linear noise schedule at time ``t``.

    Args:
        t: time value (number or tensor).
        beta_init: schedule value at ``t = 0``.
        beta_term: schedule value at ``t = 1``.
        cumulative: when true, return the integral of the schedule from 0 to
            ``t`` instead of its instantaneous value.

    Returns:
        ``beta_init + (beta_term - beta_init) * t`` or, when ``cumulative``,
        ``beta_init * t + 0.5 * (beta_term - beta_init) * t**2``.
    """
    if not cumulative:
        return beta_init + (beta_term - beta_init) * t
    return beta_init * t + 0.5 * (beta_term - beta_init) * (t**2)
class Diffusion(BaseModule):
    """Diffusion model built around a ``GradLogPEstimator2d`` estimator.

    Provides the forward (noising) process, an Euler-style reverse sampler and
    the training loss, all driven by the linear schedule from ``get_noise``.
    """

    def __init__(
        self,
        n_feats,
        dim,
        n_spks=1,
        spk_emb_dim=64,
        beta_min=0.05,
        beta_max=20,
        pe_scale=1000,
    ):
        """Store the hyper-parameters and create the estimator.

        Args:
            n_feats: size of the feature axis; used to normalise the loss and
                forwarded to the estimator.
            dim: base channel width of the estimator.
            n_spks: number of speakers.
            spk_emb_dim: size of the speaker embedding.
            beta_min: noise schedule value at ``t = 0``.
            beta_max: noise schedule value at ``t = 1``.
            pe_scale: scale of the estimator's time embedding.
        """
        super().__init__()
        self.n_feats = n_feats
        self.dim = dim
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale

        # Forward n_feats so the estimator's speaker projection matches the
        # feature axis; otherwise it silently falls back to its own default.
        self.estimator = GradLogPEstimator2d(
            dim,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

    def forward_diffusion(self, x0, mask, mu, t):
        """Sample ``x_t`` from ``x0`` at time ``t``.

        ``t`` gets two trailing singleton axes so it broadcasts against ``x0``.

        Returns:
            ``(xt * mask, z * mask)`` where ``z`` is the standard normal noise
            that was drawn.
        """
        time = t.unsqueeze(-1).unsqueeze(-1)
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        decay = torch.exp(-0.5 * cum_noise)
        # The mean moves from x0 towards mu as the cumulative noise grows.
        mean = x0 * decay + mu * (1.0 - decay)
        variance = 1.0 - torch.exp(-cum_noise)
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
        xt = mean + z * torch.sqrt(variance)
        return xt * mask, z * mask

    def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        """Integrate backwards from ``t = 1`` to ``t = 0`` in ``n_timesteps`` steps.

        Args:
            z: starting sample.
            mask: mask applied to the state after every step.
            mu: prior mean passed to the estimator.
            n_timesteps: number of equal-width steps (must be non-zero).
            stoc: when true, add a random term to every step.
            spk: optional speaker embedding forwarded to the estimator.

        Returns:
            The masked sample after the last step.
        """
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            # Evaluate at the midpoint of the current step, counting down from 1.
            t = (1.0 - (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max, cumulative=False)
            if stoc:  # adds stochastic term
                dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
                dxt_det = dxt_det * noise_t * h
                dxt_stoc = torch.randn(
                    z.shape, dtype=z.dtype, device=z.device, requires_grad=False
                )
                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                dxt = dxt_det + dxt_stoc
            else:
                dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
                dxt = dxt * noise_t * h
            xt = (xt - dxt) * mask
        return xt

    def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        """Alias for :meth:`reverse_diffusion` with the same arguments."""
        return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)

    def loss_t(self, x0, mask, mu, t, spk=None):
        """Compute the training loss at the given times ``t``.

        Returns:
            ``(loss, xt)``: the squared error between the scaled estimator
            output and ``-z``, divided by ``sum(mask) * n_feats``, together
            with the noised sample that was fed to the estimator.
        """
        xt, z = self.forward_diffusion(x0, mask, mu, t)
        time = t.unsqueeze(-1).unsqueeze(-1)
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        std = torch.sqrt(1.0 - torch.exp(-cum_noise))
        # Out-of-place multiply: avoids mutating the estimator's output tensor.
        noise_estimation = self.estimator(xt, mask, mu, t, spk) * std
        loss = torch.sum((noise_estimation + z) ** 2) / (torch.sum(mask) * self.n_feats)
        return loss, xt

    def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
        """Draw one uniform random time per batch element and return ``loss_t``.

        Times are clamped to ``[offset, 1 - offset]`` to stay off the end points.
        """
        t = torch.rand(
            x0.shape[0], dtype=x0.dtype, device=x0.device, requires_grad=False
        )
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(x0, mask, mu, t, spk)