# wavjepa-base / configuration_wavjepa.py
# Uploaded by GokseninYuksel - commit 18f0022 (verified): "Upload model"
from transformers import PretrainedConfig
from torch import nn
from .types import TransformerLayerCFG, TransformerEncoderCFG
class WavJEPAConfig(PretrainedConfig):
    """Configuration object for the ``wavjepa-base`` model type.

    The constructor takes flat, prefix-grouped keyword arguments
    (``extractor_*``, ``encoder_*``, ``decoder_*``) and folds them into five
    attributes:

    * ``encoder_cfg`` / ``decoder_cfg`` -- results of
      ``TransformerEncoderCFG.create`` (stack-level settings).
    * ``encoder_layers_cfg`` / ``decoder_layers_cfg`` -- results of
      ``TransformerLayerCFG.create`` (per-layer settings).
    * ``extractor_config`` -- a plain ``dict`` of feature-extractor settings.

    The flat arguments themselves are not stored on the instance; only the
    grouped attributes above are.

    Args:
        extractor_layers_spec: String stored under
            ``extractor_config["conv_layers_spec"]``. NOTE(review): the default
            looks like a Python expression of 3-tuples that is evaluated by the
            consumer -- confirm how and where it is parsed.
        extractor_dropout: Stored under ``extractor_config["dropout"]``.
        extractor_mode: Stored under ``extractor_config["mode"]``.
        extractor_conv_bias: Stored under ``extractor_config["conv_bias"]``.
        extractor_depthwise: Stored under ``extractor_config["depthwise"]``.
        encoder_d_model, encoder_nhead, encoder_batch_first,
        encoder_norm_first, encoder_bias, encoder_mlp_ratio, encoder_dropout:
            Forwarded to ``TransformerLayerCFG.create`` for the encoder layers.
        encoder_num_layers, encoder_enable_nested_tensor, encoder_mask_check:
            Forwarded to ``TransformerEncoderCFG.create`` for the encoder.
        decoder_d_model, decoder_nhead, decoder_batch_first,
        decoder_norm_first, decoder_bias, decoder_mlp_ratio, decoder_dropout:
            Forwarded to ``TransformerLayerCFG.create`` for the decoder layers.
        decoder_num_layers, decoder_enable_nested_tensor, decoder_mask_check:
            Forwarded to ``TransformerEncoderCFG.create`` for the decoder.
        **kwargs: Passed through to ``PretrainedConfig.__init__``. An
            ``in_channels`` entry, if present, is also used for
            ``extractor_config["in_channels"]``.
    """

    model_type = "wavjepa-base"
    model_size = "base"
    # Class-level default for the number of input channels; used for
    # ``extractor_config["in_channels"]`` unless overridden via ``kwargs``.
    in_channels: int = 1

    def __init__(
        self,
        extractor_layers_spec: str = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)]",
        extractor_dropout: float = 0.0,
        extractor_mode: str = "default",
        extractor_conv_bias: bool = False,
        extractor_depthwise: bool = False,
        encoder_d_model: int = 768,
        encoder_nhead: int = 12,
        encoder_batch_first: bool = True,
        encoder_norm_first: bool = False,
        encoder_bias: bool = True,
        encoder_mlp_ratio: float = 4.0,
        encoder_dropout: float = 0.0,
        encoder_num_layers: int = 12,
        encoder_enable_nested_tensor: bool = False,
        encoder_mask_check: bool = True,
        decoder_d_model: int = 384,
        decoder_nhead: int = 12,
        decoder_batch_first: bool = True,
        decoder_norm_first: bool = False,
        decoder_bias: bool = True,
        decoder_mlp_ratio: float = 4.0,
        decoder_dropout: float = 0.0,
        decoder_num_layers: int = 12,
        decoder_enable_nested_tensor: bool = False,
        decoder_mask_check: bool = True,
        **kwargs,
    ):
        # Stack-level settings. The decoder is described with the same
        # ``TransformerEncoderCFG`` type as the encoder.
        self.encoder_cfg = TransformerEncoderCFG.create(
            num_layers=encoder_num_layers,
            enable_nested_tensor=encoder_enable_nested_tensor,
            mask_check=encoder_mask_check,
        )
        self.decoder_cfg = TransformerEncoderCFG.create(
            num_layers=decoder_num_layers,
            enable_nested_tensor=decoder_enable_nested_tensor,
            mask_check=decoder_mask_check,
        )

        # Per-layer settings; ``layer_norm_eps`` is fixed at 1e-6 for both
        # stacks and is not exposed as a constructor argument.
        self.encoder_layers_cfg = TransformerLayerCFG.create(
            d_model=encoder_d_model,
            nhead=encoder_nhead,
            batch_first=encoder_batch_first,
            norm_first=encoder_norm_first,
            bias=encoder_bias,
            mlp_ratio=encoder_mlp_ratio,
            dropout=encoder_dropout,
            layer_norm_eps=1e-6,
        )
        self.decoder_layers_cfg = TransformerLayerCFG.create(
            d_model=decoder_d_model,
            nhead=decoder_nhead,
            batch_first=decoder_batch_first,
            norm_first=decoder_norm_first,
            bias=decoder_bias,
            mlp_ratio=decoder_mlp_ratio,
            dropout=decoder_dropout,
            layer_norm_eps=1e-6,
        )

        # ``in_channels`` is not a named parameter, so an override can only
        # arrive through ``kwargs``, which the base class applies as instance
        # attributes at the end of this method. Read it here (without removing
        # it from ``kwargs``) so ``extractor_config`` and ``self.in_channels``
        # cannot disagree. Falls back to the class-level default otherwise.
        in_channels = kwargs.get("in_channels", self.in_channels)

        self.extractor_config = dict(
            conv_layers_spec=extractor_layers_spec,
            in_channels=in_channels,
            dropout=extractor_dropout,
            mode=extractor_mode,
            conv_bias=extractor_conv_bias,
            depthwise=extractor_depthwise,
        )

        # Called last: any grouped attribute present in ``kwargs`` (e.g. when
        # rebuilding from a saved config dict) replaces the value built above.
        super().__init__(**kwargs)