import torch
from torch import Tensor
from transformers import PreTrainedModel
from audio_encoders_pytorch import MelE1d, TanhBottleneck
from audio_diffusion_pytorch import DiffusionAE, UNetV0, LTPlugin, VDiffusion, VSampler
from .config import DMAE1dConfig


class DMAE1d(PreTrainedModel):
    """Stereo 48 kHz diffusion autoencoder wrapped as a Transformers model."""

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Learned-transform plugin: patches the waveform with a strided
        # convolution before the UNet and unpatches it afterwards.
        UNet = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=32,
        )

        self.model = DiffusionAE(
            net_t=UNet,
            dim=1,  # 1d (audio) convolutions
            in_channels=2,  # stereo waveform
            channels=[256, 512, 512, 512, 512],
            factors=[1, 2, 2, 2, 2],  # per-stage downsampling
            items=[1, 2, 2, 2, 4],  # per-stage block counts
            inject_depth=4,  # UNet depth at which the latent is injected
            # Mel-spectrogram encoder producing a 32-channel latent,
            # squashed into (-1, 1) by the tanh bottleneck.
            encoder=MelE1d(
                in_channels=2,
                channels=512,
                multipliers=[1, 1],
                factors=[2],
                num_blocks=[12],
                out_channels=32,
                mel_channels=80,
                mel_sample_rate=48000,
                mel_normalize_log=True,
                bottleneck=TanhBottleneck(),
            ),
            diffusion_t=VDiffusion,  # v-objective diffusion
            sampler_t=VSampler,  # v-objective sampler
        )

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        return self.model.decode(*args, **kwargs)
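

# Usage sketch (illustrative; not part of the original file). Assumes the
# audio-diffusion-pytorch DiffusionAE API, where calling the autoencoder
# returns the diffusion training loss, `encode` produces the latent, and
# `decode` samples audio back from it. Shapes and the step count below are
# example values only.
#
#   model = DMAE1d(DMAE1dConfig())
#   audio = torch.randn(1, 2, 2**18)             # [batch, stereo, time] at 48 kHz
#   loss = model(audio)                          # training: diffusion loss
#   latent = model.encode(audio)                 # [batch, 32, reduced time]
#   recon = model.decode(latent, num_steps=100)  # inference: sampled audio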