from transformers import HubertModel
from transformers.models.hubert.modeling_hubert import HubertFeatureEncoder

from .configuration import SfiHuBERTConfig
from .continuous_filters import FrequencyDomainRFFImplicitFilter
from .conv_any_stride import FreqRespSampConv1d

# Value previously hard-coded for the ``n_samples`` argument of the first
# (sampling-frequency-independent) convolution.
_DEFAULT_N_SAMPLES = 640


class SfiHuBERTFeatureEncoder(HubertFeatureEncoder):
    """HuBERT convolutional feature encoder with a replaced first convolution.

    The encoder is built exactly like ``HubertFeatureEncoder``; afterwards
    ``conv_layers[0].conv`` is swapped for a ``FreqRespSampConv1d`` driven by
    a ``FrequencyDomainRFFImplicitFilter``. All other layers are the ones
    the parent constructed.

    ``forward`` is inherited unchanged from ``HubertFeatureEncoder``; override
    it here if custom feature-extraction logic is needed.
    """

    def __init__(
        self,
        config: SfiHuBERTConfig,
        n_samples: int = _DEFAULT_N_SAMPLES,
    ):
        """Build the standard encoder, then install the SFI first convolution.

        Args:
            config: Model configuration. ``config.latent_filter_params`` is
                forwarded as ``filter_params`` to ``FreqRespSampConv1d``.
            n_samples: Forwarded as ``n_samples`` to ``FreqRespSampConv1d``.
                Defaults to 640, the value this class previously hard-coded.
        """
        super().__init__(config)
        # Reuse the channel count of the conv the parent just built so the
        # replacement stays shape-compatible with the following layers.
        out_channels = self.conv_layers[0].conv.out_channels
        self.conv_layers[0].conv = FreqRespSampConv1d(
            in_channels=1,
            out_channels=out_channels,
            ContFilterType=FrequencyDomainRFFImplicitFilter,
            filter_params=config.latent_filter_params,
            n_samples=n_samples,
        )


class SfiHuBERTModel(HubertModel):
    """``HubertModel`` whose feature extractor is ``SfiHuBERTFeatureEncoder``.

    ``forward`` is inherited from ``HubertModel`` rather than wrapped in a
    ``(*args, **kwargs)`` pass-through. Code that introspects the signature,
    e.g. ``inspect.signature(model.forward)`` as the Hugging Face Trainer
    does when pruning dataset columns, therefore sees the real parameter
    names. Override ``forward`` here if custom forward logic is needed.
    """

    config_class = SfiHuBERTConfig

    def __init__(self, config: SfiHuBERTConfig):
        """Build a standard ``HubertModel``, then swap in the SFI encoder.

        Args:
            config: Model configuration, also passed to
                ``SfiHuBERTFeatureEncoder``.
        """
        super().__init__(config)
        self.config = config
        # NOTE(review): in the transformers versions I'm aware of,
        # HubertModel.__init__ ends with post_init(), which runs before this
        # replacement. The new encoder's modules would then not be visited by
        # that weight-init pass. That is fine when weights are loaded from a
        # checkpoint; confirm the intended behaviour when training from scratch.
        self.feature_extractor = SfiHuBERTFeatureEncoder(config)

    def set_sample_rate(self, sample_rate) -> None:
        """Configure the SFI first convolution for a given input sample rate.

        ``sample_rate`` is normalised with ``str(int(sample_rate))``, so
        ``16000``, ``16000.0`` and ``"16000"`` all select the same entry of
        ``config.sfi_conv_parameters``. That entry's items are passed as
        keyword arguments to the convolution's ``prepare`` method, together
        with ``sample_rate`` as an ``int``.

        Args:
            sample_rate: Target sample rate (int, float or numeric string).

        Raises:
            ValueError: If ``sample_rate`` is not numeric, or if its
                normalised key is not in ``config.sfi_conv_parameters``.
        """
        key = str(int(sample_rate))
        conv_parameters = self.config.sfi_conv_parameters
        if key not in conv_parameters:
            allowed = ", ".join(str(rate) for rate in conv_parameters)
            raise ValueError(
                f"Sample rate {key} not in the list of allowed sample rates. "
                f"Allowed: {allowed}."
            )
        self.feature_extractor.conv_layers[0].conv.prepare(
            sample_rate=int(key), **conv_parameters[key]
        )