# MusicFMInference / configuration_musicfm.py
# Uploaded by tky823 with huggingface_hub (commit 5b5d0fc, verified).
from transformers import PretrainedConfig
class MusicFMConfig(PretrainedConfig):
    """Configuration holder for the ``musicfm`` model type.

    Each constructor argument is stored on the instance under an attribute of
    the same name. ``features`` and ``stat`` are shallow-copied before being
    stored; all other values are stored as given. Any extra keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of the individual hyper-parameters is not
    visible in this file; consult the model implementation that consumes
    this config before relying on any interpretation of their names.
    """

    model_type = "musicfm"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: list[str] = ["melspec_2048"],
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        stat: dict[str, float] = {},
        **kwargs,
    ) -> None:
        """Store the given hyper-parameters as attributes.

        Args:
            num_codebooks: Stored as ``self.num_codebooks``.
            codebook_dim: Stored as ``self.codebook_dim``.
            codebook_size: Stored as ``self.codebook_size``.
            features: Names stored (as a new list) in ``self.features``.
            hop_length: Stored as ``self.hop_length``.
            n_mels: Stored as ``self.n_mels``.
            conv_dim: Stored as ``self.conv_dim``.
            encoder_dim: Stored as ``self.encoder_dim``.
            encoder_depth: Stored as ``self.encoder_depth``.
            mask_hop: Stored as ``self.mask_hop``.
            mask_prob: Stored as ``self.mask_prob``.
            is_flash: Stored as ``self.is_flash``.
            stat: Mapping stored (as a new dict) in ``self.stat``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        self.num_codebooks = num_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size
        # The default list in the signature is created once and shared by
        # every call. Storing a copy keeps a later in-place edit of
        # ``config.features`` from leaking into other instances.
        self.features = list(features)
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.conv_dim = conv_dim
        self.encoder_dim = encoder_dim
        self.encoder_depth = encoder_depth
        self.mask_hop = mask_hop
        self.mask_prob = mask_prob
        self.is_flash = is_flash
        # Same reasoning as ``features``: never keep a reference to the
        # shared default dict, so ``config.stat[key] = value`` stays local.
        self.stat = dict(stat)
class MusicFMInferenceConfig(MusicFMConfig):
    """Configuration for the ``musicfm_inference`` model type.

    Accepts the same arguments as ``MusicFMConfig`` and forwards them to it,
    and additionally stores ``layer_index`` on the instance.

    NOTE(review): ``layer_index`` presumably selects which encoder layer is
    used at inference time; that is inferred from the name only. Confirm
    against the model code.
    """

    model_type = "musicfm_inference"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: list[str] = ["melspec_2048"],
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        layer_index: int = 9,
        stat: dict[str, float] = {},
        **kwargs,
    ) -> None:
        """Forward the shared hyper-parameters to the parent and store ``layer_index``.

        Args:
            num_codebooks: Forwarded to ``MusicFMConfig.__init__``.
            codebook_dim: Forwarded to ``MusicFMConfig.__init__``.
            codebook_size: Forwarded to ``MusicFMConfig.__init__``.
            features: Copied into a new list, then forwarded.
            hop_length: Forwarded to ``MusicFMConfig.__init__``.
            n_mels: Forwarded to ``MusicFMConfig.__init__``.
            conv_dim: Forwarded to ``MusicFMConfig.__init__``.
            encoder_dim: Forwarded to ``MusicFMConfig.__init__``.
            encoder_depth: Forwarded to ``MusicFMConfig.__init__``.
            mask_hop: Forwarded to ``MusicFMConfig.__init__``.
            mask_prob: Forwarded to ``MusicFMConfig.__init__``.
            is_flash: Forwarded to ``MusicFMConfig.__init__``.
            layer_index: Stored as ``self.layer_index``.
            stat: Copied into a new dict, then forwarded.
            **kwargs: Forwarded unchanged to ``MusicFMConfig.__init__``.
        """
        super().__init__(
            num_codebooks=num_codebooks,
            codebook_dim=codebook_dim,
            codebook_size=codebook_size,
            # Hand the parent fresh containers: the default list/dict in the
            # signature above are shared across calls and must not end up
            # stored on (and later mutated through) a config instance.
            features=list(features),
            hop_length=hop_length,
            n_mels=n_mels,
            conv_dim=conv_dim,
            encoder_dim=encoder_dim,
            encoder_depth=encoder_depth,
            mask_hop=mask_hop,
            mask_prob=mask_prob,
            is_flash=is_flash,
            stat=dict(stat),
            **kwargs,
        )

        self.layer_index = layer_index