| from transformers import PretrainedConfig |
| from typing import List |
|
|
|
|
class MoonshineConfig(PretrainedConfig):
    """Configuration container for a ``moonshine`` model.

    Stores the hyperparameters below as instance attributes and forwards any
    remaining keyword arguments to ``PretrainedConfig.__init__``.

    NOTE(review): the per-parameter meanings are inferred from the names;
    this file does not show how the values are consumed -- confirm against
    the model code.

    Args:
        dim: Presumably the model (embedding) width.
        inner_dim: Presumably the attention inner width. ``None`` (the
            default) means "use ``dim``". Must be divisible by ``n_head``.
        enc_depth: Presumably the number of encoder layers.
        dec_depth: Presumably the number of decoder layers.
        n_head: Number of heads; must be a positive divisor of ``inner_dim``.
        dec_voc_size: Presumably the decoder vocabulary size.
        enc_ff_swiglu: Presumably selects a SwiGLU feed-forward in the encoder.
        dec_ff_swiglu: Presumably selects a SwiGLU feed-forward in the decoder.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.

    Raises:
        ValueError: If ``n_head`` is not positive, or if ``inner_dim`` is not
            divisible by ``n_head``.
    """

    model_type = "moonshine"

    def __init__(
        self,
        dim: int = 288,
        inner_dim: int = None,
        enc_depth: int = 8,
        dec_depth: int = 8,
        n_head: int = 8,
        dec_voc_size: int = 32768,
        enc_ff_swiglu: bool = False,
        dec_ff_swiglu: bool = True,
        **kwargs
    ):
        # An omitted inner width falls back to the model width.
        if inner_dim is None:
            inner_dim = dim

        # Reject non-positive head counts explicitly: zero would otherwise
        # surface as a bare ZeroDivisionError from the modulo below, and a
        # negative count could slip through the divisibility check.
        if n_head <= 0:
            raise ValueError(f"`n_head` must be a positive integer, got {n_head}")
        if inner_dim % n_head != 0:
            raise ValueError(
                f"`inner_dim` ({inner_dim}) must be divisible by `n_head` ({n_head})"
            )

        self.dim = dim
        self.inner_dim = inner_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.n_head = n_head
        self.dec_voc_size = dec_voc_size
        self.enc_ff_swiglu = enc_ff_swiglu
        self.dec_ff_swiglu = dec_ff_swiglu

        # Attributes are set first; the base class then handles the
        # remaining keyword arguments.
        super().__init__(**kwargs)
|
|