cosrigel committed on
Commit 6d342f3 · verified · 1 Parent(s): 64b4d26

chore: upload folder dia for inference

dia/__init__.py ADDED
File without changes
dia/audio.py ADDED
@@ -0,0 +1,286 @@
+import typing as tp
+
+import torch
+
+from .config import DataConfig
+
+
+def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
+    Negative t_idx => BOS; t_idx >= T => PAD.
+    """
+    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
+
+    t_idx_BxT = torch.broadcast_to(
+        torch.arange(T, dtype=torch.int32)[None, :],
+        [B, T],
+    )
+    t_idx_BxTx1 = t_idx_BxT[..., None]
+    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
+
+    b_idx_BxTxC = torch.broadcast_to(
+        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
+        [B, T, C],
+    )
+    c_idx_BxTxC = torch.broadcast_to(
+        torch.arange(C, dtype=torch.int32).view(1, 1, C),
+        [B, T, C],
+    )
+
+    # We must clamp time indices to [0..T-1] so the gather_nd equivalent won't fail
+    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
+
+    indices_BTCx3 = torch.stack(
+        [
+            b_idx_BxTxC.reshape(-1),
+            t_clamped_BxTxC.reshape(-1),
+            c_idx_BxTxC.reshape(-1),
+        ],
+        dim=1,
+    ).long()  # Ensure indices are long type for indexing
+
+    return t_idx_BxTxC, indices_BTCx3
+
+
+def apply_audio_delay(
+    audio_BxTxC: torch.Tensor,
+    pad_value: int,
+    bos_value: int,
+    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+) -> torch.Tensor:
+    """
+    Applies the delay pattern to batched audio tokens using precomputed indices,
+    inserting BOS where t_idx < 0 and PAD where t_idx >= T.
+
+    Args:
+        audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
+        pad_value: the padding token
+        bos_value: the BOS token
+        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
+
+    Returns:
+        result_BxTxC: [B, T, C] delayed audio tokens
+    """
+    device = audio_BxTxC.device  # Get device from input tensor
+    t_idx_BxTxC, indices_BTCx3 = precomp
+    t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
+    indices_BTCx3 = indices_BTCx3.to(device)
+
+    # Equivalent of tf.gather_nd using advanced indexing
+    # Ensure indices are long type if not already (build_delay_indices should handle this)
+    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
+
+    # Create masks on the correct device
+    mask_bos = t_idx_BxTxC < 0  # => place bos_value
+    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value
+
+    # Create scalar tensors on the correct device
+    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
+    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+
+    # If mask_bos, BOS; else if mask_pad, PAD; else original gather
+    # All tensors should now be on the same device
+    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
+
+    return result_BxTxC
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def audio_to_codebook(
+    model,
+    input_values,
+    data_config: DataConfig,
+    padding_mask=None,
+    sample_rate=44100,
+):
+    """
+    Encodes the input audio waveform into discrete codes.
+
+    Args:
+        model: The model to use for encoding.
+        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+            Float values of the input audio waveform.
+        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+            Padding mask used to pad the `input_values`.
+        sample_rate (`int`, *optional*):
+            Signal sampling rate.
+
+    Returns:
+        A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
+        factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
+        `codebook` of shape `[batch_size, num_codebooks, frames]`.
+        Scale is not used here.
+    """
+    audio_data = model.preprocess(input_values, sample_rate)
+
+    if padding_mask is None:
+        padding_mask = torch.ones_like(input_values).bool()
+
+    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
+    seq_length = encoded_frame.shape[2]
+
+    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
+        B=1,
+        T=seq_length,
+        C=data_config.channels,
+        delay_pattern=data_config.delay_pattern,
+    )
+
+    encoded_frame = apply_audio_delay(
+        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
+        pad_value=data_config.audio_pad_value,
+        bos_value=data_config.audio_bos_value,
+        precomp=(t_idx_BxTxC, indices_BTCx3),
+    )
+
+    return encoded_frame
+
+
+def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Precompute indices for the revert operation using PyTorch.
+
+    Returns:
+        A tuple (t_idx_BxTxC, indices_BTCx3) where:
+            - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
+            - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
+              batch indices, clamped time indices, and channel indices.
+    """
+    # Use default device unless specified otherwise; assumes inputs might define device later
+    device = None  # Or determine dynamically if needed, e.g., from a model parameter
+
+    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
+
+    t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
+    t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
+
+    t_idx_BxTxC = torch.minimum(
+        t_idx_BT1 + delay_arr.view(1, 1, C),
+        torch.tensor(T - 1, device=device),
+    )
+    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
+    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
+
+    indices_BTCx3 = torch.stack(
+        [
+            b_idx_BxTxC.reshape(-1),
+            t_idx_BxTxC.reshape(-1),
+            c_idx_BxTxC.reshape(-1),
+        ],
+        dim=1,
+    ).long()  # Ensure indices are long type
+
+    return t_idx_BxTxC, indices_BTCx3
+
+
+def revert_audio_delay(
+    audio_BxTxC: torch.Tensor,
+    pad_value: int,
+    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+    T: int,
+) -> torch.Tensor:
+    """
+    Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
+
+    Args:
+        audio_BxTxC: Input delayed audio tensor
+        pad_value: Padding value for out-of-bounds indices
+        precomp: Precomputed revert indices tuple containing:
+            - t_idx_BxTxC: Time offset indices tensor
+            - indices_BTCx3: Gather indices tensor for original audio
+        T: Original sequence length before padding
+
+    Returns:
+        Reverted audio tensor with same shape as input
+    """
+    t_idx_BxTxC, indices_BTCx3 = precomp
+    device = audio_BxTxC.device  # Get device from input tensor
+
+    # Move precomputed indices to the same device as audio_BxTxC if they aren't already
+    t_idx_BxTxC = t_idx_BxTxC.to(device)
+    indices_BTCx3 = indices_BTCx3.to(device)
+
+    # PyTorch advanced indexing (equivalent to tf.gather_nd or the NumPy equivalent)
+    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping
+
+    # Create pad_tensor on the correct device
+    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+    # Create T tensor on the correct device for comparison
+    T_tensor = torch.tensor(T, device=device)
+
+    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)
+
+    return result_BxTxC
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def decode(
+    model,
+    audio_codes,
+):
+    """
+    Decodes the given frames into an output audio waveform.
+    """
+    if len(audio_codes) != 1:
+        raise ValueError(f"Expected one frame, got {len(audio_codes)}")
+
+    try:
+        audio_values = model.quantizer.from_codes(audio_codes)
+        audio_values = model.decode(audio_values[0])
+
+        return audio_values
+    except Exception as e:
+        print(f"Error in decode method: {str(e)}")
+        raise
+
+
+def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
+    """Process a single codebook file to generate audio."""
+    # Remove BOS token
+    generated_codes = generated_codes[:, 1:]
+
+    if generated_codes.shape[1] > T:
+        generated_codes = generated_codes[:, :T]
+
+    seq_length = generated_codes.shape[1]
+
+    # Build revert indices
+    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
+
+    # Transpose and add batch dimension
+    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
+    reverted_codebook = revert_audio_delay(
+        audio_BxTxC=audio_BxTxC,
+        pad_value=0,
+        precomp=(t_idx_BxTxC, indices_BTCx3),
+        T=seq_length,
+    )
+    # Only trim the trailing 'delay' frames when the total frame count exceeds the delay
+    delay = 30
+    num_frames = reverted_codebook.size(1)  # original T dimension
+
+    if num_frames > delay:
+        reverted_codebook = reverted_codebook[:, :-delay, :]
+    # If num_frames <= delay, skip trimming to avoid producing a T = 0 dimension
+
+    codebook = reverted_codebook.transpose(1, 2)  # (B x T x C) -> (B x C x T)
+
+    min_valid_index = 0
+    max_valid_index = 1023
+    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
+
+    num_invalid = torch.sum(invalid_mask).item()
+    if num_invalid > 0:
+        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
+
+    # Set invalid values to 0 (modify the tensor in-place)
+    codebook[invalid_mask] = 0
+    audio_array = decode(model, codebook)
+
+    return audio_array
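
The two index builders are inverses of each other: apply_audio_delay shifts channel c back by delay[c] steps (filling with BOS/PAD), and revert_audio_delay shifts it forward again. A minimal round-trip sketch, assuming the package layout above and the default delay pattern from config.json:

    import torch
    from dia.audio import (
        apply_audio_delay, build_delay_indices,
        build_revert_indices, revert_audio_delay,
    )

    B, T, C = 1, 32, 9
    delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]
    codes = torch.randint(0, 1024, (B, T, C))

    delayed = apply_audio_delay(
        codes, pad_value=1025, bos_value=1026,
        precomp=build_delay_indices(B, T, C, delay_pattern),
    )
    restored = revert_audio_delay(
        delayed, pad_value=1025,
        precomp=build_revert_indices(B, T, C, delay_pattern),
        T=T,
    )

    # Positions whose delayed source stayed inside the sequence round-trip exactly
    keep = T - max(delay_pattern)
    assert torch.equal(restored[:, :keep, :], codes[:, :keep, :])
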
dia/config.json ADDED
@@ -0,0 +1,49 @@
+{
+  "version": "0.1",
+  "model": {
+    "encoder": {
+      "n_layer": 12,
+      "n_embd": 1024,
+      "n_hidden": 4096,
+      "n_head": 16,
+      "head_dim": 128
+    },
+    "decoder": {
+      "n_layer": 18,
+      "n_embd": 2048,
+      "n_hidden": 8192,
+      "gqa_query_heads": 16,
+      "cross_query_heads": 16,
+      "kv_heads": 4,
+      "gqa_head_dim": 128,
+      "cross_head_dim": 128,
+      "d_model": 256
+    },
+    "src_vocab_size": 256,
+    "tgt_vocab_size": 1028,
+    "dropout": 0.0
+  },
+  "training": {
+    "dtype": "bfloat16"
+  },
+  "data": {
+    "text_length": 512,
+    "audio_length": 1536,
+    "channels": 9,
+    "text_pad_value": 0,
+    "audio_eos_value": 1024,
+    "audio_pad_value": 1025,
+    "audio_bos_value": 1026,
+    "delay_pattern": [
+      0,
+      8,
+      9,
+      10,
+      11,
+      12,
+      13,
+      14,
+      15
+    ]
+  }
+}
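
Note how the vocabulary and the special tokens line up: DAC codebook indices occupy 0-1023, while 1024/1025/1026 serve as EOS/PAD/BOS, so a tgt_vocab_size of 1028 covers all of them with one spare id. A quick sanity-check sketch over this file:

    import json

    with open("dia/config.json") as f:
        cfg = json.load(f)

    data = cfg["data"]
    specials = [data["audio_eos_value"], data["audio_pad_value"], data["audio_bos_value"]]
    assert max(specials) < cfg["model"]["tgt_vocab_size"]  # 1026 < 1028
    assert len(data["delay_pattern"]) == data["channels"]  # one delay per channel
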
dia/config.py ADDED
@@ -0,0 +1,206 @@
+"""Configuration management module for the Dia model.
+
+This module provides comprehensive configuration management for the Dia model,
+utilizing Pydantic for validation. It defines configurations for data processing,
+model architecture (encoder and decoder), and training settings.
+
+Key components:
+- DataConfig: Parameters for data loading and preprocessing.
+- EncoderConfig: Architecture details for the encoder module.
+- DecoderConfig: Architecture details for the decoder module.
+- ModelConfig: Combined model architecture settings.
+- TrainingConfig: Training hyperparameters and settings.
+- DiaConfig: Master configuration combining all components.
+"""
+
+import os
+from typing import Annotated
+
+from pydantic import BaseModel, BeforeValidator, Field
+
+
+class DataConfig(BaseModel, frozen=True):
+    """Configuration for data loading and preprocessing.
+
+    Attributes:
+        text_length: Maximum length of text sequences (must be multiple of 128).
+        audio_length: Maximum length of audio sequences (must be multiple of 128).
+        channels: Number of audio channels.
+        text_pad_value: Value used for padding text sequences.
+        audio_eos_value: Value representing the end of audio sequences.
+        audio_bos_value: Value representing the beginning of audio sequences.
+        audio_pad_value: Value used for padding audio sequences.
+        delay_pattern: List of delay values for each audio channel.
+    """
+
+    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+    channels: int = Field(default=9, gt=0, multiple_of=1)
+    text_pad_value: int = Field(default=0)
+    audio_eos_value: int = Field(default=1024)
+    audio_pad_value: int = Field(default=1025)
+    audio_bos_value: int = Field(default=1026)
+    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
+
+    def __hash__(self) -> int:
+        """Generate a hash based on all fields of the config."""
+        return hash(
+            (
+                self.text_length,
+                self.audio_length,
+                self.channels,
+                self.text_pad_value,
+                self.audio_pad_value,
+                self.audio_bos_value,
+                self.audio_eos_value,
+                tuple(self.delay_pattern),
+            )
+        )
+
+
+class EncoderConfig(BaseModel, frozen=True):
+    """Configuration for the encoder component of the Dia model.
+
+    Attributes:
+        n_layer: Number of transformer layers.
+        n_embd: Embedding dimension.
+        n_hidden: Hidden dimension size in the MLP layers.
+        n_head: Number of attention heads.
+        head_dim: Dimension per attention head.
+        mlp_activations: List of activation functions for the MLP layers.
+        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
+    """
+
+    n_layer: int = Field(gt=0)
+    n_embd: int = Field(gt=0)
+    n_hidden: int = Field(gt=0)
+    n_head: int = Field(gt=0)
+    head_dim: int = Field(gt=0)
+    mlp_activations: list[str] = Field(default=["silu", "linear"])
+    use_pre_norm: bool = Field(default=False)
+
+
+class DecoderConfig(BaseModel, frozen=True):
+    """Configuration for the decoder component of the Dia model.
+
+    Attributes:
+        n_layer: Number of transformer layers.
+        n_embd: Embedding dimension.
+        n_hidden: Hidden dimension size in the MLP layers.
+        gqa_query_heads: Number of query heads for grouped-query self-attention.
+        kv_heads: Number of key/value heads for grouped-query self-attention.
+        gqa_head_dim: Dimension per query head for grouped-query self-attention.
+        cross_query_heads: Number of query heads for cross-attention.
+        cross_head_dim: Dimension per cross-attention head.
+        mlp_activations: List of activation functions for the MLP layers.
+        use_pre_norm: Whether to use pre-normalization.
+    """
+
+    n_layer: int = Field(gt=0)
+    n_embd: int = Field(gt=0)
+    n_hidden: int = Field(gt=0)
+    gqa_query_heads: int = Field(gt=0)
+    kv_heads: int = Field(gt=0)
+    gqa_head_dim: int = Field(gt=0)
+    cross_query_heads: int = Field(gt=0)
+    cross_head_dim: int = Field(gt=0)
+    mlp_activations: list[str] = Field(default=["silu", "linear"])
+    use_pre_norm: bool = Field(default=False)
+
+
+class ModelConfig(BaseModel, frozen=True):
+    """Main configuration container for the Dia model architecture.
+
+    Attributes:
+        encoder: Configuration for the encoder component.
+        decoder: Configuration for the decoder component.
+        src_vocab_size: Size of the source (text) vocabulary.
+        tgt_vocab_size: Size of the target (audio code) vocabulary.
+        dropout: Dropout probability applied within the model.
+        normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
+        weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
+        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
+        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
+    """
+
+    encoder: EncoderConfig
+    decoder: DecoderConfig
+    src_vocab_size: int = Field(default=128, gt=0)
+    tgt_vocab_size: int = Field(default=1028, gt=0)
+    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
+    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
+    weight_dtype: str = Field(default="float32", description="Weight precision")
+    rope_min_timescale: int = Field(default=1, description="Timescale for global attention")
+    rope_max_timescale: int = Field(default=10_000, description="Timescale for global attention")
+
+
+class TrainingConfig(BaseModel, frozen=True):
+    """Training process configuration and hyperparameters.
+
+    Note: This configuration currently only includes precision settings.
+    Other training parameters (like batch size, learning rate, optimizer settings)
+    are assumed to be handled externally.
+
+    Attributes:
+        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
+        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
+    """
+
+    dtype: str = Field(default="bfloat16", description="Activation precision")
+    logits_dot_in_fp32: bool = Field(default=False)
+
+
+class DiaConfig(BaseModel, frozen=True):
+    """Master configuration for the Dia model.
+
+    Combines all sub-configurations into a single validated object.
+
+    Attributes:
+        version: Configuration version string.
+        model: Model architecture configuration.
+        training: Training process configuration (precision settings).
+        data: Data loading and processing configuration.
+    """
+
+    version: str = Field(default="1.0")
+    model: ModelConfig
+    training: TrainingConfig
+    data: DataConfig
+
+    def save(self, path: str) -> None:
+        """Save the current configuration instance to a JSON file.
+
+        Ensures the parent directory exists and the file has a .json extension.
+
+        Args:
+            path: The target file path to save the configuration.
+
+        Raises:
+            ValueError: If the path is not a file with a .json extension.
+        """
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        config_json = self.model_dump_json(indent=2)
+        with open(path, "w") as f:
+            f.write(config_json)
+
+    @classmethod
+    def load(cls, path: str) -> "DiaConfig | None":
+        """Load and validate a Dia configuration from a JSON file.
+
+        Args:
+            path: The path to the configuration file.
+
+        Returns:
+            A validated DiaConfig instance if the file exists and is valid,
+            otherwise None if the file is not found.
+
+        Raises:
+            ValueError: If the path does not point to an existing .json file.
+            pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
+        """
+        try:
+            with open(path, "r") as f:
+                content = f.read()
+            return cls.model_validate_json(content)
+        except FileNotFoundError:
+            return None
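
A minimal usage sketch of the config round-trip (the target path is a placeholder; note that save() expects a path that has a parent directory):

    from dia.config import DiaConfig

    cfg = DiaConfig.load("dia/config.json")   # returns None if the file is missing
    if cfg is not None:
        print(cfg.model.decoder.n_layer, cfg.data.delay_pattern)
        cfg.save("/tmp/dia/config_copy.json")  # re-serialized, validated JSON
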
dia/config_inference.json ADDED
@@ -0,0 +1,49 @@
+{
+  "version": "0.1",
+  "model": {
+    "encoder": {
+      "n_layer": 12,
+      "n_embd": 1024,
+      "n_hidden": 4096,
+      "n_head": 16,
+      "head_dim": 128
+    },
+    "decoder": {
+      "n_layer": 18,
+      "n_embd": 2048,
+      "n_hidden": 8192,
+      "gqa_query_heads": 16,
+      "cross_query_heads": 16,
+      "kv_heads": 4,
+      "gqa_head_dim": 128,
+      "cross_head_dim": 128,
+      "d_model": 256
+    },
+    "src_vocab_size": 256,
+    "tgt_vocab_size": 1028,
+    "dropout": 0.0
+  },
+  "training": {
+    "dtype": "float32"
+  },
+  "data": {
+    "text_length": 512,
+    "audio_length": 1536,
+    "channels": 9,
+    "text_pad_value": 0,
+    "audio_eos_value": 1024,
+    "audio_pad_value": 1025,
+    "audio_bos_value": 1026,
+    "delay_pattern": [
+      0,
+      8,
+      9,
+      10,
+      11,
+      12,
+      13,
+      14,
+      15
+    ]
+  }
+}
dia/convert_ckpt.py ADDED
@@ -0,0 +1,54 @@
+import argparse
+
+import torch
+
+from dia.layers import DiaModel  # adjust your import if needed
+from dia.config import DiaConfig
+
+
+def convert_checkpoint(input_ckpt: str, output_ckpt: str, config_path: str):
+    # select device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # 1) Reconstruct exactly the same compiled model you saved
+    dia_cfg = DiaConfig.load(config_path)
+    model = DiaModel(dia_cfg).to(device)
+    model = model.half()
+    model = torch.compile(model, backend="inductor")
+
+    # 2) Load your compiled/half checkpoint
+    state = torch.load(input_ckpt, map_location=device)
+    model.load_state_dict(state)
+
+    # 3) Unwrap to the original nn.Module
+    orig = getattr(model, "_orig_mod", None) or getattr(model, "__wrapped__", None) or model
+
+    # 4) Cast all params & buffers back to float32
+    orig.float()
+
+    # 5) Save its clean, float32 state_dict
+    torch.save(orig.state_dict(), output_ckpt)
+    print(f"Saved normal FP32 checkpoint to {output_ckpt}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert a compiled/half-precision checkpoint back to a standard FP32 state_dict."
+    )
+    parser.add_argument(
+        "--input-ckpt", "-i",
+        required=True,
+        help="Path to the half-precision compiled checkpoint (.pth) to load",
+    )
+    parser.add_argument(
+        "--output-ckpt", "-o",
+        required=True,
+        help="Path where the FP32 state_dict will be saved",
+    )
+    parser.add_argument(
+        "--config", "-c",
+        required=True,
+        help="Path to your DiaConfig JSON file",
+    )
+
+    args = parser.parse_args()
+    convert_checkpoint(args.input_ckpt, args.output_ckpt, args.config)
+
+
+if __name__ == "__main__":
+    main()
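
An illustrative programmatic invocation (file names are placeholders; the input must be a checkpoint saved from the same half-precision, torch.compile-wrapped model this script reconstructs):

    from dia.convert_ckpt import convert_checkpoint

    convert_checkpoint(
        input_ckpt="ckpt_step2000.pth",  # compiled/fp16 checkpoint from fine-tuning
        output_ckpt="dia_fp32.pth",      # plain FP32 state_dict, loadable without torch.compile
        config_path="dia/config.json",
    )
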
dia/dataset.py ADDED
@@ -0,0 +1,162 @@
+from pathlib import Path
+
+import torch
+import torchaudio
+import pandas as pd
+from torch.utils.data import Dataset
+
+import dac
+from .config import DiaConfig
+
+
+class LocalDiaDataset(Dataset):
+    """Dataset loader from a local CSV (pipe-separated) and audio folder."""
+
+    def __init__(self, csv_path: Path, audio_root: Path, config: DiaConfig, dac_model: dac.DAC):
+        self.df = pd.read_csv(csv_path, sep=r"\s*\|\s*", engine="python", names=["audio", "text", "channel"])
+        self.audio_root = audio_root
+        self.config = config
+        self.dac_model = dac_model
+
+    def __len__(self) -> int:
+        return len(self.df)
+
+    def __getitem__(self, idx: int):
+        row = self.df.iloc[idx]
+        text = row["text"]
+        channel = row.get("channel", None)
+        if channel and pd.notna(channel):
+            text = f"[{channel}]{text}"
+
+        audio_path = self.audio_root / row["audio"]
+        waveform, sr = torchaudio.load(audio_path)
+
+        if sr != 44100:
+            waveform = torchaudio.functional.resample(waveform, sr, 44100)
+
+        if waveform.ndim == 1:
+            waveform = waveform.unsqueeze(0)
+        elif waveform.ndim == 2:
+            waveform = waveform[:1]  # only take 1 channel if stereo
+
+        waveform = waveform.unsqueeze(0)  # (1, 1, T)
+        with torch.no_grad():
+            audio_tensor = self.dac_model.preprocess(waveform, 44100).to(
+                next(self.dac_model.parameters()).device
+            )
+            _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
+            encoded = encoded.squeeze(0).transpose(0, 1)  # (T, C)
+
+        return text, encoded, waveform
+
+
+class HFDiaDataset(Dataset):
+    """Dataset loader from a Hugging Face Datasets object."""
+
+    def __init__(self, hf_dataset, config: DiaConfig, dac_model: dac.DAC):
+        self.dataset = hf_dataset
+        self.config = config
+        self.dac_model = dac_model
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __getitem__(self, idx: int):
+        sample = self.dataset[idx]
+
+        # Handle the text tag
+        text = sample["text"]
+        channel = sample.get("channel", None)
+        lang = sample.get("language", None)
+
+        if channel and isinstance(channel, str) and channel.strip():
+            text = f"[{channel}]{text}"
+        elif lang and isinstance(lang, str):
+            text = f"[{lang}]{text}"
+
+        # Handle the audio
+        audio_info = sample["audio"]
+        waveform = torch.tensor(audio_info["array"], dtype=torch.float32)
+
+        # Ensure waveform shape (1, 1, T)
+        if waveform.ndim == 1:
+            waveform = waveform.unsqueeze(0).unsqueeze(0)
+        elif waveform.ndim == 2:
+            waveform = waveform[:1].unsqueeze(0)  # take the first channel only
+
+        # Resample if not 44100 Hz
+        sr = audio_info.get("sampling_rate", 44100)
+        if sr != 44100:
+            waveform = torchaudio.functional.resample(waveform, sr, 44100)
+
+        with torch.no_grad():
+            audio_tensor = self.dac_model.preprocess(waveform, 44100).to(next(self.dac_model.parameters()).device)
+            _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
+            encoded = encoded.squeeze(0).transpose(0, 1)  # (T, C)
+
+        return text, encoded, waveform
+
+
+class HFDiaIterDataset(torch.utils.data.IterableDataset):
+    """Iterable wrapper for a HF streaming Dataset that has `audio.array` & `text`."""
+
+    def __init__(self, hf_iterable, config: DiaConfig, dac_model: dac.DAC):
+        super().__init__()
+        self.dataset = hf_iterable
+        self.config = config
+        self.dac_model = dac_model
+
+    def __iter__(self):
+        for sample in self.dataset:
+            lang = sample.get("language", None)
+            # Get the channel info and normalize it
+            channel = sample.get("channel", "").replace("@", "").lower()
+            speaker_tag = f"[{channel}]" if channel else "[unk]"
+            # Prepend the speaker tag to the text
+            text = speaker_tag + sample["text"]
+            audio_info = sample["audio"]
+            waveform = torch.tensor(audio_info["array"], dtype=torch.float32)
+            if waveform.ndim == 1:
+                waveform = waveform.unsqueeze(0).unsqueeze(0)
+            elif waveform.ndim == 2:
+                waveform = waveform.unsqueeze(0)
+            sr = audio_info.get("sampling_rate", 44100)
+            if sr != 44100:
+                waveform = torchaudio.functional.resample(waveform, sr, 44100)
+            with torch.no_grad():
+                audio_tensor = (
+                    self.dac_model.preprocess(waveform, 44100)
+                    .to(next(self.dac_model.parameters()).device)
+                )
+                _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
+                encoded = encoded.squeeze(0).transpose(0, 1)
+            yield text, encoded, waveform
+
+
+class VietnameseDiaDataset(HFDiaIterDataset):
+    def __init__(self, dataset, dia_cfg, dac_model):
+        super().__init__(dataset, dia_cfg, dac_model)
+
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+        # 1) Add the [vi] language tag
+        text = item["text"]
+        if not text.startswith("[vi]"):
+            text = f"[vi]{text}"
+
+        # 2) Resample the audio to 44.1 kHz
+        audio_array = item["audio"]["array"]
+        sr = item["audio"]["sampling_rate"]
+        if sr != 44100:
+            audio_array = torchaudio.functional.resample(
+                torch.tensor(audio_array),
+                orig_freq=sr,
+                new_freq=44100,
+            ).numpy()
+
+        # 3) DAC-encode the audio segment, following the same pattern as the
+        #    classes above (the original called a missing `get_dac_encoding` helper)
+        waveform = torch.tensor(audio_array, dtype=torch.float32).view(1, 1, -1)
+        with torch.no_grad():
+            audio_tensor = self.dac_model.preprocess(waveform, 44100).to(
+                next(self.dac_model.parameters()).device
+            )
+            _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
+        encoding = encoded.squeeze(0).transpose(0, 1)  # (T, C)
+
+        return text, encoding, audio_array
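
A minimal sketch of wiring up the local CSV loader (paths are placeholders; the CSV holds pipe-separated `audio|text|channel` rows with audio paths relative to the audio root):

    from pathlib import Path

    import dac
    from dia.config import DiaConfig
    from dia.dataset import LocalDiaDataset

    dia_cfg = DiaConfig.load("dia/config.json")
    dac_model = dac.DAC.load(dac.utils.download())
    ds = LocalDiaDataset(Path("metadata.csv"), Path("audio/"), dia_cfg, dac_model)
    text, codes, waveform = ds[0]  # codes: (T, C) DAC codebook indices
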
dia/finetune.py ADDED
@@ -0,0 +1,787 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import random
5
+ import tempfile
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torchaudio
11
+ import pandas as pd
12
+ from torch.utils.data import Dataset, DataLoader, random_split
13
+ from torch.cuda.amp import autocast
14
+ from torch.utils.tensorboard import SummaryWriter
15
+ from torch.nn.utils import clip_grad_norm_
16
+ from transformers import get_scheduler
17
+ import torch.nn.functional as F
18
+ import bitsandbytes as bnb
19
+ from tqdm import tqdm
20
+ from datasets import load_dataset, interleave_datasets, get_dataset_config_names, DatasetDict
21
+ from huggingface_hub import hf_hub_download
22
+ import math
23
+ import gc
24
+ from torch.cuda.amp import GradScaler
25
+
26
+ import dac
27
+ from .config import DiaConfig
28
+ from .layers import DiaModel
29
+ from .model import Dia
30
+ from .audio import build_delay_indices, apply_audio_delay
31
+ from .dataset import *
32
+ from .interleaved_datasets import load_cml_tts_streamed, load_common_voice17_streamed
33
+ from datasets import load_from_disk
34
+ from .dataset import HFDiaDataset
35
+ from tqdm import tqdm
36
+
37
+ # Configure logging
38
+ logging.basicConfig(
39
+ level=logging.INFO,
40
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
41
+ )
42
+ logger = logging.getLogger(__name__)
43
+
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ torch.backends.cudnn.benchmark = True
46
+
47
+ #bytes for language tag replacement
48
+ LANG2BYTE = {
49
+ "en": 3,
50
+ "vi": 19,
51
+ }
52
+
53
+ CHANNELS = [
54
+ "5phutcrypto",
55
+ "anhbanthan",
56
+ "anhthamtu",
57
+ "animerewind.official",
58
+ "bibitv8888",
59
+ "btvgo",
60
+ "baclieutv",
61
+ "bachhoaxanhcom",
62
+ "baodientuvov",
63
+ "blvckvines",
64
+ "boringppl",
65
+ "bronub",
66
+ "cdteam-why",
67
+ "cobabinhduong",
68
+ "cosmicwriter",
69
+ "cuthongthai",
70
+ "daiphatthanhtruyenhinhsonla",
71
+ "day-be-thong-minh-tv",
72
+ "danangtv",
73
+ "daihanoi-htv",
74
+ "daiptththainguyentntv",
75
+ "dongmauviet",
76
+ "dongthaptv",
77
+ "fptbongdaofficial",
78
+ "fonosvietnam",
79
+ "hieurotrong5phut-ntkt",
80
+ "htvtintuc",
81
+ "happyhidari",
82
+ "hoabinhtvgo",
83
+ "hocenglishonline",
84
+ "hocvienbovagau",
85
+ "hungyentvvngo",
86
+ "huynhduykhuongofficial",
87
+ "huynhlapofficial",
88
+ "jvevermind",
89
+ "kenhvtc16",
90
+ "kiengiangtv",
91
+ "khanhvyofficial",
92
+ "kienthucquansu",
93
+ "lamdongtv1",
94
+ "lamvlog",
95
+ "longantv-la34",
96
+ "mangovid",
97
+ "mensbay",
98
+ "meovatcuocsonglnv",
99
+ "meuchannel",
100
+ "ntnvlogsnguyenthanhnam",
101
+ "ngamradio",
102
+ "nhanhac555",
103
+ "nhantaidaiviet",
104
+ "ptth-trt",
105
+ "ptvtruyenhinhphutho",
106
+ "phantichgame",
107
+ "phephim",
108
+ "phimhottk-l",
109
+ "riwaylegal",
110
+ "ruangao",
111
+ "suckhoetamsinh",
112
+ "sachbiquyethanhcong",
113
+ "soisangbrightsidevietnamese",
114
+ "spiderum",
115
+ "spiderumbooks",
116
+ "sukieskitchen",
117
+ "tin3phut",
118
+ "tranthanhtown",
119
+ "tulemientay",
120
+ "tayninhtv",
121
+ "thainhitv",
122
+ "thanhpahm",
123
+ "thegioilaptop",
124
+ "thepresentwriter",
125
+ "tiengiangtivi",
126
+ "tieubaobaothom",
127
+ "tintucbitcoin247",
128
+ "truyenhinhbinhphuoc-bptv",
129
+ "truyenhinhyenbaiytv",
130
+ "truyenhinhcaobang",
131
+ "truyenhinhdaklakdrt",
132
+ "truyenhinhdaknong1",
133
+ "truyenhinhdienbien23.9",
134
+ "truyenhinhkhanhhoa",
135
+ "truyenhinhkontumkrt",
136
+ "truyenhinhnaminhntv",
137
+ "truyenhinhninhthuan",
138
+ "truyenhinhquangngai",
139
+ "tuantienti2911",
140
+ "tuyenquangttv",
141
+ "vovlivedoctruyen",
142
+ "vietcetera",
143
+ "vinhlongtv",
144
+ "voizfm",
145
+ "vutrunguyenthuy",
146
+ "vuive",
147
+ "w2wanime",
148
+ "w2wcartoon",
149
+ "w2whorror",
150
+ "w2wmovie",
151
+ "web5ngay",
152
+ "xanh24h",
153
+ "aiphatthanhtruyenhinhquangtri",
154
+ "aiphatthanhvatruyenhinhhai1908",
155
+ "altonghop",
156
+ "antvtruyenhinhcongannhandan",
157
+ "baihoc10phut",
158
+ "battlecry.khampha",
159
+ "betterversionvn",
160
+ "blogkhoinghiep",
161
+ "bumcn",
162
+ "caikinhdi_vn",
163
+ "canthitg",
164
+ "chanthienmybachnien",
165
+ "chauanhchao",
166
+ "cosu",
167
+ "cungmaivaobep-monan-amthuc",
168
+ "daiptthphuyen",
169
+ "daiptthtv",
170
+ "daitruyenhinhangiang",
171
+ "daitruyenhinhbacgiang",
172
+ "dannytran2375",
173
+ "daybehoc5489",
174
+ "daylaphegame",
175
+ "dienmay",
176
+ "ducisreal",
177
+ "duongfg",
178
+ "duyluandethuong",
179
+ "duythanhish",
180
+ "elroydevops",
181
+ "gc.gamelab",
182
+ "hacthaybachthay",
183
+ "hagiangtv475",
184
+ "haiduongtv247",
185
+ "hanamtv8831",
186
+ "hangphimtailieudienanhnd",
187
+ "haugiangtv",
188
+ "haunauday",
189
+ "hieu-tv",
190
+ "hoshiphan",
191
+ "jakinatsumi2915",
192
+ "kechuyentieuhoc1719",
193
+ "kenhcovan",
194
+ "khalid_dinh",
195
+ "kiaralah",
196
+ "laichautv",
197
+ "langsontvtube",
198
+ "megame_official",
199
+ "minvestvn",
200
+ "nguoithanhcong1991",
201
+ "nhatkycuocsong.",
202
+ "ntcanima",
203
+ "ptthbentre",
204
+ "ptthquangbinh",
205
+ "qrt",
206
+ "quangninhtv",
207
+ "snewsvn",
208
+ "soctrangtv",
209
+ "sunhuynpodcast",
210
+ "tamhonanuong",
211
+ "tgddreview",
212
+ "thaibinhtv",
213
+ "thanhnamedu",
214
+ "thanhnientvnews",
215
+ "thbrt",
216
+ "thieunhitv3630",
217
+ "thtpct",
218
+ "tinnhanh3phut868",
219
+ "toansam",
220
+ "toidicodedaoblog",
221
+ "tranquochuywecommit",
222
+ "tranvyvy",
223
+ "truyenhinh4k",
224
+ "truyenhinhbinhthuan",
225
+ "truyenhinhcamau69",
226
+ "truyenhinhdongnai_dnrtv",
227
+ "truyenhinhgialai",
228
+ "truyenhinhlaocai",
229
+ "truyenhinhnghean",
230
+ "truyenhinhvinhphuc",
231
+ "txtofficial8798",
232
+ "vanhkhuyenle",
233
+ "vietnh1009",
234
+ "visaothenhipodcast",
235
+ "vtc14",
236
+ "vtcnow",
237
+ "vtv24",
238
+ "vuive123",
239
+ "zombiev4",
240
+ ]
241
+
242
+ # Tự động ánh xạ channel → token (bắt đầu từ 30)
243
+ for i, ch in enumerate(CHANNELS):
244
+ LANG2BYTE[ch] = 30 + i
245
+
246
+ test_sentences = {
247
+ "en": "In order to fully assess performance and the accuracy of language tags, this test sentence contains multiple subordinate clauses, varied punctuation, and a sufficient word count.",
248
+ "vi": "Để đánh giá toàn diện hiệu suất và độ chính xác của các thẻ ngôn ngữ, câu kiểm tra này chứa nhiều mệnh đề phụ, dấu câu đa dạng và số lượng từ đầy đủ."
249
+ }
250
+
251
+ @dataclass
252
+ class TrainConfig:
253
+ epochs: int = 1
254
+ batch_size: int = 2
255
+ grad_accum_steps: int = 2
256
+ learning_rate: float = 1e-5
257
+ warmup_steps: int = 500
258
+ unconditional_frac: float = 0.15
259
+ eval_step: int = 200
260
+ save_step: int = 2000
261
+ split_ratio: float = 0.997
262
+ shuffle_buffer_size: int = None # for streaming shuffle
263
+ seed: int = 42 # seed for reproducibility
264
+ runs_dir: Path = Path("runs")
265
+ run_name: str = "dia_finetune_cv"
266
+ output_dir: Path = Path(".cpkts/dia_finetune_cv ")
267
+ resume_from: Path = None
268
+ total_steps: int = 290007
269
+
270
+
271
+
272
+ def get_args() -> argparse.Namespace:
273
+ parser = argparse.ArgumentParser(description="Train the Dia audio model")
274
+ parser.add_argument("--config", type=Path, default=Path("dia/config.json"))
275
+ parser.add_argument("--dataset", type=str, default="Paradoxia/opendata-iisys-hui",
276
+ help="HuggingFace dataset name (if not using --csv_path).")
277
+ parser.add_argument("--dataset2", type=str, default=None,
278
+ help="(Optional) second HF dataset to interleave (streaming)")
279
+ parser.add_argument("--streaming",action="store_true",
280
+ help="Enable HuggingFace dataset streaming")
281
+ parser.add_argument("--hub_model", type=str, default="nari-labs/Dia-1.6B")
282
+ parser.add_argument("--local_ckpt", type=str, default=None)
283
+ parser.add_argument("--csv_path", type=Path, default=None,
284
+ help="Path to local CSV/TSV file with `audio|text` (if you want to train locally).")
285
+ parser.add_argument("--audio_root",type=Path, default=None,
286
+ help="Root directory for local audio files (required if --csv_path is set).")
287
+ parser.add_argument("--run_name", type=str, default=None)
288
+ parser.add_argument("--output_dir",type=Path, default=None)
289
+ parser.add_argument("--shuffle_buffer_size", type=int, default=None,
290
+ help="Buffer size for streaming dataset shuffle.")
291
+ parser.add_argument("--seed", type=int, default=42,
292
+ help="Random seed for reproducibility.")
293
+ parser.add_argument("--half", action="store_true", help="load model in fp16")
294
+ parser.add_argument("--compile", action="store_true", help="torch compile model")
295
+ parser.add_argument('--use_amp', action='store_true', help='Enable mixed precision')
296
+ parser.add_argument("--resume_from", type=str, default=None)
297
+ return parser.parse_args()
298
+
299
+
300
+
301
+ def collate_fn(batch, config: DiaConfig, device: torch.device):
302
+ from torch.nn.functional import pad
303
+
304
+ texts, encodings, waveforms = zip(*batch)
305
+
306
+ # -- Text inputs ---------------------------------------------------------
307
+
308
+ max_text = config.data.text_length
309
+ pad_tok = config.data.text_pad_value
310
+ text_ids = []
311
+ for txt in texts:
312
+ b_full = txt.encode('utf-8')
313
+ # replace leading "[lang]" prefix
314
+ for code, val in LANG2BYTE.items():
315
+ prefix = f"[{code}]".encode('utf-8')
316
+ if b_full.startswith(prefix):
317
+ b_full = bytes([val]) + b_full[len(prefix):]
318
+ break
319
+ bts = b_full[:max_text]
320
+ arr = list(bts) + [pad_tok] * (max_text - len(bts))
321
+ text_ids.append(torch.tensor(arr, dtype=torch.long))
322
+ src = torch.stack(text_ids).to(device)
323
+ src_pos = torch.arange(max_text, device=device).unsqueeze(0).expand(src.size(0), -1)
324
+ src_pad = src.ne(pad_tok)
325
+ enc_self_attn_mask = (src_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)
326
+
327
+ # -- Audio codes --------------------------------------------------------
328
+
329
+ max_audio = config.data.audio_length
330
+ # per-sample lengths (clipped to max_audio)
331
+ seq_lens = [min(e.size(0), max_audio) for e in encodings]
332
+ batch_max = max(seq_lens)
333
+
334
+ # pad or trim each encoding to the batch max length
335
+ padded = [pad(e, (0, 0, 0, batch_max - e.size(0))) if e.size(0) < batch_max else e[:batch_max]
336
+ for e in encodings]
337
+ codes = torch.stack(padded).to(device) # (B, T=batch_max, C)
338
+
339
+ B, T, C = codes.shape
340
+ t_idx, idxs = build_delay_indices(B, T, C, config.data.delay_pattern)
341
+ delayed = apply_audio_delay(
342
+ codes,
343
+ config.data.audio_pad_value,
344
+ config.data.audio_bos_value,
345
+ (t_idx, idxs)
346
+ )
347
+ # ensure no longer than max_audio
348
+ delayed = delayed[:, :max_audio, :]
349
+
350
+ # -- Targets with per-sample EOS ----------------------------------------
351
+
352
+ max_tgt_len = max_audio + 2
353
+ pad_val = config.data.audio_pad_value
354
+ bos_val = config.data.audio_bos_value
355
+ eos_val = config.data.audio_eos_value
356
+
357
+ tgt = torch.full((B, max_tgt_len, C), pad_val, dtype=torch.long, device=device)
358
+ tgt[:, 0, :] = bos_val
359
+ tgt_lens = []
360
+ for i, L in enumerate(seq_lens):
361
+ tgt[i, 1:1 + L, :] = delayed[i, :L, :]
362
+ tgt[i, 1 + L, :] = eos_val
363
+ tgt_lens.append(1 + L + 1)
364
+
365
+ tgt_pos = torch.arange(max_tgt_len, device=device).unsqueeze(0).expand(B, -1)
366
+ tgt_pad = tgt.ne(pad_val).any(-1)
367
+
368
+ causal = torch.tril(torch.ones((max_tgt_len, max_tgt_len),
369
+ dtype=torch.bool,
370
+ device=device))
371
+ dec_self_attn_mask = (tgt_pad.unsqueeze(2) & tgt_pad.unsqueeze(1) & causal).unsqueeze(1)
372
+ dec_cross_attn_mask = (tgt_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)
373
+
374
+ return {
375
+ 'src_tokens': src,
376
+ 'src_positions': src_pos,
377
+ 'enc_self_attn_mask': enc_self_attn_mask,
378
+ 'tgt_tokens': tgt,
379
+ 'tgt_positions': tgt_pos,
380
+ 'dec_self_attn_mask': dec_self_attn_mask,
381
+ 'dec_cross_attn_mask': dec_cross_attn_mask,
382
+ 'waveforms': waveforms,
383
+ 'raw_text': texts[0],
384
+ 'tgt_lens': torch.tensor(tgt_lens, dtype=torch.long, device=device),
385
+ }
386
+
387
+ def setup_loaders(dataset, dia_cfg: DiaConfig, train_cfg: TrainConfig, device):
388
+ collate = lambda b: collate_fn(b, dia_cfg, device)
389
+ if isinstance(dataset, HFDiaIterDataset):
390
+ total = getattr(dataset, "total_examples", None)
391
+ if total is None:
392
+ total = dataset.dataset.info.splits["train"].num_examples
393
+ n_train = int(train_cfg.split_ratio * total)
394
+ n_val = total - n_train
395
+ if n_val <= 0:
396
+ raise RuntimeError(f"No validation samples: total={total}, split_ratio={train_cfg.split_ratio}")
397
+ base = dataset.dataset.shuffle(buffer_size=train_cfg.shuffle_buffer_size, seed=train_cfg.seed) if train_cfg.shuffle_buffer_size else dataset.dataset
398
+ val_stream = base.take(n_val)
399
+ train_stream = base.skip(n_val)
400
+ train_ds = HFDiaIterDataset(train_stream, dia_cfg, dataset.dac_model)
401
+ val_ds = HFDiaIterDataset(val_stream, dia_cfg, dataset.dac_model)
402
+ train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=False, collate_fn=collate)
403
+ train_loader.steps_per_epoch = math.ceil(n_train / train_cfg.batch_size)
404
+ val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
405
+ return train_loader, val_loader
406
+ ds_len = len(dataset)
407
+ n_train = int(train_cfg.split_ratio * ds_len)
408
+ train_ds, val_ds = random_split(dataset, [n_train, ds_len - n_train])
409
+ train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=True, collate_fn=collate)
410
+ val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
411
+ return train_loader, val_loader
412
+
413
+
414
+
415
+ def setup_optimizer_and_scheduler(model, train_loader, train_cfg):
416
+ opt = bnb.optim.AdamW8bit(model.parameters(), lr=train_cfg.learning_rate)
417
+ # Determine steps per epoch: prefer len(), else use attached attribute
418
+ try:
419
+ steps_per_epoch = len(train_loader)
420
+ except TypeError:
421
+ if hasattr(train_loader, 'steps_per_epoch'):
422
+ steps_per_epoch = train_loader.steps_per_epoch
423
+ else:
424
+ raise RuntimeError("Cannot determine steps_per_epoch for streaming loader")
425
+ total_training_steps = steps_per_epoch * train_cfg.epochs
426
+ sched = get_scheduler(
427
+ 'cosine', opt,
428
+ num_warmup_steps=train_cfg.warmup_steps / train_cfg.grad_accum_steps,
429
+ num_training_steps=total_training_steps / train_cfg.grad_accum_steps
430
+ )
431
+ return opt, sched
432
+
433
+
434
+
435
+ def train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step,scaler):
436
+ """
437
+ Perform a single training step: forward, loss, backward, update, log.
438
+ Now uses per‑sample tgt_lens to mask out padding after each EOS,
439
+ and applies 4× loss weight on the first channel.
440
+ """
441
+ # (optional) unconditional conditioning
442
+ if random.random() < train_cfg.unconditional_frac:
443
+ pad_tok = dia_cfg.data.text_pad_value
444
+ batch['src_tokens'] = torch.zeros_like(batch['src_tokens'])
445
+ batch['enc_self_attn_mask'] = torch.zeros_like(batch['enc_self_attn_mask'])
446
+ batch['dec_cross_attn_mask'] = torch.zeros_like(batch['dec_cross_attn_mask'])
447
+
448
+ with autocast(dtype=torch.float16):
449
+ # forward pass
450
+ logits = model(
451
+ src_BxS=batch['src_tokens'],
452
+ tgt_BxTxC=batch['tgt_tokens'],
453
+ src_positions=batch['src_positions'],
454
+ tgt_positions=batch['tgt_positions'],
455
+ enc_self_attn_mask=batch['enc_self_attn_mask'],
456
+ dec_self_attn_mask=batch['dec_self_attn_mask'],
457
+ dec_cross_attn_mask=batch['dec_cross_attn_mask'],
458
+ enable_dropout=True,
459
+ )
460
+ # fetch per-sample target‑lengths (including BOS+frames+EOS)
461
+ lens = batch['tgt_lens'] # shape: (B,)
462
+ max_L = int(lens.max().item()) # maximum over batch
463
+
464
+ # keep only up through the last possible EOS slot
465
+ # logits: (B, T, C, V) -> (B, max_L-1, C, V)
466
+ logits = logits[:, : max_L - 1]
467
+
468
+ # targets: shift off the BOS so 0..<max_L-1> align with logits
469
+ # target: (B, T, C) -> (B, max_L-1, C)
470
+ target = batch['tgt_tokens'][:, 1:max_L, :]
471
+
472
+ B, Tm1, C = target.shape
473
+ pad_val = dia_cfg.data.audio_pad_value
474
+
475
+ # build a mask [B x (max_L-1)] that is True for t < (lens[i]-1)
476
+ time_idx = torch.arange(Tm1, device=lens.device).unsqueeze(0) # (1, Tm1)
477
+ valid_time = time_idx < (lens.unsqueeze(1) - 1) # (B, Tm1)
478
+ mask = valid_time.unsqueeze(-1).expand(-1, -1, C) # (B, Tm1, C)
479
+
480
+ # apply 4× weight on first channel, 1× on others
481
+ channel_weights = [4.0] + [1.0] * (C - 1)
482
+ loss_c = 0.0
483
+ _, _, _, V = logits.size()
484
+
485
+ for c, w in enumerate(channel_weights):
486
+ # flatten this channel
487
+ lc = logits[:, :, c, :].reshape(-1, V) # (B*Tm1, V)
488
+ tc = target[:, :, c].reshape(-1) # (B*Tm1,)
489
+ mc = mask[:, :, c].reshape(-1) # (B*Tm1,)
490
+
491
+ # mask out padding and compute cross-entropy
492
+ lc_valid = lc[mc]
493
+ tc_valid = tc[mc]
494
+ loss_c += w * F.cross_entropy(
495
+ lc_valid, tc_valid,
496
+ ignore_index=pad_val
497
+ )
498
+
499
+ # normalize by sum of weights
500
+ loss = loss_c / sum(channel_weights)
501
+
502
+ # scale + backward
503
+ loss = loss / train_cfg.grad_accum_steps
504
+ scaler.scale(loss).backward()
505
+
506
+
507
+ # step & log
508
+
509
+ if (step_in_epoch + 1) % train_cfg.grad_accum_steps == 0:
510
+ # Unscale before clipping
511
+ scaler.unscale_(opt)
512
+ grad_norm = clip_grad_norm_(model.parameters(), max_norm=1e9)
513
+
514
+ scaler.step(opt)
515
+ scaler.update()
516
+ opt.zero_grad()
517
+ sched.step()
518
+
519
+ true_loss = loss.item() * train_cfg.grad_accum_steps
520
+ current_lr = sched.get_last_lr()[0]
521
+
522
+ writer.add_scalar('GradNorm/global', grad_norm, global_step)
523
+ writer.add_scalar('LR', current_lr, global_step)
524
+ writer.add_scalar('Loss/train', true_loss, global_step)
525
+
526
+ return true_loss
527
+ else:
528
+ return loss.item() * train_cfg.grad_accum_steps
529
+
530
+
531
+
532
+
533
+ def eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step):
534
+ """
535
+ Run evaluation: compute average loss on validation set and log audio samples.
536
+ """
537
+ import gc
538
+ eval_losses = []
539
+ last_batch = None
540
+ with torch.inference_mode():
541
+ for eb in tqdm(val_loader, desc="eval"):
542
+ last_batch = eb
543
+
544
+ with autocast(dtype=torch.float16):
545
+ logits16 = model(
546
+ src_BxS=eb['src_tokens'],
547
+ tgt_BxTxC=eb['tgt_tokens'],
548
+ src_positions=eb['src_positions'],
549
+ tgt_positions=eb['tgt_positions'],
550
+ enc_self_attn_mask=eb['enc_self_attn_mask'],
551
+ dec_self_attn_mask=eb['dec_self_attn_mask'],
552
+ dec_cross_attn_mask=eb['dec_cross_attn_mask'],
553
+ enable_dropout=False,
554
+ )[:, :-1]
555
+
556
+ logits = logits16.float()
557
+ target = eb['tgt_tokens'][:, 1:]
558
+ B_e, T_e, C_e = target.shape
559
+ V_e = logits.size(-1)
560
+
561
+ loss_e = 0.0
562
+ weights_e = [4.0] + [1.0] * (C_e - 1)
563
+ for c, w in enumerate(weights_e):
564
+ lc = logits[:, :, c, :].reshape(-1, V_e)
565
+ tc = target[:, :, c].reshape(-1)
566
+ loss_e += w * F.cross_entropy(
567
+ lc, tc, ignore_index=dia_cfg.data.audio_pad_value
568
+ )
569
+ loss_e = loss_e / sum(weights_e)
570
+
571
+ eval_losses.append(loss_e)
572
+
573
+ avg_eval_loss = sum(eval_losses) / len(eval_losses)
574
+ writer.add_scalar('Loss/eval', avg_eval_loss.item(), global_step)
575
+
576
+ # --- Inference test sentence ---
577
+ try:
578
+ orig_dtype = next(model.parameters()).dtype
579
+ model = model.float()
580
+ dia_gen = Dia(dia_cfg, device)
581
+ dia_gen.model, dia_gen.dac_model = model, dac_model
582
+
583
+ # ✅ Test câu hội thoại đa giọng
584
+ test_dialogue = "[vtv24] Em vừa đi học về, anh ạ. [duongfg] Ừ, em ăn cơm chưa? [vtv24] Em ăn rồi!"
585
+
586
+ if len(test_dialogue) > 10:
587
+ try:
588
+ audio = dia_gen.generate(text=test_dialogue)
589
+ writer.add_audio("Eval/test_dialogue", audio, global_step, 44100)
590
+ except Exception:
591
+ logger.exception("Eval error during test_dialogue")
592
+ finally:
593
+ if 'audio' in locals():
594
+ del audio
595
+
596
+
597
+ except Exception:
598
+ logger.exception("Eval error")
599
+
600
+ finally:
601
+ if 'audio' in locals():
602
+ del audio
603
+ gc.collect()
604
+ torch.cuda.empty_cache()
605
+ if orig_dtype == torch.float16:
606
+ model = model.half()
607
+
608
+ def train(model, dia_cfg: DiaConfig, dac_model: dac.DAC, dataset, train_cfg: TrainConfig):
609
+ """
610
+ Run the full training loop over epochs.
611
+ """
612
+ # prepare directories
613
+ train_cfg.output_dir.mkdir(parents=True, exist_ok=True)
614
+ (train_cfg.runs_dir / train_cfg.run_name).mkdir(parents=True, exist_ok=True)
615
+ model = model.to(device)
616
+
617
+ train_loader, val_loader = setup_loaders(dataset, dia_cfg, train_cfg, device)
618
+ opt, sched = setup_optimizer_and_scheduler(model, train_loader, train_cfg)
619
+
620
+ writer = SummaryWriter(train_cfg.runs_dir / train_cfg.run_name)
621
+ model.train()
622
+ scaler = GradScaler()
623
+ start_epoch = 0
624
+ global_step = 0
625
+ resume_ckpt = getattr(train_cfg, "resume_from", None)
626
+ if resume_ckpt and resume_ckpt.exists():
627
+ logger.info(f"Resuming from checkpoint: {resume_ckpt}")
628
+ checkpoint = torch.load(resume_ckpt, map_location=device)
629
+ model.load_state_dict(checkpoint["model"])
630
+ opt.load_state_dict(checkpoint["optimizer"])
631
+ sched.load_state_dict(checkpoint["scheduler"])
632
+ scaler.load_state_dict(checkpoint["scaler"])
633
+ start_epoch = checkpoint.get("epoch", 0)
634
+ global_step = checkpoint.get("global_step", 0)
635
+
636
+
637
+ steps_per_epoch = getattr(train_loader, 'steps_per_epoch', None)
638
+ if steps_per_epoch is None:
639
+ try:
640
+ steps_per_epoch = len(train_loader)
641
+ except Exception:
642
+ steps_per_epoch = None
643
+
644
+ for epoch in range(start_epoch, train_cfg.epochs):
645
+ # iterate with progress bar, using total if known
646
+ loader_iter = tqdm(
647
+ train_loader,
648
+ desc=f"E{epoch+1}",
649
+ total=steps_per_epoch
650
+ )
651
+ pbar = tqdm(loader_iter, total=train_cfg.total_steps, initial=global_step, desc=f"E{epoch}")
652
+ for step_in_epoch, batch in enumerate(pbar):
653
+ global_step += 1
654
+ # training step
655
+ loss = train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step, scaler)
656
+
657
+ cur_alloc = torch.cuda.memory_allocated() # bytes currently allocated by tensors
658
+ peak_alloc = torch.cuda.max_memory_allocated() # bytes peak during program
659
+ # optionally convert to GB
660
+ cur_gb = cur_alloc / 1024**3
661
+ peak_gb = peak_alloc / 1024**3
662
+
663
+ # update the tqdm postfix
664
+ loader_iter.set_postfix({
665
+ 'loss': f"{loss:.4f}",
666
+ 'VRAM (GB)': f"{cur_gb:.2f}/{peak_gb:.2f}"
667
+ })
668
+
669
+ # remember to zero the peak if you want rolling peaks per step
670
+ if torch.cuda.is_available():
671
+ torch.cuda.reset_peak_memory_stats()
672
+
673
+
674
+ # evaluation
675
+ if step_in_epoch % train_cfg.eval_step == 0:
676
+ model.eval()
677
+ with torch.no_grad():
678
+ eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step)
679
+ model.train()
680
+ scaler = GradScaler()
681
+
682
+ # checkpoint
683
+ if step_in_epoch and step_in_epoch % train_cfg.save_step == 0:
684
+ ckpt = train_cfg.output_dir / f"ckpt_step{global_step}.pth"
685
+ torch.save({
686
+ "model": model.state_dict(),
687
+ "optimizer": opt.state_dict(),
688
+ "scheduler": sched.state_dict(),
689
+ "scaler": scaler.state_dict(),
690
+ "epoch": epoch,
691
+ "global_step": global_step
692
+ }, ckpt)
693
+ logger.info(f"Saved checkpoint: {ckpt}")
694
+
695
+ # end of epoch checkpoint
696
+ ckpt_e = train_cfg.output_dir / f"ckpt_epoch{epoch+1}.pth"
697
+ torch.save({
698
+ "model": model.state_dict(),
699
+ "optimizer": opt.state_dict(),
700
+ "scheduler": sched.state_dict(),
701
+ "scaler": scaler.state_dict(),
702
+ "epoch": epoch + 1,
703
+ "global_step": global_step
704
+ }, ckpt_e)
705
+ logger.info(f"Saved end-of-epoch checkpoint: {ckpt_e}")
706
+
707
+ from datasets import disable_caching
708
+
709
+ def main():
710
+ args = get_args()
711
+ import os
712
+ os.environ["HF_DATASETS_CACHE"] = "/tmp/force_streaming" # ép cache mới
713
+ disable_caching()
714
+ # tắt toàn bộ cache local HuggingFace
715
+ import json
716
+ with open(args.config, "r", encoding="utf-8") as f:
717
+ config_dict = json.load(f)
718
+
719
+ dia_cfg = DiaConfig(**config_dict)
720
+ dac_model = dac.DAC.load(dac.utils.download()).to(device)
721
+ dataset = None
722
+
723
+ if not dataset:
724
+ if args.csv_path:
725
+ if not args.audio_root:
726
+ raise ValueError("`--audio_root` must be set when using `--csv_path`")
727
+ dataset = LocalDiaDataset(args.csv_path, args.audio_root, dia_cfg, dac_model)
728
+ else:
729
+ # ✅ Check nếu dataset là đường dẫn local
730
+ if Path(args.dataset).exists():
731
+ print(f"Loading dataset from local path: {args.dataset}")
732
+ ds1 = load_from_disk(args.dataset)
733
+ if isinstance(ds1, DatasetDict):
734
+ ds1 = ds1["train"]
735
+ dataset = HFDiaDataset(ds1, dia_cfg, dac_model)
736
+ else:
737
+ print(f"Loading HuggingFace dataset: {args.dataset} (streaming)")
738
+ ds1 = load_dataset(args.dataset, split="train", streaming=True)
739
+
740
+ if args.dataset2:
741
+ ds2 = load_dataset(args.dataset2, split="train", streaming=True)
742
+ hf_ds = interleave_datasets([ds1, ds2])
743
+ dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)
744
+ else:
745
+ hf_ds = ds1
746
+ dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)
747
+
748
+
749
+
750
+ train_cfg = TrainConfig(
751
+ run_name = args.run_name or TrainConfig.run_name,
752
+ output_dir = args.output_dir or TrainConfig.output_dir,
753
+ shuffle_buffer_size = args.shuffle_buffer_size,
754
+ seed = args.seed,
755
+ )
756
+ if args.resume_from:
757
+ train_cfg.resume_from = Path(args.resume_from)
758
+ # load model checkpoint
759
+ if args.local_ckpt:
760
+ ckpt_file = args.local_ckpt
761
+ else:
762
+ ckpt_file = hf_hub_download(args.hub_model, filename="dia-v0_1.pth")
763
+ model = DiaModel(dia_cfg)
764
+ if args.half:
765
+ model = model.half()
766
+ if args.compile:
767
+ model = torch.compile(model, backend="inductor")
768
+ ckpt = torch.load(ckpt_file, map_location="cpu")
769
+ state_dict = ckpt["model"] if "model" in ckpt else ckpt
770
+ new_state_dict = {}
771
+
772
+ for k, v in state_dict.items():
773
+ if "encoder.embedding.weight" in k:
774
+ if v.shape != model.state_dict()[k].shape:
775
+ print(f"⚠️ Bỏ qua {k} do shape không khớp: {v.shape} vs {model.state_dict()[k].shape}")
776
+ continue
777
+ new_state_dict[k] = v
778
+
779
+ model.load_state_dict(new_state_dict, strict=False)
780
+
781
+
782
+ # start training
783
+ train(model, dia_cfg, dac_model, dataset, train_cfg)
784
+
785
+
786
+ if __name__ == "__main__":
787
+ main()
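
Note: the checkpoints written above store everything needed to resume. Below is a minimal sketch of the matching load side, assuming only the key layout used by the `torch.save` calls in the loop (`model`, `optimizer`, `scheduler`, `scaler`, `epoch`, `global_step`); the helper name is hypothetical, not part of this repo.

```python
import torch

def load_training_checkpoint(path, model, opt, sched, scaler, device="cpu"):
    """Restore the state dicts saved by the training loop's torch.save calls."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    opt.load_state_dict(ckpt["optimizer"])
    sched.load_state_dict(ckpt["scheduler"])
    scaler.load_state_dict(ckpt["scaler"])
    # Resume bookkeeping exactly where the save left off.
    return ckpt["epoch"], ckpt["global_step"]

# Usage with the objects constructed in main()/train():
# start_epoch, global_step = load_training_checkpoint(train_cfg.resume_from, model, opt, sched, scaler)
```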
dia/interleaved_datasets.py ADDED
@@ -0,0 +1,144 @@
1
+ from datasets import load_dataset, get_dataset_config_names, interleave_datasets, load_dataset_builder
2
+ from .dataset import HFDiaIterDataset
3
+ import pandas as pd
4
+ from huggingface_hub import hf_hub_download
5
+
6
+
7
+ LANG_NAME_TO_CODE = {
8
+ "dutch": "nl",
9
+ "french": "fr",
10
+ "german": "de",
11
+ "italian": "it",
12
+ "polish": "pl",
13
+ "portuguese": "pt",
14
+ "spanish": "es",
15
+ # add more if other configs appear...
16
+ }
17
+
18
+
19
+
20
+
21
+
22
+
23
+ def load_cml_tts_streamed(dia_cfg, dac_model):
24
+ """
25
+ Stream all language subsets of the CML-TTS dataset in train split,
26
+ add a `language` field, drop all except `text`, `audio`, `language`,
27
+ and interleave them into one streaming Dataset.
28
+
29
+ Returns:
30
+ datasets.IterableDataset: interleaved streaming dataset
31
+ """
32
+ # 1) Discover all language subsets
33
+ lang_configs = get_dataset_config_names("ylacombe/cml-tts")
34
+
35
+ # 2) Build one streaming subset per language, with only desired columns
36
+ streams = []
37
+ num_ex = 0
38
+ for lang in lang_configs:
39
+
40
+ iso_code = LANG_NAME_TO_CODE.get(lang, lang)
41
+ ds_stream = load_dataset(
42
+ "ylacombe/cml-tts",
43
+ name=lang,
44
+ split="train",
45
+ streaming=True
46
+ )
47
+
48
+ num_ex += ds_stream.info.splits['train'].num_examples
49
+ # keep only text, audio, and add language
50
+ def _add_lang(ex, iso=iso_code):
51
+ return {
52
+ "text": ex["text"],
53
+ "audio": ex["audio"],
54
+ "language": iso
55
+ }
56
+ ds_stream = ds_stream.map(
57
+ _add_lang,
58
+ remove_columns=[c for c in ds_stream.column_names if c not in ["text", "audio", "language"]]
59
+ )
60
+ streams.append(ds_stream)
61
+
62
+ # 3) Interleave all streams into one unified stream
63
+ interleaved = interleave_datasets(streams, stopping_strategy="all_exhausted")
64
+ ds = HFDiaIterDataset(interleaved, dia_cfg, dac_model)
65
+ ds.total_examples = num_ex
66
+ return ds
67
+
68
+
69
+
70
+
71
+
72
+
73
+ def count_tsv_rows(
74
+ repo_id: str,
75
+ subset: str,
76
+ split: str = "train",
77
+ revision: str = "main"
78
+ ) -> int:
79
+ """Download the TSV for a given subset/split and return its number of rows."""
80
+ file_path = f"transcript/{subset}/{split}.tsv"
81
+ try:
82
+ local_file = hf_hub_download(
83
+ repo_id=repo_id,
84
+ filename=file_path,
85
+ repo_type="dataset",
86
+ revision=revision
87
+ )
88
+ except Exception as e:
+ raise RuntimeError(f"Error fetching TSV metadata for {subset}/{split}") from e
90
+
91
+ df = pd.read_csv(local_file, sep="\t", low_memory=False)
92
+ return len(df)
93
+
94
+ def load_common_voice17_streamed(dia_cfg, dac_model, revision="main"):
95
+ """
96
+ Stream the train split of Common Voice 17 for the given language codes,
97
+ rename `sentence`→`text`, keep only `text`, `audio`, and `language`,
98
+ then interleave into a single streaming Dataset.
99
+
100
+ Languages loaded: en, de, fr, es, it, nl, pl, pt, tr, hu
101
+ """
102
+ repo_id = "mozilla-foundation/common_voice_17_0"
103
+ langs = ["en", "de", "fr", "es", "it", "nl", "pl", "pt", "tr", "hu"]
104
+
105
+ streams = []
106
+ row_counts = []
107
+
108
+ for lang in langs:
109
+ # 1) figure out how many rows in the TSV
110
+ n_rows = count_tsv_rows(repo_id, lang, split="train", revision=revision)
111
+ row_counts.append(n_rows)
112
+
113
+ # 2) load in streaming mode
114
+ ds_stream = load_dataset(
115
+ repo_id,
116
+ name=lang,
117
+ split="train",
118
+ streaming=True,
119
+ revision=revision
120
+ )
121
+
122
+ # 3) map to desired schema
123
+ def _prep(ex, iso=lang):
124
+ return {
125
+ "text": ex["sentence"],
126
+ "audio": ex["audio"],
127
+ "language": iso
128
+ }
129
+
130
+ ds_stream = ds_stream.map(
131
+ _prep,
132
+ remove_columns=[c for c in ds_stream.column_names if c not in ("text", "audio", "language")]
133
+ )
134
+ streams.append(ds_stream)
135
+
136
+ # 4) interleave: all_exhausted ⇒ max_length * num_streams
137
+ interleaved = interleave_datasets(streams, stopping_strategy="all_exhausted")
138
+
139
+ # 5) wrap and attach total_examples
140
+ ds = HFDiaIterDataset(interleaved, dia_cfg, dac_model)
141
+ ds.total_examples = max(row_counts) * len(langs)
142
+
143
+ return ds
144
+
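
Note: both loaders above lean on `stopping_strategy="all_exhausted"`, which restarts shorter streams until the longest one ends; that is why `total_examples` is estimated as `max(row_counts) * len(langs)`. A minimal sketch of that semantics on toy in-memory datasets (streaming datasets interleave the same way):

```python
from datasets import Dataset, interleave_datasets

short_ds = Dataset.from_dict({"text": ["a", "b"]})
long_ds = Dataset.from_dict({"text": ["1", "2", "3", "4", "5"]})

# "all_exhausted" oversamples the short stream until the long one is spent,
# so the result has max(length) * num_streams = 5 * 2 = 10 rows.
mixed = interleave_datasets([short_ds, long_ds], stopping_strategy="all_exhausted")
print([ex["text"] for ex in mixed])
# ['a', '1', 'b', '2', 'a', '3', 'b', '4', 'a', '5']
```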
dia/layers.py ADDED
@@ -0,0 +1,909 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+ from torch.nn import RMSNorm
8
+
9
+ from .config import DiaConfig
10
+
11
+
12
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
13
+ return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
14
+
15
+
16
+ def _str_to_dtype(dtype_str: str | None) -> torch.dtype | None:
17
+ # Allow None for default behavior
18
+ if dtype_str is None or dtype_str.lower() == "none":
19
+ return None
20
+ if dtype_str == "float32":
21
+ return torch.float32
22
+ elif dtype_str == "float16":
23
+ return torch.float16
24
+ elif dtype_str == "bfloat16":
25
+ return torch.bfloat16
26
+ else:
27
+ raise ValueError(f"Unsupported dtype string: {dtype_str}")
28
+
29
+
30
+ class DenseGeneral(nn.Module):
31
+ """
32
+ PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
33
+
34
+ Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
35
+ for the generalized matrix multiplication. The kernel shape is calculated
+ and the parameter created during initialization.
+
+ Attributes:
+ axis (Tuple[int, ...]): Input axis or axes to contract.
+ in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
+ out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
+ weight (nn.Parameter): The kernel parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ in_shapes: tuple[int, ...],
51
+ out_features: tuple[int, ...],
52
+ axis: tuple[int, ...] = (-1,),
53
+ dtype: torch.dtype | None = None,
54
+ weight_dtype: torch.dtype | None = None,
55
+ device: torch.device | None = None,
56
+ ):
57
+ super().__init__()
58
+ self.in_shapes = in_shapes
59
+ self.out_features = out_features
60
+ self.axis = axis
61
+ self.dtype = dtype
62
+ self.kernel_shape = self.in_shapes + self.out_features
63
+
64
+ factory_kwargs = {"device": device, "dtype": weight_dtype}
65
+ self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
66
+ self.register_parameter("bias", None)
67
+
68
+ def forward(self, inputs: Tensor) -> Tensor:
69
+ norm_axis = _normalize_axes(self.axis, inputs.ndim)
70
+ kernel_contract_axes = tuple(range(len(norm_axis)))
71
+
72
+ output = torch.tensordot(
73
+ inputs.float(),
74
+ self.weight.float(),
75
+ dims=(norm_axis, kernel_contract_axes),
76
+ ).to(inputs.dtype)
77
+ return output
78
+
79
+
80
+ def get_activation_fn(activation_string: str) -> nn.Module: # Return Module instance
81
+ """Maps activation string to PyTorch activation function module."""
82
+ if activation_string == "gelu":
83
+ return nn.GELU()
84
+ elif activation_string == "relu":
85
+ return nn.ReLU()
86
+ elif activation_string == "silu" or activation_string == "swish":
87
+ return nn.SiLU()
88
+ elif activation_string == "linear":
89
+ return nn.Identity()
90
+ else:
91
+ raise ValueError(f"Unsupported activation function: {activation_string}")
92
+
93
+
94
+ class MlpBlock(nn.Module):
95
+ """MLP block using DenseGeneral."""
96
+
97
+ def __init__(
98
+ self,
99
+ config: DiaConfig,
100
+ embed_dim: int,
101
+ intermediate_dim: int,
102
+ dropout_rate: float,
103
+ activations: list[str] = ["silu", "linear"],
104
+ use_pre_norm: bool = False,
105
+ ):
106
+ super().__init__()
107
+ self.use_pre_norm = use_pre_norm
108
+ num_activations = len(activations)
109
+ compute_dtype = _str_to_dtype(config.training.dtype)
110
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
111
+ self.dtype = compute_dtype
112
+ # Assume default device for now, could be passed in config
113
+
114
+ if use_pre_norm:
115
+ self.pre_norm = RMSNorm(
116
+ embed_dim,
117
+ eps=config.model.normalization_layer_epsilon,
118
+ dtype=torch.float32,
119
+ )
120
+
121
+ self.wi_fused = DenseGeneral(
122
+ in_shapes=(embed_dim,),
123
+ out_features=(
124
+ num_activations,
125
+ intermediate_dim,
126
+ ),
127
+ axis=(-1,),
128
+ dtype=compute_dtype,
129
+ weight_dtype=weight_dtype,
130
+ )
131
+
132
+ self.activation_fn_0 = get_activation_fn(activations[0]) # silu
133
+ self.activation_fn_1 = get_activation_fn(activations[1]) # linear
134
+
135
+ self.dropout = nn.Dropout(dropout_rate)
136
+
137
+ # Output layer using DenseGeneral
138
+ self.wo = DenseGeneral(
139
+ in_shapes=(intermediate_dim,),
140
+ out_features=(embed_dim,),
141
+ axis=(-1,),
142
+ dtype=compute_dtype,
143
+ weight_dtype=weight_dtype,
144
+ )
145
+
146
+ def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
147
+ """Forward pass."""
148
+ if self.use_pre_norm and hasattr(self, "pre_norm"):
149
+ x = self.pre_norm(x)
150
+
151
+ fused_x = self.wi_fused(x)
152
+
153
+ gate_input = fused_x[..., 0, :]
154
+ up_input = fused_x[..., 1, :]
155
+
156
+ gate = self.activation_fn_0(gate_input)
157
+ up = self.activation_fn_1(up_input)
158
+ hidden = torch.mul(gate, up).to(self.dtype)
159
+
160
+ if not deterministic:
161
+ hidden = self.dropout(hidden)
162
+
163
+ output = self.wo(hidden)
164
+ return output
165
+
166
+
167
+ class RotaryEmbedding(nn.Module):
168
+ """Rotary Position Embedding (RoPE) implementation in PyTorch."""
169
+
170
+ def __init__(
171
+ self,
172
+ embedding_dims: int,
173
+ min_timescale: int = 1,
174
+ max_timescale: int = 10000,
175
+ dtype: torch.dtype = torch.float32,
176
+ ):
177
+ super().__init__()
178
+ if embedding_dims % 2 != 0:
179
+ raise ValueError("Embedding dim must be even for RoPE.")
180
+ self.embedding_dims = embedding_dims
181
+ self.min_timescale = min_timescale
182
+ self.max_timescale = max_timescale
183
+ self.dtype = dtype
184
+
185
+ half_embedding_dim = embedding_dims // 2
186
+ fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
187
+ self.register_buffer(
188
+ "timescale",
189
+ self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
190
+ persistent=False,
191
+ )
192
+
193
+ def extra_repr(self) -> str:
194
+ s = f"{self.timescale.shape}"
195
+ return s
196
+
197
+ def forward(self, inputs: torch.Tensor, position: torch.Tensor):
198
+ """Applies RoPE."""
199
+ position = position.unsqueeze(-1).unsqueeze(-1)
200
+ timescale = self.timescale.to(inputs.device)
201
+ sinusoid_inp = position / timescale
202
+ sin = torch.sin(sinusoid_inp).to(inputs.dtype)
203
+ cos = torch.cos(sinusoid_inp).to(inputs.dtype)
204
+ first_half, second_half = torch.chunk(inputs, 2, dim=-1)
205
+ first_part = first_half * cos - second_half * sin
206
+ second_part = second_half * cos + first_half * sin
207
+ return torch.cat((first_part, second_part), dim=-1)
208
+
209
+
210
+ class KVCache:
211
+ def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
212
+ self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
213
+ self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
214
+ self.current_idx = 0
215
+ self.max_len = max_len
216
+
217
+ def get_kv_for_attention(self, current_k, current_v):
218
+ if self.current_idx == 0:
219
+ return current_k, current_v
220
+ else:
221
+ past_k = self.k[:, :, : self.current_idx, :]
222
+ past_v = self.v[:, :, : self.current_idx, :]
223
+ attn_k = torch.cat((past_k, current_k), dim=2)
224
+ attn_v = torch.cat((past_v, current_v), dim=2)
225
+ return attn_k, attn_v
226
+
227
+ def update_cache(self, k, v):
228
+ assert self.current_idx < self.max_len
229
+ self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
230
+ self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
231
+ self.current_idx += 1
232
+
233
+ def prefill_kv(self, k, v):
234
+ prefill_len = k.shape[2]
235
+ assert prefill_len <= self.max_len
236
+ self.k[:, :, :prefill_len, :] = k
237
+ self.v[:, :, :prefill_len, :] = v
238
+ self.current_idx = prefill_len
239
+
240
+
241
+ class Attention(nn.Module):
242
+ """Attention using DenseGeneral."""
243
+
244
+ def __init__(
245
+ self,
246
+ config: DiaConfig,
247
+ q_embed_dim: int,
248
+ kv_embed_dim: int,
249
+ num_query_heads: int,
250
+ num_kv_heads: int,
251
+ head_dim: int,
252
+ dropout_rate: float,
253
+ is_cross_attn: bool = False,
254
+ out_embed_dim: int | None = None,
255
+ ):
256
+ super().__init__()
257
+ self.num_query_heads = num_query_heads
258
+ self.num_kv_heads = num_kv_heads
259
+ self.head_dim = head_dim
260
+ self.is_cross_attn = is_cross_attn
261
+ self.dropout_rate = dropout_rate
262
+ compute_dtype = _str_to_dtype(config.training.dtype)
263
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
264
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
265
+ self.projected_query_dim = num_query_heads * head_dim
266
+ if num_query_heads % num_kv_heads != 0:
267
+ raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
268
+ self.num_gqa_groups = num_query_heads // num_kv_heads
269
+
270
+ # --- Projection Layers using DenseGeneral ---
271
+ self.q_proj = DenseGeneral(
272
+ in_shapes=(q_embed_dim,),
273
+ out_features=(num_query_heads, head_dim),
274
+ axis=(-1,),
275
+ dtype=compute_dtype,
276
+ weight_dtype=weight_dtype,
277
+ )
278
+ self.k_proj = DenseGeneral(
279
+ in_shapes=(kv_embed_dim,),
280
+ out_features=(num_kv_heads, head_dim),
281
+ axis=(-1,),
282
+ dtype=compute_dtype,
283
+ weight_dtype=weight_dtype,
284
+ )
285
+ self.v_proj = DenseGeneral(
286
+ in_shapes=(kv_embed_dim,),
287
+ out_features=(num_kv_heads, head_dim),
288
+ axis=(-1,),
289
+ dtype=compute_dtype,
290
+ weight_dtype=weight_dtype,
291
+ )
292
+ self.o_proj = DenseGeneral(
293
+ in_shapes=(num_query_heads, head_dim),
294
+ out_features=(self.output_dim,),
295
+ axis=(-2, -1),
296
+ dtype=compute_dtype,
297
+ weight_dtype=weight_dtype,
298
+ )
299
+
300
+ # --- Rotary Embedding ---
301
+ self.rotary_emb = RotaryEmbedding(
302
+ embedding_dims=self.head_dim,
303
+ min_timescale=config.model.rope_min_timescale,
304
+ max_timescale=config.model.rope_max_timescale,
305
+ dtype=compute_dtype,
306
+ )
307
+
308
+ def forward(
309
+ self,
310
+ Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
311
+ Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
312
+ q_positions: torch.Tensor, # (B, T)
313
+ kv_positions: torch.Tensor | None = None, # (B, S)
314
+ deterministic: bool = True,
315
+ attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
316
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
317
+ prefill: bool = False, # True only when prefilling KV Cache
318
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
319
+ """
320
+ Performs attention calculation with optional KV caching.
321
+
322
+ Args:
323
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
324
+ Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
325
+ q_positions: Positions for queries (B, T).
326
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
327
+ deterministic: If True, disable dropout.
328
+ attn_mask: Attention mask.
329
+ cache: KVCache.
330
+ prefill: If True, use prefill mode.
331
+
332
+ Returns:
333
+ A tuple containing:
334
+ - output: The attention output tensor (B, T, output_dim).
335
+ - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
336
+ """
337
+ if kv_positions is None:
338
+ kv_positions = q_positions
339
+ original_dtype = Xq.dtype
340
+
341
+ Xq_BxTxNxH = self.q_proj(Xq)
342
+ Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
343
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
344
+
345
+ # Input values into attention calculation
346
+ attn_k: torch.Tensor | None = None
347
+ attn_v: torch.Tensor | None = None
348
+ new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
349
+
350
+ # Decoder Cross Attention
351
+ if self.is_cross_attn:
352
+ # Directly use cache (no need to check index)
353
+ attn_k, attn_v = cache.k, cache.v
354
+ if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
355
+ raise ValueError(
356
+ f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
357
+ f"does not match num_query_heads ({self.num_query_heads}). "
358
+ "Cache should be pre-repeated for GQA."
359
+ )
360
+ # Self Attention
361
+ else:
362
+ Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
363
+ Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
364
+ Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions) # (B, S, K, H)
365
+
366
+ Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
367
+ Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
368
+ # S=1 for Decode Step
369
+
370
+ if self.num_gqa_groups > 1:
371
+ Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
372
+ Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
373
+ else:
374
+ Xk_BxNxSxH = Xk_BxKxSxH
375
+ Xv_BxNxSxH = Xv_BxKxSxH
376
+
377
+ # Encoder Self Attention
378
+ if cache is None:
379
+ attn_k = Xk_BxNxSxH
380
+ attn_v = Xv_BxNxSxH
381
+ # Decoder Self Attention
382
+ else:
383
+ # In prefill mode, we fill in cache until prefill length
384
+ if prefill:
385
+ attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
386
+ cache.prefill_kv(attn_k, attn_v)
387
+ # In decode step, we add current K/V to cache step by step
388
+ else:
389
+ new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
390
+ attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)
391
+
392
+ attn_output = F.scaled_dot_product_attention(
393
+ Xq_BxNxTxH,
394
+ attn_k,
395
+ attn_v,
396
+ attn_mask=attn_mask,
397
+ dropout_p=self.dropout_rate if not deterministic else 0.0,
398
+ scale=1.0,
399
+ )
400
+
401
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
402
+ output = self.o_proj(attn_output)
403
+
404
+ return output.to(original_dtype), new_kv_cache
405
+
406
+
407
+ class EncoderLayer(nn.Module):
408
+ """Transformer Encoder Layer using DenseGeneral."""
409
+
410
+ def __init__(self, config: DiaConfig):
411
+ super().__init__()
412
+ self.config = config
413
+ model_config = config.model
414
+ enc_config = config.model.encoder
415
+ embed_dim = enc_config.n_embd
416
+
417
+ self.pre_sa_norm = RMSNorm(
418
+ embed_dim,
419
+ eps=model_config.normalization_layer_epsilon,
420
+ dtype=torch.float32,
421
+ )
422
+ self.self_attention = Attention(
423
+ config=config,
424
+ q_embed_dim=embed_dim,
425
+ kv_embed_dim=embed_dim,
426
+ num_query_heads=enc_config.n_head,
427
+ num_kv_heads=enc_config.n_head,
428
+ head_dim=enc_config.head_dim,
429
+ dropout_rate=model_config.dropout,
430
+ is_cross_attn=False,
431
+ out_embed_dim=embed_dim,
432
+ )
433
+ self.post_sa_norm = RMSNorm(
434
+ embed_dim,
435
+ eps=model_config.normalization_layer_epsilon,
436
+ dtype=torch.float32,
437
+ )
438
+ self.mlp = MlpBlock(
439
+ config=config,
440
+ embed_dim=embed_dim,
441
+ intermediate_dim=enc_config.n_hidden,
442
+ activations=enc_config.mlp_activations,
443
+ dropout_rate=model_config.dropout,
444
+ use_pre_norm=enc_config.use_pre_norm,
445
+ )
446
+ self.dropout = nn.Dropout(model_config.dropout)
447
+
448
+ def forward(
449
+ self,
450
+ x: torch.Tensor,
451
+ src_positions: torch.Tensor | None = None,
452
+ deterministic: bool = True,
453
+ attn_mask: torch.Tensor | None = None,
454
+ ) -> torch.Tensor:
455
+ residual = x
456
+ x_norm = self.pre_sa_norm(x)
457
+
458
+ sa_out, _ = self.self_attention(
459
+ Xq=x_norm,
460
+ Xkv=x_norm,
461
+ q_positions=src_positions,
462
+ kv_positions=src_positions,
463
+ deterministic=deterministic,
464
+ attn_mask=attn_mask,
465
+ )
466
+ x = residual + sa_out
467
+
468
+ residual = x
469
+ x_norm = self.post_sa_norm(x)
470
+ mlp_out = self.mlp(x_norm, deterministic=deterministic)
471
+ x = residual + mlp_out
472
+
473
+ if not deterministic:
474
+ x = self.dropout(x)
475
+ return x
476
+
477
+
478
+ class Encoder(nn.Module):
479
+ """Transformer Encoder Stack using DenseGeneral."""
480
+
481
+ def __init__(self, config: DiaConfig):
482
+ super().__init__()
483
+ self.config = config
484
+ model_config = config.model
485
+ enc_config = config.model.encoder
486
+ compute_dtype = _str_to_dtype(config.training.dtype)
487
+
488
+ self.embedding = nn.Embedding(
489
+ model_config.src_vocab_size,
490
+ enc_config.n_embd,
491
+ dtype=compute_dtype,
492
+ )
493
+ self.dropout = nn.Dropout(model_config.dropout)
494
+ self.layers = nn.ModuleList([EncoderLayer(config=config) for _ in range(enc_config.n_layer)])
495
+ self.norm = RMSNorm(
496
+ enc_config.n_embd,
497
+ eps=model_config.normalization_layer_epsilon,
498
+ dtype=torch.float32,
499
+ )
500
+
501
+ def forward(
502
+ self,
503
+ x_ids: torch.Tensor,
504
+ src_positions: torch.Tensor | None = None,
505
+ deterministic: bool = True,
506
+ attn_mask: torch.Tensor | None = None,
507
+ ) -> torch.Tensor:
508
+ x = self.embedding(x_ids)
509
+
510
+ if not deterministic:
511
+ x = self.dropout(x)
512
+
513
+ for layer in self.layers:
514
+ x = layer(
515
+ x,
516
+ src_positions=src_positions,
517
+ deterministic=deterministic,
518
+ attn_mask=attn_mask,
519
+ )
520
+ x = self.norm(x)
521
+ if not deterministic:
522
+ x = self.dropout(x)
523
+ return x
524
+
525
+
526
+ class DecoderLayer(nn.Module):
527
+ """Transformer Decoder Layer using DenseGeneral."""
528
+
529
+ def __init__(self, config: DiaConfig):
530
+ super().__init__()
531
+ self.config = config
532
+ model_config = config.model
533
+ dec_config = config.model.decoder
534
+ enc_config = config.model.encoder
535
+ dec_embed_dim = dec_config.n_embd
536
+ enc_embed_dim = enc_config.n_embd
537
+
538
+ # Norms
539
+ self.pre_sa_norm = RMSNorm(
540
+ dec_embed_dim,
541
+ eps=model_config.normalization_layer_epsilon,
542
+ dtype=torch.float32,
543
+ )
544
+ self.pre_ca_norm = RMSNorm(
545
+ dec_embed_dim,
546
+ eps=model_config.normalization_layer_epsilon,
547
+ dtype=torch.float32,
548
+ )
549
+ self.pre_mlp_norm = RMSNorm(
550
+ dec_embed_dim,
551
+ eps=model_config.normalization_layer_epsilon,
552
+ dtype=torch.float32,
553
+ )
554
+
555
+ # Self-Attention (GQA) with Causal Masking
556
+ self.self_attention = Attention(
557
+ config=config,
558
+ q_embed_dim=dec_embed_dim,
559
+ kv_embed_dim=dec_embed_dim,
560
+ num_query_heads=dec_config.gqa_query_heads,
561
+ num_kv_heads=dec_config.kv_heads,
562
+ head_dim=dec_config.gqa_head_dim,
563
+ dropout_rate=model_config.dropout,
564
+ is_cross_attn=False,
565
+ out_embed_dim=dec_embed_dim,
566
+ )
567
+ # Cross-Attention (MHA)
568
+ self.cross_attention = Attention(
569
+ config=config,
570
+ q_embed_dim=dec_embed_dim,
571
+ kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
572
+ num_query_heads=dec_config.cross_query_heads,
573
+ num_kv_heads=dec_config.cross_query_heads,
574
+ head_dim=dec_config.cross_head_dim,
575
+ dropout_rate=model_config.dropout,
576
+ is_cross_attn=True,
577
+ out_embed_dim=dec_embed_dim,
578
+ )
579
+ # MLP
580
+ self.mlp = MlpBlock(
581
+ config=config,
582
+ embed_dim=dec_embed_dim,
583
+ intermediate_dim=dec_config.n_hidden,
584
+ activations=dec_config.mlp_activations,
585
+ dropout_rate=model_config.dropout,
586
+ use_pre_norm=dec_config.use_pre_norm,
587
+ )
588
+
589
+ def forward(
590
+ self,
591
+ x: torch.Tensor,
592
+ encoder_out: torch.Tensor,
593
+ tgt_positions: torch.Tensor,
594
+ src_positions: torch.Tensor | None,
595
+ deterministic: bool,
596
+ self_attn_mask: torch.Tensor,
597
+ cross_attn_mask: torch.Tensor,
598
+ self_attn_cache: KVCache,
599
+ cross_attn_cache: KVCache,
600
+ prefill: bool = False,
601
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
602
+ residual = x
603
+ x_norm = self.pre_sa_norm(x)
604
+
605
+ sa_out, new_kv_cache = self.self_attention(
606
+ Xq=x_norm, # (2, 1, D)
607
+ Xkv=x_norm, # (2, 1, D)
608
+ q_positions=tgt_positions, # (2, 1)
609
+ kv_positions=tgt_positions, # (2, 1)
610
+ deterministic=deterministic,
611
+ attn_mask=self_attn_mask, # (2, 1, 1, S_max)
612
+ cache=self_attn_cache,
613
+ prefill=prefill,
614
+ )
615
+
616
+ x = residual + sa_out
617
+
618
+ # 2. Cross-Attention
619
+ residual = x
620
+ x_norm = self.pre_ca_norm(x)
621
+ ca_out, _ = self.cross_attention(
622
+ Xq=x_norm,
623
+ Xkv=encoder_out,
624
+ q_positions=tgt_positions,
625
+ kv_positions=src_positions,
626
+ deterministic=deterministic,
627
+ attn_mask=cross_attn_mask,
628
+ cache=cross_attn_cache,
629
+ )
630
+ x = residual + ca_out
631
+
632
+ # 3. MLP
633
+ residual = x
634
+ x_norm = self.pre_mlp_norm(x)
635
+ mlp_out = self.mlp(x_norm, deterministic=deterministic)
636
+ x = residual + mlp_out
637
+
638
+ return x, new_kv_cache
639
+
640
+
641
+ class Decoder(nn.Module):
642
+ """Transformer Decoder Stack using DenseGeneral."""
643
+
644
+ def __init__(self, config: DiaConfig):
645
+ super().__init__()
646
+ self.config = config
647
+ model_config = config.model
648
+ dec_config = config.model.decoder
649
+ train_config = config.training
650
+ data_config = config.data
651
+ compute_dtype = _str_to_dtype(config.training.dtype)
652
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
653
+ self.num_channels = data_config.channels
654
+ self.num_layers = dec_config.n_layer
655
+
656
+ self.embeddings = nn.ModuleList(
657
+ [
658
+ nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
659
+ for _ in range(self.num_channels)
660
+ ]
661
+ )
662
+ self.dropout = nn.Dropout(model_config.dropout)
663
+ self.layers = nn.ModuleList([DecoderLayer(config=config) for _ in range(self.num_layers)])
664
+ self.norm = RMSNorm(
665
+ dec_config.n_embd,
666
+ eps=model_config.normalization_layer_epsilon,
667
+ dtype=torch.float32,
668
+ )
669
+
670
+ # Final Logits Projection using DenseGeneral
671
+ self.logits_dense = DenseGeneral(
672
+ in_shapes=(dec_config.n_embd,),
673
+ out_features=(self.num_channels, model_config.tgt_vocab_size),
674
+ axis=(-1,),
675
+ dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
676
+ weight_dtype=weight_dtype,
677
+ )
678
+ self.logits_in_fp32 = train_config.logits_dot_in_fp32
679
+
680
+ def precompute_cross_attention_kv(
681
+ self,
682
+ max_len: int,
683
+ encoder_out: torch.Tensor, # (B, S, E)
684
+ src_positions: torch.Tensor | None, # (B, S)
685
+ ) -> list[KVCache]:
686
+ """
687
+ Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
688
+ """
689
+ per_layer_kv_cache: list[KVCache] = []
690
+
691
+ for layer in self.layers:
692
+ cross_attn_module = layer.cross_attention
693
+ k_proj = cross_attn_module.k_proj(encoder_out)
694
+ v_proj = cross_attn_module.v_proj(encoder_out)
695
+
696
+ k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
697
+ k = k_proj.transpose(1, 2)
698
+ v = v_proj.transpose(1, 2)
699
+
700
+ per_layer_kv_cache.append(
701
+ KVCache(
702
+ cross_attn_module.num_kv_heads,
703
+ max_len,
704
+ cross_attn_module.head_dim,
705
+ k.device,
706
+ k=k,
707
+ v=v,
708
+ )
709
+ )
710
+
711
+ return per_layer_kv_cache
712
+
713
+ def decode_step(
714
+ self,
715
+ tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
716
+ tgt_pos_Bx1: torch.Tensor, # [B, 1]
717
+ encoder_out: torch.Tensor, # [B, S, E]
718
+ self_attn_mask: Any, # None
719
+ cross_attn_mask: torch.Tensor, # [B, 1, 1, S]
720
+ self_attention_cache: list[KVCache],
721
+ cross_attention_cache: list[KVCache],
722
+ ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]:
723
+ """
724
+ Performs a single decoding step, managing KV caches layer by layer.
725
+
726
+ Returns:
727
+ A tuple containing:
728
+ - logits_Bx1xCxV: The final output logits for the current step (B, 1, C, V), cast to float32.
+ - new_cache: per-layer (K, V) tensors from this step, to be written back into the self-attention caches.
729
+ """
730
+ assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"
731
+
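+ # Sum the per-channel token embeddings into a single decoder input vector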
732
+ x = None
733
+ for i in range(self.num_channels):
734
+ channel_tokens = tgt_ids_Bx1xC[..., i]
735
+ channel_embed = self.embeddings[i](channel_tokens)
736
+ x = channel_embed if x is None else x + channel_embed
737
+
738
+ new_cache = []
739
+
740
+ for i, layer in enumerate(self.layers):
741
+ self_cache = self_attention_cache[i]
742
+ cross_cache = cross_attention_cache[i]
743
+ x, new_kv_cache = layer(
744
+ x, # (2, 1, D)
745
+ encoder_out, # (2, S, E)
746
+ src_positions=None, # CA KV is already computed
747
+ tgt_positions=tgt_pos_Bx1, # (2, 1)
748
+ deterministic=True,
749
+ self_attn_mask=None,
750
+ cross_attn_mask=cross_attn_mask,
751
+ self_attn_cache=self_cache,
752
+ cross_attn_cache=cross_cache,
753
+ )
754
+ new_cache.append(new_kv_cache)
755
+
756
+ x = self.norm(x)
757
+ logits_Bx1xCxV = self.logits_dense(x)
758
+
759
+ return logits_Bx1xCxV.to(torch.float32), new_cache
760
+
761
+ def forward(
762
+ self,
763
+ tgt_ids_BxTxC: torch.Tensor,
764
+ encoder_out: torch.Tensor,
765
+ tgt_positions: torch.Tensor,
766
+ src_positions: torch.Tensor,
767
+ deterministic: bool,
768
+ self_attn_mask: torch.Tensor,
769
+ cross_attn_mask: torch.Tensor,
770
+ self_attention_cache: list[KVCache],
771
+ cross_attention_cache: list[KVCache],
772
+ ) -> torch.Tensor:
773
+ """
774
+ Forward pass for the Decoder stack, managing KV caches.
775
+
776
+ Args:
777
+ tgt_ids_BxTxC: Target token IDs (B, T, C).
778
+ encoder_out: Output from the encoder (B, S, E).
779
+ tgt_positions: Positions for target sequence (B, T).
780
+ src_positions: Positions for source sequence (B, S).
781
+ deterministic: Disable dropout if True.
782
+ self_attn_mask: Mask for self-attention.
783
+ cross_attn_mask: Mask for cross-attention.
784
+ self_attention_cache: Per-layer self-attention KV caches; its length must equal
+ `num_layers`. The caches are filled in prefill mode during this pass.
+ cross_attention_cache: Per-layer cross-attention KV caches precomputed from
+ `encoder_out` via `precompute_cross_attention_kv`.
+
+ Returns:
+ logits: The final output logits (B, T, C, V), cast to float32.
796
+ """
797
+ _, _, num_channels_in = tgt_ids_BxTxC.shape
798
+ assert num_channels_in == self.num_channels, "Input channels mismatch"
799
+
800
+ # Embeddings
801
+ x = None
802
+ for i in range(self.num_channels):
803
+ channel_tokens = tgt_ids_BxTxC[..., i]
804
+ channel_embed = self.embeddings[i](channel_tokens)
805
+ x = channel_embed if x is None else x + channel_embed
806
+
807
+ if not deterministic:
808
+ x = self.dropout(x)
809
+
810
+ for i, layer in enumerate(self.layers):
811
+ x, _ = layer(
812
+ x,
813
+ encoder_out,
814
+ tgt_positions=tgt_positions,
815
+ src_positions=src_positions,
816
+ deterministic=deterministic,
817
+ self_attn_mask=self_attn_mask,
818
+ cross_attn_mask=cross_attn_mask,
819
+ self_attn_cache=self_attention_cache[i],
820
+ cross_attn_cache=cross_attention_cache[i],
821
+ prefill=True,
822
+ )
823
+
824
+ # Final Norm
825
+ x = self.norm(x)
826
+ logits_BxTxCxV = self.logits_dense(x)
827
+
828
+ return logits_BxTxCxV.to(torch.float32)
829
+
830
+
831
+ class DiaModel(nn.Module):
832
+ """PyTorch Dia Model using DenseGeneral."""
833
+
834
+ def __init__(self, config: DiaConfig):
835
+ super().__init__()
836
+ self.config = config
837
+ self.encoder = Encoder(config)
838
+ self.decoder = Decoder(config)
839
+ # self._init_weights()
840
+
841
+
842
+ def _init_weights(self):
843
+ for module in self.modules():
844
+ if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)):
845
+ torch.nn.init.xavier_uniform_(module.weight)
846
+ if module.bias is not None:
847
+ torch.nn.init.zeros_(module.bias)
848
+ elif isinstance(module, torch.nn.Embedding):
849
+ torch.nn.init.xavier_uniform_(module.weight)
850
+ elif isinstance(module, torch.nn.LayerNorm) or isinstance(module, torch.nn.modules.normalization.RMSNorm):
851
+ if hasattr(module, 'weight') and module.weight is not None:
852
+ torch.nn.init.ones_(module.weight)
853
+ if hasattr(module, 'bias') and module.bias is not None:
854
+ torch.nn.init.zeros_(module.bias)
855
+
856
+ def forward(
857
+ self,
858
+ src_BxS: torch.Tensor,
859
+ tgt_BxTxC: torch.Tensor,
860
+ src_positions: torch.Tensor | None = None,
861
+ tgt_positions: torch.Tensor | None = None,
862
+ enc_self_attn_mask: torch.Tensor | None = None,
863
+ dec_self_attn_mask: torch.Tensor | None = None,
864
+ dec_cross_attn_mask: torch.Tensor | None = None,
865
+ enable_dropout: bool = True,
866
+ ):
867
+ deterministic = not enable_dropout
868
+
869
+ # --- Encoder Pass ---
870
+ encoder_out = self.encoder(
871
+ x_ids=src_BxS,
872
+ src_positions=src_positions,
873
+ deterministic=deterministic,
874
+ attn_mask=enc_self_attn_mask,
875
+ )
876
+
877
+ B, T, C = tgt_BxTxC.shape # Batch size, target sequence length, channels
878
+ device = tgt_BxTxC.device
879
+
880
+ self_attention_cache = [
881
+ KVCache(
882
+ num_heads=self.decoder.layers[i].self_attention.num_query_heads, # cache holds GQA-expanded K/V, so size by query heads
883
+ max_len=T,
884
+ head_dim=self.decoder.layers[i].self_attention.head_dim,
885
+ device=device,
886
+ )
887
+ for i in range(self.decoder.num_layers)
888
+ ]
889
+
890
+ cross_attention_cache = self.decoder.precompute_cross_attention_kv(
891
+ max_len=encoder_out.shape[1],
892
+ encoder_out=encoder_out,
893
+ src_positions=src_positions,
894
+ )
895
+
896
+ # --- Decoder Pass ---
897
+ logits = self.decoder(
898
+ tgt_ids_BxTxC=tgt_BxTxC,
899
+ encoder_out=encoder_out,
900
+ tgt_positions=tgt_positions,
901
+ src_positions=src_positions,
902
+ deterministic=deterministic,
903
+ self_attn_mask=dec_self_attn_mask,
904
+ cross_attn_mask=dec_cross_attn_mask,
905
+ self_attention_cache=self_attention_cache,
906
+ cross_attention_cache=cross_attention_cache
907
+ )
908
+
909
+ return logits
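
Note: a minimal sketch exercising the `KVCache` prefill/decode contract defined above (the leading batch dimension of 2 matches the unconditional/conditional CFG pair hard-coded into the cache tensors; the `dia` import path assumes this repo layout):

```python
import torch
from dia.layers import KVCache

cache = KVCache(num_heads=4, max_len=16, head_dim=8, device=torch.device("cpu"))

# Prefill with a 5-token prompt, as Decoder.forward does before generation.
cache.prefill_kv(torch.randn(2, 4, 5, 8), torch.randn(2, 4, 5, 8))
assert cache.current_idx == 5

# One autoregressive step: fetch past + current K/V, then commit the new slot.
k1, v1 = torch.randn(2, 4, 1, 8), torch.randn(2, 4, 1, 8)
attn_k, attn_v = cache.get_kv_for_attention(k1, v1)
assert attn_k.shape[2] == 6  # 5 cached positions + 1 current step
cache.update_cache(k1, v1)
assert cache.current_idx == 6
```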
dia/model.py ADDED
@@ -0,0 +1,648 @@
1
+ import dac
2
+ import numpy as np
3
+ import torch
4
+ import torchaudio
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ from .audio import audio_to_codebook, codebook_to_audio
8
+ from .config import DiaConfig
9
+ from .layers import DiaModel, KVCache
10
+
11
+
12
+ def get_default_device():
13
+ if torch.cuda.is_available():
14
+ return torch.device("cuda")
15
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
16
+ return torch.device("mps")
17
+ return torch.device("cpu")
18
+
19
+
20
+ def _sample_next_token(
21
+ logits_BCxV: torch.Tensor,
22
+ temperature: float,
23
+ top_p: float,
24
+ use_cfg_filter: bool,
25
+ cfg_filter_top_k: int | None = None,
26
+ ) -> torch.Tensor:
27
+ if temperature == 0.0:
28
+ return torch.argmax(logits_BCxV, dim=-1)
29
+
30
+ logits_BCxV = logits_BCxV / temperature
31
+ if use_cfg_filter and cfg_filter_top_k is not None:
32
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
33
+ mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
34
+ mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
35
+ logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
36
+
37
+ if top_p < 1.0:
38
+ probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
39
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
40
+ cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
41
+
42
+ # Calculate indices to remove based on top_p
43
+ sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
44
+ # Shift the mask to the right to keep the first token above the threshold
45
+ sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
46
+ sorted_indices_to_remove_BCxV[..., 0] = 0 # Always keep the most probable token
47
+
48
+ indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
49
+ indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
50
+ logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
51
+
52
+ final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
53
+
54
+ sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
55
+ sampled_indices_C = sampled_indices_BC.squeeze(-1)
56
+ return sampled_indices_C
57
+
58
+
59
+ class Dia:
60
+ def __init__(self, config: DiaConfig, device: torch.device | None = None):
61
+ """Initializes the Dia model.
62
+
63
+ Args:
64
+ config: The configuration object for the model.
65
+ device: The device to load the model onto. If None, will automatically select the best available device.
66
+
67
+ Raises:
68
+ RuntimeError: If there is an error loading the DAC model.
69
+ """
70
+ super().__init__()
71
+ self.config = config
72
+ self.device = device if device is not None else get_default_device()
73
+ self.model = DiaModel(config)
74
+ self.dac_model = None
75
+
76
+ @classmethod
77
+ def from_local(cls, config_path: str, checkpoint_path: str, device: torch.device | None = None) -> "Dia":
78
+ """Loads the Dia model from local configuration and checkpoint files.
79
+
80
+ Args:
81
+ config_path: Path to the configuration JSON file.
82
+ checkpoint_path: Path to the model checkpoint (.pth) file.
83
+ device: The device to load the model onto. If None, will automatically select the best available device.
84
+
85
+ Returns:
86
+ An instance of the Dia model loaded with weights and set to eval mode.
87
+
88
+ Raises:
89
+ FileNotFoundError: If the config or checkpoint file is not found.
90
+ RuntimeError: If there is an error loading the checkpoint.
91
+ """
92
+ config = DiaConfig.load(config_path)
93
+ if config is None:
94
+ raise FileNotFoundError(f"Config file not found at {config_path}")
95
+
96
+ dia = cls(config, device)
97
+
98
+ try:
99
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ if "model" in checkpoint:
+ state_dict = checkpoint["model"] # training checkpoints nest the weights under "model"
104
+ else:
105
+ state_dict = checkpoint
106
+ dia.model.load_state_dict(state_dict)
107
+ except FileNotFoundError:
108
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
109
+ except Exception as e:
110
+ raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e
111
+
112
+ dia.model.to(dia.device)
113
+ dia.model.eval()
114
+ dia._load_dac_model()
115
+ return dia
116
+
117
+ @classmethod
118
+ def from_pretrained(
119
+ cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device | None = None
120
+ ) -> "Dia":
121
+ """Loads the Dia model from a Hugging Face Hub repository.
122
+
123
+ Downloads the configuration and checkpoint files from the specified
124
+ repository ID and then loads the model.
125
+
126
+ Args:
127
+ model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
128
+ device: The device to load the model onto. If None, will automatically select the best available device.
129
+
130
+ Returns:
131
+ An instance of the Dia model loaded with weights and set to eval mode.
132
+
133
+ Raises:
134
+ FileNotFoundError: If config or checkpoint download/loading fails.
135
+ RuntimeError: If there is an error loading the checkpoint.
136
+ """
137
+ config_path = hf_hub_download(repo_id=model_name, filename="config.json")
138
+ checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
139
+ return cls.from_local(config_path, checkpoint_path, device)
140
+
141
+ def _load_dac_model(self):
142
+ try:
143
+ dac_model_path = dac.utils.download()
144
+ dac_model = dac.DAC.load(dac_model_path).to(self.device)
145
+ except Exception as e:
146
+ raise RuntimeError("Failed to load DAC model") from e
147
+ self.dac_model = dac_model
148
+
149
+ def _create_attn_mask(
150
+ self,
151
+ q_padding_mask_1d: torch.Tensor,
152
+ k_padding_mask_1d: torch.Tensor,
153
+ is_causal: bool = False,
154
+ ) -> torch.Tensor:
155
+ """
156
+ Creates the attention mask (self or cross) mimicking JAX segment ID logic.
157
+ """
158
+ B1, Tq = q_padding_mask_1d.shape
159
+ B2, Tk = k_padding_mask_1d.shape
160
+ assert B1 == B2, "Query and key batch dimensions must match"
161
+
162
+ p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
163
+ p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
164
+
165
+ # Condition A: Non-padding query attends to non-padding key
166
+ non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
167
+
168
+ # Condition B: Padding query attends to padding key
169
+ pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
170
+
171
+ # Combine: True if padding status is compatible (both non-pad OR both pad)
172
+ # This implementation follows Jax TPU splash attention kernel
173
+ mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
174
+
175
+ if is_causal:
176
+ # Ensure causality for self-attention (Tq == Tk)
177
+ assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
178
+ # Standard lower-triangular causal mask (True means allow)
179
+ causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device)) # Shape [Tq, Tk]
180
+ causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
181
+ return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
182
+ else:
183
+ # For cross-attention or non-causal self-attention
184
+ return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
185
+
186
+ def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
187
+ """Encodes text prompt, pads, and creates attention mask and positions."""
188
+ text_pad_value = self.config.data.text_pad_value
189
+ max_len = self.config.data.text_length
190
+
191
+ byte_text = text.encode("utf-8")
192
+ replaced_bytes = byte_text
193
+
194
+ LANG2BYTE = {
195
+ "en": 3,
196
+ "vi": 19,
197
+ }
198
+
199
+ CHANNELS = [
200
+ "5phutcrypto",
201
+ "anhbanthan",
202
+ "anhthamtu",
203
+ "animerewind.official",
204
+ "bibitv8888",
205
+ "btvgo",
206
+ "baclieutv",
207
+ "bachhoaxanhcom",
208
+ "baodientuvov",
209
+ "blvckvines",
210
+ "boringppl",
211
+ "bronub",
212
+ "cdteam-why",
213
+ "cobabinhduong",
214
+ "cosmicwriter",
215
+ "cuthongthai",
216
+ "daiphatthanhtruyenhinhsonla",
217
+ "day-be-thong-minh-tv",
218
+ "danangtv",
219
+ "daihanoi-htv",
220
+ "daiptththainguyentntv",
221
+ "dongmauviet",
222
+ "dongthaptv",
223
+ "fptbongdaofficial",
224
+ "fonosvietnam",
225
+ "hieurotrong5phut-ntkt",
226
+ "htvtintuc",
227
+ "happyhidari",
228
+ "hoabinhtvgo",
229
+ "hocenglishonline",
230
+ "hocvienbovagau",
231
+ "hungyentvvngo",
232
+ "huynhduykhuongofficial",
233
+ "huynhlapofficial",
234
+ "jvevermind",
235
+ "kenhvtc16",
236
+ "kiengiangtv",
237
+ "khanhvyofficial",
238
+ "kienthucquansu",
239
+ "lamdongtv1",
240
+ "lamvlog",
241
+ "longantv-la34",
242
+ "mangovid",
243
+ "mensbay",
244
+ "meovatcuocsonglnv",
245
+ "meuchannel",
246
+ "ntnvlogsnguyenthanhnam",
247
+ "ngamradio",
248
+ "nhanhac555",
249
+ "nhantaidaiviet",
250
+ "ptth-trt",
251
+ "ptvtruyenhinhphutho",
252
+ "phantichgame",
253
+ "phephim",
254
+ "phimhottk-l",
255
+ "riwaylegal",
256
+ "ruangao",
257
+ "suckhoetamsinh",
258
+ "sachbiquyethanhcong",
259
+ "soisangbrightsidevietnamese",
260
+ "spiderum",
261
+ "spiderumbooks",
262
+ "sukieskitchen",
263
+ "tin3phut",
264
+ "tranthanhtown",
265
+ "tulemientay",
266
+ "tayninhtv",
267
+ "thainhitv",
268
+ "thanhpahm",
269
+ "thegioilaptop",
270
+ "thepresentwriter",
271
+ "tiengiangtivi",
272
+ "tieubaobaothom",
273
+ "tintucbitcoin247",
274
+ "truyenhinhbinhphuoc-bptv",
275
+ "truyenhinhyenbaiytv",
276
+ "truyenhinhcaobang",
277
+ "truyenhinhdaklakdrt",
278
+ "truyenhinhdaknong1",
279
+ "truyenhinhdienbien23.9",
280
+ "truyenhinhkhanhhoa",
281
+ "truyenhinhkontumkrt",
282
+ "truyenhinhnaminhntv",
283
+ "truyenhinhninhthuan",
284
+ "truyenhinhquangngai",
285
+ "tuantienti2911",
286
+ "tuyenquangttv",
287
+ "vovlivedoctruyen",
288
+ "vietcetera",
289
+ "vinhlongtv",
290
+ "voizfm",
291
+ "vutrunguyenthuy",
292
+ "vuive",
293
+ "w2wanime",
294
+ "w2wcartoon",
295
+ "w2whorror",
296
+ "w2wmovie",
297
+ "web5ngay",
298
+ "xanh24h",
299
+ "aiphatthanhtruyenhinhquangtri",
300
+ "aiphatthanhvatruyenhinhhai1908",
301
+ "altonghop",
302
+ "antvtruyenhinhcongannhandan",
303
+ "baihoc10phut",
304
+ "battlecry.khampha",
305
+ "betterversionvn",
306
+ "blogkhoinghiep",
307
+ "bumcn",
308
+ "caikinhdi_vn",
309
+ "canthitg",
310
+ "chanthienmybachnien",
311
+ "chauanhchao",
312
+ "cosu",
313
+ "cungmaivaobep-monan-amthuc",
314
+ "daiptthphuyen",
315
+ "daiptthtv",
316
+ "daitruyenhinhangiang",
317
+ "daitruyenhinhbacgiang",
318
+ "dannytran2375",
319
+ "daybehoc5489",
320
+ "daylaphegame",
321
+ "dienmay",
322
+ "ducisreal",
323
+ "duongfg",
324
+ "duyluandethuong",
325
+ "duythanhish",
326
+ "elroydevops",
327
+ "gc.gamelab",
328
+ "hacthaybachthay",
329
+ "hagiangtv475",
330
+ "haiduongtv247",
331
+ "hanamtv8831",
332
+ "hangphimtailieudienanhnd",
333
+ "haugiangtv",
334
+ "haunauday",
335
+ "hieu-tv",
336
+ "hoshiphan",
337
+ "jakinatsumi2915",
338
+ "kechuyentieuhoc1719",
339
+ "kenhcovan",
340
+ "khalid_dinh",
341
+ "kiaralah",
342
+ "laichautv",
343
+ "langsontvtube",
344
+ "megame_official",
345
+ "minvestvn",
346
+ "nguoithanhcong1991",
347
+ "nhatkycuocsong.",
348
+ "ntcanima",
349
+ "ptthbentre",
350
+ "ptthquangbinh",
351
+ "qrt",
352
+ "quangninhtv",
353
+ "snewsvn",
354
+ "soctrangtv",
355
+ "sunhuynpodcast",
356
+ "tamhonanuong",
357
+ "tgddreview",
358
+ "thaibinhtv",
359
+ "thanhnamedu",
360
+ "thanhnientvnews",
361
+ "thbrt",
362
+ "thieunhitv3630",
363
+ "thtpct",
364
+ "tinnhanh3phut868",
365
+ "toansam",
366
+ "toidicodedaoblog",
367
+ "tranquochuywecommit",
368
+ "tranvyvy",
369
+ "truyenhinh4k",
370
+ "truyenhinhbinhthuan",
371
+ "truyenhinhcamau69",
372
+ "truyenhinhdongnai_dnrtv",
373
+ "truyenhinhgialai",
374
+ "truyenhinhlaocai",
375
+ "truyenhinhnghean",
376
+ "truyenhinhvinhphuc",
377
+ "txtofficial8798",
378
+ "vanhkhuyenle",
379
+ "vietnh1009",
380
+ "visaothenhipodcast",
381
+ "vtc14",
382
+ "vtcnow",
383
+ "vtv24",
384
+ "vuive123",
385
+ "zombiev4",
386
+ ]
387
+ LANG2BYTE.update({ch: 30 + i for i, ch in enumerate(CHANNELS)})
388
+ # Replace [tag] markers in the text with their single-byte codes
+ for tag, byte_val in LANG2BYTE.items():
+ pattern = f"[{tag}]".encode("ascii") # e.g. b"[5phutcrypto]"
+ code = bytes([byte_val]) # e.g. b"\x1e"
392
+ replaced_bytes = replaced_bytes.replace(pattern, code)
393
+ text_tokens = list(replaced_bytes)
394
+
395
+ current_len = len(text_tokens)
396
+ padding_needed = max_len - current_len
397
+ if padding_needed <= 0:
398
+ text_tokens = text_tokens[:max_len]
399
+ padded_text_np = np.array(text_tokens, dtype=np.uint8)
400
+ else:
401
+ padded_text_np = np.pad(
402
+ text_tokens,
403
+ (0, padding_needed),
404
+ mode="constant",
405
+ constant_values=text_pad_value,
406
+ ).astype(np.uint8)
407
+
408
+ src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0) # [1, S]
409
+ src_positions = torch.arange(max_len, device=self.device).to(torch.long).unsqueeze(0) # [1, S]
410
+
411
+ src_padding_mask = (src_tokens != text_pad_value).to(self.device) # [1, S]
412
+
413
+ enc_self_attn_mask = self._create_attn_mask(src_padding_mask, src_padding_mask, is_causal=False) # [1, S, S]
414
+
415
+ return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
416
+
417
+ @torch.inference_mode()
418
+ def generate(
419
+ self,
420
+ text: str,
421
+ max_tokens: int | None = None,
422
+ cfg_scale: float = 3.0,
423
+ temperature: float = 1.3,
424
+ top_p: float = 0.95,
425
+ use_cfg_filter: bool = True,
426
+ use_torch_compile: bool = False,
427
+ cfg_filter_top_k: int = 35,
428
+ audio_prompt_path: str | None = None,
429
+ ) -> np.ndarray:
430
+ """
431
+ Generates audio from a text prompt (and optional audio prompt) using the Nari model.
432
+
433
+ Returns:
434
+ A NumPy array containing the generated audio waveform.
435
+ """
436
+ num_channels = self.config.data.channels
437
+ audio_bos_value = self.config.data.audio_bos_value
438
+ audio_eos_value = self.config.data.audio_eos_value
439
+ audio_pad_value = self.config.data.audio_pad_value
440
+ delay_pattern = self.config.data.delay_pattern
441
+ max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
442
+ delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
443
+ max_delay_pattern = max(delay_pattern)
444
+ self.model.eval()
445
+
446
+ (
447
+ cond_src_BxS,
448
+ cond_src_positions_BxS,
449
+ cond_src_padding_mask_BxS,
450
+ cond_enc_self_attn_mask_Bx1xSxS,
451
+ ) = self._prepare_text_input(text)
452
+
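+ # Classifier-free guidance batch: row 0 is the unconditional (zeroed text) input, row 1 the conditional one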
453
+ unc_src_BxS = torch.zeros_like(cond_src_BxS)
454
+ src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
455
+ src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
456
+ src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
457
+ enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)
458
+
459
+ # 2. Encoder Pass
460
+ # with torch.autocast(device_type="cuda", dtype=forward_dtype):
461
+ encoder_out = self.model.encoder(
462
+ x_ids=src_BxS,
463
+ src_positions=src_positions_BxS,
464
+ deterministic=True,
465
+ attn_mask=enc_self_attn_mask_Bx1xSxS,
466
+ ) # Shape: (B, S, E)
467
+
468
+ # 3. Prepare Decoder Inputs
469
+ # 3-1. Allocate KV Cache (Static)
470
+ decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
471
+ max_tokens, encoder_out, src_positions_BxS
472
+ )
473
+
474
+ decoder_self_attention_cache: list[KVCache] = []
475
+ for _ in range(self.model.decoder.num_layers):
476
+ decoder_self_attention_cache.append(
477
+ KVCache(
478
+ self.config.model.decoder.gqa_query_heads,
479
+ max_tokens,
480
+ self.config.model.decoder.gqa_head_dim,
481
+ self.device,
482
+ )
483
+ )
484
+
485
+ # 3-2. Initialize Decoder Inputs
486
+ generated_BxTxC = torch.full(
487
+ (2, 1, num_channels),
488
+ fill_value=audio_bos_value,
489
+ dtype=torch.long,
490
+ device=self.device,
491
+ )
492
+
493
+ current_step = 0
494
+ prompt_len_inc_bos = 1 # Start with BOS length
495
+
+ # 3-3. Load Audio Prompt (if provided)
+ if audio_prompt_path is not None:
+     audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True)  # C, T
+     if sr != 44100:  # Resample to 44.1kHz
+         audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
+     audio_prompt = audio_prompt.to(self.device).unsqueeze(0)  # 1, C, T
+     audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
+     print("✅ Prompt shape:", audio_prompt.shape)
+     generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)
+
+     prefill_len = generated_BxTxC.shape[1]
+     prompt_len_inc_bos = prefill_len
+     prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
+     prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)
+
+     prefill_self_attn_mask = self._create_attn_mask(
+         prefill_tgt_padding_mask,
+         prefill_tgt_padding_mask,
+         is_causal=True,
+     )
+     prefill_cross_attn_mask = self._create_attn_mask(
+         prefill_tgt_padding_mask,
+         src_padding_mask_BxS,
+         is_causal=False,
+     )
+
+     _ = self.model.decoder.forward(
+         tgt_ids_BxTxC=generated_BxTxC,
+         encoder_out=encoder_out,
+         tgt_positions=prefill_tgt_pos,
+         src_positions=src_positions_BxS,
+         deterministic=True,
+         self_attn_mask=prefill_self_attn_mask,
+         cross_attn_mask=prefill_cross_attn_mask,
+         self_attention_cache=decoder_self_attention_cache,
+         cross_attention_cache=decoder_cross_attention_cache,
+     )
+
+     current_step = prefill_len - 1
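+ # The prefill pass above runs the whole BOS+prompt sequence through the
+ # decoder once, discarding the logits; its only purpose is to populate the
+ # self-attention KV caches so the loop below can resume generation from
+ # position `prefill_len - 1`.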
+
+ # 4. Autoregressive Generation Loop
+ eos_detected_channel_0 = False
+ eos_countdown = -1
+ extra_steps_after_eos = 30
+ # Make generated_BxTxC a fixed-size tensor:
+ # length is either 1 + max_tokens or 1 + prompt_len + max_tokens.
+ generated_BxTxC = torch.cat(
+     [
+         generated_BxTxC,
+         torch.full(
+             (2, max_tokens, num_channels),
+             fill_value=-1,
+             dtype=torch.long,
+             device=self.device,
+         ),
+     ],
+     dim=1,
+ )
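+ # The slots appended above are filled with -1 as a "not yet generated"
+ # sentinel and are overwritten in place at each step; keeping the tensor a
+ # fixed size avoids reallocations and keeps shapes static, which also plays
+ # well with torch.compile below.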
+
+ decode_step = self.model.decoder.decode_step
+ if use_torch_compile:
+     decode_step = torch.compile(
+         self.model.decoder.decode_step,
+         mode="default",
+     )
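+ # Note: torch.compile traces and optimizes decode_step lazily, so the first
+ # generation step pays a one-time compilation cost; subsequent steps reuse
+ # the compiled graph.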
+
+ tgt_padding_mask = (
+     (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2).to(self.device)
+ )  # [B, 1]
+ # Generated tokens are never PAD, so we can use a fixed mask
+ decoder_cross_attn_mask = self._create_attn_mask(
+     tgt_padding_mask,  # Query mask [B, 1]
+     src_padding_mask_BxS,  # Key mask [B, S]
+     is_causal=False,
+ )  # [B, 1, 1, S]
+
+ for step in range(current_step, current_step + max_tokens):
+     tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
+     tgt_pos_Bx1 = torch.full(
+         (2, 1),
+         fill_value=step,
+         dtype=torch.long,
+         device=self.device,
+     )
+
+     logits_Bx1xCxV, new_cache = decode_step(
+         tgt_ids_Bx1xC=tgt_ids_Bx1xC,
+         tgt_pos_Bx1=tgt_pos_Bx1,
+         encoder_out=encoder_out,
+         self_attn_mask=None,
+         cross_attn_mask=decoder_cross_attn_mask,
+         self_attention_cache=decoder_self_attention_cache,
+         cross_attention_cache=decoder_cross_attention_cache,
+     )
+
+     for i, layer_cache in enumerate(decoder_self_attention_cache):
+         layer_cache.update_cache(new_cache[i][0], new_cache[i][1])
+
+     V = self.config.model.tgt_vocab_size
+     logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]  # B, C, V
+     uncond_logits_CxV = logits_last_BxCxV[0, :, :]
+     cond_logits_CxV = logits_last_BxCxV[1, :, :]
+
+     cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
+
+     logits_CxV = cfg_logits_CxV.reshape((-1, V))  # C, V
+     logits_CxV[:, 1025:] = -torch.inf  # hard-mask vocab IDs >= 1025 so they can never be sampled
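+     # Classifier-free guidance in logit space: with guidance weight w,
+     #     cfg = cond + w * (cond - uncond) = (1 + w) * cond - w * uncond,
+     # i.e. the conditional logits are pushed further away from the
+     # unconditional ones as cfg_scale grows; cfg_scale = 0 recovers the
+     # purely conditional distribution.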
+
+     # Sample next token
+     pred_C = _sample_next_token(
+         logits_CxV.float(),
+         temperature=temperature,
+         top_p=top_p,
+         use_cfg_filter=use_cfg_filter,
+         cfg_filter_top_k=cfg_filter_top_k,
+     )
+
+     generation_step_index = step - current_step
+     if audio_prompt_path is None:
+         # Honor the channel delay pattern: channel c may only emit real
+         # tokens once generation_step_index >= delay_pattern[c]; until then
+         # it is forced to BOS.
+         pred_C = torch.where(
+             generation_step_index >= delay_tensor,
+             pred_C,
+             audio_bos_value,
+         )
+
+     generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)
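+     # For example, with a hypothetical delay pattern [0, 8, 9] the frames
+     # written above would look like:
+     #     step 0: [tok, BOS, BOS]
+     #     step 8: [tok, tok, BOS]
+     #     step 9: [tok, tok, tok]
+     # i.e. each codebook channel lags channel 0 by its configured delay.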
+
+     if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
+         eos_detected_channel_0 = True
+         eos_countdown = extra_steps_after_eos
+
+     if eos_countdown > 0:
+         # Wind down after EOS on channel 0: each channel i receives EOS when
+         # the post-EOS step count reaches its delay, and PAD afterwards, so
+         # the delayed streams all terminate cleanly once undelayed.
+         step_after_eos = max_delay_pattern - eos_countdown
+         for i, d in enumerate(delay_pattern):
+             if step_after_eos == d:
+                 generated_BxTxC[:, step + 1, i] = audio_eos_value
+             elif step_after_eos > d:
+                 generated_BxTxC[:, step + 1, i] = audio_pad_value
+         eos_countdown -= 1
+         if eos_countdown == 0:
+             break
+
+     generation_step_index = step - current_step + 1
+
+ output_codes = generated_BxTxC[:, prompt_len_inc_bos : step + 1, :]
+
+ generated_codes = output_codes[0]
+
+ audio = codebook_to_audio(
+     generated_codes.transpose(1, 0), self.dac_model, delay_pattern, B=1, T=max_tokens, C=num_channels
+ )
+ print("🟩 Total generated tokens:", generated_codes.shape[0])
+ return audio.squeeze().cpu().numpy()
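+
+ # A minimal usage sketch (the class name `Dia`, the `from_pretrained`
+ # loader, and the checkpoint ID are assumptions not shown in this diff; the
+ # keyword arguments are the ones `generate` actually consumes, and 44100 is
+ # the sample rate this method resamples prompts to):
+ #
+ #     import soundfile as sf
+ #
+ #     model = Dia.from_pretrained("<checkpoint>")  # hypothetical loader
+ #     waveform = model.generate(
+ #         text="Hello there, this is a test.",
+ #         cfg_scale=3.0,
+ #         temperature=1.3,
+ #         top_p=0.95,
+ #         use_torch_compile=False,
+ #     )
+ #     sf.write("output.wav", waveform, 44100)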