from typing import Optional

from transformers import PretrainedConfig
class MusicFMConfig(PretrainedConfig):
    """Configuration container for the ``musicfm`` model type.

    Every named constructor argument is stored as an attribute of the same
    name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the per-field meanings below are inferred from the argument
    names only -- confirm them against the model code that reads this config.

    Args:
        num_codebooks: Presumably the number of codebooks.
        codebook_dim: Presumably the dimensionality of each codebook entry.
        codebook_size: Presumably the number of entries per codebook.
        features: Feature names. ``None`` (the default) yields a fresh
            ``["melspec_2048"]`` list for this instance.
        hop_length: Presumably the hop length of the spectrogram front end.
        n_mels: Presumably the number of mel bins.
        conv_dim: Presumably the width of the convolutional front end.
        encoder_dim: Presumably the hidden size of the encoder.
        encoder_depth: Presumably the number of encoder layers.
        mask_hop: Masking hyper-parameter (units not visible here).
        mask_prob: Presumably the masking probability.
        is_flash: Presumably toggles a flash-attention code path.
        stat: Mapping of statistic name to value. ``None`` (the default)
            yields a fresh empty dict for this instance.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "musicfm"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: Optional[list[str]] = None,
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        stat: Optional[dict[str, float]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.num_codebooks = num_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size
        # Build the default per instance: a list literal in the signature is
        # created once and would be shared (and mutable) across every config
        # constructed without an explicit ``features``.
        self.features = ["melspec_2048"] if features is None else features
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.conv_dim = conv_dim
        self.encoder_dim = encoder_dim
        self.encoder_depth = encoder_depth
        self.mask_hop = mask_hop
        self.mask_prob = mask_prob
        self.is_flash = is_flash
        # Same reasoning as ``features``: never hand out a shared default dict.
        self.stat = {} if stat is None else stat
class MusicFMInferenceConfig(MusicFMConfig):
    """Configuration for the ``musicfm_inference`` model type.

    Accepts the same arguments as ``MusicFMConfig`` and forwards them to it,
    and additionally stores ``layer_index`` on the instance.

    Args:
        num_codebooks, codebook_dim, codebook_size, features, hop_length,
        n_mels, conv_dim, encoder_dim, encoder_depth, mask_hop, mask_prob,
        is_flash, stat: Forwarded to ``MusicFMConfig.__init__``. For
            ``features`` and ``stat``, ``None`` (the default) is replaced by a
            fresh ``["melspec_2048"]`` list / empty dict before forwarding.
        layer_index: Stored as ``self.layer_index``. Presumably selects which
            encoder layer's output is used at inference time -- TODO confirm
            against the model code.
        **kwargs: Forwarded to ``MusicFMConfig.__init__``.
    """

    model_type = "musicfm_inference"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: Optional[list[str]] = None,
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        layer_index: int = 9,
        stat: Optional[dict[str, float]] = None,
        **kwargs,
    ) -> None:
        # Resolve the container defaults here, before delegating, so each
        # instance gets its own list/dict instead of a signature-level object
        # shared by every default-constructed config.
        if features is None:
            features = ["melspec_2048"]
        if stat is None:
            stat = {}

        super().__init__(
            num_codebooks=num_codebooks,
            codebook_dim=codebook_dim,
            codebook_size=codebook_size,
            features=features,
            hop_length=hop_length,
            n_mels=n_mels,
            conv_dim=conv_dim,
            encoder_dim=encoder_dim,
            encoder_depth=encoder_depth,
            mask_hop=mask_hop,
            mask_prob=mask_prob,
            is_flash=is_flash,
            stat=stat,
            **kwargs,
        )
        self.layer_index = layer_index