File size: 1,344 Bytes
fefd7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import ast
from typing import Union

import torch
from transformers import PreTrainedModel

from .audio_extractor import ConvFeatureExtractor
from .configuration_wavjepa import WavJEPAConfig
from .model import WavJEPA

class WavJEPAModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a WavJEPA network.

    Builds the underlying :class:`WavJEPA` model from a
    :class:`WavJEPAConfig` and exposes audio-representation extraction via
    the standard ``forward`` interface, so the model can be loaded with
    ``from_pretrained`` / saved with ``save_pretrained``.
    """

    config_class = WavJEPAConfig

    def __init__(self, config):
        """Construct the wrapped WavJEPA model.

        Args:
            config: A ``WavJEPAConfig`` providing ``extractor_config`` (a
                dict of ConvFeatureExtractor settings), the encoder/decoder
                layer and top-level configs, and ``model_size``.
        """
        super().__init__(config)

        extractor_cfg = config.extractor_config
        # ``conv_layers_spec`` is serialized as a string in the config
        # (presumably a literal such as a list of tuples). Parse it with
        # ast.literal_eval rather than eval: configs may come from the Hub,
        # and eval() would execute arbitrary code embedded in them.
        conv_layers_spec = ast.literal_eval(extractor_cfg['conv_layers_spec'])

        self.model = WavJEPA(
            feature_extractor=ConvFeatureExtractor(
                conv_layers_spec=conv_layers_spec,
                in_channels=extractor_cfg['in_channels'],
                dropout=extractor_cfg['dropout'],
                mode=extractor_cfg['mode'],
                conv_bias=extractor_cfg['conv_bias'],
                depthwise=extractor_cfg['depthwise'],
            ),
            transformer_encoder_layers_cfg=config.encoder_layers_cfg,
            transformer_encoder_cfg=config.encoder_cfg,
            transformer_decoder_layers_cfg=config.decoder_layers_cfg,
            transformer_decoder_cfg=config.decoder_cfg,
            size=config.model_size,
        )

    def forward(self, tensor) -> torch.Tensor:
        """Return the audio representation for a raw waveform batch.

        NOTE(review): the original annotation was
        ``Union[torch.Tensor, torch.Tensor]``, which collapses to plain
        ``torch.Tensor``; if ``get_audio_representation`` actually returns a
        pair, the intended type was ``Tuple[torch.Tensor, torch.Tensor]`` —
        confirm against ``WavJEPA.get_audio_representation``.
        """
        return self.model.get_audio_representation(tensor)