| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ Dasheng model configuration""" |
|
|
| from transformers import PretrainedConfig |
|
|
| DASHENG_PRETRAINED_CONFIG_ARCHIVE_MAP = { |
| "mispeech/dasheng-base": "https://huggingface.co/mispeech/dasheng-base/resolve/main/config.json", |
| "mispeech/dasheng-0.6B": "https://huggingface.co/mispeech/dasheng-0.6B/resolve/main/config.json", |
| "mispeech/dasheng-1.2B": "https://huggingface.co/mispeech/dasheng-1.2B/resolve/main/config.json", |
| } |
|
|
|
|
| class DashengConfig(PretrainedConfig): |
| model_type = "dasheng" |
|
|
| def __init__( |
| self, |
| name: str = "dasheng-base", |
| loss: str = "BCELoss", |
| **kwargs, |
| ): |
| r""" |
| Configuration class for the Dasheng model. |
| |
| Args: |
| name (str, *optional*): |
| Can be "dasheng-base", "dasheng-0.6B", or "dasheng-1.2B". Default to "dasheng-base". |
| loss (str, *optional*): |
| Name of the loss function to use. Can be any loss in `nn.modules.loss`. Default to "BCELoss". |
| kwargs (dict, *optional*): |
| Additional keyword arguments, see `dasheng_model.modeling_dasheng.DashengFeatureExtractor` and `dasheng_model.modeling_dasheng.AudioTransformerMAE_Encoder` for more details. |
| """ |
|
|
| super().__init__(**kwargs) |
|
|
| encoder_kwargs = dict(target_length=1008, patch_size=[64, 4], patch_stride=[64, 4]) |
|
|
| if name == "dasheng-1.2B": |
| encoder_kwargs["embed_dim"] = 1536 |
| encoder_kwargs["depth"] = 40 |
| encoder_kwargs["num_heads"] = 24 |
| elif name == "dasheng-0.6B": |
| encoder_kwargs["embed_dim"] = 1280 |
| encoder_kwargs["depth"] = 32 |
| encoder_kwargs["num_heads"] = 16 |
| elif name == "dasheng-base": |
| encoder_kwargs["embed_dim"] = 768 |
| encoder_kwargs["depth"] = 12 |
| encoder_kwargs["num_heads"] = 12 |
| else: |
| raise ValueError(f"Unrecognized model name: {name}") |
| self.name = name |
|
|
| encoder_kwargs.update((k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs)) |
| self.encoder_kwargs = {**encoder_kwargs, **kwargs} |
|
|
| self.loss = loss |
|
|