Spaces:
Sleeping
Sleeping
File size: 2,070 Bytes
957e2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
import torch.nn.functional as F
from src.simulation.component import Component
################################################################################
# Pre-emphasis filter
################################################################################
class PreEmphasis(Component):
    """
    Apply pre-emphasis filter via waveform convolution. Adapted from
    https://github.com/clovaai/voxceleb_trainer/blob/master/utils.py
    """
    def __init__(self, coef: float = 0.97, method: str = 'shift'):
        """
        Initialize filter
        :param coef: pre-emphasis coefficient
        :param method: implementation; must be one of `conv`, `shift`, or
            None (None disables filtering; forward becomes a pass-through)
        :raises ValueError: if `method` is not one of the accepted values
        """
        super().__init__()
        self.coef = coef
        if method not in ('conv', 'shift', None):
            raise ValueError(f'Invalid method {method}')
        self.method = method
        # flip filter (cross-correlation --> convolution); shape (1, 1, 2)
        # = (out_channels, in_channels/groups, kernel). Registered as a
        # buffer so it follows the module across .to(device)/.cuda().
        # torch.tensor replaces the legacy torch.FloatTensor constructor.
        self.register_buffer(
            'flipped_filter',
            torch.tensor([-self.coef, 1.], dtype=torch.float32).view(1, 1, 2)
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply pre-emphasis filter via waveform convolution
        :param x: waveform batch; at least 2 dims, with batch first and the
            signal on the last axis (any middle dims are folded into a
            single channel axis)
        :return: filtered tensor of shape (n_batch, channels, signal_length)
            (always 3-D, even for 2-D input); unfiltered reshaped input if
            `method` is None
        :raises ValueError: if the input lacks a batch dimension
        """
        # explicit raise instead of `assert`: asserts vanish under -O
        if x.ndim < 2:
            raise ValueError(f'Expected ndim >= 2, got shape {tuple(x.shape)}')
        n_batch, signal_length = x.shape[0], x.shape[-1]
        # conv1d requires an explicit channel dimension
        x = x.reshape(n_batch, -1, signal_length)
        in_channels = x.shape[1]
        if self.method == 'conv':
            # reflect padding keeps in/out lengths equal; requires
            # signal_length >= 2. NOTE: this makes y[0] = x[0] - coef*x[1],
            # which differs from the `shift` method's y[0] = x[0].
            x = F.pad(x, (1, 0), 'reflect')
            # one filter copy per channel with groups=in_channels so each
            # channel is filtered independently
            return F.conv1d(
                x,
                self.flipped_filter.repeat(in_channels, 1, 1),
                groups=in_channels
            )
        elif self.method == 'shift':
            # y[0] = x[0]; y[t] = x[t] - coef * x[t-1] for t >= 1
            return torch.cat(
                [
                    x[..., 0:1],
                    x[..., 1:] - self.coef * x[..., :-1]
                ], dim=-1)
        else:
            # method is None: pass-through (reshaped only)
            return x
|