| |
| from transformers import PretrainedConfig |
|
|
class ModelConfig(PretrainedConfig):
    """Hugging Face configuration for the ``"SongFormer"`` model type.

    Each constructor argument is stored unchanged as an instance attribute of
    the same name. No validation or derivation happens here. Any extra keyword
    arguments are passed through to ``PretrainedConfig.__init__``.

    The code that consumes these hyperparameters (model layers, losses,
    post-processing) is not in this file. The section groupings below are
    inferred from the parameter names only.

    NOTE(review): ``frame_rates`` presumably gives output frames per second
    (e.g. for converting frame indices to time). Confirm against the
    inference code before relying on its units.
    """

    model_type = "SongFormer"

    def __init__(
        self,
        input_dim: int = 2048,
        input_dim_raw: int = 4096,
        transformer_encoder_input_dim: int = 1024,
        transformer_input_dim: int = 512,
        num_transformer_layers: int = 4,
        transformer_nhead: int = 8,
        transformer_dropout: float = 0.1,
        num_classes: int = 128,
        num_dataset_classes: int = 64,
        down_sample_conv_kernel_size: int = 3,
        down_sample_conv_stride: int = 3,
        down_sample_conv_dropout: float = 0.1,
        down_sample_conv_padding: int = 0,
        boundary_tv_loss_beta: float = 0.6,
        boundary_tv_loss_lambda: float = 0.4,
        boundary_tv_loss_boundary_threshold: float = 0.01,
        boundary_tv_loss_reduction_weight: float = 0.1,
        boundary_tvloss_weight: float = 0.05,
        label_focal_loss_alpha: float = 0.25,
        label_focal_loss_gamma: float = 2.0,
        label_focal_loss_weight: float = 0.2,
        loss_weight_section: float = 0.2,
        loss_weight_function: float = 0.8,
        learn_label: bool = True,
        learn_segment: bool = True,
        local_maxima_filter_size: int = 3,
        frame_rates: float = 8.333,
        **kwargs,
    ):
        # The base class handles its own generic options (and stores any
        # unrecognised kwargs) first. The model-specific attributes are then
        # set on top, in a fixed order so the serialized dict layout is stable.
        super().__init__(**kwargs)

        # --- Input feature / transformer sizes ---
        self.input_dim = input_dim
        self.input_dim_raw = input_dim_raw
        self.transformer_encoder_input_dim = transformer_encoder_input_dim
        self.transformer_input_dim = transformer_input_dim
        self.num_transformer_layers = num_transformer_layers
        self.transformer_nhead = transformer_nhead
        self.transformer_dropout = transformer_dropout

        # --- Output label spaces ---
        self.num_classes = num_classes
        self.num_dataset_classes = num_dataset_classes

        # --- Down-sampling convolution ---
        self.down_sample_conv_kernel_size = down_sample_conv_kernel_size
        self.down_sample_conv_stride = down_sample_conv_stride
        self.down_sample_conv_dropout = down_sample_conv_dropout
        self.down_sample_conv_padding = down_sample_conv_padding

        # --- Boundary TV-loss hyperparameters ---
        self.boundary_tv_loss_beta = boundary_tv_loss_beta
        self.boundary_tv_loss_lambda = boundary_tv_loss_lambda
        self.boundary_tv_loss_boundary_threshold = boundary_tv_loss_boundary_threshold
        self.boundary_tv_loss_reduction_weight = boundary_tv_loss_reduction_weight
        self.boundary_tvloss_weight = boundary_tvloss_weight

        # --- Label focal-loss hyperparameters ---
        self.label_focal_loss_alpha = label_focal_loss_alpha
        self.label_focal_loss_gamma = label_focal_loss_gamma
        self.label_focal_loss_weight = label_focal_loss_weight

        # --- Relative weights of the section / function losses ---
        self.loss_weight_section = loss_weight_section
        self.loss_weight_function = loss_weight_function

        # --- Training switches ---
        self.learn_label = learn_label
        self.learn_segment = learn_segment

        # --- Post-processing / timing ---
        self.local_maxima_filter_size = local_maxima_filter_size
        self.frame_rates = frame_rates