cronos3k commited on
Commit
a63aa33
·
verified ·
1 Parent(s): a607236

Upload audiodit/configuration_audiodit.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. audiodit/configuration_audiodit.py +225 -0
audiodit/configuration_audiodit.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AudioDiT model configuration"""
2
+
3
+ from transformers import PreTrainedConfig, logging
4
+ from transformers.models.umt5.configuration_umt5 import UMT5Config
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
class AudioDiTVaeConfig(PreTrainedConfig):
    r"""
    Configuration for the WAV-VAE audio autoencoder used by AudioDiT.

    The encoder downsamples raw audio through a stack of strided stages into a
    latent sequence; the decoder mirrors that stack to reconstruct audio.

    Args:
        in_channels (`int`, *optional*, defaults to 1):
            Number of input audio channels (1 = mono).
        channels (`int`, *optional*, defaults to 128):
            Base channel width of the encoder/decoder.
        c_mults (`list[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
            Per-stage channel multipliers.
        strides (`list[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
            Per-stage downsampling strides.
        latent_dim (`int`, *optional*, defaults to 64):
            Latent dimensionality after the VAE bottleneck (the encoder emits
            `encoder_latent_dim` channels which are split into mean + scale).
        encoder_latent_dim (`int`, *optional*, defaults to 128):
            Encoder output dimensionality before the VAE bottleneck.
        use_snake (`bool`, *optional*, defaults to `True`):
            Use Snake activation instead of ELU.
        downsample_shortcut (`str`, *optional*, defaults to `"averaging"`):
            Shortcut type in encoder downsampling blocks.
        upsample_shortcut (`str`, *optional*, defaults to `"duplicating"`):
            Shortcut type in decoder upsampling blocks.
        out_shortcut (`str`, *optional*, defaults to `"averaging"`):
            Shortcut type in the encoder output projection.
        in_shortcut (`str`, *optional*, defaults to `"duplicating"`):
            Shortcut type in the decoder input projection.
        final_tanh (`bool`, *optional*, defaults to `False`):
            Apply `tanh` to the decoder output.
        downsampling_ratio (`int`, *optional*, defaults to 2048):
            Total audio-samples-per-latent-frame ratio.
        sample_rate (`int`, *optional*, defaults to 24000):
            Audio sample rate in Hz.
        scale (`float`, *optional*, defaults to 0.71):
            Latent-space scaling factor.
    """

    model_type = "audiodit_vae"

    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 128,
        c_mults: list[int] | None = None,
        strides: list[int] | None = None,
        latent_dim: int = 64,
        encoder_latent_dim: int = 128,
        use_snake: bool = True,
        downsample_shortcut: str = "averaging",
        upsample_shortcut: str = "duplicating",
        out_shortcut: str = "averaging",
        in_shortcut: str = "duplicating",
        final_tanh: bool = False,
        downsampling_ratio: int = 2048,
        sample_rate: int = 24000,
        scale: float = 0.71,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Stage layout; defaults are materialized here (not in the signature)
        # to avoid shared mutable default arguments.
        self.c_mults = [1, 2, 4, 8, 16] if c_mults is None else c_mults
        self.strides = [2, 4, 4, 8, 8] if strides is None else strides

        # Encoder/decoder topology.
        self.in_channels = in_channels
        self.channels = channels
        self.latent_dim = latent_dim
        self.encoder_latent_dim = encoder_latent_dim

        # Block construction options.
        self.use_snake = use_snake
        self.downsample_shortcut = downsample_shortcut
        self.upsample_shortcut = upsample_shortcut
        self.out_shortcut = out_shortcut
        self.in_shortcut = in_shortcut
        self.final_tanh = final_tanh

        # Audio framing parameters.
        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate
        self.scale = scale
84
+
85
+
86
class AudioDiTConfig(PreTrainedConfig):
    r"""
    Configuration for AudioDiT, a Conditional Flow Matching TTS model built on
    a DiT (diffusion transformer) backbone with a UMT5 text encoder and a
    WAV-VAE audio autoencoder.

    Args:
        dit_dim (`int`, *optional*, defaults to 1536):
            Hidden size of the DiT transformer.
        dit_depth (`int`, *optional*, defaults to 24):
            Number of transformer layers.
        dit_heads (`int`, *optional*, defaults to 24):
            Number of attention heads.
        dit_ff_mult (`float`, *optional*, defaults to 4.0):
            Feed-forward width multiplier.
        dit_text_dim (`int`, *optional*, defaults to 768):
            Dimensionality of the text-encoder output (UMT5-base).
        dit_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability.
        dit_bias (`bool`, *optional*, defaults to `True`):
            Use bias terms in linear layers.
        dit_cross_attn (`bool`, *optional*, defaults to `True`):
            Include cross-attention layers.
        dit_adaln_type (`str`, *optional*, defaults to `"global"`):
            Adaptive layer-norm variant (`"global"` or `"local"`).
        dit_adaln_use_text_cond (`bool`, *optional*, defaults to `True`):
            Condition AdaLN on text embeddings.
        dit_long_skip (`bool`, *optional*, defaults to `True`):
            Add the input to the output via a long skip connection.
        dit_text_conv (`bool`, *optional*, defaults to `True`):
            Apply ConvNeXt blocks to text embeddings.
        dit_qk_norm (`bool`, *optional*, defaults to `True`):
            RMS-normalize Q and K.
        dit_cross_attn_norm (`bool`, *optional*, defaults to `False`):
            Layer-normalize inside cross-attention.
        dit_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by normalization layers.
        dit_use_latent_condition (`bool`, *optional*, defaults to `True`):
            Enable latent conditioning (prompt audio).
        repa_dit_layer (`int`, *optional*, defaults to 8):
            Layer index used for representation alignment.
        latent_dim (`int`, *optional*, defaults to 64):
            Dimensionality of the audio latent space.
        sigma (`float`, *optional*, defaults to 0.0):
            Noise level for conditional flow matching.
        sampling_rate (`int`, *optional*, defaults to 24000):
            Audio sample rate in Hz.
        latent_hop (`int`, *optional*, defaults to 2048):
            Audio samples per latent frame.
        max_wav_duration (`float`, *optional*, defaults to 30.0):
            Maximum supported audio duration in seconds.
        text_encoder_model (`str`, *optional*, defaults to `"google/umt5-base"`):
            Hugging Face identifier of the text encoder.
        text_add_embed (`bool`, *optional*, defaults to `True`):
            Add the first hidden state to the last hidden state during text
            encoding.
        text_norm_feat (`bool`, *optional*, defaults to `True`):
            Layer-normalize text features.
        vae_config (`AudioDiTVaeConfig` or `dict`, *optional*):
            Configuration of the WAV-VAE autoencoder; a fresh default config is
            created when omitted.
        text_encoder_config (`UMT5Config` or `dict`, *optional*):
            Configuration of the UMT5 text encoder.

    Example:

    ```python
    >>> from transformers import AudioDiTConfig, AudioDiTModel

    >>> configuration = AudioDiTConfig()
    >>> model = AudioDiTModel(configuration)
    >>> configuration = model.config
    ```
    """

    model_type = "audiodit"
    sub_configs = {"vae_config": AudioDiTVaeConfig, "text_encoder_config": UMT5Config}

    def __init__(
        self,
        dit_dim: int = 1536,
        dit_depth: int = 24,
        dit_heads: int = 24,
        dit_ff_mult: float = 4.0,
        dit_text_dim: int = 768,
        dit_dropout: float = 0.0,
        dit_bias: bool = True,
        dit_cross_attn: bool = True,
        dit_adaln_type: str = "global",
        dit_adaln_use_text_cond: bool = True,
        dit_long_skip: bool = True,
        dit_text_conv: bool = True,
        dit_qk_norm: bool = True,
        dit_cross_attn_norm: bool = False,
        dit_eps: float = 1e-6,
        dit_use_latent_condition: bool = True,
        repa_dit_layer: int = 8,
        latent_dim: int = 64,
        sigma: float = 0.0,
        sampling_rate: int = 24000,
        latent_hop: int = 2048,
        max_wav_duration: float = 30.0,
        text_encoder_model: str = "google/umt5-base",
        text_add_embed: bool = True,
        text_norm_feat: bool = True,
        vae_config: AudioDiTVaeConfig | dict | None = None,
        text_encoder_config: UMT5Config | dict | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer backbone.
        self.dit_dim = dit_dim
        self.dit_depth = dit_depth
        self.dit_heads = dit_heads
        self.dit_ff_mult = dit_ff_mult
        self.dit_text_dim = dit_text_dim
        self.dit_dropout = dit_dropout
        self.dit_bias = dit_bias
        self.dit_cross_attn = dit_cross_attn
        self.dit_adaln_type = dit_adaln_type
        self.dit_adaln_use_text_cond = dit_adaln_use_text_cond
        self.dit_long_skip = dit_long_skip
        self.dit_text_conv = dit_text_conv
        self.dit_qk_norm = dit_qk_norm
        self.dit_cross_attn_norm = dit_cross_attn_norm
        self.dit_eps = dit_eps
        self.dit_use_latent_condition = dit_use_latent_condition
        self.repa_dit_layer = repa_dit_layer

        # Flow-matching / audio framing.
        self.latent_dim = latent_dim
        self.sigma = sigma
        self.sampling_rate = sampling_rate
        self.latent_hop = latent_hop
        self.max_wav_duration = max_wav_duration

        # Text-encoder options.
        self.text_encoder_model = text_encoder_model
        self.text_add_embed = text_add_embed
        self.text_norm_feat = text_norm_feat

        # Sub-config resolution: dicts are promoted to config objects; a
        # missing VAE config falls back to defaults.
        if isinstance(vae_config, dict):
            self.vae_config = AudioDiTVaeConfig(**vae_config)
        elif vae_config is None:
            self.vae_config = AudioDiTVaeConfig()
        else:
            self.vae_config = vae_config

        # NOTE(review): unlike vae_config, a None text_encoder_config is kept
        # as None rather than defaulting to UMT5Config() — presumably it is
        # later derived from `text_encoder_model`; confirm with the model code.
        if isinstance(text_encoder_config, dict):
            self.text_encoder_config = UMT5Config(**text_encoder_config)
        else:
            self.text_encoder_config = text_encoder_config
223
+
224
+
225
# Public API of this module.
__all__ = ["AudioDiTConfig", "AudioDiTVaeConfig"]