| from transformers import PreTrainedModel |
| from .model import GRAMT |
| from .configuration_gramt_mono import GRAMTMonoConfig |
|
|
|
|
class GRAMTMonoModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the core :class:`GRAMT` network.

    The wrapper's only job is to translate a :class:`GRAMTMonoConfig` into the
    keyword arguments expected by the ``GRAMT`` constructor and to expose its
    ``get_audio_representation`` method through ``forward``.
    """

    config_class = GRAMTMonoConfig

    # Config attributes forwarded verbatim (same names) to the GRAMT constructor.
    _GRAMT_ARG_NAMES = (
        "in_channels",
        "decoder_mlp_ratio",
        "decoder_depth",
        "decoder_num_heads",
        "decoder_embedding_dim",
        "decoder_window_sizes",
        "encoder_num_layers",
        "encoder_num_heads",
        "encoder_hidden_dim",
        "encoder_mlp_ratio",
        "encoder_dropout",
        "encoder_attention_dropout",
        "encoder_norm_layer_eps",
        "patch_size",
        "frequency_stride",
        "time_stride",
        "max_length",
        "num_mel_bins",
    )

    def __init__(self, config):
        """Build the wrapped ``GRAMT`` model from the fields of *config*."""
        super().__init__(config)
        gramt_kwargs = {name: getattr(config, name) for name in self._GRAMT_ARG_NAMES}
        self.model = GRAMT(**gramt_kwargs)

    def forward(self, tensor, strategy="raw"):
        """Return the audio representation produced by the wrapped model.

        Both arguments are passed straight through to
        ``GRAMT.get_audio_representation``; the semantics of *strategy*
        (default ``"raw"``) are defined by that method — presumably it selects
        how the representation is extracted/aggregated, but confirm against
        the GRAMT implementation.
        """
        return self.model.get_audio_representation(tensor, strategy=strategy)
|
|
|
|