import ast
from typing import Union

import torch
from transformers import PreTrainedModel

from .audio_extractor import ConvFeatureExtractor
from .configuration_wavjepa import WavJEPAConfig
from .model import WavJEPA
class WavJEPAModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the WavJEPA model.

    Builds the underlying :class:`WavJEPA` (feature extractor + transformer
    encoder/decoder) from a :class:`WavJEPAConfig` and exposes audio
    representation extraction through ``forward``.
    """

    config_class = WavJEPAConfig

    def __init__(self, config):
        """Construct the wrapped WavJEPA model from *config*.

        Args:
            config: A ``WavJEPAConfig`` carrying ``extractor_config`` (a dict
                of ConvFeatureExtractor kwargs), the encoder/decoder layer
                configs, and ``model_size``.
        """
        super().__init__(config)
        ext_cfg = config.extractor_config
        self.model = WavJEPA(
            feature_extractor=ConvFeatureExtractor(
                # The spec is stored as a string literal (e.g. "[(512, 10, 5), ...]").
                # ast.literal_eval parses it safely; the original eval() would
                # execute arbitrary code embedded in a config file.
                conv_layers_spec=ast.literal_eval(ext_cfg['conv_layers_spec']),
                in_channels=ext_cfg['in_channels'],
                dropout=ext_cfg['dropout'],
                mode=ext_cfg['mode'],
                conv_bias=ext_cfg['conv_bias'],
                depthwise=ext_cfg['depthwise'],
            ),
            transformer_encoder_layers_cfg=config.encoder_layers_cfg,
            transformer_encoder_cfg=config.encoder_cfg,
            transformer_decoder_layers_cfg=config.decoder_layers_cfg,
            transformer_decoder_cfg=config.decoder_cfg,
            size=config.model_size,
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the audio representation produced by the wrapped model.

        Args:
            tensor: Input waveform tensor (shape assumed to match what
                ``WavJEPA.get_audio_representation`` expects — not visible
                here; confirm against the ``WavJEPA`` definition).

        NOTE(review): the original annotation was
        ``Union[torch.Tensor, torch.Tensor]``, which collapses to plain
        ``torch.Tensor``; if ``get_audio_representation`` actually returns a
        pair, the intended annotation was likely ``Tuple[torch.Tensor,
        torch.Tensor]`` — confirm against its implementation.
        """
        return self.model.get_audio_representation(tensor)