# Commit 53e45f1 (yoyolicoris): refactor: use modules from the dataset repo
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial, reduce
from typing import Optional, List
from torchaudio.transforms import MelSpectrogram, MFCC
class LogMelSpectrogram(MelSpectrogram):
    """Mel spectrogram transform whose output is log-compressed.

    Construction is inherited unchanged from ``MelSpectrogram``; only
    ``forward`` differs.
    """

    def forward(self, waveform):
        """Return ``log(mel + 1e-8)`` for the mel spectrogram of *waveform*.

        The constant 1e-8 is added before the natural logarithm so that
        bins that are exactly zero do not map to ``-inf``.
        """
        mel = super().forward(waveform)
        return (mel + 1e-8).log()
class LogMFCC(MFCC):
    """``MFCC`` transform that is always built with ``log_mels=True``."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``MFCC`` and force ``log_mels=True``.

        ``log_mels`` is supplied here, so a caller that also passes it
        gets a ``TypeError`` for the duplicated argument.
        """
        super(LogMFCC, self).__init__(*args, **kwargs, log_mels=True)
class LightningSequential(nn.Sequential):
    """Sequential container built from a list that can pass several values between stages.

    Unlike ``nn.Sequential``, the constructor takes one list of modules
    rather than the modules as separate positional arguments.
    """

    def __init__(self, modules: List[nn.Module]):
        """Register every module in *modules*, in order."""
        super().__init__(*modules)

    def forward(self, *args):
        """Run the inputs through each module in registration order.

        The first module receives ``*args``. After that, a tuple result is
        unpacked into the next module's positional arguments and any other
        result is passed as a single argument. With no modules registered,
        the ``args`` tuple itself is returned.
        """
        current = args
        for stage in self:
            if isinstance(current, tuple):
                current = stage(*current)
            else:
                current = stage(current)
        return current
class ResidualWrapper(nn.Module):
    """Wrap a module so its output is added to its own input."""

    def __init__(self, m: nn.Module):
        """Store the wrapped module *m* as ``self.m``."""
        super(ResidualWrapper, self).__init__()
        self.m = m

    def forward(self, x):
        """Return ``x + self.m(x)``."""
        branch = self.m(x)
        return x + branch