# jerryhai
# Track binary files with Git LFS
# 90f7c1e
import math
import torch
from .base import BaseModule
from .modules import Mish, Upsample, Downsample, Rezero, Block, ResnetBlock
from .modules import LinearAttention, Residual, Timesteps, TimbreBlock, PitchPosEmb
from einops import rearrange
class UNetPitcher(BaseModule):
    """U-Net style network conditioned on time, pitch and optional timbre.

    The input to the first stage is built by stacking ``x`` and ``mean`` as
    two channels and appending a pitch condition (always) and a timbre
    condition (only when ``use_ref_t`` or ``use_embed`` is set).

    Args:
        dim_base: Base channel width. Also the width of the time embedding
            and the channel count consumed by the final block.
        dim_cond: Channel width of the pitch condition and, when present,
            of the timbre condition.
        use_ref_t: If true, ``ref`` is passed through a ``TimbreBlock`` and
            contributes ``dim_cond`` features to the timbre vector.
        use_embed: If true, ``embed`` contributes ``dim_embed`` features to
            the timbre vector.
        dim_embed: Feature size expected for ``embed``.
        dim_mults: Per-stage multipliers applied to ``dim_base``.
            NOTE(review): the final block is built for ``dim_base`` input
            channels, which looks like it requires ``dim_mults[0] == 1``
            -- confirm against ``Upsample``.
        pitch_type: Recorded on the instance; nothing in this class
            branches on it.
    """

    def __init__(self,
                 dim_base,
                 dim_cond,
                 use_ref_t,
                 use_embed,
                 dim_embed=256,
                 dim_mults=(1, 2, 4),
                 pitch_type='bins'):
        super().__init__()
        nn = torch.nn

        self.use_ref_t = use_ref_t
        self.use_embed = use_embed
        self.pitch_type = pitch_type

        # Time embedding: positional features followed by a small MLP.
        self.time_pos_emb = Timesteps(num_channels=dim_base,
                                      flip_sin_to_cos=True,
                                      downscale_freq_shift=0)
        self.mlp = nn.Sequential(
            nn.Linear(dim_base, dim_base * 4),
            Mish(),
            nn.Linear(dim_base * 4, dim_base),
        )

        # Speaker/timbre embedding: each enabled source adds to the width
        # of the vector that the timbre MLP projects down to dim_cond.
        timbre_width = ((dim_cond if use_ref_t else 0)
                        + (dim_embed if use_embed else 0))
        if use_ref_t:
            self.ref_block = TimbreBlock(out_dim=dim_cond)
        if timbre_width != 0:
            self.timbre_block = nn.Sequential(
                nn.Linear(timbre_width, 4 * dim_cond),
                Mish(),
                nn.Linear(4 * dim_cond, dim_cond),
            )

        # Pitch embedding: positional embedding refined by 1x1 convolutions.
        self.pitch_pos_emb = PitchPosEmb(dim_cond)
        self.pitch_mlp = nn.Sequential(
            nn.Conv1d(dim_cond, dim_cond * 4, 1, stride=1),
            Mish(),
            nn.Conv1d(dim_cond * 4, dim_cond, 1, stride=1),
        )

        # Two channels for (x, mean), dim_cond for pitch, and another
        # dim_cond when any timbre source is enabled.
        has_timbre = use_embed or use_ref_t
        in_channels = 2 + (dim_cond if has_timbre else 0) + dim_cond

        widths = [in_channels] + [dim_base * mult for mult in dim_mults]
        stage_dims = list(zip(widths, widths[1:]))

        # Encoder and decoder stage containers.
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        final_stage = len(stage_dims) - 1
        for stage, (c_in, c_out) in enumerate(stage_dims):
            first_res = ResnetBlock(c_in, c_out, time_emb_dim=dim_base)
            second_res = ResnetBlock(c_out, c_out, time_emb_dim=dim_base)
            attention = Residual(Rezero(LinearAttention(c_out)))
            # The deepest stage keeps its resolution.
            resample = (nn.Identity() if stage >= final_stage
                        else Downsample(c_out))
            self.downs.append(
                nn.ModuleList([first_res, second_res, attention, resample]))

        bottleneck = widths[-1]
        self.mid_block1 = ResnetBlock(bottleneck, bottleneck,
                                      time_emb_dim=dim_base)
        self.mid_attn = Residual(Rezero(LinearAttention(bottleneck)))
        self.mid_block2 = ResnetBlock(bottleneck, bottleneck,
                                      time_emb_dim=dim_base)

        # The decoder mirrors every encoder stage except the first, so it
        # has one stage fewer than the encoder.
        for c_in, c_out in reversed(stage_dims[1:]):
            self.ups.append(nn.ModuleList([
                # Input is doubled by the skip concatenation in forward().
                ResnetBlock(c_out * 2, c_in, time_emb_dim=dim_base),
                ResnetBlock(c_in, c_in, time_emb_dim=dim_base),
                Residual(Rezero(LinearAttention(c_in))),
                Upsample(c_in),
            ]))

        self.final_block = Block(dim_base, dim_base)
        self.final_conv = nn.Conv2d(dim_base, 1, 1)

    def forward(self, x, mean, f0, t, ref=None, embed=None):
        """Run the network on ``x`` given its conditions.

        Args:
            x: Input tensor; stacked with ``mean`` along a new dim 1.
                Assumes (batch, bins, frames) -- TODO confirm with caller.
            mean: Tensor stacked with ``x``; must be stackable with it.
            f0: Pitch input passed to the pitch positional embedding.
            t: Timestep. A non-tensor is wrapped into a one-element long
                tensor; a 0-dim tensor is broadcast to the batch size.
            ref: Reference input for the timbre block; read only when
                ``use_ref_t`` is set.
            embed: Timbre embedding; read only when ``use_embed`` is set.

        Returns:
            Output of the final 1-channel convolution with dim 1 squeezed.
        """
        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)
        elif t.dim() == 0:
            t = t * torch.ones(x.shape[0], dtype=t.dtype, device=x.device)
        t_emb = self.mlp(self.time_pos_emb(t))

        x = torch.stack((x, mean), dim=1)
        n_rows, n_cols = x.shape[2], x.shape[3]

        # Pitch condition: insert dim 2 and tile it to match x.
        pitch = self.pitch_mlp(self.pitch_pos_emb(f0)).unsqueeze(2)
        pitch = torch.cat([pitch] * n_rows, dim=2)

        # Collect the enabled timbre sources into one feature vector.
        timbre = None
        if self.use_ref_t:
            timbre = self.ref_block(torch.stack((ref,), dim=1))
        if self.use_embed:
            timbre = (embed if timbre is None
                      else torch.cat((timbre, embed), dim=1))

        if timbre is not None:
            # Project, then tile over dims 2 and 3 to match x.
            timbre = self.timbre_block(timbre).unsqueeze(-1).unsqueeze(-1)
            timbre = torch.cat([timbre] * n_rows, dim=2)
            timbre = torch.cat([timbre] * n_cols, dim=3)
            condition = torch.cat((pitch, timbre), dim=1)
        else:
            # No timbre source available: condition on pitch only.
            condition = pitch

        x = torch.cat((x, condition), dim=1)

        # Encoder: remember each stage's output before resampling.
        skips = []
        for res_a, res_b, attention, down in self.downs:
            x = attention(res_b(res_a(x, t_emb), t_emb))
            skips.append(x)
            x = down(x)

        x = self.mid_block1(x, t_emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t_emb)

        # Decoder: consume skips from deepest to shallowest; the first
        # encoder stage's skip is left unused.
        for res_a, res_b, attention, up in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = attention(res_b(res_a(x, t_emb), t_emb))
            x = up(x)

        return self.final_conv(self.final_block(x)).squeeze(1)