# NOTE(review): the following text is page-scrape residue (Hugging Face Spaces
# UI header), kept here as a comment so the file parses as Python:
#   "Spaces: Running on Zero  Running on Zero"
# Standard library
from typing import Dict, List

# Third-party
import torch
import torch.nn.functional as F
from einops import pack, repeat

# Local
from .estimator_dit import DiT
from .upsample_encoder import UpsampleConformerEncoder
class DualEmbedding(torch.nn.Module):
    """Embed factorized token ids with two codebooks, then project to model width.

    A token id ``t`` in ``[0, codebook_size ** 2)`` is split into a low factor
    (``t % codebook_size``, looked up in the first codebook) and a high factor
    (``t // codebook_size``, looked up in the second codebook). The two
    embeddings are concatenated along the feature axis and linearly projected
    to ``channels``.
    """

    def __init__(
        self,
        channels: int = 512,
        codebook_size: int = 128,
        codebook_dim: int = 128,
    ):
        """
        Args:
            channels: output embedding width.
            codebook_size: number of entries per codebook; token ids are
                expected in ``[0, codebook_size ** 2)``. Defaults to 128
                (the original hard-coded value).
            codebook_dim: embedding width of each individual codebook.
                Defaults to 128 (the original hard-coded value).
        """
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        # Two independent codebooks: index 0 holds the low factor,
        # index 1 holds the high factor of each token id.
        self.codebook = torch.nn.ModuleList([
            torch.nn.Embedding(self.codebook_size, self.codebook_dim),
            torch.nn.Embedding(self.codebook_size, self.codebook_dim),
        ])
        self.out_proj = torch.nn.Linear(self.codebook_dim * 2, channels)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tokens: integer token ids, shape (b, t).
        Returns:
            token_embs: shape (b, t, channels).
        """
        # Low- and high-factor lookups, concatenated on the feature axis.
        token_embs = torch.cat([
            self.codebook[0](tokens % self.codebook_size),
            self.codebook[1](tokens // self.codebook_size),
        ], dim=-1)
        token_embs = self.out_proj(token_embs)
        return token_embs
class CausalFmWithSpkCtx(torch.nn.Module):
    """Conditional flow-matching mel generator with speaker and mel-context conditioning.

    Tokens are embedded (``token_emb``), upsampled and encoded (``encoder``),
    then packed together with a projected speaker x-vector and a prompt-mel
    context into the condition tensor fed to the DiT velocity estimator.
    Mel features are sampled from fixed pre-registered noise with an Euler
    ODE solver using classifier-free guidance (CFG).
    """

    def __init__(
        self,
        # Basic in-out
        spk_channels: int,
        spk_enc_channels: int,  # out channels of spk & encoder projection
        # Module
        token_emb: DualEmbedding,
        encoder: UpsampleConformerEncoder,
        estimator: DiT,
        # Flow cfg
        infer_cfg_rate: float = 0.7,
    ):
        """
        Args:
            spk_channels: width of the incoming speaker x-vector.
            spk_enc_channels: common width that both the speaker vector and
                the encoder output are projected to before packing.
            token_emb: token embedding module (see ``DualEmbedding``).
            encoder: upsampling conformer encoder; must expose ``up_stride``
                and ``output_size``.
            estimator: DiT velocity-field estimator; must expose
                ``out_channels``.
            infer_cfg_rate: classifier-free guidance strength used during
                inference (0 disables guidance extrapolation).
        """
        super().__init__()
        # Variants
        self.up_stride = encoder.up_stride
        self.infer_cfg_rate = infer_cfg_rate
        # Module
        self.spk_proj = torch.nn.Linear(spk_channels, spk_enc_channels)
        self.token_emb = token_emb
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(encoder.output_size, spk_enc_channels)
        self.estimator = estimator
        # Fixed initial noise shared across calls, maximum of 600s
        # (50 frames/s assumed — TODO confirm frame rate against feature
        # extractor). Non-persistent: not saved in checkpoints.
        self.register_buffer(
            "x0",
            torch.randn([1, self.estimator.out_channels, 50 * 600]),
            persistent=False,
        )

    def _euler(
        self,
        x0: torch.Tensor,
        c: torch.Tensor,
        n_timesteps: int = 10,
    ) -> torch.Tensor:
        """Integrate the learned velocity field from t=0 to t=1 with Euler steps.

        Args:
            x0: initial noise, shape (b, estimator.out_channels, t).
            c: condition tensor, same batch/time layout as ``x0`` along the
                last axis.
            n_timesteps: number of Euler steps.
        Returns:
            xt: the integrated sample at t=1, same shape as ``x0``.
        """
        # time steps
        t_span = torch.linspace(0, 1, n_timesteps + 1).to(x0)
        # cosine time scheduling: denser steps near t=0
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        # euler solver
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)
        xt = x0
        for step in range(1, len(t_span)):
            # pack conditional and unconditional branches into one batch:
            # second half gets a zeroed condition for CFG
            x_in = torch.cat([xt, xt], dim=0)
            c_in = torch.cat([c, torch.zeros_like(c)], dim=0)
            t_in = torch.cat([t, t], dim=0)
            # model call
            with torch.no_grad():
                vt = self.estimator.forward(x_in, c_in, t_in)
            # cfg: extrapolate away from the unconditional velocity
            vt_cond, vt_cfg = vt.chunk(2, dim=0)
            vt = (1.0 + self.infer_cfg_rate) * vt_cond - self.infer_cfg_rate * vt_cfg
            xt = xt + dt * vt
            t = t + dt
            if step < len(t_span) - 1:
                # non-uniform schedule: recompute step size from the next knot
                dt = t_span[step + 1] - t
        return xt

    def inference(
        self,
        prompt_token: torch.Tensor,
        prompt_xvec: torch.Tensor,
        prompt_feat: torch.Tensor,
        token: torch.Tensor,
    ) -> torch.Tensor:
        """Generate mel features for ``token`` continuing from a prompt.

        Args:
            prompt_token: prompt token ids, shape (b, t_p).
            prompt_xvec: speaker x-vector, shape (b, spk_channels).
            prompt_feat: prompt mel features, shape (b, t_mel, n_mels) —
                presumed; t_mel must equal t_p * up_stride (see NOTE below).
            token: token ids to synthesize, shape (b, t).
        Returns:
            feat: generated features for ``token`` only (prompt frames
                stripped), shape (b, t * up_stride, out_channels).
        """
        # NOTE align prompt_token, prompt_feat in advance
        # Spk condition
        embedding = F.normalize(prompt_xvec, dim=1)
        spks = self.spk_proj(embedding)
        # Token condition: prompt and new tokens are encoded jointly
        token = torch.concat([prompt_token, token], dim=1)
        xs = self.token_emb(token)
        xs_lens = torch.tensor([xs.shape[1]]).to(token)
        # NOTE(review): assumes the encoder returns a plain tensor here, not
        # an (output, mask) tuple — confirm against UpsampleConformerEncoder.
        xs = self.encoder(xs, xs_lens)
        mu = self.encoder_proj(xs)
        # Mel context: prompt frames kept, frames to generate zeroed
        ctx = torch.zeros_like(mu)
        ctx[:, : prompt_feat.shape[1]] = prompt_feat
        # Compose condition: (b, t, c) -> (b, c, t), broadcast speaker over
        # time, then pack [mu; spks; ctx] along the channel axis
        cond = mu.transpose(1, 2)
        ctx = ctx.transpose(1, 2)
        spks = repeat(spks, "b c -> b c t", t=cond.shape[-1])
        cond = pack([cond, spks, ctx], "b * t")[0]
        # FM inference from the fixed noise buffer, truncated to the needed
        # length (mu.shape[1] is the post-upsampling time length)
        x0 = self.x0[..., : mu.shape[1]]
        x1 = self._euler(x0, cond, n_timesteps=10)
        # Drop the prompt portion; return only the newly generated frames
        feat = x1.transpose(1, 2)[:, prompt_feat.shape[1] :]
        return feat