File size: 1,724 Bytes
0e7a38a
 
 
4ad33fa
16293fd
f634d4b
0e7a38a
 
 
 
 
 
f634d4b
 
 
 
0e7a38a
f634d4b
0e7a38a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ff868a
0e7a38a
 
 
 
 
8ff868a
0e7a38a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from transformers import HubertModel
from transformers.models.hubert.modeling_hubert import HubertFeatureEncoder

from .configuration import SfiHuBERTConfig
from .continuous_filters import FrequencyDomainRFFImplicitFilter
from .conv_any_stride import FreqRespSampConv1d


class SfiHuBERTFeatureEncoder(HubertFeatureEncoder):
    """HuBERT feature encoder with a ``FreqRespSampConv1d`` as its first convolution.

    The stock ``HubertFeatureEncoder`` is constructed first. Then the ``conv``
    module of its first layer (``conv_layers[0].conv``) is replaced with a
    ``FreqRespSampConv1d``. The replacement takes one input channel and keeps
    the number of output channels of the conv it replaces. Its continuous-filter
    type is ``FrequencyDomainRFFImplicitFilter``, parameterized by
    ``config.latent_filter_params``.
    """

    def __init__(self, config: SfiHuBERTConfig):
        super().__init__(config)
        first_layer = self.conv_layers[0]
        sfi_conv = FreqRespSampConv1d(
            in_channels=1,
            # Same channel count as the conv being replaced, so the width seen
            # by later layers is unchanged.
            out_channels=first_layer.conv.out_channels,
            ContFilterType=FrequencyDomainRFFImplicitFilter,
            filter_params=config.latent_filter_params,
            # NOTE(review): fixed constant; its meaning is defined by
            # FreqRespSampConv1d (not visible here) — confirm.
            n_samples=640,
        )
        first_layer.conv = sfi_conv

    def forward(self, *args, **kwargs):
        """Run ``HubertFeatureEncoder.forward`` unchanged.

        Kept as a placeholder for custom feature-extraction logic.
        """
        return super().forward(*args, **kwargs)


class SfiHuBERTModel(HubertModel):
    """``HubertModel`` whose feature extractor is a ``SfiHuBERTFeatureEncoder``.

    The first convolution of the feature extractor can be (re)configured for a
    given input sample rate with :meth:`set_sample_rate`. The per-rate settings
    come from ``config.sfi_conv_parameters``.
    """

    config_class = SfiHuBERTConfig

    def __init__(self, config: SfiHuBERTConfig):
        super().__init__(config)
        # Explicitly keep a reference to the config object passed in.
        self.config = config
        # Replace the encoder built by HubertModel.__init__ with the SFI variant.
        self.feature_extractor = SfiHuBERTFeatureEncoder(config)

    # ``forward`` is intentionally not overridden. A ``*args, **kwargs``
    # pass-through hides HubertModel.forward's explicit signature.
    # transformers.Trainer inspects that signature (remove_unused_columns) to
    # decide which dataset columns to keep, so hiding it breaks training.

    def set_sample_rate(self, sample_rate):
        """Configure the first (SFI) convolution for ``sample_rate``.

        Args:
            sample_rate: Input sampling rate (presumably in Hz). It is converted
                with ``int()``, so integral floats and numeric strings are
                accepted and fractional values are truncated.

        Raises:
            ValueError: If ``config.sfi_conv_parameters`` has no entry for the
                rate, or if ``int()`` cannot convert ``sample_rate``.
        """
        rate = int(sample_rate)
        params_by_rate = self.config.sfi_conv_parameters
        # Configs loaded from JSON have string keys. A config built directly in
        # Python may use int keys. Accept either, preferring the string key.
        for key in (str(rate), rate):
            if key in params_by_rate:
                conv_params = params_by_rate[key]
                break
        else:
            allowed = ", ".join(str(k) for k in params_by_rate)
            raise ValueError(
                f"Sample rate {rate} not in the list of allowed sample rates: "
                f"{allowed}."
            )
        self.feature_extractor.conv_layers[0].conv.prepare(
            sample_rate=rate, **conv_params
        )