File size: 1,562 Bytes
3d43b81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch 
from torch import Tensor
from transformers import PreTrainedModel
from audio_encoders_pytorch import MelE1d, TanhBottleneck
from audio_diffusion_pytorch import DiffusionAE, UNetV0, LTPlugin, VDiffusion, VSampler
from .config import DMAE1dConfig 


class DMAE1d(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a fixed ``DiffusionAE``.

    All autoencoder hyperparameters are hard-coded in ``__init__``. The
    ``config`` argument is handed to ``PreTrainedModel.__init__`` but none of
    its fields are read in this class. Every public method delegates to the
    wrapped ``self.model``.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # U-Net type for the diffusion network, wrapped by LTPlugin. The
        # resulting type (not an instance) is passed to DiffusionAE as net_t.
        unet_type = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=32,
        )

        # Mel-spectrogram encoder whose output goes through a tanh bottleneck.
        # Built before DiffusionAE, matching the original evaluation order.
        mel_encoder = MelE1d(
            in_channels=2,
            channels=512,
            multipliers=[1, 1],
            factors=[2],
            num_blocks=[12],
            out_channels=32,
            mel_channels=80,
            mel_sample_rate=48000,
            mel_normalize_log=True,
            bottleneck=TanhBottleneck(),
        )

        self.model = DiffusionAE(
            net_t=unet_type,
            dim=1,
            in_channels=2,
            channels=[256, 512, 512, 512, 512],
            factors=[1, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 4],
            inject_depth=4,
            encoder=mel_encoder,
            diffusion_t=VDiffusion,
            sampler_t=VSampler,
        )

    def forward(self, *args, **kwargs):
        """Call the wrapped ``DiffusionAE``, forwarding all arguments unchanged."""
        wrapped = self.model
        return wrapped(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to ``DiffusionAE.encode``.

        Autograd is not disabled here (unlike ``decode``), so this runs under
        whatever grad mode the caller has active.
        """
        return self.model.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``DiffusionAE.decode`` with autograd disabled."""
        with torch.no_grad():
            return self.model.decode(*args, **kwargs)