chore: upload folder dia for inference
- dia/__init__.py +0 -0
- dia/audio.py +286 -0
- dia/config.json +49 -0
- dia/config.py +206 -0
- dia/config_inference.json +49 -0
- dia/convert_ckpt.py +54 -0
- dia/dataset.py +162 -0
- dia/finetune.py +787 -0
- dia/interleaved_datasets.py +144 -0
- dia/layers.py +909 -0
- dia/model.py +648 -0
dia/__init__.py
ADDED
File without changes
dia/audio.py
ADDED
@@ -0,0 +1,286 @@
import typing as tp

import torch

from .config import DataConfig


def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
    Negative t_idx => BOS; t_idx >= T => PAD.
    """
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )

    # We must clamp time indices to [0..T-1] so the gather below won't fail
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type for indexing

    return t_idx_BxTxC, indices_BTCx3


def apply_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    bos_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """
    Applies the delay pattern to batched audio tokens using precomputed indices,
    inserting BOS where t_idx < 0 and PAD where t_idx >= T.

    Args:
        audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
        pad_value: the padding token
        bos_value: the BOS token
        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices

    Returns:
        result_BxTxC: [B, T, C] delayed audio tokens
    """
    device = audio_BxTxC.device  # Get device from input tensor
    t_idx_BxTxC, indices_BTCx3 = precomp
    t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
    indices_BTCx3 = indices_BTCx3.to(device)

    # Equivalent of tf.gather_nd using advanced indexing
    # (build_delay_indices already casts the indices to long)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)

    # Create masks on the correct device
    mask_bos = t_idx_BxTxC < 0  # => place bos_value
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value

    # Create scalar tensors on the correct device
    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    # If mask_bos, BOS; else if mask_pad, PAD; else the gathered value.
    # All tensors are now on the same device.
    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))

    return result_BxTxC


@torch.no_grad()
@torch.inference_mode()
def audio_to_codebook(
    model,
    input_values,
    data_config: DataConfig,
    padding_mask=None,
    sample_rate=44100,
):
    """
    Encodes the input audio waveform into discrete codes.

    Args:
        model: The model to use for encoding.
        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
            Float values of the input audio waveform.
        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
            Padding mask used to pad the `input_values`.
        sample_rate (`int`, *optional*):
            Signal sampling rate.

    Returns:
        A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
        factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
        `codebook` of shape `[batch_size, num_codebooks, frames]`.
        Scale is not used here.
    """
    audio_data = model.preprocess(input_values, sample_rate)

    if padding_mask is None:
        padding_mask = torch.ones_like(input_values).bool()

    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
    seq_length = encoded_frame.shape[2]

    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
        B=1,
        T=seq_length,
        C=data_config.channels,
        delay_pattern=data_config.delay_pattern,
    )

    encoded_frame = apply_audio_delay(
        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
        pad_value=data_config.audio_pad_value,
        bos_value=data_config.audio_bos_value,
        precomp=(t_idx_BxTxC, indices_BTCx3),
    )

    return encoded_frame


def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute indices for the revert operation using PyTorch.

    Returns:
        A tuple (t_idx_BxTxC, indices_BTCx3) where:
            - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
            - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from
              batch indices, clamped time indices, and channel indices.
    """
    # Use the default device; revert_audio_delay moves these to the audio tensor's device
    device = None

    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)

    t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
    t_idx_BT1 = t_idx_BT1.unsqueeze(-1)

    t_idx_BxTxC = torch.minimum(
        t_idx_BT1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1, device=device),
    )
    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type

    return t_idx_BxTxC, indices_BTCx3


def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    """
    Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).

    Args:
        audio_BxTxC: Input delayed audio tensor
        pad_value: Padding value for out-of-bounds indices
        precomp: Precomputed revert indices tuple containing:
            - t_idx_BxTxC: Time offset indices tensor
            - indices_BTCx3: Gather indices tensor for original audio
        T: Original sequence length before padding

    Returns:
        Reverted audio tensor with the same shape as the input
    """
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device  # Get device from input tensor

    # Move precomputed indices to the same device as audio_BxTxC if they aren't already
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # PyTorch advanced indexing (equivalent to tf.gather_nd)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping

    # Create pad_tensor and the T comparison tensor on the correct device
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    T_tensor = torch.tensor(T, device=device)

    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC


@torch.no_grad()
@torch.inference_mode()
def decode(
    model,
    audio_codes,
):
    """
    Decodes the given frames into an output audio waveform.
    """
    if len(audio_codes) != 1:
        raise ValueError(f"Expected one frame, got {len(audio_codes)}")

    try:
        audio_values = model.quantizer.from_codes(audio_codes)
        audio_values = model.decode(audio_values[0])

        return audio_values
    except Exception as e:
        print(f"Error in decode method: {str(e)}")
        raise


def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
    """Process a single codebook file to generate audio."""
    # Remove BOS token
    generated_codes = generated_codes[:, 1:]

    if generated_codes.shape[1] > T:
        generated_codes = generated_codes[:, :T]

    seq_length = generated_codes.shape[1]

    # Build revert indices
    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)

    # Transpose and add batch dimension
    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
    reverted_codebook = revert_audio_delay(
        audio_BxTxC=audio_BxTxC,
        pad_value=0,
        precomp=(t_idx_BxTxC, indices_BTCx3),
        T=seq_length,
    )
    # Only trim the trailing 'delay' frames when there are more frames than that
    delay = 30
    num_frames = reverted_codebook.size(1)  # original T dimension

    if num_frames > delay:
        reverted_codebook = reverted_codebook[:, :-delay, :]
    # If num_frames <= delay, skip the trim to avoid producing a T = 0 dimension

    codebook = reverted_codebook.transpose(1, 2)  # (B x T x C) -> (B x C x T)

    min_valid_index = 0
    max_valid_index = 1023
    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)

    num_invalid = torch.sum(invalid_mask).item()
    if num_invalid > 0:
        print(f"Warning: Setting {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")

    # Set invalid values to 0 (modifies the tensor in-place)
    codebook[invalid_mask] = 0
    audio_array = decode(model, codebook)

    return audio_array
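A minimal round-trip sketch of the delay machinery above; the toy sizes and the pad/BOS values (1025/1026, mirroring config.json) are assumptions:

import torch
from dia.audio import (
    build_delay_indices, apply_audio_delay,
    build_revert_indices, revert_audio_delay,
)

B, T, C = 1, 6, 3
delay_pattern = [0, 1, 2]
codes = torch.arange(B * T * C, dtype=torch.long).view(B, T, C)

# Delay: channel c is shifted right by delay_pattern[c]; BOS (1026) fills the gap.
delayed = apply_audio_delay(codes, pad_value=1025, bos_value=1026,
                            precomp=build_delay_indices(B, T, C, delay_pattern))

# Revert: shift each channel back again; only the last max(delay) frames differ.
reverted = revert_audio_delay(delayed, pad_value=1025,
                              precomp=build_revert_indices(B, T, C, delay_pattern), T=T)

valid = T - max(delay_pattern)
assert torch.equal(reverted[:, :valid], codes[:, :valid])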
dia/config.json
ADDED
@@ -0,0 +1,49 @@
{
  "version": "0.1",
  "model": {
    "encoder": {
      "n_layer": 12,
      "n_embd": 1024,
      "n_hidden": 4096,
      "n_head": 16,
      "head_dim": 128
    },
    "decoder": {
      "n_layer": 18,
      "n_embd": 2048,
      "n_hidden": 8192,
      "gqa_query_heads": 16,
      "cross_query_heads": 16,
      "kv_heads": 4,
      "gqa_head_dim": 128,
      "cross_head_dim": 128,
      "d_model": 256
    },
    "src_vocab_size": 256,
    "tgt_vocab_size": 1028,
    "dropout": 0.0
  },
  "training": {
    "dtype": "bfloat16"
  },
  "data": {
    "text_length": 512,
    "audio_length": 1536,
    "channels": 9,
    "text_pad_value": 0,
    "audio_eos_value": 1024,
    "audio_pad_value": 1025,
    "audio_bos_value": 1026,
    "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15]
  }
}
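Reading the data block together with the model block: tgt_vocab_size is 1028 because DAC code indices occupy 0-1023 and the EOS/PAD/BOS specials sit at 1024/1025/1026 (index 1027 appears unused), while delay_pattern assigns one frame offset to each of the 9 channels. A quick consistency check over this file (standard library only; the path is illustrative):

import json

with open("dia/config.json") as f:
    cfg = json.load(f)

data = cfg["data"]
specials = [data["audio_eos_value"], data["audio_pad_value"], data["audio_bos_value"]]
assert max(specials) < cfg["model"]["tgt_vocab_size"]    # specials fit inside the vocab
assert len(data["delay_pattern"]) == data["channels"]    # one delay entry per channel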
dia/config.py
ADDED
@@ -0,0 +1,206 @@
"""Configuration management module for the Dia model.

This module provides comprehensive configuration management for the Dia model,
utilizing Pydantic for validation. It defines configurations for data processing,
model architecture (encoder and decoder), and training settings.

Key components:
    - DataConfig: Parameters for data loading and preprocessing.
    - EncoderConfig: Architecture details for the encoder module.
    - DecoderConfig: Architecture details for the decoder module.
    - ModelConfig: Combined model architecture settings.
    - TrainingConfig: Training hyperparameters and settings.
    - DiaConfig: Master configuration combining all components.
"""

import os
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, Field


class DataConfig(BaseModel, frozen=True):
    """Configuration for data loading and preprocessing.

    Attributes:
        text_length: Maximum length of text sequences (must be multiple of 128).
        audio_length: Maximum length of audio sequences (must be multiple of 128).
        channels: Number of audio channels.
        text_pad_value: Value used for padding text sequences.
        audio_eos_value: Value representing the end of audio sequences.
        audio_bos_value: Value representing the beginning of audio sequences.
        audio_pad_value: Value used for padding audio sequences.
        delay_pattern: List of delay values for each audio channel.
    """

    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
    channels: int = Field(default=9, gt=0, multiple_of=1)
    text_pad_value: int = Field(default=0)
    audio_eos_value: int = Field(default=1024)
    audio_pad_value: int = Field(default=1025)
    audio_bos_value: int = Field(default=1026)
    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])

    def __hash__(self) -> int:
        """Generate a hash based on all fields of the config."""
        return hash(
            (
                self.text_length,
                self.audio_length,
                self.channels,
                self.text_pad_value,
                self.audio_pad_value,
                self.audio_bos_value,
                self.audio_eos_value,
                tuple(self.delay_pattern),
            )
        )


class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
    """

    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)
    n_head: int = Field(gt=0)
    head_dim: int = Field(gt=0)
    mlp_activations: list[str] = Field(default=["silu", "linear"])
    use_pre_norm: bool = Field(default=False)


class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        gqa_query_heads: Number of query heads for grouped-query self-attention.
        kv_heads: Number of key/value heads for grouped-query self-attention.
        gqa_head_dim: Dimension per query head for grouped-query self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization.
    """

    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)
    gqa_query_heads: int = Field(gt=0)
    kv_heads: int = Field(gt=0)
    gqa_head_dim: int = Field(gt=0)
    cross_query_heads: int = Field(gt=0)
    cross_head_dim: int = Field(gt=0)
    mlp_activations: list[str] = Field(default=["silu", "linear"])
    use_pre_norm: bool = Field(default=False)


class ModelConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        encoder: Configuration for the encoder component.
        decoder: Configuration for the decoder component.
        src_vocab_size: Size of the source (text) vocabulary.
        tgt_vocab_size: Size of the target (audio code) vocabulary.
        dropout: Dropout probability applied within the model.
        normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
        weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
    """

    encoder: EncoderConfig
    decoder: DecoderConfig
    src_vocab_size: int = Field(default=128, gt=0)
    tgt_vocab_size: int = Field(default=1028, gt=0)
    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
    weight_dtype: str = Field(default="float32", description="Weight precision")
    rope_min_timescale: int = Field(default=1, description="Timescale for global attention")
    rope_max_timescale: int = Field(default=10_000, description="Timescale for global attention")


class TrainingConfig(BaseModel, frozen=True):
    """Training process configuration and hyperparameters.

    Note: This configuration currently only includes precision settings.
    Other training parameters (like batch size, learning rate, optimizer settings)
    are assumed to be handled externally.

    Attributes:
        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
    """

    dtype: str = Field(default="bfloat16", description="Activation precision")
    logits_dot_in_fp32: bool = Field(default=False)


class DiaConfig(BaseModel, frozen=True):
    """Master configuration for the Dia model.

    Combines all sub-configurations into a single validated object.

    Attributes:
        version: Configuration version string.
        model: Model architecture configuration.
        training: Training process configuration (precision settings).
        data: Data loading and processing configuration.
    """

    version: str = Field(default="1.0")
    model: ModelConfig
    training: TrainingConfig
    data: DataConfig

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Ensures the parent directory exists and the file has a .json extension.

        Args:
            path: The target file path to save the configuration.

        Raises:
            ValueError: If the path is not a file with a .json extension.
        """
        os.makedirs(os.path.dirname(path), exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        with open(path, "w") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            ValueError: If the path does not point to an existing .json file.
            pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
        """
        try:
            with open(path, "r") as f:
                content = f.read()
            return cls.model_validate_json(content)
        except FileNotFoundError:
            return None
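A short usage sketch for the classes above (file paths are illustrative):

from dia.config import DiaConfig

cfg = DiaConfig.load("dia/config.json")        # returns None if the file is missing
if cfg is not None:
    print(cfg.model.decoder.n_layer)           # 18 with the shipped config
    print(cfg.data.delay_pattern)              # [0, 8, 9, 10, 11, 12, 13, 14, 15]
    cfg.save("checkpoints/config_backup.json")  # re-serializes via model_dump_json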
dia/config_inference.json
ADDED
@@ -0,0 +1,49 @@
{
  "version": "0.1",
  "model": {
    "encoder": {
      "n_layer": 12,
      "n_embd": 1024,
      "n_hidden": 4096,
      "n_head": 16,
      "head_dim": 128
    },
    "decoder": {
      "n_layer": 18,
      "n_embd": 2048,
      "n_hidden": 8192,
      "gqa_query_heads": 16,
      "cross_query_heads": 16,
      "kv_heads": 4,
      "gqa_head_dim": 128,
      "cross_head_dim": 128,
      "d_model": 256
    },
    "src_vocab_size": 256,
    "tgt_vocab_size": 1028,
    "dropout": 0.0
  },
  "training": {
    "dtype": "float32"
  },
  "data": {
    "text_length": 512,
    "audio_length": 1536,
    "channels": 9,
    "text_pad_value": 0,
    "audio_eos_value": 1024,
    "audio_pad_value": 1025,
    "audio_bos_value": 1026,
    "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15]
  }
}
dia/convert_ckpt.py
ADDED
@@ -0,0 +1,54 @@
import argparse

import torch

from dia.layers import DiaModel  # adjust your import if needed
from dia.config import DiaConfig


def convert_checkpoint(input_ckpt: str, output_ckpt: str, config_path: str):
    # select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1) Reconstruct exactly the same compiled model that was saved
    dia_cfg = DiaConfig.load(config_path)
    if dia_cfg is None:
        raise FileNotFoundError(f"Config file not found: {config_path}")
    model = DiaModel(dia_cfg).to(device)
    model = model.half()
    model = torch.compile(model, backend="inductor")

    # 2) Load the compiled/half checkpoint
    state = torch.load(input_ckpt, map_location=device)
    model.load_state_dict(state)

    # 3) Unwrap to the original nn.Module
    orig = getattr(model, "_orig_mod", None) or getattr(model, "__wrapped__", None) or model

    # 4) Cast all params & buffers back to float32
    orig.float()

    # 5) Save its clean, float32 state_dict
    torch.save(orig.state_dict(), output_ckpt)
    print(f"Saved normal FP32 checkpoint to {output_ckpt}")


def main():
    parser = argparse.ArgumentParser(
        description="Convert a compiled/half-precision checkpoint back to a standard FP32 state_dict."
    )
    parser.add_argument(
        "--input-ckpt", "-i",
        required=True,
        help="Path to the half-precision compiled checkpoint (.pth) to load",
    )
    parser.add_argument(
        "--output-ckpt", "-o",
        required=True,
        help="Path where the FP32 state_dict will be saved",
    )
    parser.add_argument(
        "--config", "-c",
        required=True,
        help="Path to your DiaConfig JSON file",
    )

    args = parser.parse_args()
    convert_checkpoint(args.input_ckpt, args.output_ckpt, args.config)


if __name__ == "__main__":
    main()
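A typical invocation sketch (all paths are illustrative):

# Shell: python -m dia.convert_ckpt -i ckpt_step2000.pth -o ckpt_fp32.pth -c dia/config.json
from dia.convert_ckpt import convert_checkpoint

convert_checkpoint(
    input_ckpt="ckpt_step2000.pth",   # half-precision, torch.compile-wrapped state_dict
    output_ckpt="ckpt_fp32.pth",      # plain FP32 state_dict, loadable without compile
    config_path="dia/config.json",
)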
dia/dataset.py
ADDED
@@ -0,0 +1,162 @@
from pathlib import Path

import torch
import torchaudio
import pandas as pd
from torch.utils.data import Dataset

import dac
from .config import DiaConfig


class LocalDiaDataset(Dataset):
    """Dataset loader from a local CSV (pipe-separated) and audio folder."""

    def __init__(self, csv_path: Path, audio_root: Path, config: DiaConfig, dac_model: dac.DAC):
        self.df = pd.read_csv(csv_path, sep=r"\s*\|\s*", engine="python", names=["audio", "text", "channel"])
        self.audio_root = audio_root
        self.config = config
        self.dac_model = dac_model

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        text = row["text"]
        channel = row.get("channel", None)
        if channel and pd.notna(channel):
            text = f"[{channel}]{text}"

        audio_path = self.audio_root / row["audio"]
        waveform, sr = torchaudio.load(audio_path)

        if sr != 44100:
            waveform = torchaudio.functional.resample(waveform, sr, 44100)

        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.ndim == 2:
            waveform = waveform[:1]  # only take one channel if stereo

        waveform = waveform.unsqueeze(0)  # (1, 1, T)
        with torch.no_grad():
            audio_tensor = self.dac_model.preprocess(waveform, 44100).to(
                next(self.dac_model.parameters()).device
            )
            _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
            encoded = encoded.squeeze(0).transpose(0, 1)  # (T, C)

        return text, encoded, waveform


class HFDiaDataset(Dataset):
    """Dataset loader for a Hugging Face Datasets object."""

    def __init__(self, hf_dataset, config: DiaConfig, dac_model: dac.DAC):
        self.dataset = hf_dataset
        self.config = config
        self.dac_model = dac_model

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        sample = self.dataset[idx]

        # Build the text tag
        text = sample["text"]
        channel = sample.get("channel", None)
        lang = sample.get("language", None)

        if channel and isinstance(channel, str) and channel.strip():
            text = f"[{channel}]{text}"
        elif lang and isinstance(lang, str):
            text = f"[{lang}]{text}"

        # Process the audio
        audio_info = sample["audio"]
        waveform = torch.tensor(audio_info["array"], dtype=torch.float32)

        # Ensure waveform shape is (1, 1, T)
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0).unsqueeze(0)
        elif waveform.ndim == 2:
            waveform = waveform[:1].unsqueeze(0)  # take the first channel

        # Resample if the audio is not 44100 Hz
        sr = audio_info.get("sampling_rate", 44100)
        if sr != 44100:
            waveform = torchaudio.functional.resample(waveform, sr, 44100)

        with torch.no_grad():
            audio_tensor = self.dac_model.preprocess(waveform, 44100).to(next(self.dac_model.parameters()).device)
            _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
            encoded = encoded.squeeze(0).transpose(0, 1)  # (T, C)

        return text, encoded, waveform


class HFDiaIterDataset(torch.utils.data.IterableDataset):
    """Iterable wrapper for a HF streaming Dataset that has `audio.array` & `text`."""

    def __init__(self, hf_iterable, config: DiaConfig, dac_model: dac.DAC):
        super().__init__()
        self.dataset = hf_iterable
        self.config = config
        self.dac_model = dac_model

    def __iter__(self):
        for sample in self.dataset:
            lang = sample.get("language", None)
            # Get the channel name and normalize it
            channel = sample.get("channel", "").replace("@", "").lower()
            speaker_tag = f"[{channel}]" if channel else "[unk]"
            # Prepend the speaker tag to the text
            text = speaker_tag + sample["text"]
            audio_info = sample["audio"]
            waveform = torch.tensor(audio_info["array"], dtype=torch.float32)
            if waveform.ndim == 1:
                waveform = waveform.unsqueeze(0).unsqueeze(0)
            elif waveform.ndim == 2:
                waveform = waveform.unsqueeze(0)
            sr = audio_info.get("sampling_rate", 44100)
            if sr != 44100:
                waveform = torchaudio.functional.resample(waveform, sr, 44100)
            with torch.no_grad():
                audio_tensor = (
                    self.dac_model.preprocess(waveform, 44100)
                    .to(next(self.dac_model.parameters()).device)
                )
                _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
                encoded = encoded.squeeze(0).transpose(0, 1)
            yield text, encoded, waveform


class VietnameseDiaDataset(HFDiaIterDataset):
    # Note: this class indexes self.dataset directly, so it assumes a map-style
    # (non-streaming) underlying dataset despite inheriting the iterable wrapper.
    def __init__(self, dataset, dia_cfg, dac_model):
        super().__init__(dataset, dia_cfg, dac_model)

    def get_dac_encoding(self, audio_array):
        # This helper was not defined in the uploaded file; sketched here from the
        # same DAC preprocess/encode pattern used by the dataset classes above.
        waveform = torch.tensor(audio_array, dtype=torch.float32)
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            audio_tensor = self.dac_model.preprocess(waveform, 44100).to(
                next(self.dac_model.parameters()).device
            )
            _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
        return encoded.squeeze(0).transpose(0, 1)  # (T, C)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # 1) Add the [vi] language tag
        text = item["text"]
        if not text.startswith("[vi]"):
            text = f"[vi]{text}"

        # 2) Resample the audio to 44.1 kHz
        audio_array = item["audio"]["array"]
        sr = item["audio"]["sampling_rate"]
        if sr != 44100:
            audio_array = torchaudio.functional.resample(
                torch.tensor(audio_array),
                orig_freq=sr,
                new_freq=44100,
            ).numpy()

        # 3) DAC-encode the audio (at the codec frame rate)
        encoding = self.get_dac_encoding(audio_array)

        return text, encoding, audio_array
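A minimal wiring sketch for the local loader (the 44 kHz weight download follows the descript-audio-codec README; the CSV path and row layout are assumptions):

import dac
from pathlib import Path
from dia.config import DiaConfig
from dia.dataset import LocalDiaDataset

dac_model = dac.DAC.load(dac.utils.download(model_type="44khz")).eval()
cfg = DiaConfig.load("dia/config.json")

# CSV rows look like: clips/0001.wav|xin chào|spiderum
ds = LocalDiaDataset(Path("metadata.csv"), Path("audio/"), cfg, dac_model)
text, codes, waveform = ds[0]   # codes: (T, C) DAC token indices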
dia/finetune.py
ADDED
@@ -0,0 +1,787 @@
import argparse
import logging
import os
import random
import tempfile
from dataclasses import dataclass
from pathlib import Path

import torch
import torchaudio
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import clip_grad_norm_
from transformers import get_scheduler
import torch.nn.functional as F
import bitsandbytes as bnb
from tqdm import tqdm
from datasets import load_dataset, interleave_datasets, get_dataset_config_names, DatasetDict, load_from_disk
from huggingface_hub import hf_hub_download
import math
import gc

import dac
from .config import DiaConfig
from .layers import DiaModel
from .model import Dia
from .audio import build_delay_indices, apply_audio_delay
from .dataset import LocalDiaDataset, HFDiaDataset, HFDiaIterDataset, VietnameseDiaDataset
from .interleaved_datasets import load_cml_tts_streamed, load_common_voice17_streamed

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# Bytes reserved for language/channel tag replacement
LANG2BYTE = {
    "en": 3,
    "vi": 19,
}

CHANNELS = [
    "5phutcrypto",
    "anhbanthan",
    "anhthamtu",
    "animerewind.official",
    "bibitv8888",
    "btvgo",
    "baclieutv",
    "bachhoaxanhcom",
    "baodientuvov",
    "blvckvines",
    "boringppl",
    "bronub",
    "cdteam-why",
    "cobabinhduong",
    "cosmicwriter",
    "cuthongthai",
    "daiphatthanhtruyenhinhsonla",
    "day-be-thong-minh-tv",
    "danangtv",
    "daihanoi-htv",
    "daiptththainguyentntv",
    "dongmauviet",
    "dongthaptv",
    "fptbongdaofficial",
    "fonosvietnam",
    "hieurotrong5phut-ntkt",
    "htvtintuc",
    "happyhidari",
    "hoabinhtvgo",
    "hocenglishonline",
    "hocvienbovagau",
    "hungyentvvngo",
    "huynhduykhuongofficial",
    "huynhlapofficial",
    "jvevermind",
    "kenhvtc16",
    "kiengiangtv",
    "khanhvyofficial",
    "kienthucquansu",
    "lamdongtv1",
    "lamvlog",
    "longantv-la34",
    "mangovid",
    "mensbay",
    "meovatcuocsonglnv",
    "meuchannel",
    "ntnvlogsnguyenthanhnam",
    "ngamradio",
    "nhanhac555",
    "nhantaidaiviet",
    "ptth-trt",
    "ptvtruyenhinhphutho",
    "phantichgame",
    "phephim",
    "phimhottk-l",
    "riwaylegal",
    "ruangao",
    "suckhoetamsinh",
    "sachbiquyethanhcong",
    "soisangbrightsidevietnamese",
    "spiderum",
    "spiderumbooks",
    "sukieskitchen",
    "tin3phut",
    "tranthanhtown",
    "tulemientay",
    "tayninhtv",
    "thainhitv",
    "thanhpahm",
    "thegioilaptop",
    "thepresentwriter",
    "tiengiangtivi",
    "tieubaobaothom",
    "tintucbitcoin247",
    "truyenhinhbinhphuoc-bptv",
    "truyenhinhyenbaiytv",
    "truyenhinhcaobang",
    "truyenhinhdaklakdrt",
    "truyenhinhdaknong1",
    "truyenhinhdienbien23.9",
    "truyenhinhkhanhhoa",
    "truyenhinhkontumkrt",
    "truyenhinhnaminhntv",
    "truyenhinhninhthuan",
    "truyenhinhquangngai",
    "tuantienti2911",
    "tuyenquangttv",
    "vovlivedoctruyen",
    "vietcetera",
    "vinhlongtv",
    "voizfm",
    "vutrunguyenthuy",
    "vuive",
    "w2wanime",
    "w2wcartoon",
    "w2whorror",
    "w2wmovie",
    "web5ngay",
    "xanh24h",
    "aiphatthanhtruyenhinhquangtri",
    "aiphatthanhvatruyenhinhhai1908",
    "altonghop",
    "antvtruyenhinhcongannhandan",
    "baihoc10phut",
    "battlecry.khampha",
    "betterversionvn",
    "blogkhoinghiep",
    "bumcn",
    "caikinhdi_vn",
    "canthitg",
    "chanthienmybachnien",
    "chauanhchao",
    "cosu",
    "cungmaivaobep-monan-amthuc",
    "daiptthphuyen",
    "daiptthtv",
    "daitruyenhinhangiang",
    "daitruyenhinhbacgiang",
    "dannytran2375",
    "daybehoc5489",
    "daylaphegame",
    "dienmay",
    "ducisreal",
    "duongfg",
    "duyluandethuong",
    "duythanhish",
    "elroydevops",
    "gc.gamelab",
    "hacthaybachthay",
    "hagiangtv475",
    "haiduongtv247",
    "hanamtv8831",
    "hangphimtailieudienanhnd",
    "haugiangtv",
    "haunauday",
    "hieu-tv",
    "hoshiphan",
    "jakinatsumi2915",
    "kechuyentieuhoc1719",
    "kenhcovan",
    "khalid_dinh",
    "kiaralah",
    "laichautv",
    "langsontvtube",
    "megame_official",
    "minvestvn",
    "nguoithanhcong1991",
    "nhatkycuocsong.",
    "ntcanima",
    "ptthbentre",
    "ptthquangbinh",
    "qrt",
    "quangninhtv",
    "snewsvn",
    "soctrangtv",
    "sunhuynpodcast",
    "tamhonanuong",
    "tgddreview",
    "thaibinhtv",
    "thanhnamedu",
    "thanhnientvnews",
    "thbrt",
    "thieunhitv3630",
    "thtpct",
    "tinnhanh3phut868",
    "toansam",
    "toidicodedaoblog",
    "tranquochuywecommit",
    "tranvyvy",
    "truyenhinh4k",
    "truyenhinhbinhthuan",
    "truyenhinhcamau69",
    "truyenhinhdongnai_dnrtv",
    "truyenhinhgialai",
    "truyenhinhlaocai",
    "truyenhinhnghean",
    "truyenhinhvinhphuc",
    "txtofficial8798",
    "vanhkhuyenle",
    "vietnh1009",
    "visaothenhipodcast",
    "vtc14",
    "vtcnow",
    "vtv24",
    "vuive123",
    "zombiev4",
]

# Automatically map each channel to a token byte (starting at 30)
for i, ch in enumerate(CHANNELS):
    LANG2BYTE[ch] = 30 + i

test_sentences = {
    "en": "In order to fully assess performance and the accuracy of language tags, this test sentence contains multiple subordinate clauses, varied punctuation, and a sufficient word count.",
    "vi": "Để đánh giá toàn diện hiệu suất và độ chính xác của các thẻ ngôn ngữ, câu kiểm tra này chứa nhiều mệnh đề phụ, dấu câu đa dạng và số lượng từ đầy đủ."
}

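# The LANG2BYTE mapping above is consumed by collate_fn below: a leading "[lang]"
# or "[channel]" tag is collapsed into a single reserved byte before the UTF-8
# text reaches the byte-level encoder. A toy trace (values from LANG2BYTE above):
#
#     text = "[vi]Xin chào"
#     b_full = text.encode("utf-8")
#     for code, val in LANG2BYTE.items():          # same loop as in collate_fn
#         prefix = f"[{code}]".encode("utf-8")
#         if b_full.startswith(prefix):
#             b_full = bytes([val]) + b_full[len(prefix):]
#             break
#     assert b_full[0] == 19                        # "vi" maps to byte 19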
@dataclass
class TrainConfig:
    epochs: int = 1
    batch_size: int = 2
    grad_accum_steps: int = 2
    learning_rate: float = 1e-5
    warmup_steps: int = 500
    unconditional_frac: float = 0.15
    eval_step: int = 200
    save_step: int = 2000
    split_ratio: float = 0.997
    shuffle_buffer_size: int = None  # for streaming shuffle
    seed: int = 42  # seed for reproducibility
    runs_dir: Path = Path("runs")
    run_name: str = "dia_finetune_cv"
    output_dir: Path = Path(".cpkts/dia_finetune_cv")
    resume_from: Path = None
    total_steps: int = 290007


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train the Dia audio model")
    parser.add_argument("--config", type=Path, default=Path("dia/config.json"))
    parser.add_argument("--dataset", type=str, default="Paradoxia/opendata-iisys-hui",
                        help="HuggingFace dataset name (if not using --csv_path).")
    parser.add_argument("--dataset2", type=str, default=None,
                        help="(Optional) second HF dataset to interleave (streaming)")
    parser.add_argument("--streaming", action="store_true",
                        help="Enable HuggingFace dataset streaming")
    parser.add_argument("--hub_model", type=str, default="nari-labs/Dia-1.6B")
    parser.add_argument("--local_ckpt", type=str, default=None)
    parser.add_argument("--csv_path", type=Path, default=None,
                        help="Path to local CSV/TSV file with `audio|text` (if you want to train locally).")
    parser.add_argument("--audio_root", type=Path, default=None,
                        help="Root directory for local audio files (required if --csv_path is set).")
    parser.add_argument("--run_name", type=str, default=None)
    parser.add_argument("--output_dir", type=Path, default=None)
    parser.add_argument("--shuffle_buffer_size", type=int, default=None,
                        help="Buffer size for streaming dataset shuffle.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility.")
    parser.add_argument("--half", action="store_true", help="load model in fp16")
    parser.add_argument("--compile", action="store_true", help="torch compile model")
    parser.add_argument("--use_amp", action="store_true", help="Enable mixed precision")
    parser.add_argument("--resume_from", type=str, default=None)
    return parser.parse_args()


def collate_fn(batch, config: DiaConfig, device: torch.device):
    from torch.nn.functional import pad

    texts, encodings, waveforms = zip(*batch)

    # -- Text inputs ---------------------------------------------------------

    max_text = config.data.text_length
    pad_tok = config.data.text_pad_value
    text_ids = []
    for txt in texts:
        b_full = txt.encode('utf-8')
        # replace a leading "[lang]" prefix with its reserved byte
        for code, val in LANG2BYTE.items():
            prefix = f"[{code}]".encode('utf-8')
            if b_full.startswith(prefix):
                b_full = bytes([val]) + b_full[len(prefix):]
                break
        bts = b_full[:max_text]
        arr = list(bts) + [pad_tok] * (max_text - len(bts))
        text_ids.append(torch.tensor(arr, dtype=torch.long))
    src = torch.stack(text_ids).to(device)
    src_pos = torch.arange(max_text, device=device).unsqueeze(0).expand(src.size(0), -1)
    src_pad = src.ne(pad_tok)
    enc_self_attn_mask = (src_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)

    # -- Audio codes ---------------------------------------------------------

    max_audio = config.data.audio_length
    # per-sample lengths (clipped to max_audio)
    seq_lens = [min(e.size(0), max_audio) for e in encodings]
    batch_max = max(seq_lens)

    # pad or trim each encoding to the batch max length
    padded = [pad(e, (0, 0, 0, batch_max - e.size(0))) if e.size(0) < batch_max else e[:batch_max]
              for e in encodings]
    codes = torch.stack(padded).to(device)  # (B, T=batch_max, C)

    B, T, C = codes.shape
    t_idx, idxs = build_delay_indices(B, T, C, config.data.delay_pattern)
    delayed = apply_audio_delay(
        codes,
        config.data.audio_pad_value,
        config.data.audio_bos_value,
        (t_idx, idxs)
    )
    # ensure no longer than max_audio
    delayed = delayed[:, :max_audio, :]

    # -- Targets with per-sample EOS -----------------------------------------

    max_tgt_len = max_audio + 2
    pad_val = config.data.audio_pad_value
    bos_val = config.data.audio_bos_value
    eos_val = config.data.audio_eos_value

    tgt = torch.full((B, max_tgt_len, C), pad_val, dtype=torch.long, device=device)
    tgt[:, 0, :] = bos_val
    tgt_lens = []
    for i, L in enumerate(seq_lens):
        tgt[i, 1:1 + L, :] = delayed[i, :L, :]
        tgt[i, 1 + L, :] = eos_val
        tgt_lens.append(1 + L + 1)

    tgt_pos = torch.arange(max_tgt_len, device=device).unsqueeze(0).expand(B, -1)
    tgt_pad = tgt.ne(pad_val).any(-1)

    causal = torch.tril(torch.ones((max_tgt_len, max_tgt_len),
                                   dtype=torch.bool,
                                   device=device))
    dec_self_attn_mask = (tgt_pad.unsqueeze(2) & tgt_pad.unsqueeze(1) & causal).unsqueeze(1)
    dec_cross_attn_mask = (tgt_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)

    return {
        'src_tokens': src,
        'src_positions': src_pos,
        'enc_self_attn_mask': enc_self_attn_mask,
        'tgt_tokens': tgt,
        'tgt_positions': tgt_pos,
        'dec_self_attn_mask': dec_self_attn_mask,
        'dec_cross_attn_mask': dec_cross_attn_mask,
        'waveforms': waveforms,
        'raw_text': texts[0],
        'tgt_lens': torch.tensor(tgt_lens, dtype=torch.long, device=device),
    }

+
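# Illustrative layout sketch (not executed): for a batch with seq_lens = [5, 3]
# and C codebook channels, collate_fn lays out each target row as
#   t = 0        -> BOS on every channel
#   t = 1..L     -> delay-shifted codec frames
#   t = L + 1    -> EOS on every channel
#   t > L + 1    -> PAD (excluded from the loss via tgt_lens = L + 2)
# so tgt_lens would be [7, 5] here. The names mirror collate_fn's locals.
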
def setup_loaders(dataset, dia_cfg: DiaConfig, train_cfg: TrainConfig, device):
    collate = lambda b: collate_fn(b, dia_cfg, device)
    if isinstance(dataset, HFDiaIterDataset):
        total = getattr(dataset, "total_examples", None)
        if total is None:
            total = dataset.dataset.info.splits["train"].num_examples
        n_train = int(train_cfg.split_ratio * total)
        n_val = total - n_train
        if n_val <= 0:
            raise RuntimeError(f"No validation samples: total={total}, split_ratio={train_cfg.split_ratio}")
        base = dataset.dataset.shuffle(buffer_size=train_cfg.shuffle_buffer_size, seed=train_cfg.seed) if train_cfg.shuffle_buffer_size else dataset.dataset
        # validation is the head of the (optionally shuffled) stream; training skips past it
        val_stream = base.take(n_val)
        train_stream = base.skip(n_val)
        train_ds = HFDiaIterDataset(train_stream, dia_cfg, dataset.dac_model)
        val_ds = HFDiaIterDataset(val_stream, dia_cfg, dataset.dac_model)
        train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=False, collate_fn=collate)
        # iterable datasets have no len(); attach the step count for the scheduler and progress bars
        train_loader.steps_per_epoch = math.ceil(n_train / train_cfg.batch_size)
        val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
        return train_loader, val_loader
    ds_len = len(dataset)
    n_train = int(train_cfg.split_ratio * ds_len)
    train_ds, val_ds = random_split(dataset, [n_train, ds_len - n_train])
    train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
    return train_loader, val_loader


def setup_optimizer_and_scheduler(model, train_loader, train_cfg):
    opt = bnb.optim.AdamW8bit(model.parameters(), lr=train_cfg.learning_rate)
    # Determine steps per epoch: prefer len(), else use the attached attribute
    try:
        steps_per_epoch = len(train_loader)
    except TypeError:
        if hasattr(train_loader, 'steps_per_epoch'):
            steps_per_epoch = train_loader.steps_per_epoch
        else:
            raise RuntimeError("Cannot determine steps_per_epoch for streaming loader")
    total_training_steps = steps_per_epoch * train_cfg.epochs
    # The scheduler is stepped once per optimizer update, so divide by
    # grad_accum_steps (and cast to int: get_scheduler expects integer counts).
    sched = get_scheduler(
        'cosine', opt,
        num_warmup_steps=int(train_cfg.warmup_steps / train_cfg.grad_accum_steps),
        num_training_steps=int(total_training_steps / train_cfg.grad_accum_steps),
    )
    return opt, sched

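# Worked example (illustrative numbers, not from any config): with
# steps_per_epoch = 1000, epochs = 4 and grad_accum_steps = 8, the scheduler
# sees int(4000 / 8) = 500 optimizer updates, matching the single sched.step()
# per accumulation window in train_step below.
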
def train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step, scaler):
    """
    Perform a single training step: forward, loss, backward, update, log.
    Uses per-sample tgt_lens to mask out padding after each EOS,
    and applies 4x loss weight on the first channel.
    """
    # (optional) unconditional conditioning: zero out text tokens and masks
    if random.random() < train_cfg.unconditional_frac:
        batch['src_tokens'] = torch.zeros_like(batch['src_tokens'])
        batch['enc_self_attn_mask'] = torch.zeros_like(batch['enc_self_attn_mask'])
        batch['dec_cross_attn_mask'] = torch.zeros_like(batch['dec_cross_attn_mask'])

    with autocast(dtype=torch.float16):
        # forward pass
        logits = model(
            src_BxS=batch['src_tokens'],
            tgt_BxTxC=batch['tgt_tokens'],
            src_positions=batch['src_positions'],
            tgt_positions=batch['tgt_positions'],
            enc_self_attn_mask=batch['enc_self_attn_mask'],
            dec_self_attn_mask=batch['dec_self_attn_mask'],
            dec_cross_attn_mask=batch['dec_cross_attn_mask'],
            enable_dropout=True,
        )
        # fetch per-sample target lengths (BOS + frames + EOS)
        lens = batch['tgt_lens']        # shape: (B,)
        max_L = int(lens.max().item())  # maximum over the batch

        # keep only up through the last possible EOS slot
        # logits: (B, T, C, V) -> (B, max_L-1, C, V)
        logits = logits[:, : max_L - 1]

        # targets: shift off the BOS so 0..max_L-1 align with logits
        # target: (B, T, C) -> (B, max_L-1, C)
        target = batch['tgt_tokens'][:, 1:max_L, :]

        B, Tm1, C = target.shape
        pad_val = dia_cfg.data.audio_pad_value

        # build a mask (B, Tm1) that is True for t < (lens[i] - 1)
        time_idx = torch.arange(Tm1, device=lens.device).unsqueeze(0)  # (1, Tm1)
        valid_time = time_idx < (lens.unsqueeze(1) - 1)                # (B, Tm1)
        mask = valid_time.unsqueeze(-1).expand(-1, -1, C)              # (B, Tm1, C)

        # apply 4x weight on the first channel, 1x on the others
        channel_weights = [4.0] + [1.0] * (C - 1)
        loss_c = 0.0
        _, _, _, V = logits.size()

        for c, w in enumerate(channel_weights):
            # flatten this channel
            lc = logits[:, :, c, :].reshape(-1, V)  # (B*Tm1, V)
            tc = target[:, :, c].reshape(-1)        # (B*Tm1,)
            mc = mask[:, :, c].reshape(-1)          # (B*Tm1,)

            # mask out padding and compute cross-entropy
            lc_valid = lc[mc]
            tc_valid = tc[mc]
            loss_c += w * F.cross_entropy(
                lc_valid, tc_valid,
                ignore_index=pad_val,
            )

        # normalize by the sum of weights
        loss = loss_c / sum(channel_weights)

    # scale for gradient accumulation + backward
    loss = loss / train_cfg.grad_accum_steps
    scaler.scale(loss).backward()

    # step & log
    if (step_in_epoch + 1) % train_cfg.grad_accum_steps == 0:
        # Unscale before clipping
        scaler.unscale_(opt)
        # max_norm=1e9 effectively disables clipping; the call is used to measure the norm
        grad_norm = clip_grad_norm_(model.parameters(), max_norm=1e9)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad()
        sched.step()

        true_loss = loss.item() * train_cfg.grad_accum_steps
        current_lr = sched.get_last_lr()[0]

        writer.add_scalar('GradNorm/global', grad_norm, global_step)
        writer.add_scalar('LR', current_lr, global_step)
        writer.add_scalar('Loss/train', true_loss, global_step)

        return true_loss
    else:
        return loss.item() * train_cfg.grad_accum_steps

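# Masking sketch (illustrative): with lens = [7, 5] and max_L = 7, valid_time is
#   [[1, 1, 1, 1, 1, 1],
#    [1, 1, 1, 1, 0, 0]]
# over Tm1 = 6 shifted positions, so the second sample's two trailing PAD slots
# never reach the cross-entropy even before ignore_index is applied.
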
def eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step):
    """
    Run evaluation: compute average loss on the validation set and log audio samples.
    """
    import gc
    eval_losses = []
    last_batch = None
    with torch.inference_mode():
        for eb in tqdm(val_loader, desc="eval"):
            last_batch = eb

            with autocast(dtype=torch.float16):
                logits16 = model(
                    src_BxS=eb['src_tokens'],
                    tgt_BxTxC=eb['tgt_tokens'],
                    src_positions=eb['src_positions'],
                    tgt_positions=eb['tgt_positions'],
                    enc_self_attn_mask=eb['enc_self_attn_mask'],
                    dec_self_attn_mask=eb['dec_self_attn_mask'],
                    dec_cross_attn_mask=eb['dec_cross_attn_mask'],
                    enable_dropout=False,
                )[:, :-1]

            logits = logits16.float()
            target = eb['tgt_tokens'][:, 1:]
            B_e, T_e, C_e = target.shape
            V_e = logits.size(-1)

            loss_e = 0.0
            weights_e = [4.0] + [1.0] * (C_e - 1)
            for c, w in enumerate(weights_e):
                lc = logits[:, :, c, :].reshape(-1, V_e)
                tc = target[:, :, c].reshape(-1)
                loss_e += w * F.cross_entropy(
                    lc, tc, ignore_index=dia_cfg.data.audio_pad_value
                )
            loss_e = loss_e / sum(weights_e)

            eval_losses.append(loss_e)

    avg_eval_loss = sum(eval_losses) / len(eval_losses)
    writer.add_scalar('Loss/eval', avg_eval_loss.item(), global_step)

    # --- Inference test sentence ---
    try:
        # run generation in float32, restoring the original dtype afterwards
        orig_dtype = next(model.parameters()).dtype
        model = model.float()
        dia_gen = Dia(dia_cfg, device)
        dia_gen.model, dia_gen.dac_model = model, dac_model

        # Test a multi-voice dialogue sentence
        test_dialogue = "[vtv24] Em vừa đi học về, anh ạ. [duongfg] Ừ, em ăn cơm chưa? [vtv24] Em ăn rồi!"

        if len(test_dialogue) > 10:
            try:
                audio = dia_gen.generate(text=test_dialogue)
                writer.add_audio("Eval/test_dialogue", audio, global_step, 44100)
            except Exception:
                logger.exception("Eval error during test_dialogue")
            finally:
                if 'audio' in locals():
                    del audio

    except Exception:
        logger.exception("Eval error")

    finally:
        if 'audio' in locals():
            del audio
        gc.collect()
        torch.cuda.empty_cache()
        if orig_dtype == torch.float16:
            model = model.half()

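# Note on the dtype round-trip above: autoregressive generation runs in float32
# and the model is cast back to half only if it started out in fp16, presumably
# to avoid fp16 numerical issues during sampling at the cost of extra memory.
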
def train(model, dia_cfg: DiaConfig, dac_model: dac.DAC, dataset, train_cfg: TrainConfig):
    """
    Run the full training loop over epochs.
    """
    # prepare directories
    train_cfg.output_dir.mkdir(parents=True, exist_ok=True)
    (train_cfg.runs_dir / train_cfg.run_name).mkdir(parents=True, exist_ok=True)
    model = model.to(device)

    train_loader, val_loader = setup_loaders(dataset, dia_cfg, train_cfg, device)
    opt, sched = setup_optimizer_and_scheduler(model, train_loader, train_cfg)

    writer = SummaryWriter(train_cfg.runs_dir / train_cfg.run_name)
    model.train()
    scaler = GradScaler()
    start_epoch = 0
    global_step = 0
    resume_ckpt = getattr(train_cfg, "resume_from", None)
    if resume_ckpt and resume_ckpt.exists():
        logger.info(f"Resuming from checkpoint: {resume_ckpt}")
        checkpoint = torch.load(resume_ckpt, map_location=device)
        model.load_state_dict(checkpoint["model"])
        opt.load_state_dict(checkpoint["optimizer"])
        sched.load_state_dict(checkpoint["scheduler"])
        scaler.load_state_dict(checkpoint["scaler"])
        start_epoch = checkpoint.get("epoch", 0)
        global_step = checkpoint.get("global_step", 0)

    steps_per_epoch = getattr(train_loader, 'steps_per_epoch', None)
    if steps_per_epoch is None:
        try:
            steps_per_epoch = len(train_loader)
        except Exception:
            steps_per_epoch = None

    for epoch in range(start_epoch, train_cfg.epochs):
        # single progress bar per epoch, with total when known (the original
        # wrapped one tqdm in another, which rendered two bars for one loop)
        pbar = tqdm(train_loader, desc=f"E{epoch + 1}", total=steps_per_epoch)
        for step_in_epoch, batch in enumerate(pbar):
            global_step += 1
            # training step
            loss = train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step_in_epoch, global_step, scaler)

            cur_alloc = torch.cuda.memory_allocated()       # bytes currently allocated by tensors
            peak_alloc = torch.cuda.max_memory_allocated()  # peak bytes since the last reset
            # convert to GB for display
            cur_gb = cur_alloc / 1024**3
            peak_gb = peak_alloc / 1024**3

            # update the tqdm postfix
            pbar.set_postfix({
                'loss': f"{loss:.4f}",
                'VRAM (GB)': f"{cur_gb:.2f}/{peak_gb:.2f}"
            })

            # zero the peak so the readout shows a rolling per-step peak
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

            # evaluation
            if step_in_epoch % train_cfg.eval_step == 0:
                model.eval()
                with torch.no_grad():
                    eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step)
                model.train()
                # note: recreating GradScaler here resets its loss scale to the default
                scaler = GradScaler()

            # checkpoint
            if step_in_epoch and step_in_epoch % train_cfg.save_step == 0:
                ckpt = train_cfg.output_dir / f"ckpt_step{global_step}.pth"
                torch.save({
                    "model": model.state_dict(),
                    "optimizer": opt.state_dict(),
                    "scheduler": sched.state_dict(),
                    "scaler": scaler.state_dict(),
                    "epoch": epoch,
                    "global_step": global_step
                }, ckpt)
                logger.info(f"Saved checkpoint: {ckpt}")

        # end-of-epoch checkpoint
        ckpt_e = train_cfg.output_dir / f"ckpt_epoch{epoch + 1}.pth"
        torch.save({
            "model": model.state_dict(),
            "optimizer": opt.state_dict(),
            "scheduler": sched.state_dict(),
            "scaler": scaler.state_dict(),
            "epoch": epoch + 1,
            "global_step": global_step
        }, ckpt_e)
        logger.info(f"Saved end-of-epoch checkpoint: {ckpt_e}")

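# Resuming picks epoch and global_step back up from the saved dict, e.g.
# (path illustrative): --resume_from ckpts/ckpt_step5000.pth. The checkpoint
# layout written above is {"model", "optimizer", "scheduler", "scaler",
# "epoch", "global_step"}.
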
from datasets import disable_caching


def main():
    args = get_args()
    import os
    # point HF datasets at a fresh cache location and disable local caching
    # entirely (note: setting HF_DATASETS_CACHE after `datasets` is imported
    # may not take effect for this process)
    os.environ["HF_DATASETS_CACHE"] = "/tmp/force_streaming"
    disable_caching()
    import json
    with open(args.config, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    dia_cfg = DiaConfig(**config_dict)
    dac_model = dac.DAC.load(dac.utils.download()).to(device)
    dataset = None

    if not dataset:
        if args.csv_path:
            if not args.audio_root:
                raise ValueError("`--audio_root` must be set when using `--csv_path`")
            dataset = LocalDiaDataset(args.csv_path, args.audio_root, dia_cfg, dac_model)
        else:
            # Check whether the dataset argument is a local path
            if Path(args.dataset).exists():
                print(f"Loading dataset from local path: {args.dataset}")
                ds1 = load_from_disk(args.dataset)
                if isinstance(ds1, DatasetDict):
                    ds1 = ds1["train"]
                dataset = HFDiaDataset(ds1, dia_cfg, dac_model)
            else:
                print(f"Loading HuggingFace dataset: {args.dataset} (streaming)")
                ds1 = load_dataset(args.dataset, split="train", streaming=True)

                if args.dataset2:
                    ds2 = load_dataset(args.dataset2, split="train", streaming=True)
                    hf_ds = interleave_datasets([ds1, ds2])
                    dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)
                else:
                    hf_ds = ds1
                    dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)

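    # Dataset resolution order, for reference: --csv_path + --audio_root wins,
    # then a local load_from_disk path, then one or two streamed Hub datasets
    # (interleaved when --dataset2 is given).
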
    train_cfg = TrainConfig(
        run_name=args.run_name or TrainConfig.run_name,
        output_dir=args.output_dir or TrainConfig.output_dir,
        shuffle_buffer_size=args.shuffle_buffer_size,
        seed=args.seed,
    )
    if args.resume_from:
        train_cfg.resume_from = Path(args.resume_from)
    # load model checkpoint
    if args.local_ckpt:
        ckpt_file = args.local_ckpt
    else:
        ckpt_file = hf_hub_download(args.hub_model, filename="dia-v0_1.pth")
    model = DiaModel(dia_cfg)
    if args.half:
        model = model.half()
    if args.compile:
        model = torch.compile(model, backend="inductor")
    ckpt = torch.load(ckpt_file, map_location="cpu")
    state_dict = ckpt["model"] if "model" in ckpt else ckpt
    new_state_dict = {}

    for k, v in state_dict.items():
        # skip the text-embedding weights if the vocabulary size changed
        if "encoder.embedding.weight" in k:
            if v.shape != model.state_dict()[k].shape:
                print(f"⚠️ Skipping {k}: shape mismatch {v.shape} vs {model.state_dict()[k].shape}")
                continue
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict, strict=False)

    # start training
    train(model, dia_cfg, dac_model, dataset, train_cfg)


if __name__ == "__main__":
    main()
dia/interleaved_datasets.py
ADDED
@@ -0,0 +1,144 @@
from datasets import load_dataset, get_dataset_config_names, interleave_datasets, load_dataset_builder
from .dataset import HFDiaIterDataset
import pandas as pd
from huggingface_hub import hf_hub_download


LANG_NAME_TO_CODE = {
    "dutch": "nl",
    "french": "fr",
    "german": "de",
    "italian": "it",
    "polish": "pl",
    "portuguese": "pt",
    "spanish": "es",
    # add more if other configs appear...
}


def load_cml_tts_streamed(dia_cfg, dac_model):
    """
    Stream all language subsets of the CML-TTS dataset in the train split,
    add a `language` field, drop everything except `text`, `audio`, `language`,
    and interleave them into one streaming Dataset.

    Returns:
        datasets.IterableDataset: interleaved streaming dataset
    """
    # 1) Discover all language subsets
    lang_configs = get_dataset_config_names("ylacombe/cml-tts")

    # 2) Build one streaming subset per language, with only the desired columns
    streams = []
    num_ex = 0
    for lang in lang_configs:
        iso_code = LANG_NAME_TO_CODE.get(lang, lang)
        ds_stream = load_dataset(
            "ylacombe/cml-tts",
            name=lang,
            split="train",
            streaming=True
        )

        num_ex += ds_stream.info.splits['train'].num_examples

        # keep only text and audio, and add language
        def _add_lang(ex, iso=iso_code):
            return {
                "text": ex["text"],
                "audio": ex["audio"],
                "language": iso
            }
        ds_stream = ds_stream.map(
            _add_lang,
            remove_columns=[c for c in ds_stream.column_names if c not in ["text", "audio", "language"]]
        )
        streams.append(ds_stream)

    # 3) Interleave all streams into one unified stream
    interleaved = interleave_datasets(streams, stopping_strategy="all_exhausted")
    ds = HFDiaIterDataset(interleaved, dia_cfg, dac_model)
    ds.total_examples = num_ex
    return ds


def count_tsv_rows(
    repo_id: str,
    subset: str,
    split: str = "train",
    revision: str = "main"
) -> int:
    """Download the TSV for a given subset/split and return its number of rows."""
    file_path = f"transcript/{subset}/{split}.tsv"
    try:
        local_file = hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            repo_type="dataset",
            revision=revision
        )
    except Exception as err:
        # the original swallowed this with a bare `except:` and then crashed on
        # the undefined `local_file`; re-raise with context instead
        raise RuntimeError(f"error fetching TSV metadata for {subset}/{split}") from err

    df = pd.read_csv(local_file, sep="\t", low_memory=False)
    return len(df)


def load_common_voice17_streamed(dia_cfg, dac_model, revision="main"):
    """
    Stream the train split of Common Voice 17 for the given language codes,
    rename `sentence` -> `text`, keep only `text`, `audio`, and `language`,
    then interleave into a single streaming Dataset.

    Languages loaded: en, de, fr, es, it, nl, pl, pt, tr, hu
    """
    repo_id = "mozilla-foundation/common_voice_17_0"
    langs = ["en", "de", "fr", "es", "it", "nl", "pl", "pt", "tr", "hu"]

    streams = []
    row_counts = []

    for lang in langs:
        # 1) figure out how many rows are in the TSV
        n_rows = count_tsv_rows(repo_id, lang, split="train", revision=revision)
        row_counts.append(n_rows)

        # 2) load in streaming mode
        ds_stream = load_dataset(
            repo_id,
            name=lang,
            split="train",
            streaming=True,
            revision=revision
        )

        # 3) map to the desired schema
        def _prep(ex, iso=lang):
            return {
                "text": ex["sentence"],
                "audio": ex["audio"],
                "language": iso
            }

        ds_stream = ds_stream.map(
            _prep,
            remove_columns=[c for c in ds_stream.column_names if c not in ("sentence", "audio")]
        )
        streams.append(ds_stream)

    # 4) interleave: all_exhausted => shorter streams are cycled until the longest finishes
    interleaved = interleave_datasets(streams, stopping_strategy="all_exhausted")

    # 5) wrap and attach total_examples (upper bound: max_length * num_streams)
    ds = HFDiaIterDataset(interleaved, dia_cfg, dac_model)
    ds.total_examples = max(row_counts) * len(langs)

    return ds
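# Usage sketch (illustrative; assumes dia_cfg and dac_model are already built
# as in finetune.main):
#   ds = load_cml_tts_streamed(dia_cfg, dac_model)
#   train_loader, val_loader = setup_loaders(ds, dia_cfg, train_cfg, device)
# Both helpers attach ds.total_examples so that streaming datasets can still be
# split into train/val by ratio.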
dia/layers.py
ADDED
@@ -0,0 +1,909 @@
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import RMSNorm

from .config import DiaConfig


def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
    # Allow None for default behavior
    if dtype_str is None or dtype_str.lower() == "none":
        return None
    if dtype_str == "float32":
        return torch.float32
    elif dtype_str == "float16":
        return torch.float16
    elif dtype_str == "bfloat16":
        return torch.bfloat16
    else:
        raise ValueError(f"Unsupported dtype string: {dtype_str}")


class DenseGeneral(nn.Module):
    """
    PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.

    Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
    for the generalized matrix multiplication. Weight shapes are calculated
    and parameters created during initialization based on config.

    Attributes:
        axis (Tuple[int, ...]): Input axis or axes to contract.
        in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
        out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
        weight (nn.Parameter): The kernel parameter.
    """

    def __init__(
        self,
        in_shapes: tuple[int, ...],
        out_features: tuple[int, ...],
        axis: tuple[int, ...] = (-1,),
        dtype: torch.dtype | None = None,
        weight_dtype: torch.dtype | None = None,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.in_shapes = in_shapes
        self.out_features = out_features
        self.axis = axis
        self.dtype = dtype
        self.kernel_shape = self.in_shapes + self.out_features

        factory_kwargs = {"device": device, "dtype": weight_dtype}
        self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
        self.register_parameter("bias", None)  # no bias term; registered for state-dict parity

    def forward(self, inputs: Tensor) -> Tensor:
        norm_axis = _normalize_axes(self.axis, inputs.ndim)
        kernel_contract_axes = tuple(range(len(norm_axis)))

        # contract in float32 for stability, then cast back to the input dtype
        output = torch.tensordot(
            inputs.float(),
            self.weight.float(),
            dims=(norm_axis, kernel_contract_axes),
        ).to(inputs.dtype)
        return output

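# Equivalence sketch (illustrative): for the o_proj case with axis=(-2, -1),
# DenseGeneral contracts the head and head_dim axes in one tensordot, i.e.
#   layer = DenseGeneral(in_shapes=(8, 64), out_features=(512,), axis=(-2, -1))
#   y = layer(x)  # x: (B, T, 8, 64) -> y: (B, T, 512)
# which matches torch.einsum("btnh,nhd->btd", x, layer.weight).
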
def get_activation_fn(activation_string: str) -> nn.Module:  # Return Module instance
    """Maps an activation string to a PyTorch activation function module."""
    if activation_string == "gelu":
        return nn.GELU()
    elif activation_string == "relu":
        return nn.ReLU()
    elif activation_string == "silu" or activation_string == "swish":
        return nn.SiLU()
    elif activation_string == "linear":
        return nn.Identity()
    else:
        raise ValueError(f"Unsupported activation function: {activation_string}")

class MlpBlock(nn.Module):
    """MLP block using DenseGeneral."""

    def __init__(
        self,
        config: DiaConfig,
        embed_dim: int,
        intermediate_dim: int,
        dropout_rate: float,
        activations: list[str] = ["silu", "linear"],  # mutable default; treated as read-only
        use_pre_norm: bool = False,
    ):
        super().__init__()
        self.use_pre_norm = use_pre_norm
        num_activations = len(activations)
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.dtype = compute_dtype
        # Assume the default device for now; could be passed in via config

        if use_pre_norm:
            self.pre_norm = RMSNorm(
                embed_dim,
                eps=config.model.normalization_layer_epsilon,
                dtype=torch.float32,
            )

        # fused gate/up projection: a single kernel produces both branches
        self.wi_fused = DenseGeneral(
            in_shapes=(embed_dim,),
            out_features=(
                num_activations,
                intermediate_dim,
            ),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

        self.activation_fn_0 = get_activation_fn(activations[0])  # silu
        self.activation_fn_1 = get_activation_fn(activations[1])  # linear

        self.dropout = nn.Dropout(dropout_rate)

        # Output layer using DenseGeneral
        self.wo = DenseGeneral(
            in_shapes=(intermediate_dim,),
            out_features=(embed_dim,),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

    def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
        """Forward pass."""
        if self.use_pre_norm and hasattr(self, "pre_norm"):
            x = self.pre_norm(x)

        fused_x = self.wi_fused(x)

        gate_input = fused_x[..., 0, :]
        up_input = fused_x[..., 1, :]

        gate = self.activation_fn_0(gate_input)
        up = self.activation_fn_1(up_input)
        hidden = torch.mul(gate, up).to(self.dtype)

        if not deterministic:
            hidden = self.dropout(hidden)

        output = self.wo(hidden)
        return output

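# Gated-MLP sketch (illustrative): with activations=["silu", "linear"] the block
# computes SwiGLU-style hidden = silu(x @ W_gate) * (x @ W_up), except both
# projections come out of the single wi_fused kernel, indexed on the
# num_activations axis (fused_x[..., 0, :] is the gate, fused_x[..., 1, :] the up).
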
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) implementation in PyTorch."""

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10000,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        if embedding_dims % 2 != 0:
            raise ValueError("Embedding dim must be even for RoPE.")
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.dtype = dtype

        half_embedding_dim = embedding_dims // 2
        fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
        self.register_buffer(
            "timescale",
            self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
            persistent=False,
        )

    def extra_repr(self) -> str:
        s = f"{self.timescale.shape}"
        return s

    def forward(self, inputs: torch.Tensor, position: torch.Tensor):
        """Applies RoPE."""
        position = position.unsqueeze(-1).unsqueeze(-1)
        timescale = self.timescale.to(inputs.device)
        sinusoid_inp = position / timescale
        sin = torch.sin(sinusoid_inp).to(inputs.dtype)
        cos = torch.cos(sinusoid_inp).to(inputs.dtype)
        # rotate pairs taken from the two halves of the feature dimension
        first_half, second_half = torch.chunk(inputs, 2, dim=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        return torch.cat((first_part, second_part), dim=-1)

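# RoPE worked example (illustrative): for head_dim = 64 the timescale buffer is
# 10000 ** (2i/64) for i in 0..31, so a feature pair (a, b) at position p is
# rotated to (a*cos(p/ts) - b*sin(p/ts), b*cos(p/ts) + a*sin(p/ts)) -- pairing
# the two halves of the vector rather than adjacent elements, as chunked above.
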
class KVCache:
    def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
        # the leading dim of 2 appears to assume a fixed batch of two rows
        # (e.g. a conditional/unconditional pair for CFG-style generation)
        self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
        self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
        self.current_idx = 0
        self.max_len = max_len

    def get_kv_for_attention(self, current_k, current_v):
        if self.current_idx == 0:
            return current_k, current_v
        else:
            past_k = self.k[:, :, : self.current_idx, :]
            past_v = self.v[:, :, : self.current_idx, :]
            attn_k = torch.cat((past_k, current_k), dim=2)
            attn_v = torch.cat((past_v, current_v), dim=2)
            return attn_k, attn_v

    def update_cache(self, k, v):
        assert self.current_idx < self.max_len
        self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
        self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
        self.current_idx += 1

    def prefill_kv(self, k, v):
        prefill_len = k.shape[2]
        assert prefill_len <= self.max_len
        self.k[:, :, :prefill_len, :] = k
        self.v[:, :, :prefill_len, :] = v
        self.current_idx = prefill_len

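# Decode-loop sketch (illustrative; k_prompt/k_step are assumed names): a caller
# prefills once with the prompt's K/V, then per generated token concatenates the
# fresh K/V for attention and writes it back:
#   cache.prefill_kv(k_prompt, v_prompt)
#   attn_k, attn_v = cache.get_kv_for_attention(k_step, v_step)
#   cache.update_cache(k_step, v_step)
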
class Attention(nn.Module):
    """Attention using DenseGeneral."""

    def __init__(
        self,
        config: DiaConfig,
        q_embed_dim: int,
        kv_embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dropout_rate: float,
        is_cross_attn: bool = False,
        out_embed_dim: int | None = None,
    ):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attn = is_cross_attn
        self.dropout_rate = dropout_rate
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
        self.projected_query_dim = num_query_heads * head_dim
        if num_query_heads % num_kv_heads != 0:
            raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        self.num_gqa_groups = num_query_heads // num_kv_heads

        # --- Projection Layers using DenseGeneral ---
        self.q_proj = DenseGeneral(
            in_shapes=(q_embed_dim,),
            out_features=(num_query_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.k_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.v_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.o_proj = DenseGeneral(
            in_shapes=(num_query_heads, head_dim),
            out_features=(self.output_dim,),
            axis=(-2, -1),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

        # --- Rotary Embedding ---
        self.rotary_emb = RotaryEmbedding(
            embedding_dims=self.head_dim,
            min_timescale=config.model.rope_min_timescale,
            max_timescale=config.model.rope_max_timescale,
            dtype=compute_dtype,
        )

    def forward(
        self,
        Xq: torch.Tensor,  # (B, T, D)  T = 1 in AR generation
        Xkv: torch.Tensor,  # (B, S, E)  S = 1 in AR generation
        q_positions: torch.Tensor,  # (B, T)
        kv_positions: torch.Tensor | None = None,  # (B, S)
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,  # None in decoder self-attention, valid mask in others
        cache: KVCache | None = None,  # None in encoder, KVCache in decoder
        prefill: bool = False,  # True only when prefilling the KV cache
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        """
        Performs attention calculation with optional KV caching.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
            deterministic: If True, disable dropout.
            attn_mask: Attention mask.
            cache: KVCache.
            prefill: If True, use prefill mode.

        Returns:
            A tuple containing:
            - output: The attention output tensor (B, T, output_dim).
            - new_kv_cache: The K/V state to cache for the next step
              ((B, N, S_new, H), (B, N, S_new, H)), or None when nothing new
              needs caching. For self-attn, S_new = S_past + S; for cross-attn
              the precomputed cache is reused as-is.
        """
        if kv_positions is None:
            kv_positions = q_positions
        original_dtype = Xq.dtype

        Xq_BxTxNxH = self.q_proj(Xq)
        Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
        Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        # Inputs into the attention calculation
        attn_k: torch.Tensor | None = None
        attn_v: torch.Tensor | None = None
        new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None

        # Decoder Cross-Attention
        if self.is_cross_attn:
            # Directly use the cache (no need to check the index)
            attn_k, attn_v = cache.k, cache.v
            if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
                raise ValueError(
                    f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
                    f"does not match num_query_heads ({self.num_query_heads}). "
                    "Cache should be pre-repeated for GQA."
                )
        # Self-Attention
        else:
            Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
            Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
            Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions)  # (B, S, K, H)

            Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            # S=1 for a decode step

            if self.num_gqa_groups > 1:
                Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
                Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
            else:
                Xk_BxNxSxH = Xk_BxKxSxH
                Xv_BxNxSxH = Xv_BxKxSxH

            # Encoder Self-Attention
            if cache is None:
                attn_k = Xk_BxNxSxH
                attn_v = Xv_BxNxSxH
            # Decoder Self-Attention
            else:
                # In prefill mode, fill the cache up to the prefill length
                if prefill:
                    attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
                    cache.prefill_kv(attn_k, attn_v)
                # In a decode step, add the current K/V to the cache step by step
                else:
                    new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
                    attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)

        # note: scale=1.0 disables the default 1/sqrt(head_dim) scaling
        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_rate if not deterministic else 0.0,
            scale=1.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
        output = self.o_proj(attn_output)

        return output.to(original_dtype), new_kv_cache

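# GQA sketch (illustrative numbers): with num_query_heads=16 and num_kv_heads=4,
# each K/V head is repeat_interleave'd across num_gqa_groups=4 query heads, so
# K/V go from (B, 4, S, H) to (B, 16, S, H) before scaled_dot_product_attention.
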
class EncoderLayer(nn.Module):
    """Transformer Encoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        embed_dim = enc_config.n_embd

        self.pre_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.self_attention = Attention(
            config=config,
            q_embed_dim=embed_dim,
            kv_embed_dim=embed_dim,
            num_query_heads=enc_config.n_head,
            num_kv_heads=enc_config.n_head,
            head_dim=enc_config.head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=False,
            out_embed_dim=embed_dim,
        )
        self.post_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.mlp = MlpBlock(
            config=config,
            embed_dim=embed_dim,
            intermediate_dim=enc_config.n_hidden,
            activations=enc_config.mlp_activations,
            dropout_rate=model_config.dropout,
            use_pre_norm=enc_config.use_pre_norm,
        )
        self.dropout = nn.Dropout(model_config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = x
        x_norm = self.pre_sa_norm(x)

        sa_out, _ = self.self_attention(
            Xq=x_norm,
            Xkv=x_norm,
            q_positions=src_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=attn_mask,
        )
        x = residual + sa_out

        residual = x
        x_norm = self.post_sa_norm(x)
        mlp_out = self.mlp(x_norm, deterministic=deterministic)
        x = residual + mlp_out

        if not deterministic:
            x = self.dropout(x)
        return x

class Encoder(nn.Module):
    """Transformer Encoder Stack using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        compute_dtype = _str_to_dtype(config.training.dtype)

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
        self.dropout = nn.Dropout(model_config.dropout)
        self.layers = nn.ModuleList([EncoderLayer(config=config) for _ in range(enc_config.n_layer)])
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

    def forward(
        self,
        x_ids: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = self.embedding(x_ids)

        if not deterministic:
            x = self.dropout(x)

        for layer in self.layers:
            x = layer(
                x,
                src_positions=src_positions,
                deterministic=deterministic,
                attn_mask=attn_mask,
            )
        x = self.norm(x)
        if not deterministic:
            x = self.dropout(x)
        return x

class DecoderLayer(nn.Module):
    """Transformer Decoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        enc_config = config.model.encoder
        dec_embed_dim = dec_config.n_embd
        enc_embed_dim = enc_config.n_embd

        # Norms
        self.pre_sa_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_ca_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_mlp_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Self-Attention (GQA) with causal masking
        self.self_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=dec_embed_dim,
            num_query_heads=dec_config.gqa_query_heads,
            num_kv_heads=dec_config.kv_heads,
            head_dim=dec_config.gqa_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=False,
            out_embed_dim=dec_embed_dim,
        )
        # Cross-Attention (MHA)
        self.cross_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=enc_embed_dim,  # note kv_embed_dim
            num_query_heads=dec_config.cross_query_heads,
            num_kv_heads=dec_config.cross_query_heads,
            head_dim=dec_config.cross_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=True,
            out_embed_dim=dec_embed_dim,
        )
        # MLP
        self.mlp = MlpBlock(
            config=config,
            embed_dim=dec_embed_dim,
            intermediate_dim=dec_config.n_hidden,
            activations=dec_config.mlp_activations,
            dropout_rate=model_config.dropout,
            use_pre_norm=dec_config.use_pre_norm,
        )

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor | None,
        deterministic: bool,
        self_attn_mask: torch.Tensor,
        cross_attn_mask: torch.Tensor,
        self_attn_cache: KVCache,
        cross_attn_cache: KVCache,
        prefill: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        # 1. Self-Attention
        residual = x
        x_norm = self.pre_sa_norm(x)

        sa_out, new_kv_cache = self.self_attention(
            Xq=x_norm,  # (2, 1, D)
            Xkv=x_norm,  # (2, 1, D)
            q_positions=tgt_positions,  # (2, 1)
            kv_positions=tgt_positions,  # (2, 1)
            deterministic=deterministic,
            attn_mask=self_attn_mask,  # (2, 1, 1, S_max)
            cache=self_attn_cache,
            prefill=prefill,
        )

        x = residual + sa_out

        # 2. Cross-Attention
        residual = x
        x_norm = self.pre_ca_norm(x)
        ca_out, _ = self.cross_attention(
            Xq=x_norm,
            Xkv=encoder_out,
            q_positions=tgt_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=cross_attn_mask,
            cache=cross_attn_cache,
        )
        x = residual + ca_out

        # 3. MLP
        residual = x
        x_norm = self.pre_mlp_norm(x)
        mlp_out = self.mlp(x_norm, deterministic=deterministic)
        x = residual + mlp_out

        return x, new_kv_cache

class Decoder(nn.Module):
|
| 642 |
+
"""Transformer Decoder Stack using DenseGeneral."""
|
| 643 |
+
|
| 644 |
+
def __init__(self, config: DiaConfig):
|
| 645 |
+
super().__init__()
|
| 646 |
+
self.config = config
|
| 647 |
+
model_config = config.model
|
| 648 |
+
dec_config = config.model.decoder
|
| 649 |
+
train_config = config.training
|
| 650 |
+
data_config = config.data
|
| 651 |
+
compute_dtype = _str_to_dtype(config.training.dtype)
|
| 652 |
+
weight_dtype = _str_to_dtype(config.model.weight_dtype)
|
| 653 |
+
self.num_channels = data_config.channels
|
| 654 |
+
self.num_layers = dec_config.n_layer
|
| 655 |
+
|
| 656 |
+
self.embeddings = nn.ModuleList(
|
| 657 |
+
[
|
| 658 |
+
nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
|
| 659 |
+
for _ in range(self.num_channels)
|
| 660 |
+
]
|
| 661 |
+
)
|
| 662 |
+
self.dropout = nn.Dropout(model_config.dropout)
|
| 663 |
+
self.layers = nn.ModuleList([DecoderLayer(config=config) for _ in range(self.num_layers)])
|
| 664 |
+
self.norm = RMSNorm(
|
| 665 |
+
dec_config.n_embd,
|
| 666 |
+
eps=model_config.normalization_layer_epsilon,
|
| 667 |
+
dtype=torch.float32,
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
# Final Logits Projection using DenseGeneral
|
| 671 |
+
self.logits_dense = DenseGeneral(
|
| 672 |
+
in_shapes=(dec_config.n_embd,),
|
| 673 |
+
out_features=(self.num_channels, model_config.tgt_vocab_size),
|
| 674 |
+
axis=(-1,),
|
| 675 |
+
dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
|
| 676 |
+
weight_dtype=weight_dtype,
|
| 677 |
+
)
|
| 678 |
+
self.logits_in_fp32 = train_config.logits_dot_in_fp32
|
| 679 |
+
|
| 680 |
+
def precompute_cross_attention_kv(
|
| 681 |
+
self,
|
| 682 |
+
max_len: int,
|
| 683 |
+
encoder_out: torch.Tensor, # (B, S, E)
|
| 684 |
+
src_positions: torch.Tensor | None, # (B, S)
|
| 685 |
+
) -> list[KVCache]:
|
| 686 |
+
"""
|
| 687 |
+
Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
|
| 688 |
+
"""
|
| 689 |
+
per_layer_kv_cache: list[KVCache] = []
|
| 690 |
+
|
| 691 |
+
for layer in self.layers:
|
| 692 |
+
cross_attn_module = layer.cross_attention
|
| 693 |
+
k_proj = cross_attn_module.k_proj(encoder_out)
|
| 694 |
+
v_proj = cross_attn_module.v_proj(encoder_out)
|
| 695 |
+
|
| 696 |
+
k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
|
| 697 |
+
k = k_proj.transpose(1, 2)
|
| 698 |
+
v = v_proj.transpose(1, 2)
|
| 699 |
+
|
| 700 |
+
per_layer_kv_cache.append(
|
| 701 |
+
KVCache(
|
| 702 |
+
cross_attn_module.num_kv_heads,
|
| 703 |
+
max_len,
|
| 704 |
+
cross_attn_module.head_dim,
|
| 705 |
+
k.device,
|
| 706 |
+
k=k,
|
| 707 |
+
v=v,
|
| 708 |
+
)
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
return per_layer_kv_cache
|
| 712 |
+
|
| 713 |
+
    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
        tgt_pos_Bx1: torch.Tensor,  # [B, 1]
        encoder_out: torch.Tensor,  # [B, S, E]
        self_attn_mask: Any,  # None
        cross_attn_mask: torch.Tensor,  # [B, 1, 1, S]
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
    ) -> tuple[torch.Tensor, list]:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        Returns:
            A tuple containing:
            - logits_Bx1xCxV: The final output logits for the current step (B, 1, C, V), cast to float32.
            - new_cache: The updated per-layer self-attention KV entries for this step.
        """
        assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"

        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_Bx1xC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        new_cache = []

        for i, layer in enumerate(self.layers):
            self_cache = self_attention_cache[i]
            cross_cache = cross_attention_cache[i]
            x, new_kv_cache = layer(
                x,  # (2, 1, D)
                encoder_out,  # (2, S, E)
                src_positions=None,  # CA KV is already computed
                tgt_positions=tgt_pos_Bx1,  # (2, 1)
                deterministic=True,
                self_attn_mask=None,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_cache,
                cross_attn_cache=cross_cache,
            )
            new_cache.append(new_kv_cache)

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

        return logits_Bx1xCxV.to(torch.float32), new_cache

    def forward(
        self,
        tgt_ids_BxTxC: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor,
        deterministic: bool,
        self_attn_mask: torch.Tensor,
        cross_attn_mask: torch.Tensor,
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
    ) -> torch.Tensor:
        """
        Forward pass for the Decoder stack, managing KV caches.

        Args:
            tgt_ids_BxTxC: Target token IDs (B, T, C).
            encoder_out: Output from the encoder (B, S, E).
            tgt_positions: Positions for target sequence (B, T).
            src_positions: Positions for source sequence (B, S).
            deterministic: Disable dropout if True.
            self_attn_mask: Mask for self-attention.
            cross_attn_mask: Mask for cross-attention.
            self_attention_cache: Per-layer self-attention KV caches; its length
                should equal `num_layers`.
            cross_attention_cache: Per-layer cross-attention KV caches pre-computed
                from `encoder_out` (see `precompute_cross_attention_kv`); passed
                identically to all layers.

        Returns:
            logits: The final output logits (B, T, C, V), cast to float32.
        """
        _, _, num_channels_in = tgt_ids_BxTxC.shape
        assert num_channels_in == self.num_channels, "Input channels mismatch"

        # Embeddings
        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_BxTxC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        if not deterministic:
            x = self.dropout(x)

        for i, layer in enumerate(self.layers):
            x, _ = layer(
                x,
                encoder_out,
                tgt_positions=tgt_positions,
                src_positions=src_positions,
                deterministic=deterministic,
                self_attn_mask=self_attn_mask,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_attention_cache[i],
                cross_attn_cache=cross_attention_cache[i],
                prefill=True,
            )

        # Final Norm
        x = self.norm(x)
        logits_BxTxCxV = self.logits_dense(x)

        return logits_BxTxCxV.to(torch.float32)

class DiaModel(nn.Module):
    """PyTorch Dia Model using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        # self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, torch.nn.Embedding):
                torch.nn.init.xavier_uniform_(module.weight)
            elif isinstance(module, (torch.nn.LayerNorm, torch.nn.modules.normalization.RMSNorm)):
                if hasattr(module, "weight") and module.weight is not None:
                    torch.nn.init.ones_(module.weight)
                if hasattr(module, "bias") and module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        src_BxS: torch.Tensor,
        tgt_BxTxC: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        tgt_positions: torch.Tensor | None = None,
        enc_self_attn_mask: torch.Tensor | None = None,
        dec_self_attn_mask: torch.Tensor | None = None,
        dec_cross_attn_mask: torch.Tensor | None = None,
        enable_dropout: bool = True,
    ):
        deterministic = not enable_dropout

        # --- Encoder Pass ---
        encoder_out = self.encoder(
            x_ids=src_BxS,
            src_positions=src_positions,
            deterministic=deterministic,
            attn_mask=enc_self_attn_mask,
        )

        B, T, C = tgt_BxTxC.shape  # batch size, target sequence length, channels
        device = tgt_BxTxC.device

        self_attention_cache = [
            KVCache(
                num_heads=self.decoder.layers[i].self_attention.num_query_heads,  # use query heads, not KV heads
                max_len=T,
                head_dim=self.decoder.layers[i].self_attention.head_dim,
                device=device,
            )
            for i in range(self.decoder.num_layers)
        ]

        cross_attention_cache = self.decoder.precompute_cross_attention_kv(
            max_len=encoder_out.shape[1],
            encoder_out=encoder_out,
            src_positions=src_positions,
        )

        # --- Decoder Pass ---
        logits = self.decoder(
            tgt_ids_BxTxC=tgt_BxTxC,
            encoder_out=encoder_out,
            tgt_positions=tgt_positions,
            src_positions=src_positions,
            deterministic=deterministic,
            self_attn_mask=dec_self_attn_mask,
            cross_attn_mask=dec_cross_attn_mask,
            self_attention_cache=self_attention_cache,
            cross_attention_cache=cross_attention_cache,
        )

        return logits
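
# Usage sketch (illustrative, not part of the upload): driving DiaModel.forward in
# teacher-forcing mode. The dummy zero inputs and the helper itself are assumptions;
# shapes follow the B/T/C naming used above.
def _dia_model_forward_example(config: DiaConfig) -> torch.Tensor:
    model = DiaModel(config)
    B, T = 1, 8
    src = torch.zeros((B, config.data.text_length), dtype=torch.long)  # (B, S) byte tokens
    tgt = torch.zeros((B, T, config.data.channels), dtype=torch.long)  # (B, T, C) codebook tokens
    src_pos = torch.arange(config.data.text_length).unsqueeze(0)  # (B, S)
    tgt_pos = torch.arange(T).unsqueeze(0)  # (B, T)
    logits = model(src, tgt, src_positions=src_pos, tgt_positions=tgt_pos, enable_dropout=False)
    return logits  # (B, T, C, V), float32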
dia/model.py
ADDED
@@ -0,0 +1,648 @@
import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

from .audio import audio_to_codebook, codebook_to_audio
from .config import DiaConfig
from .layers import DiaModel, KVCache


def get_default_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

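# Illustrative: resolve the device once at startup and pass it around explicitly,
# e.g. `device = get_default_device()`; preference order is CUDA, then MPS, then CPU.
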
def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    use_cfg_filter: bool,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature
    if use_cfg_filter and cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
        logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
        sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        # Calculate indices to remove based on top_p
        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
        # Shift the mask to the right to keep the first token above the threshold
        sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
        sorted_indices_to_remove_BCxV[..., 0] = 0  # Always keep the most probable token

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
        indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

    sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
    sampled_indices_C = sampled_indices_BC.squeeze(-1)
    return sampled_indices_C

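# Illustrative sketch (toy values, not part of the upload): how the top-k masking in
# _sample_next_token reshapes a distribution before multinomial sampling.
def _sampling_filter_example() -> torch.Tensor:
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # (1, V=5) hypothetical logits
    _, top_idx = torch.topk(logits, k=3, dim=-1)  # indices of the 3 largest entries
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask.scatter_(dim=-1, index=top_idx, value=False)  # False marks entries to keep
    filtered = logits.masked_fill(mask, -torch.inf)  # masked entries get zero probability
    probs = torch.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # one sampled token id per row
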
class Dia:
    def __init__(self, config: DiaConfig, device: torch.device | None = None):
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
            device: The device to load the model onto. If None, will automatically select the best available device.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
        self.device = device if device is not None else get_default_device()
        self.model = DiaModel(config)
        self.dac_model = None

    @classmethod
    def from_local(cls, config_path: str, checkpoint_path: str, device: torch.device | None = None) -> "Dia":
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
            device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If the config or checkpoint file is not found.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config = DiaConfig.load(config_path)
        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

        dia = cls(config, device)

        try:
            # state_dict = torch.load(checkpoint_path, map_location=dia.device)
            # dia.model.load_state_dict(state_dict)
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if "model" in checkpoint:
                state_dict = checkpoint["model"]  # extract only the model weights
            else:
                state_dict = checkpoint
            dia.model.load_state_dict(state_dict)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
            raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e

        dia.model.to(dia.device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
        cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device | None = None
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.

        Downloads the configuration and checkpoint files from the specified
        repository ID and then loads the model.

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
            device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If config or checkpoint download/loading fails.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
        return cls.from_local(config_path, checkpoint_path, device)

    def _load_dac_model(self):
        try:
            dac_model_path = dac.utils.download()
            dac_model = dac.DAC.load(dac_model_path).to(self.device)
        except Exception as e:
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

    def _create_attn_mask(
        self,
        q_padding_mask_1d: torch.Tensor,
        k_padding_mask_1d: torch.Tensor,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Creates the attention mask (self or cross) mimicking JAX segment ID logic.
        """
        B1, Tq = q_padding_mask_1d.shape
        B2, Tk = k_padding_mask_1d.shape
        assert B1 == B2, "Query and key batch dimensions must match"

        p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
        p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]

        # Condition A: Non-padding query attends to non-padding key
        non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]

        # Condition B: Padding query attends to padding key
        pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]

        # Combine: True if padding status is compatible (both non-pad OR both pad)
        # This implementation follows the JAX TPU splash attention kernel
        mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]

        if is_causal:
            # Ensure causality for self-attention (Tq == Tk)
            assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
            # Standard lower-triangular causal mask (True means allow)
            causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device))  # Shape [Tq, Tk]
            causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
            return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
        else:
            # For cross-attention or non-causal self-attention
            return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads

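    # Illustrative (hypothetical 3-token example): with q and k padding masks both
    # [True, True, False], pair (2, 2) is pad-attends-pad and stays True, a mixed
    # pair such as (0, 2) is masked out, and is_causal=True additionally removes
    # everything above the diagonal.
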
    def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes text prompt, pads, and creates attention mask and positions."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length

        byte_text = text.encode("utf-8")
        replaced_bytes = byte_text

        LANG2BYTE = {
            "en": 3,
            "vi": 19,
        }

        CHANNELS = [
            "5phutcrypto",
            "anhbanthan",
            "anhthamtu",
            "animerewind.official",
            "bibitv8888",
            "btvgo",
            "baclieutv",
            "bachhoaxanhcom",
            "baodientuvov",
            "blvckvines",
            "boringppl",
            "bronub",
            "cdteam-why",
            "cobabinhduong",
            "cosmicwriter",
            "cuthongthai",
            "daiphatthanhtruyenhinhsonla",
            "day-be-thong-minh-tv",
            "danangtv",
            "daihanoi-htv",
            "daiptththainguyentntv",
            "dongmauviet",
            "dongthaptv",
            "fptbongdaofficial",
            "fonosvietnam",
            "hieurotrong5phut-ntkt",
            "htvtintuc",
            "happyhidari",
            "hoabinhtvgo",
            "hocenglishonline",
            "hocvienbovagau",
            "hungyentvvngo",
            "huynhduykhuongofficial",
            "huynhlapofficial",
            "jvevermind",
            "kenhvtc16",
            "kiengiangtv",
            "khanhvyofficial",
            "kienthucquansu",
            "lamdongtv1",
            "lamvlog",
            "longantv-la34",
            "mangovid",
            "mensbay",
            "meovatcuocsonglnv",
            "meuchannel",
            "ntnvlogsnguyenthanhnam",
            "ngamradio",
            "nhanhac555",
            "nhantaidaiviet",
            "ptth-trt",
            "ptvtruyenhinhphutho",
            "phantichgame",
            "phephim",
            "phimhottk-l",
            "riwaylegal",
            "ruangao",
            "suckhoetamsinh",
            "sachbiquyethanhcong",
            "soisangbrightsidevietnamese",
            "spiderum",
            "spiderumbooks",
            "sukieskitchen",
            "tin3phut",
            "tranthanhtown",
            "tulemientay",
            "tayninhtv",
            "thainhitv",
            "thanhpahm",
            "thegioilaptop",
            "thepresentwriter",
            "tiengiangtivi",
            "tieubaobaothom",
            "tintucbitcoin247",
            "truyenhinhbinhphuoc-bptv",
            "truyenhinhyenbaiytv",
            "truyenhinhcaobang",
            "truyenhinhdaklakdrt",
            "truyenhinhdaknong1",
            "truyenhinhdienbien23.9",
            "truyenhinhkhanhhoa",
            "truyenhinhkontumkrt",
            "truyenhinhnaminhntv",
            "truyenhinhninhthuan",
            "truyenhinhquangngai",
            "tuantienti2911",
            "tuyenquangttv",
            "vovlivedoctruyen",
            "vietcetera",
            "vinhlongtv",
            "voizfm",
            "vutrunguyenthuy",
            "vuive",
            "w2wanime",
            "w2wcartoon",
            "w2whorror",
            "w2wmovie",
            "web5ngay",
            "xanh24h",
            "aiphatthanhtruyenhinhquangtri",
            "aiphatthanhvatruyenhinhhai1908",
            "altonghop",
            "antvtruyenhinhcongannhandan",
            "baihoc10phut",
            "battlecry.khampha",
            "betterversionvn",
            "blogkhoinghiep",
            "bumcn",
            "caikinhdi_vn",
            "canthitg",
            "chanthienmybachnien",
            "chauanhchao",
            "cosu",
            "cungmaivaobep-monan-amthuc",
            "daiptthphuyen",
            "daiptthtv",
            "daitruyenhinhangiang",
            "daitruyenhinhbacgiang",
            "dannytran2375",
            "daybehoc5489",
            "daylaphegame",
            "dienmay",
            "ducisreal",
            "duongfg",
            "duyluandethuong",
            "duythanhish",
            "elroydevops",
            "gc.gamelab",
            "hacthaybachthay",
            "hagiangtv475",
            "haiduongtv247",
            "hanamtv8831",
            "hangphimtailieudienanhnd",
            "haugiangtv",
            "haunauday",
            "hieu-tv",
            "hoshiphan",
            "jakinatsumi2915",
            "kechuyentieuhoc1719",
            "kenhcovan",
            "khalid_dinh",
            "kiaralah",
            "laichautv",
            "langsontvtube",
            "megame_official",
            "minvestvn",
            "nguoithanhcong1991",
            "nhatkycuocsong.",
            "ntcanima",
            "ptthbentre",
            "ptthquangbinh",
            "qrt",
            "quangninhtv",
            "snewsvn",
            "soctrangtv",
            "sunhuynpodcast",
            "tamhonanuong",
            "tgddreview",
            "thaibinhtv",
            "thanhnamedu",
            "thanhnientvnews",
            "thbrt",
            "thieunhitv3630",
            "thtpct",
            "tinnhanh3phut868",
            "toansam",
            "toidicodedaoblog",
            "tranquochuywecommit",
            "tranvyvy",
            "truyenhinh4k",
            "truyenhinhbinhthuan",
            "truyenhinhcamau69",
            "truyenhinhdongnai_dnrtv",
            "truyenhinhgialai",
            "truyenhinhlaocai",
            "truyenhinhnghean",
            "truyenhinhvinhphuc",
            "txtofficial8798",
            "vanhkhuyenle",
            "vietnh1009",
            "visaothenhipodcast",
            "vtc14",
            "vtcnow",
            "vtv24",
            "vuive123",
            "zombiev4",
        ]
        LANG2BYTE.update({ch: 30 + i for i, ch in enumerate(CHANNELS)})
        # Replace each bracketed tag with its assigned byte code
        for tag, byte_val in LANG2BYTE.items():
            pattern = f"[{tag}]".encode("ascii")  # e.g. b"[5phutcrypto]"
            code = bytes([byte_val])  # e.g. b"\x1e"
            replaced_bytes = replaced_bytes.replace(pattern, code)
        text_tokens = list(replaced_bytes)

        current_len = len(text_tokens)
        padding_needed = max_len - current_len
        if padding_needed <= 0:
            text_tokens = text_tokens[:max_len]
            padded_text_np = np.array(text_tokens, dtype=np.uint8)
        else:
            padded_text_np = np.pad(
                text_tokens,
                (0, padding_needed),
                mode="constant",
                constant_values=text_pad_value,
            ).astype(np.uint8)

        src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)  # [1, S]
        src_positions = torch.arange(max_len, device=self.device).to(torch.long).unsqueeze(0)  # [1, S]

        src_padding_mask = (src_tokens != text_pad_value).to(self.device)  # [1, S]

        enc_self_attn_mask = self._create_attn_mask(src_padding_mask, src_padding_mask, is_causal=False)  # [1, S, S]

        return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask

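    # Illustrative (hypothetical prompt): in "[vi][spiderum] xin chào", b"[vi]"
    # collapses to the single byte 19 and b"[spiderum]" to its assigned channel
    # byte before the sequence is padded to config.data.text_length.
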
    @torch.inference_mode()
    def generate(
        self,
        text: str,
        max_tokens: int | None = None,
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_cfg_filter: bool = True,
        use_torch_compile: bool = False,
        cfg_filter_top_k: int = 35,
        audio_prompt_path: str | None = None,
    ) -> np.ndarray:
        """
        Generates audio from a text prompt (and optional audio prompt) using the Nari model.

        Returns:
            The generated audio waveform as a NumPy array.
        """
        num_channels = self.config.data.channels
        audio_bos_value = self.config.data.audio_bos_value
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

        # 1. Prepare Text Input
        (
            cond_src_BxS,
            cond_src_positions_BxS,
            cond_src_padding_mask_BxS,
            cond_enc_self_attn_mask_Bx1xSxS,
        ) = self._prepare_text_input(text)

        unc_src_BxS = torch.zeros_like(cond_src_BxS)
        src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
        src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
        src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
        enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)

        # 2. Encoder Pass
        # with torch.autocast(device_type="cuda", dtype=forward_dtype):
        encoder_out = self.model.encoder(
            x_ids=src_BxS,
            src_positions=src_positions_BxS,
            deterministic=True,
            attn_mask=enc_self_attn_mask_Bx1xSxS,
        )  # Shape: (B, S, E)

        # 3. Prepare Decoder Inputs
        # 3-1. Allocate KV Cache (Static)
        decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
            max_tokens, encoder_out, src_positions_BxS
        )

        decoder_self_attention_cache: list[KVCache] = []
        for _ in range(self.model.decoder.num_layers):
            decoder_self_attention_cache.append(
                KVCache(
                    self.config.model.decoder.gqa_query_heads,
                    max_tokens,
                    self.config.model.decoder.gqa_head_dim,
                    self.device,
                )
            )

        # 3-2. Initialize Decoder Inputs
        generated_BxTxC = torch.full(
            (2, 1, num_channels),
            fill_value=audio_bos_value,
            dtype=torch.long,
            device=self.device,
        )

        current_step = 0
        prompt_len_inc_bos = 1  # Start with BOS length

        # 3-3. Load Audio Prompt (if provided)
        if audio_prompt_path is not None:
            audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True)  # C, T
            if sr != 44100:  # Resample to 44.1kHz
                audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
            audio_prompt = audio_prompt.to(self.device).unsqueeze(0)  # 1, C, T
            audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
            print("✅ Prompt shape:", audio_prompt.shape)
            generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)

            prefill_len = generated_BxTxC.shape[1]
            prompt_len_inc_bos = prefill_len
            prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
            prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)

            prefill_self_attn_mask = self._create_attn_mask(
                prefill_tgt_padding_mask,
                prefill_tgt_padding_mask,
                is_causal=True,
            )
            prefill_cross_attn_mask = self._create_attn_mask(
                prefill_tgt_padding_mask,
                src_padding_mask_BxS,
                is_causal=False,
            )

            _ = self.model.decoder.forward(
                tgt_ids_BxTxC=generated_BxTxC,
                encoder_out=encoder_out,
                tgt_positions=prefill_tgt_pos,
                src_positions=src_positions_BxS,
                deterministic=True,
                self_attn_mask=prefill_self_attn_mask,
                cross_attn_mask=prefill_cross_attn_mask,
                self_attention_cache=decoder_self_attention_cache,
                cross_attention_cache=decoder_cross_attention_cache,
            )

            current_step = prefill_len - 1

        # 4. Autoregressive Generation Loop
        eos_detected_channel_0 = False
        eos_countdown = -1
        extra_steps_after_eos = 30
        # Make generated_BxTxC a fixed-size tensor
        # Length is either 1 + max tokens or 1 + prompt len + max tokens
        generated_BxTxC = torch.cat(
            [
                generated_BxTxC,
                torch.full(
                    (2, max_tokens, num_channels),
                    fill_value=-1,
                    dtype=torch.long,
                    device=self.device,
                ),
            ],
            dim=1,
        )

        decode_step = self.model.decoder.decode_step
        if use_torch_compile:
            decode_step = torch.compile(
                self.model.decoder.decode_step,
                mode="default",
            )

        tgt_padding_mask = (
            (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2).to(self.device)
        )  # [B, 1]
        # Generated tokens are never PAD, so we use a fixed mask
        decoder_cross_attn_mask = self._create_attn_mask(
            tgt_padding_mask,  # Query mask [B, 1]
            src_padding_mask_BxS,  # Key mask [B, S]
            is_causal=False,
        )  # [B, 1, 1, S]

        for step in range(current_step, current_step + max_tokens):
            tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
            tgt_pos_Bx1 = torch.full(
                (2, 1),
                fill_value=step,
                dtype=torch.long,
                device=self.device,
            )

            logits_Bx1xCxV, new_cache = decode_step(
                tgt_ids_Bx1xC=tgt_ids_Bx1xC,
                tgt_pos_Bx1=tgt_pos_Bx1,
                encoder_out=encoder_out,
                self_attn_mask=None,
                cross_attn_mask=decoder_cross_attn_mask,
                self_attention_cache=decoder_self_attention_cache,
                cross_attention_cache=decoder_cross_attention_cache,
            )

            for i, layer_cache in enumerate(decoder_self_attention_cache):
                layer_cache.update_cache(new_cache[i][0], new_cache[i][1])

            V = self.config.model.tgt_vocab_size
            logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]  # B, C, V
            uncond_logits_CxV = logits_last_BxCxV[0, :, :]
            cond_logits_CxV = logits_last_BxCxV[1, :, :]

            # Classifier-free guidance: push conditional logits away from unconditional ones
            cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)

            logits_CxV = cfg_logits_CxV.reshape((-1, V))  # C, V
            logits_CxV[:, 1025:] = -torch.inf  # disallow sampling token ids >= 1025

            # Sample next token
            pred_C = _sample_next_token(
                logits_CxV.float(),
                temperature=temperature,
                top_p=top_p,
                use_cfg_filter=use_cfg_filter,
                cfg_filter_top_k=cfg_filter_top_k,
            )

            generation_step_index = step - current_step
            if audio_prompt_path is None:
                pred_C = torch.where(
                    generation_step_index >= delay_tensor,
                    pred_C,
                    audio_bos_value,
                )

            generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)

            if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
                eos_detected_channel_0 = True
                eos_countdown = extra_steps_after_eos

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
                        generated_BxTxC[:, step + 1, i] = audio_eos_value
                    elif step_after_eos > d:
                        generated_BxTxC[:, step + 1, i] = audio_pad_value
                eos_countdown -= 1
                if eos_countdown == 0:
                    break

            generation_step_index = step - current_step + 1

        output_codes = generated_BxTxC[:, prompt_len_inc_bos : step + 1, :]

        generated_codes = output_codes[0]

        audio = codebook_to_audio(
            generated_codes.transpose(1, 0), self.dac_model, delay_pattern, B=1, T=max_tokens, C=num_channels
        )
        print("🟩 Total number of generated tokens:", generated_codes.shape[0])
        return audio.squeeze().cpu().numpy()
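
# End-to-end usage sketch (illustrative): the repo id matches from_pretrained's
# default, while the prompt text, output path, and soundfile dependency are
# assumptions.
if __name__ == "__main__":
    import soundfile as sf  # any writer that accepts 44.1 kHz float arrays works

    dia = Dia.from_pretrained("nari-labs/Dia-1.6B")
    waveform = dia.generate("[vi] xin chào", temperature=1.3, top_p=0.95)
    sf.write("out.wav", waveform, 44100)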