# sfi_hubert / modeling.py
# Repository page header (kept for provenance): "Wataru's picture"
# Commit message: add non_integer stride
# Commit: 8ff868a (verified)
from transformers import HubertModel
from transformers.models.hubert.modeling_hubert import HubertFeatureEncoder
from .configuration import SfiHuBERTConfig
from .continuous_filters import FrequencyDomainRFFImplicitFilter
from .conv_any_stride import FreqRespSampConv1d
class SfiHuBERTFeatureEncoder(HubertFeatureEncoder):
    """HuBERT feature encoder with its first convolution swapped out.

    The parent constructor builds the regular convolution stack; afterwards
    the ``conv`` attribute of the first layer is replaced by a
    ``FreqRespSampConv1d`` that keeps the same number of output channels.
    """

    def __init__(self, config: SfiHuBERTConfig):
        """Build the stock encoder, then replace the first layer's convolution.

        Args:
            config: Model configuration. ``config.latent_filter_params`` is
                forwarded to the replacement convolution as ``filter_params``.
        """
        super().__init__(config)

        first_layer = self.conv_layers[0]
        # The channel count is read from the stock convolution before it is
        # replaced, so the following layers still receive the width they were
        # built for.
        first_layer.conv = FreqRespSampConv1d(
            in_channels=1,
            out_channels=first_layer.conv.out_channels,
            ContFilterType=FrequencyDomainRFFImplicitFilter,
            filter_params=config.latent_filter_params,
            n_samples=640,
        )

    def forward(self, *args, **kwargs):
        """Delegate unchanged to the parent ``forward``."""
        # Extension point: custom feature-extraction logic would go here.
        return super().forward(*args, **kwargs)
class SfiHuBERTModel(HubertModel):
    """``HubertModel`` that uses ``SfiHuBERTFeatureEncoder`` as its feature extractor.

    Call ``set_sample_rate`` to prepare the first convolution for one of the
    sample rates listed in ``config.sfi_conv_parameters``.
    """

    config_class = SfiHuBERTConfig

    def __init__(self, config: SfiHuBERTConfig):
        """Build the parent model, then install the custom feature encoder.

        Args:
            config: Model configuration; stored on the instance and passed to
                the replacement feature encoder.
        """
        super().__init__(config)
        self.config = config
        # Overwrite the feature extractor created by the parent constructor.
        self.feature_extractor = SfiHuBERTFeatureEncoder(config)

    def forward(self, *args, **kwargs):
        """Delegate unchanged to the parent ``forward``."""
        # Extension point: custom forward-pass logic would go here.
        return super().forward(*args, **kwargs)

    def set_sample_rate(self, sample_rate):
        """Prepare the first convolution for the given sample rate.

        Args:
            sample_rate: Sample rate; it is truncated to an ``int`` and its
                decimal string form is used as the lookup key in
                ``config.sfi_conv_parameters``.

        Raises:
            ValueError: If the stringified sample rate is not a key of
                ``config.sfi_conv_parameters``.
        """
        rate_as_int = int(sample_rate)
        sample_rate = str(rate_as_int)

        per_rate_params = self.config.sfi_conv_parameters
        if sample_rate not in per_rate_params:
            raise ValueError(
                f"Sample rate {sample_rate} not in the list of allowed sample rates."
            )

        first_conv = self.feature_extractor.conv_layers[0].conv
        # The entry stored for this rate is unpacked as extra keyword
        # arguments to ``prepare``.
        first_conv.prepare(sample_rate=rate_as_int, **per_rate_params[sample_rate])