import torch from torch import nn import torch.nn.functional as F from typing import List class LogSpectralCentroid(nn.Module): def forward(self, spec): # assume spec is of shape (..., freq, time) freqs = torch.linspace(0, 1, spec.size(-2), device=spec.device) spec_T = spec.transpose(-1, -2) normalised_spec = spec_T / spec_T.sum(-1, keepdim=True).clamp_min(1e-8) return torch.log(normalised_spec @ freqs + 1e-8).unsqueeze(-2) class LogSpectralFlatness(nn.Module): def forward(self, spec): # assume spec is of shape (..., freq, time) spec_pow = spec.clamp(1e-8).square() log_gmean = spec_pow.log().mean(-2, keepdim=True) log_amean = spec_pow.mean(-2, keepdim=True).log() return log_gmean - log_amean class LogSpectralBandwidth(nn.Module): def __init__(self): super().__init__() self.centroid = LogSpectralCentroid() def forward(self, spec): # assume spec is of shape (..., freq, time) freqs = torch.linspace(0, 1, spec.size(-2), device=spec.device) centroid = self.centroid(spec).exp() normalised_spec = spec / spec.sum(-2, keepdim=True).clamp_min(1e-8) return ( torch.log( (normalised_spec * (freqs[:, None] - centroid).square()).sum( -2, keepdim=True ) + 1e-8 ) * 0.5 ) class LogRMS(nn.Module): def forward(self, frame): return torch.log(frame.square().mean(-2, keepdim=True).sqrt() + 1e-8) class LogCrest(nn.Module): def __init__(self): super().__init__() self.rms = LogRMS() def forward(self, frame): log_rms = self.rms(frame) return frame.abs().amax(-2, keepdim=True).add(1e-8).log() - log_rms class LogSpread(nn.Module): def __init__(self): super().__init__() self.rms = LogRMS() def forward(self, frame): log_rms = self.rms(frame) return (frame.abs().add(1e-8).log() - log_rms).mean(-2, keepdim=True) class MapAndMerge(nn.Module): def __init__(self, funcs: List[nn.Module], dim=-1): super().__init__() self.funcs = nn.ModuleList(funcs) self.dim = dim def forward(self, frame): return torch.cat([f(frame) for f in self.funcs], dim=self.dim) class Frame(nn.Module): def __init__(self, frame_length, hop_length, center=False): super().__init__() self.frame_length = frame_length self.hop_length = hop_length self.center = center def forward(self, waveform): if self.center: waveform = F.pad(waveform, (self.frame_length // 2, self.frame_length // 2)) return waveform.unfold(-1, self.frame_length, self.hop_length).transpose(-1, -2) class StatisticReduction(nn.Module): def __init__(self, dim=-1): super().__init__() self.dim = dim def forward(self, x): mu = x.mean(self.dim, keepdim=True) diffs = x - mu std = diffs.square().mean(self.dim, keepdim=True).sqrt() zscores = diffs / std.clamp_min(1e-8) skews = zscores.pow(3).mean(self.dim, keepdim=True) kurts = zscores.pow(4).mean(self.dim, keepdim=True) - 3 return torch.cat([mu, std, skews, kurts], dim=self.dim)