# wavjepa-base / modeling_wavjepa.py
# Uploaded by GokseninYuksel ("Upload model", commit fefd7ae, verified)
import ast
from typing import Union

import torch
from transformers import PreTrainedModel

from .audio_extractor import ConvFeatureExtractor
from .configuration_wavjepa import WavJEPAConfig
from .model import WavJEPA
class WavJEPAModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the WavJEPA audio model.

    Builds the underlying :class:`WavJEPA` network from the values stored in a
    :class:`WavJEPAConfig` and exposes a ``forward`` that delegates to
    ``WavJEPA.get_audio_representation``.
    """

    config_class = WavJEPAConfig

    def __init__(self, config: WavJEPAConfig):
        """Construct the wrapped WavJEPA model from *config*.

        Args:
            config: A :class:`WavJEPAConfig` carrying ``extractor_config``
                (dict of ConvFeatureExtractor kwargs), the encoder/decoder
                layer configs, and ``model_size``.
        """
        super().__init__(config)
        extractor_cfg = config.extractor_config
        # The conv-layer spec is serialized as a string in the config.  Parse
        # it with ast.literal_eval rather than eval(): model configs are
        # downloaded files (untrusted input), and eval() would execute
        # arbitrary code embedded in them.  literal_eval accepts only Python
        # literals (tuples/lists/numbers/strings), which is all a layer spec
        # should contain.
        conv_layers_spec = ast.literal_eval(extractor_cfg['conv_layers_spec'])
        self.model = WavJEPA(
            feature_extractor=ConvFeatureExtractor(
                conv_layers_spec=conv_layers_spec,
                in_channels=extractor_cfg['in_channels'],
                dropout=extractor_cfg['dropout'],
                mode=extractor_cfg['mode'],
                conv_bias=extractor_cfg['conv_bias'],
                depthwise=extractor_cfg['depthwise'],
            ),
            transformer_encoder_layers_cfg=config.encoder_layers_cfg,
            transformer_encoder_cfg=config.encoder_cfg,
            transformer_decoder_layers_cfg=config.decoder_layers_cfg,
            transformer_decoder_cfg=config.decoder_cfg,
            size=config.model_size,
        )

    def forward(self, tensor: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]:
        """Return the audio representation(s) for *tensor*.

        NOTE(review): ``Union[Tensor, Tensor]`` collapses to ``Tensor``; the
        author probably intended ``Tuple[Tensor, Tensor]`` — confirm against
        ``WavJEPA.get_audio_representation`` before changing the annotation.
        """
        return self.model.get_audio_representation(tensor)