File size: 5,513 Bytes
90f7c1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import math
import torch

from .base import BaseModule
from .modules import Mish, Upsample, Downsample, Rezero, Block, ResnetBlock
from .modules import LinearAttention, Residual, Timesteps, TimbreBlock, PitchPosEmb

from einops import rearrange


class UNetPitcher(BaseModule):
    """U-Net denoiser conditioned on diffusion step, pitch (f0) and timbre.

    The network input stacks the noisy sample with its predicted mean as two
    channels; pitch embeddings — and optionally a speaker-timbre vector — are
    broadcast across the spectrogram and concatenated as extra channels
    before the encoder/decoder path.
    """

    def __init__(self,
                 dim_base,
                 dim_cond,
                 use_ref_t,
                 use_embed,
                 dim_embed=256,
                 dim_mults=(1, 2, 4),
                 pitch_type='bins'):
        """Build the conditioned U-Net.

        Args:
            dim_base: base channel width; multiplied by ``dim_mults`` per level.
            dim_cond: width of the pitch/timbre conditioning channels.
            use_ref_t: condition on a reference mel via ``TimbreBlock``.
            use_embed: condition on a precomputed speaker embedding.
            dim_embed: width of the external speaker embedding.
            dim_mults: channel multipliers, one per resolution level.
            pitch_type: stored for callers; not used in this module's graph.
        """
        super(UNetPitcher, self).__init__()
        self.use_ref_t = use_ref_t
        self.use_embed = use_embed
        self.pitch_type = pitch_type

        # Two base input channels: the noisy sample and its mean.
        in_channels = 2

        # Sinusoidal timestep embedding refined by a small MLP.
        self.time_pos_emb = Timesteps(num_channels=dim_base,
                                      flip_sin_to_cos=True,
                                      downscale_freq_shift=0)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim_base, dim_base * 4),
            Mish(),
            torch.nn.Linear(dim_base * 4, dim_base))

        # Speaker conditioning: reference-mel features and/or an external
        # embedding, fused down to ``dim_cond`` by a shared MLP.
        cond_width = 0
        if use_ref_t:
            self.ref_block = TimbreBlock(out_dim=dim_cond)
            cond_width += dim_cond
        if use_embed:
            cond_width += dim_embed
        if cond_width != 0:
            self.timbre_block = torch.nn.Sequential(
                torch.nn.Linear(cond_width, 4 * dim_cond),
                Mish(),
                torch.nn.Linear(4 * dim_cond, dim_cond))
        if use_embed or use_ref_t:
            in_channels += dim_cond

        # Pitch (f0) conditioning path.
        self.pitch_pos_emb = PitchPosEmb(dim_cond)
        self.pitch_mlp = torch.nn.Sequential(
            torch.nn.Conv1d(dim_cond, dim_cond * 4, 1, stride=1),
            Mish(),
            torch.nn.Conv1d(dim_cond * 4, dim_cond, 1, stride=1))
        in_channels += dim_cond

        # Channel widths along the encoder; pair them up per level.
        dims = [in_channels] + [dim_base * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for level, (ch_in, ch_out) in enumerate(in_out):
            last = level >= num_resolutions - 1
            self.downs.append(torch.nn.ModuleList([
                ResnetBlock(ch_in, ch_out, time_emb_dim=dim_base),
                ResnetBlock(ch_out, ch_out, time_emb_dim=dim_base),
                Residual(Rezero(LinearAttention(ch_out))),
                torch.nn.Identity() if last else Downsample(ch_out)]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)

        # Decoder mirrors the encoder (skipping the first level); the first
        # ResnetBlock takes doubled channels because of skip concatenation.
        for ch_in, ch_out in reversed(in_out[1:]):
            self.ups.append(torch.nn.ModuleList([
                ResnetBlock(ch_out * 2, ch_in, time_emb_dim=dim_base),
                ResnetBlock(ch_in, ch_in, time_emb_dim=dim_base),
                Residual(Rezero(LinearAttention(ch_in))),
                Upsample(ch_in)]))

        self.final_block = Block(dim_base, dim_base)
        self.final_conv = torch.nn.Conv2d(dim_base, 1, 1)

    def forward(self, x, mean, f0, t, ref=None, embed=None):
        """Run one denoising step.

        Args:
            x: noisy sample; stacked with ``mean`` as two channels.
            mean: mean prediction, same shape as ``x``.
            f0: pitch sequence fed to the pitch embedding.
            t: diffusion timestep (scalar, python number, or batch tensor).
            ref: reference input for ``TimbreBlock`` (required if use_ref_t).
            embed: speaker embedding (required if use_embed).

        Returns:
            Tensor with the channel dimension squeezed out.
        """
        # Normalise t into a per-batch tensor on the right device.
        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)
        if len(t.shape) == 0:
            t = t * torch.ones(x.shape[0], dtype=t.dtype, device=x.device)

        t = self.mlp(self.time_pos_emb(t))

        # Stack sample and mean as a two-channel "image".
        x = torch.stack([x, mean], 1)

        # Embed pitch, then tile it across the frequency axis (dim 2).
        f0 = self.pitch_mlp(self.pitch_pos_emb(f0)).unsqueeze(2)
        f0 = torch.cat([f0] * x.shape[2], 2)

        # Assemble the timbre vector from whichever sources are enabled.
        timbre = None
        if self.use_ref_t:
            timbre = self.ref_block(torch.stack([ref], 1))
        if self.use_embed:
            timbre = embed if timbre is None else torch.cat([timbre, embed], 1)

        if timbre is None:
            # No timbre conditioning configured: pitch alone conditions x.
            condition = f0
        else:
            # Project, then tile the speaker vector over both spatial axes.
            timbre = self.timbre_block(timbre).unsqueeze(-1).unsqueeze(-1)
            timbre = torch.cat([timbre] * x.shape[2], 2)
            timbre = torch.cat([timbre] * x.shape[3], 3)
            condition = torch.cat([f0, timbre], 1)

        x = torch.cat([x, condition], 1)

        # Encoder: keep pre-downsample activations for skip connections.
        skips = []
        for res1, res2, attn, down in self.downs:
            x = attn(res2(res1(x, t), t))
            skips.append(x)
            x = down(x)

        # Bottleneck.
        x = self.mid_block2(self.mid_attn(self.mid_block1(x, t)), t)

        # Decoder: concatenate matching skip before each level.
        for res1, res2, attn, up in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = up(attn(res2(res1(x, t), t)))

        return self.final_conv(self.final_block(x)).squeeze(1)