tky823 committed on
Commit
5b5d0fc
·
verified ·
1 Parent(s): 610b667

Upload configuration_musicfm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_musicfm.py +79 -0
configuration_musicfm.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class MusicFMConfig(PretrainedConfig):
    """Configuration container registered under ``model_type = "musicfm"``.

    Every named constructor argument is stored on the instance under the
    same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``. How the individual values are interpreted
    is decided by the model code, which is not part of this file.

    ``features`` and ``stat`` are shallow-copied on assignment so that an
    instance never shares the default objects from the signature (or a
    container owned by the caller) with other instances.
    """

    model_type = "musicfm"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: list[str] = ["melspec_2048"],
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        stat: dict[str, float] = {},
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.num_codebooks = num_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size
        # Copy instead of aliasing: the default list in the signature is a
        # single object shared by every call, so storing it directly would
        # let an in-place edit on one config leak into all later configs.
        # An explicit ``None`` is passed through untouched, as before.
        self.features = None if features is None else list(features)
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.conv_dim = conv_dim
        self.encoder_dim = encoder_dim
        self.encoder_depth = encoder_depth
        self.mask_hop = mask_hop
        self.mask_prob = mask_prob
        self.is_flash = is_flash
        # Same reasoning as ``features``: never hand out the shared default dict.
        self.stat = None if stat is None else dict(stat)
39
+
40
+
41
class MusicFMInferenceConfig(MusicFMConfig):
    """``MusicFMConfig`` variant registered as ``model_type = "musicfm_inference"``.

    Takes the same arguments as ``MusicFMConfig`` and forwards them to it,
    plus one extra value, ``layer_index``, which is stored on the instance.
    NOTE(review): ``layer_index`` presumably selects which encoder layer's
    output is used at inference time -- confirm against the model code.

    ``features`` and ``stat`` are shallow-copied before being forwarded so
    that instances never share the default objects from this signature.
    """

    model_type = "musicfm_inference"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: list[str] = ["melspec_2048"],
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        layer_index: int = 9,
        stat: dict[str, float] = {},
        **kwargs,
    ) -> None:
        super().__init__(
            num_codebooks=num_codebooks,
            codebook_dim=codebook_dim,
            codebook_size=codebook_size,
            # Forward a fresh copy: the default list/dict in this signature
            # is one object shared by every call, so passing it on directly
            # would make all default-constructed configs alias each other.
            # An explicit ``None`` is passed through untouched, as before.
            features=None if features is None else list(features),
            hop_length=hop_length,
            n_mels=n_mels,
            conv_dim=conv_dim,
            encoder_dim=encoder_dim,
            encoder_depth=encoder_depth,
            mask_hop=mask_hop,
            mask_prob=mask_prob,
            is_flash=is_flash,
            stat=None if stat is None else dict(stat),
            **kwargs,
        )

        self.layer_index = layer_index