Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,251 Bytes
faadabf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import torch
import torch.nn.functional as F
from typing import Dict, List
from einops import pack, repeat
from .estimator_dit import DiT
from .upsample_encoder import UpsampleConformerEncoder
class DualEmbedding(torch.nn.Module):
    """Embed each token id through two sub-codebooks and project the result.

    A token id is split into a low index (``token % codebook_size``) and a
    high index (``token // codebook_size``). Each index is looked up in its
    own embedding table, the two embeddings are concatenated along the last
    axis, and a linear layer maps the concatenation to ``channels``.

    Args:
        channels: Output feature size of the final projection.
        codebook_size: Number of entries in each of the two embedding tables.
            Token ids are expected to lie in ``[0, codebook_size ** 2)``;
            anything outside that range indexes past an embedding table.
        codebook_dim: Embedding width of each table.
    """

    def __init__(
        self,
        channels: int = 512,
        codebook_size: int = 128,
        codebook_dim: int = 128,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        # Index 0 holds the "low" table, index 1 the "high" table. The
        # creation order is kept so seeded initialisation is reproducible.
        self.codebook = torch.nn.ModuleList([
            torch.nn.Embedding(self.codebook_size, self.codebook_dim),
            torch.nn.Embedding(self.codebook_size, self.codebook_dim),
        ])
        self.out_proj = torch.nn.Linear(self.codebook_dim * 2, channels)

    def forward(self, tokens):
        """
        Args:
            tokens: shape (b, t)
        Returns:
            token_embs: shape (b, t, c)
        """
        low_index = tokens % self.codebook_size
        high_index = tokens // self.codebook_size
        token_embs = torch.cat(
            [self.codebook[0](low_index), self.codebook[1](high_index)],
            dim=-1,
        )
        return self.out_proj(token_embs)
class CausalFmWithSpkCtx(torch.nn.Module):
    """Flow-matching generator conditioned on tokens, a speaker vector and a prompt context.

    ``inference`` embeds and encodes the token sequence, builds a condition
    tensor from the encoder output, the projected speaker vector and the
    prompt features, then integrates the estimator's velocity field from a
    fixed noise buffer with an Euler solver and classifier-free guidance.

    Args:
        spk_channels: Input size of the speaker-vector projection.
        spk_enc_channels: Output size shared by the speaker projection and
            the encoder projection.
        token_emb: Module mapping token ids to embeddings.
        encoder: Encoder applied to the token embeddings; must expose
            ``up_stride`` and ``output_size``.
        estimator: Velocity estimator; must expose ``out_channels`` and a
            ``forward(x, c, t)`` method.
        infer_cfg_rate: Classifier-free guidance strength used at inference.
    """

    def __init__(
        self,
        # Basic in-out
        spk_channels: int,
        spk_enc_channels: int,  # out channels of spk & encoder projection
        # Module
        token_emb: DualEmbedding,
        encoder: UpsampleConformerEncoder,
        estimator: DiT,
        # Flow cfg
        infer_cfg_rate: float = 0.7,
    ):
        super().__init__()
        # Variants
        self.up_stride = encoder.up_stride
        self.infer_cfg_rate = infer_cfg_rate
        # Module
        self.spk_proj = torch.nn.Linear(spk_channels, spk_enc_channels)
        self.token_emb = token_emb
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(encoder.output_size, spk_enc_channels)
        self.estimator = estimator
        # Initial noise, maximum of 600s
        # (presumably 50 frames per second -- TODO confirm the frame rate).
        # Sampled once and not saved in the state dict (persistent=False).
        self.register_buffer(
            "x0",
            torch.randn([1, self.estimator.out_channels, 50 * 600]),
            persistent=False,
        )

    def _euler(
        self,
        x0: torch.Tensor,
        c: torch.Tensor,
        n_timesteps: int = 10,
    ):
        """Integrate the estimator's velocity field from ``x0`` with Euler steps.

        Args:
            x0: Starting sample; time is the last axis and batch the first.
            c: Condition tensor passed to the estimator alongside ``x0``.
            n_timesteps: Number of Euler steps (must be at least 1).

        Returns:
            The sample after the final step, same shape as ``x0``.

        Raises:
            ValueError: If ``n_timesteps`` is smaller than 1.
        """
        if n_timesteps < 1:
            raise ValueError(f"n_timesteps must be >= 1, got {n_timesteps}")
        # time steps
        t_span = torch.linspace(0, 1, n_timesteps + 1).to(x0)
        # cosine time scheduling
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        # euler solver
        t = t_span[0].unsqueeze(dim=0)
        dt = t_span[1] - t_span[0]
        # The condition never changes between steps, so the conditional /
        # zeroed (unconditional) pair is stacked once, outside the loop.
        c_in = torch.cat([c, torch.zeros_like(c)], dim=0)
        last_step = len(t_span) - 1
        xt = x0
        for step in range(1, len(t_span)):
            # pack input: the same state is evaluated with and without condition
            x_in = torch.cat([xt, xt], dim=0)
            t_in = torch.cat([t, t], dim=0)
            # model call
            with torch.no_grad():
                vt = self.estimator.forward(x_in, c_in, t_in)
            # cfg: extrapolate away from the unconditional prediction
            vt_cond, vt_uncond = vt.chunk(2, dim=0)
            vt = (1.0 + self.infer_cfg_rate) * vt_cond - self.infer_cfg_rate * vt_uncond
            xt = xt + dt * vt
            t = t + dt
            if step < last_step:
                dt = t_span[step + 1] - t
        return xt

    def inference(
        self,
        prompt_token: torch.Tensor,
        prompt_xvec: torch.Tensor,
        prompt_feat: torch.Tensor,
        token: torch.Tensor,
        n_timesteps: int = 10,
    ):
        """Generate features for ``token``, continuing from a prompt.

        Args:
            prompt_token: Prompt token ids, concatenated before ``token``
                along dim 1.
            prompt_xvec: Speaker vector; L2-normalised along dim 1 and then
                projected.
            prompt_feat: Prompt features written into the start of the
                context (assumed to be (b, t_prompt, spk_enc_channels) so it
                fits the encoder-projection output -- TODO confirm).
            token: Token ids to synthesise.
            n_timesteps: Number of Euler solver steps.

        Returns:
            Generated features with the leading ``prompt_feat.shape[1]``
            frames removed; time is dim 1.

        Raises:
            ValueError: If the encoded sequence is longer than the
                pre-sampled noise buffer ``x0``.
        """
        # NOTE align prompt_token, prompt_feat in advance
        # Spk condition
        embedding = F.normalize(prompt_xvec, dim=1)
        spks = self.spk_proj(embedding)
        # Token condition
        token = torch.concat([prompt_token, token], dim=1)
        xs = self.token_emb(token)
        # A single length entry: this path handles one sequence at a time.
        xs_lens = torch.tensor([xs.shape[1]]).to(token)
        xs = self.encoder(xs, xs_lens)
        mu = self.encoder_proj(xs)
        n_frames = mu.shape[1]
        max_frames = self.x0.shape[-1]
        if n_frames > max_frames:
            # Slicing x0 would silently give fewer noise frames than the
            # condition has, so fail early with an explicit message.
            raise ValueError(
                f"sequence of {n_frames} frames exceeds the noise buffer "
                f"length of {max_frames} frames"
            )
        # Mel context: prompt features at the start, zeros for the rest
        ctx = torch.zeros_like(mu)
        prompt_len = prompt_feat.shape[1]
        ctx[:, :prompt_len] = prompt_feat
        # Compose condition (channel-first, time last)
        cond = mu.transpose(1, 2)
        ctx = ctx.transpose(1, 2)
        spks = repeat(spks, "b c -> b c t", t=cond.shape[-1])
        cond = pack([cond, spks, ctx], "b * t")[0]
        # FM inference
        x0 = self.x0[..., :n_frames]
        x1 = self._euler(x0, cond, n_timesteps=n_timesteps)
        feat = x1.transpose(1, 2)[:, prompt_len:]
        return feat
|