|
|
from transformers import HubertModel |
|
|
from transformers.models.hubert.modeling_hubert import HubertFeatureEncoder |
|
|
|
|
|
from .configuration import SfiHuBERTConfig |
|
|
from .continuous_filters import FrequencyDomainRFFImplicitFilter |
|
|
from .conv_any_stride import FreqRespSampConv1d |
|
|
|
|
|
|
|
|
class SfiHuBERTFeatureEncoder(HubertFeatureEncoder):
    """HuBERT feature encoder whose first convolution is a ``FreqRespSampConv1d``.

    The parent constructor builds the regular convolutional stack. This
    subclass then swaps the ``conv`` module of the first layer for a
    ``FreqRespSampConv1d`` that uses ``FrequencyDomainRFFImplicitFilter`` as
    its continuous filter type. All other layers are kept as the parent
    created them.
    """

    def __init__(self, config: SfiHuBERTConfig) -> None:
        """Build the parent encoder, then replace the first convolution.

        Args:
            config: Model configuration. ``config.latent_filter_params`` is
                forwarded to the replacement convolution as ``filter_params``.
        """
        super().__init__(config)

        first_layer = self.conv_layers[0]
        # Reuse the channel count of the convolution being replaced so the
        # output width seen by the following layers does not change.
        first_layer.conv = FreqRespSampConv1d(
            # NOTE(review): presumably a single-channel waveform input; confirm.
            in_channels=1,
            out_channels=first_layer.conv.out_channels,
            ContFilterType=FrequencyDomainRFFImplicitFilter,
            filter_params=config.latent_filter_params,
            # NOTE(review): 640 is hard-coded; its meaning is defined by
            # FreqRespSampConv1d, so confirm before changing it.
            n_samples=640,
        )

    def forward(self, *args, **kwargs):
        """Delegate unchanged to the parent encoder's ``forward``."""
        encoded = super().forward(*args, **kwargs)
        return encoded
|
|
|
|
|
|
|
|
class SfiHuBERTModel(HubertModel):
    """``HubertModel`` variant that uses ``SfiHuBERTFeatureEncoder``.

    The feature extractor created by the parent constructor is replaced with
    a ``SfiHuBERTFeatureEncoder``. ``set_sample_rate`` configures the first
    convolution of that encoder for a given input sample rate.
    """

    config_class = SfiHuBERTConfig

    def __init__(self, config: SfiHuBERTConfig) -> None:
        """Build the parent model, then install the custom feature encoder.

        Args:
            config: Model configuration, passed to both the parent
                constructor and ``SfiHuBERTFeatureEncoder``.
        """
        super().__init__(config)

        encoder = SfiHuBERTFeatureEncoder(config)
        self.config = config
        # Overrides the feature extractor that the parent constructor created.
        self.feature_extractor = encoder

    def forward(self, *args, **kwargs):
        """Delegate unchanged to the parent model's ``forward``."""
        outputs = super().forward(*args, **kwargs)
        return outputs

    def set_sample_rate(self, sample_rate) -> None:
        """Prepare the first convolution for the given sample rate.

        ``sample_rate`` is converted to ``int`` and then to ``str`` to form
        the lookup key into ``config.sfi_conv_parameters``. The entry found
        there is passed as keyword arguments, together with the integer
        sample rate, to the first convolution's ``prepare`` method.

        Args:
            sample_rate: Any value accepted by ``int()``.

        Raises:
            ValueError: If the normalized sample rate is not a key of
                ``config.sfi_conv_parameters``.
        """
        rate_hz = int(sample_rate)
        # Lookup keys are the string form of the integer rate.
        sample_rate = str(rate_hz)

        conv_parameters = self.config.sfi_conv_parameters
        if sample_rate not in conv_parameters:
            raise ValueError(
                f"Sample rate {sample_rate} not in the list of allowed sample rates."
            )

        first_conv = self.feature_extractor.conv_layers[0].conv
        first_conv.prepare(sample_rate=rate_hz, **conv_parameters[sample_rate])
|
|
|