# Provenance (from commit history): Shen Feiyu — "add 1s" — commit faadabf
import torch
import torch.nn.functional as F
from typing import Dict, List
from einops import pack, repeat
from .estimator_dit import DiT
from .upsample_encoder import UpsampleConformerEncoder
class DualEmbedding(torch.nn.Module):
    """Embed token ids through two parallel 128-way codebooks.

    Each token id is decomposed into a low part (``id % 128``) and a high
    part (``id // 128``); the two 128-dim embeddings are concatenated and
    projected to ``channels``.
    """

    def __init__(
        self,
        channels: int = 512,
    ):
        super().__init__()
        self.codebook_size = 128
        self.codebook_dim = 128
        # Two independent codebooks: index 0 for the low half of the id,
        # index 1 for the high half.
        self.codebook = torch.nn.ModuleList(
            torch.nn.Embedding(self.codebook_size, self.codebook_dim)
            for _ in range(2)
        )
        self.out_proj = torch.nn.Linear(self.codebook_dim * 2, channels)

    def forward(self, tokens):
        """
        Args:
            tokens: shape (b, t)
        Returns:
            token_embs: shape (b, t, c)
        """
        low = self.codebook[0](tokens % self.codebook_size)
        high = self.codebook[1](tokens // self.codebook_size)
        return self.out_proj(torch.cat((low, high), dim=-1))
class CausalFmWithSpkCtx(torch.nn.Module):
    """Conditional flow-matching feature generator with speaker and mel context.

    Composes a token embedding, an upsampling conformer encoder and a DiT
    vector-field estimator.  Inference runs a classifier-free-guided Euler
    ODE solve from a fixed noise buffer and returns the features for the
    non-prompt portion of the token sequence.
    """

    def __init__(
        self,
        # Basic in-out
        spk_channels: int,
        spk_enc_channels: int,  # out channels of spk & encoder projection
        # Module
        token_emb: DualEmbedding,
        encoder: UpsampleConformerEncoder,
        estimator: DiT,
        # Flow cfg
        infer_cfg_rate: float = 0.7,  # classifier-free guidance strength
    ):
        super().__init__()
        # Variants
        self.up_stride = encoder.up_stride
        self.infer_cfg_rate = infer_cfg_rate
        # Module
        self.spk_proj = torch.nn.Linear(spk_channels, spk_enc_channels)
        self.token_emb = token_emb
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(encoder.output_size, spk_enc_channels)
        self.estimator = estimator
        # Fixed initial noise shared across calls, so repeated inference is
        # deterministic.  Sized for a maximum of 600 s (presumably at
        # 50 frames/s — TODO confirm frame rate).  persistent=False keeps it
        # out of checkpoints; it is regenerated on each construction.
        self.register_buffer(
            "x0",
            torch.randn([1, self.estimator.out_channels, 50 * 600]),
            persistent=False,
        )

    def _euler(
        self,
        x0: torch.Tensor,
        c: torch.Tensor,
        n_timesteps: int = 10,
    ) -> torch.Tensor:
        """Integrate the learned vector field from t=0 to t=1 with CFG.

        Args:
            x0: initial noise; assumed (b, c, t) to match the estimator —
                TODO confirm.
            c: packed conditioning tensor aligned with ``x0``.
            n_timesteps: number of Euler steps.
        Returns:
            The integrated sample at t=1 (same shape as ``x0``).
        """
        # time steps
        t_span = torch.linspace(0, 1, n_timesteps + 1).to(x0)
        # cosine time scheduling: warps [0, 1] -> [0, 1] with smaller steps
        # near t=0 and larger steps near t=1
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        # euler solver
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)
        xt = x0
        for step in range(1, len(t_span)):
            # pack input: conditional and unconditional (zeroed condition)
            # branches are batched into one estimator call
            x_in = torch.cat([xt, xt], dim=0)
            c_in = torch.cat([c, torch.zeros_like(c)], dim=0)
            t_in = torch.cat([t, t], dim=0)
            # model call
            with torch.no_grad():
                vt = self.estimator.forward(x_in, c_in, t_in)
            # cfg: extrapolate the conditional prediction away from the
            # unconditional one by infer_cfg_rate
            vt_cond, vt_cfg = vt.chunk(2, dim=0)
            vt = (1.0 + self.infer_cfg_rate) * vt_cond - self.infer_cfg_rate * vt_cfg
            xt = xt + dt * vt
            t = t + dt
            # dt is non-uniform under the cosine schedule; recompute it from
            # the next grid point for every step but the last
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return xt

    def inference(
        self,
        prompt_token: torch.Tensor,
        prompt_xvec: torch.Tensor,
        prompt_feat: torch.Tensor,
        token: torch.Tensor,
    ) -> torch.Tensor:
        """Generate features for ``token`` conditioned on a spoken prompt.

        Args:
            prompt_token: prompt token ids; assumed (b, t) — TODO confirm.
            prompt_xvec: speaker x-vector; assumed (b, spk_channels).
            prompt_feat: prompt acoustic features; assumed (b, t', c),
                already time-aligned with ``prompt_token`` (see NOTE below).
            token: target token ids to synthesize, same layout as
                ``prompt_token``.
        Returns:
            Generated features for the target tokens only (prompt frames
            are stripped).
        """
        # NOTE align prompt_token, prompt_feat in advance
        # Spk condition: L2-normalized x-vector projected to encoder width
        embedding = F.normalize(prompt_xvec, dim=1)
        spks = self.spk_proj(embedding)
        # Token condition: prompt and target tokens form one sequence
        token = torch.concat([prompt_token, token], dim=1)
        xs = self.token_emb(token)
        xs_lens = torch.tensor([xs.shape[1]]).to(token)
        # NOTE(review): assumes the encoder returns a single tensor, not an
        # (output, mask/lengths) tuple — verify against UpsampleConformerEncoder
        xs = self.encoder(xs, xs_lens)
        mu = self.encoder_proj(xs)
        # Mel context: prompt features up front, zeros where we generate
        ctx = torch.zeros_like(mu)
        ctx[:, : prompt_feat.shape[1]] = prompt_feat
        # Compose condition: move to (b, c, t) and stack encoder output,
        # broadcast speaker embedding and mel context along channels
        cond = mu.transpose(1, 2)
        ctx = ctx.transpose(1, 2)
        spks = repeat(spks, "b c -> b c t", t=cond.shape[-1])
        cond = pack([cond, spks, ctx], "b * t")[0]
        # FM inference: slice the fixed noise buffer to the needed length
        x0 = self.x0[..., : mu.shape[1]]
        x1 = self._euler(x0, cond, n_timesteps=10)
        # Drop the prompt frames; return only the newly generated part
        feat = x1.transpose(1, 2)[:, prompt_feat.shape[1] :]
        return feat