UDface11jkj commited on
Commit
0107cd3
·
verified ·
1 Parent(s): 3fc4f90

post at 11.29

Browse files
Files changed (12) hide show
  1. .dockerignore +17 -0
  2. README.md +10 -10
  3. app.py +67 -0
  4. dia/__init__.py +0 -0
  5. dia/audio.py +203 -0
  6. dia/config.py +197 -0
  7. dia/layers.py +642 -0
  8. dia/model.py +488 -0
  9. dia/state.py +234 -0
  10. dockerfile +30 -0
  11. requirements.txt +9 -0
  12. templates/index.html +104 -0
.dockerignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ venv/
3
+
4
+
5
+
6
+ __pycache__/
7
+ *.pyc
8
+ *.pyo
9
+ *.pyd
10
+
11
+ # Ignore git
12
+ .git/
13
+ .gitignore
14
+
15
+ # Ignore IDE settings
16
+ .vscode/
17
+ .idea/
README.md CHANGED
@@ -1,10 +1,10 @@
1
- ---
2
- title: 'Text '
3
- emoji: 🏆
4
- colorFrom: yellow
5
- colorTo: indigo
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: 'Text '
3
+ emoji: 🏆
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Form, Request, HTTPException
2
+ from fastapi.responses import HTMLResponse, FileResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.templating import Jinja2Templates
5
+ import soundfile as sf
6
+ from dia.model import Dia
7
+ import os
8
+ import uuid
9
+ import torch
10
+
11
+ app = FastAPI()
12
+ app.mount("/static", StaticFiles(directory="static"), name="static")
13
+ templates = Jinja2Templates(directory="templates")
14
+
15
+
16
+ try:
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ print(f"Using device: {device}")
19
+
20
+
21
+ os.makedirs("static/audio", exist_ok=True)
22
+
23
+ model = Dia.from_pretrained(
24
+ "nari-labs/Dia-1.6B",
25
+ compute_dtype="float16",
26
+ device=device,
27
+ use_torch_compile=True,
28
+ low_cpu_mem_usage=True,
29
+ )
30
+
31
+ if device == "cpu":
32
+ model = model.eval()
33
+ torch.set_num_threads(4)
34
+
35
+ print("Model loaded successfully with optimizations")
36
+ except Exception as e:
37
+ print(f"Error loading Dia model: {str(e)}")
38
+ raise
39
+
40
+ @app.get('/')
41
+ async def index(request: Request):
42
+ return templates.TemplateResponse("index.html", {'request': request})
43
+
44
+ @app.post("/convertor")
45
+ async def process(request: Request, paragraph: str = Form(...), action: str = Form(...)):
46
+ try:
47
+ if not paragraph:
48
+ raise HTTPException(status_code=400, detail="Text is required")
49
+
50
+ if action == "audio":
51
+
52
+ output = model.generate(paragraph)
53
+ filename = f"audio_{uuid.uuid4()}.wav"
54
+ filepath = f"static/audio/{filename}"
55
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
56
+
57
+ sf.write(filepath, output, 44100)
58
+ return FileResponse(filepath, media_type="audio/wav", filename=filename)
59
+
60
+ elif action == "summarize":
61
+
62
+ raise HTTPException(status_code=400, detail="Summarization not implemented")
63
+
64
+ return HTTPException(status_code=400, detail="Invalid action")
65
+
66
+ except Exception as e:
67
+ raise HTTPException(status_code=500, detail=str(e))
dia/__init__.py ADDED
File without changes
dia/audio.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+
6
+ def build_delay_indices(
7
+ B: int, T: int, C: int, delay_pattern: tp.List[int]
8
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
9
+ """
10
+ Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
11
+ Negative t_idx => BOS; t_idx >= T => PAD.
12
+ """
13
+ delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
14
+
15
+ t_idx_BxT = torch.broadcast_to(
16
+ torch.arange(T, dtype=torch.int32)[None, :],
17
+ [B, T],
18
+ )
19
+ t_idx_BxTx1 = t_idx_BxT[..., None]
20
+ t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
21
+
22
+ b_idx_BxTxC = torch.broadcast_to(
23
+ torch.arange(B, dtype=torch.int32).view(B, 1, 1),
24
+ [B, T, C],
25
+ )
26
+ c_idx_BxTxC = torch.broadcast_to(
27
+ torch.arange(C, dtype=torch.int32).view(1, 1, C),
28
+ [B, T, C],
29
+ )
30
+
31
+ # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail
32
+ t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
33
+
34
+ indices_BTCx3 = torch.stack(
35
+ [
36
+ b_idx_BxTxC.reshape(-1),
37
+ t_clamped_BxTxC.reshape(-1),
38
+ c_idx_BxTxC.reshape(-1),
39
+ ],
40
+ dim=1,
41
+ ).long() # Ensure indices are long type for indexing
42
+
43
+ return t_idx_BxTxC, indices_BTCx3
44
+
45
+
46
+ def apply_audio_delay(
47
+ audio_BxTxC: torch.Tensor,
48
+ pad_value: int,
49
+ bos_value: int,
50
+ precomp: tp.Tuple[torch.Tensor, torch.Tensor],
51
+ ) -> torch.Tensor:
52
+ """
53
+ Applies the delay pattern to batched audio tokens using precomputed indices,
54
+ inserting BOS where t_idx < 0 and PAD where t_idx >= T.
55
+
56
+ Args:
57
+ audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
58
+ pad_value: the padding token
59
+ bos_value: the BOS token
60
+ precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
61
+
62
+ Returns:
63
+ result_BxTxC: [B, T, C] delayed audio tokens
64
+ """
65
+ device = audio_BxTxC.device # Get device from input tensor
66
+ t_idx_BxTxC, indices_BTCx3 = precomp
67
+ t_idx_BxTxC = t_idx_BxTxC.to(device) # Move precomputed indices to device
68
+ indices_BTCx3 = indices_BTCx3.to(device)
69
+
70
+ # Equivalent of tf.gather_nd using advanced indexing
71
+ # Ensure indices are long type if not already (build_delay_indices should handle this)
72
+ gathered_flat = audio_BxTxC[
73
+ indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
74
+ ]
75
+ gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
76
+
77
+ # Create masks on the correct device
78
+ mask_bos = t_idx_BxTxC < 0 # => place bos_value
79
+ mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1] # => place pad_value
80
+
81
+ # Create scalar tensors on the correct device
82
+ bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
83
+ pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
84
+
85
+ # If mask_bos, BOS; else if mask_pad, PAD; else original gather
86
+ # All tensors should now be on the same device
87
+ result_BxTxC = torch.where(
88
+ mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC)
89
+ )
90
+
91
+ return result_BxTxC
92
+
93
+
94
+ def build_revert_indices(
95
+ B: int, T: int, C: int, delay_pattern: tp.List[int]
96
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
97
+ """
98
+ Precompute indices for the revert operation using PyTorch.
99
+
100
+ Returns:
101
+ A tuple (t_idx_BxTxC, indices_BTCx3) where:
102
+ - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
103
+ - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
104
+ batch indices, clamped time indices, and channel indices.
105
+ """
106
+ # Use default device unless specified otherwise; assumes inputs might define device later
107
+ device = None # Or determine dynamically if needed, e.g., from a model parameter
108
+
109
+ delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
110
+
111
+ t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
112
+ t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
113
+
114
+ t_idx_BxTxC = torch.minimum(
115
+ t_idx_BT1 + delay_arr.view(1, 1, C),
116
+ torch.tensor(T - 1, device=device),
117
+ )
118
+ b_idx_BxTxC = torch.broadcast_to(
119
+ torch.arange(B, device=device).view(B, 1, 1), [B, T, C]
120
+ )
121
+ c_idx_BxTxC = torch.broadcast_to(
122
+ torch.arange(C, device=device).view(1, 1, C), [B, T, C]
123
+ )
124
+
125
+ indices_BTCx3 = torch.stack(
126
+ [
127
+ b_idx_BxTxC.reshape(-1),
128
+ t_idx_BxTxC.reshape(-1),
129
+ c_idx_BxTxC.reshape(-1),
130
+ ],
131
+ axis=1,
132
+ ).long() # Ensure indices are long type
133
+
134
+ return t_idx_BxTxC, indices_BTCx3
135
+
136
+
137
+ def revert_audio_delay(
138
+ audio_BxTxC: torch.Tensor,
139
+ pad_value: int,
140
+ precomp: tp.Tuple[torch.Tensor, torch.Tensor],
141
+ T: int,
142
+ ) -> torch.Tensor:
143
+ """
144
+ Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
145
+
146
+ Args:
147
+ audio_BxTxC: Input delayed audio tensor
148
+ pad_value: Padding value for out-of-bounds indices
149
+ precomp: Precomputed revert indices tuple containing:
150
+ - t_idx_BxTxC: Time offset indices tensor
151
+ - indices_BTCx3: Gather indices tensor for original audio
152
+ T: Original sequence length before padding
153
+
154
+ Returns:
155
+ Reverted audio tensor with same shape as input
156
+ """
157
+ t_idx_BxTxC, indices_BTCx3 = precomp
158
+ device = audio_BxTxC.device # Get device from input tensor
159
+
160
+ # Move precomputed indices to the same device as audio_BxTxC if they aren't already
161
+ t_idx_BxTxC = t_idx_BxTxC.to(device)
162
+ indices_BTCx3 = indices_BTCx3.to(device)
163
+
164
+ # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
165
+ gathered_flat = audio_BxTxC[
166
+ indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
167
+ ]
168
+ gathered_BxTxC = gathered_flat.view(
169
+ audio_BxTxC.size()
170
+ ) # Use .size() for robust reshaping
171
+
172
+ # Create pad_tensor on the correct device
173
+ pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
174
+ # Create T tensor on the correct device for comparison
175
+ T_tensor = torch.tensor(T, device=device)
176
+
177
+ result_BxTxC = torch.where(
178
+ t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC
179
+ ) # Changed np.where to torch.where
180
+
181
+ return result_BxTxC
182
+
183
+
184
+ @torch.no_grad()
185
+ @torch.inference_mode()
186
+ def decode(
187
+ model,
188
+ audio_codes,
189
+ ):
190
+ """
191
+ Decodes the given frames into an output audio waveform
192
+ """
193
+ if len(audio_codes) != 1:
194
+ raise ValueError(f"Expected one frame, got {len(audio_codes)}")
195
+
196
+ try:
197
+ audio_values = model.quantizer.from_codes(audio_codes)
198
+ audio_values = model.decode(audio_values[0])
199
+
200
+ return audio_values
201
+ except Exception as e:
202
+ print(f"Error in decode method: {str(e)}")
203
+ raise
dia/config.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration management module for the Dia model.
2
+
3
+ This module provides comprehensive configuration management for the Dia model,
4
+ utilizing Pydantic for validation. It defines configurations for data processing,
5
+ model architecture (encoder and decoder), and training settings.
6
+
7
+ Key components:
8
+ - DataConfig: Parameters for data loading and preprocessing.
9
+ - EncoderConfig: Architecture details for the encoder module.
10
+ - DecoderConfig: Architecture details for the decoder module.
11
+ - ModelConfig: Combined model architecture settings.
12
+ - TrainingConfig: Training hyperparameters and settings.
13
+ - DiaConfig: Master configuration combining all components.
14
+ """
15
+
16
+ import os
17
+ from typing import Annotated
18
+
19
+ from pydantic import BaseModel, BeforeValidator, Field
20
+
21
+
22
+ class DataConfig(BaseModel, frozen=True):
23
+ """Configuration for data loading and preprocessing.
24
+
25
+ Attributes:
26
+ text_length: Maximum length of text sequences (must be multiple of 128).
27
+ audio_length: Maximum length of audio sequences (must be multiple of 128).
28
+ channels: Number of audio channels.
29
+ text_pad_value: Value used for padding text sequences.
30
+ audio_eos_value: Value representing the end of audio sequences.
31
+ audio_bos_value: Value representing the beginning of audio sequences.
32
+ audio_pad_value: Value used for padding audio sequences.
33
+ delay_pattern: List of delay values for each audio channel.
34
+ """
35
+
36
+ text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
37
+ Field(gt=0, multiple_of=128)
38
+ )
39
+ audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
40
+ Field(gt=0, multiple_of=128)
41
+ )
42
+ channels: int = Field(default=9, gt=0, multiple_of=1)
43
+ text_pad_value: int = Field(default=0)
44
+ audio_eos_value: int = Field(default=1024)
45
+ audio_pad_value: int = Field(default=1025)
46
+ audio_bos_value: int = Field(default=1026)
47
+ delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
48
+ default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
49
+ )
50
+
51
+ def __hash__(self) -> int:
52
+ """Generate a hash based on all fields of the config."""
53
+ return hash(
54
+ (
55
+ self.text_length,
56
+ self.audio_length,
57
+ self.channels,
58
+ self.text_pad_value,
59
+ self.audio_pad_value,
60
+ self.audio_bos_value,
61
+ self.audio_eos_value,
62
+ tuple(self.delay_pattern),
63
+ )
64
+ )
65
+
66
+
67
+ class EncoderConfig(BaseModel, frozen=True):
68
+ """Configuration for the encoder component of the Dia model.
69
+
70
+ Attributes:
71
+ n_layer: Number of transformer layers.
72
+ n_embd: Embedding dimension.
73
+ n_hidden: Hidden dimension size in the MLP layers.
74
+ n_head: Number of attention heads.
75
+ head_dim: Dimension per attention head.
76
+ """
77
+
78
+ n_layer: int = Field(gt=0)
79
+ n_embd: int = Field(gt=0)
80
+ n_hidden: int = Field(gt=0)
81
+ n_head: int = Field(gt=0)
82
+ head_dim: int = Field(gt=0)
83
+
84
+
85
+ class DecoderConfig(BaseModel, frozen=True):
86
+ """Configuration for the decoder component of the Dia model.
87
+
88
+ Attributes:
89
+ n_layer: Number of transformer layers.
90
+ n_embd: Embedding dimension.
91
+ n_hidden: Hidden dimension size in the MLP layers.
92
+ gqa_query_heads: Number of query heads for grouped-query self-attention.
93
+ kv_heads: Number of key/value heads for grouped-query self-attention.
94
+ gqa_head_dim: Dimension per query head for grouped-query self-attention.
95
+ cross_query_heads: Number of query heads for cross-attention.
96
+ cross_head_dim: Dimension per cross-attention head.
97
+ """
98
+
99
+ n_layer: int = Field(gt=0)
100
+ n_embd: int = Field(gt=0)
101
+ n_hidden: int = Field(gt=0)
102
+ gqa_query_heads: int = Field(gt=0)
103
+ kv_heads: int = Field(gt=0)
104
+ gqa_head_dim: int = Field(gt=0)
105
+ cross_query_heads: int = Field(gt=0)
106
+ cross_head_dim: int = Field(gt=0)
107
+
108
+
109
+ class ModelConfig(BaseModel, frozen=True):
110
+ """Main configuration container for the Dia model architecture.
111
+
112
+ Attributes:
113
+ encoder: Configuration for the encoder component.
114
+ decoder: Configuration for the decoder component.
115
+ src_vocab_size: Size of the source (text) vocabulary.
116
+ tgt_vocab_size: Size of the target (audio code) vocabulary.
117
+ dropout: Dropout probability applied within the model.
118
+ normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
119
+ weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
120
+ rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
121
+ rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
122
+ """
123
+
124
+ encoder: EncoderConfig
125
+ decoder: DecoderConfig
126
+ src_vocab_size: int = Field(default=128, gt=0)
127
+ tgt_vocab_size: int = Field(default=1028, gt=0)
128
+ dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
129
+ normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
130
+ weight_dtype: str = Field(default="float32", description="Weight precision")
131
+ rope_min_timescale: int = Field(
132
+ default=1, description="Timescale For global Attention"
133
+ )
134
+ rope_max_timescale: int = Field(
135
+ default=10_000, description="Timescale For global Attention"
136
+ )
137
+
138
+
139
+ class TrainingConfig(BaseModel, frozen=True):
140
+ pass
141
+
142
+
143
+ class DiaConfig(BaseModel, frozen=True):
144
+ """Master configuration for the Dia model.
145
+
146
+ Combines all sub-configurations into a single validated object.
147
+
148
+ Attributes:
149
+ version: Configuration version string.
150
+ model: Model architecture configuration.
151
+ training: Training process configuration (precision settings).
152
+ data: Data loading and processing configuration.
153
+ """
154
+
155
+ version: str = Field(default="1.0")
156
+ model: ModelConfig
157
+ # TODO: remove training. this is just for backwards-compatability
158
+ training: TrainingConfig
159
+ data: DataConfig
160
+
161
+ def save(self, path: str) -> None:
162
+ """Save the current configuration instance to a JSON file.
163
+
164
+ Ensures the parent directory exists and the file has a .json extension.
165
+
166
+ Args:
167
+ path: The target file path to save the configuration.
168
+
169
+ Raises:
170
+ ValueError: If the path is not a file with a .json extension.
171
+ """
172
+ os.makedirs(os.path.dirname(path), exist_ok=True)
173
+ config_json = self.model_dump_json(indent=2)
174
+ with open(path, "w") as f:
175
+ f.write(config_json)
176
+
177
+ @classmethod
178
+ def load(cls, path: str) -> "DiaConfig | None":
179
+ """Load and validate a Dia configuration from a JSON file.
180
+
181
+ Args:
182
+ path: The path to the configuration file.
183
+
184
+ Returns:
185
+ A validated DiaConfig instance if the file exists and is valid,
186
+ otherwise None if the file is not found.
187
+
188
+ Raises:
189
+ ValueError: If the path does not point to an existing .json file.
190
+ pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
191
+ """
192
+ try:
193
+ with open(path, "r") as f:
194
+ content = f.read()
195
+ return cls.model_validate_json(content)
196
+ except FileNotFoundError:
197
+ return None
dia/layers.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+ from torch.nn import RMSNorm
6
+
7
+ from .config import DiaConfig
8
+ from .state import DecoderInferenceState, EncoderInferenceState, KVCache
9
+
10
+
11
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
12
+ return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
13
+
14
+
15
+ class DenseGeneral(nn.Module):
16
+ """
17
+ PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
18
+
19
+ Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
20
+ for the generalized matrix multiplication. Weight/bias shapes are calculated
21
+ and parameters created during initialization based on config.
22
+ `load_weights` validates shapes and copies data.
23
+
24
+ Attributes:
25
+ axis (Tuple[int, ...]): Input axis or axes to contract.
26
+ in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
27
+ out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
28
+ use_bias (bool): Whether to add a bias term.
29
+ weight (nn.Parameter): The kernel parameter.
30
+ bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ in_shapes: tuple[int, ...],
36
+ out_features: tuple[int, ...],
37
+ axis: tuple[int, ...] = (-1,),
38
+ weight_dtype: torch.dtype | None = None,
39
+ device: torch.device | None = None,
40
+ ):
41
+ super().__init__()
42
+ self.in_shapes = in_shapes
43
+ self.out_features = out_features
44
+ self.axis = axis
45
+ self.kernel_shape = self.in_shapes + self.out_features
46
+
47
+ factory_kwargs = {"device": device, "dtype": weight_dtype}
48
+ self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
49
+ self.register_parameter("bias", None)
50
+
51
+ def forward(self, inputs: Tensor) -> Tensor:
52
+ norm_axis = _normalize_axes(self.axis, inputs.ndim)
53
+ kernel_contract_axes = tuple(range(len(norm_axis)))
54
+
55
+ output = torch.tensordot(
56
+ inputs.to(self.weight.dtype),
57
+ self.weight,
58
+ dims=(norm_axis, kernel_contract_axes),
59
+ ).to(inputs.dtype)
60
+ return output
61
+
62
+
63
+ class MlpBlock(nn.Module):
64
+ """MLP block using DenseGeneral."""
65
+
66
+ def __init__(
67
+ self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
68
+ ):
69
+ super().__init__()
70
+ self.dtype = compute_dtype
71
+
72
+ self.wi_fused = DenseGeneral(
73
+ in_shapes=(embed_dim,),
74
+ out_features=(2, intermediate_dim),
75
+ axis=(-1,),
76
+ weight_dtype=compute_dtype,
77
+ )
78
+
79
+ self.wo = DenseGeneral(
80
+ in_shapes=(intermediate_dim,),
81
+ out_features=(embed_dim,),
82
+ axis=(-1,),
83
+ weight_dtype=compute_dtype,
84
+ )
85
+
86
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
87
+ """Forward pass."""
88
+ fused_x = self.wi_fused(x)
89
+
90
+ gate = fused_x[..., 0, :]
91
+ up = fused_x[..., 1, :]
92
+
93
+ hidden = torch.mul(F.silu(gate), up).to(self.dtype)
94
+
95
+ output = self.wo(hidden)
96
+ return output
97
+
98
+
99
+ class RotaryEmbedding(nn.Module):
100
+ """Rotary Position Embedding (RoPE) implementation in PyTorch."""
101
+
102
+ def __init__(
103
+ self,
104
+ embedding_dims: int,
105
+ min_timescale: int = 1,
106
+ max_timescale: int = 10000,
107
+ dtype: torch.dtype = torch.float32,
108
+ ):
109
+ super().__init__()
110
+ if embedding_dims % 2 != 0:
111
+ raise ValueError("Embedding dim must be even for RoPE.")
112
+ self.embedding_dims = embedding_dims
113
+ self.min_timescale = min_timescale
114
+ self.max_timescale = max_timescale
115
+ self.dtype = dtype
116
+
117
+ half_embedding_dim = embedding_dims // 2
118
+ fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
119
+ self.register_buffer(
120
+ "timescale",
121
+ self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
122
+ persistent=False,
123
+ )
124
+
125
+ def extra_repr(self) -> str:
126
+ s = f"{self.timescale.shape}"
127
+ return s
128
+
129
+ def forward(self, inputs: torch.Tensor, position: torch.Tensor):
130
+ """Applies RoPE."""
131
+ position = position.unsqueeze(-1).unsqueeze(-1)
132
+ timescale = self.timescale.to(inputs.device)
133
+ sinusoid_inp = position / timescale
134
+ sin = torch.sin(sinusoid_inp).to(inputs.dtype)
135
+ cos = torch.cos(sinusoid_inp).to(inputs.dtype)
136
+ first_half, second_half = torch.chunk(inputs, 2, dim=-1)
137
+ first_part = first_half * cos - second_half * sin
138
+ second_part = second_half * cos + first_half * sin
139
+ return torch.cat((first_part, second_part), dim=-1)
140
+
141
+
142
+ class Attention(nn.Module):
143
+ """Attention using DenseGeneral."""
144
+
145
+ def __init__(
146
+ self,
147
+ config: DiaConfig,
148
+ q_embed_dim: int,
149
+ kv_embed_dim: int,
150
+ num_query_heads: int,
151
+ num_kv_heads: int,
152
+ head_dim: int,
153
+ compute_dtype: torch.dtype,
154
+ is_cross_attn: bool = False,
155
+ out_embed_dim: int | None = None,
156
+ ):
157
+ super().__init__()
158
+ self.num_query_heads = num_query_heads
159
+ self.num_kv_heads = num_kv_heads
160
+ self.head_dim = head_dim
161
+ self.is_cross_attn = is_cross_attn
162
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
163
+ self.projected_query_dim = num_query_heads * head_dim
164
+ if num_query_heads % num_kv_heads != 0:
165
+ raise ValueError(
166
+ f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
167
+ )
168
+ self.num_gqa_groups = num_query_heads // num_kv_heads
169
+
170
+ # --- Projection Layers using DenseGeneral ---
171
+ self.q_proj = DenseGeneral(
172
+ in_shapes=(q_embed_dim,),
173
+ out_features=(num_query_heads, head_dim),
174
+ axis=(-1,),
175
+ weight_dtype=compute_dtype,
176
+ )
177
+ self.k_proj = DenseGeneral(
178
+ in_shapes=(kv_embed_dim,),
179
+ out_features=(num_kv_heads, head_dim),
180
+ axis=(-1,),
181
+ weight_dtype=compute_dtype,
182
+ )
183
+ self.v_proj = DenseGeneral(
184
+ in_shapes=(kv_embed_dim,),
185
+ out_features=(num_kv_heads, head_dim),
186
+ axis=(-1,),
187
+ weight_dtype=compute_dtype,
188
+ )
189
+ self.o_proj = DenseGeneral(
190
+ in_shapes=(num_query_heads, head_dim),
191
+ out_features=(self.output_dim,),
192
+ axis=(-2, -1),
193
+ weight_dtype=compute_dtype,
194
+ )
195
+
196
+ # --- Rotary Embedding ---
197
+ self.rotary_emb = RotaryEmbedding(
198
+ embedding_dims=self.head_dim,
199
+ min_timescale=config.model.rope_min_timescale,
200
+ max_timescale=config.model.rope_max_timescale,
201
+ dtype=compute_dtype,
202
+ )
203
+
204
+ def forward(
205
+ self,
206
+ Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
207
+ Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
208
+ q_positions: torch.Tensor, # (B, T)
209
+ kv_positions: torch.Tensor | None = None, # (B, S)
210
+ attn_mask: torch.Tensor
211
+ | None = None, # None in Decoder Self Attention, Valid mask in Others
212
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
213
+ prefill: bool = False,
214
+ is_causal: bool = False,
215
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
216
+ """
217
+ Performs attention calculation with optional KV caching.
218
+
219
+ Args:
220
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
221
+ Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
222
+ q_positions: Positions for queries (B, T).
223
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
224
+ attn_mask: Attention mask.
225
+ cache: KVCache.
226
+ prefill: If True, use prefill mode.
227
+
228
+ Returns:
229
+ A tuple containing:
230
+ - output: The attention output tensor (B, T, output_dim).
231
+ - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
232
+ """
233
+ if kv_positions is None:
234
+ kv_positions = q_positions
235
+ original_dtype = Xq.dtype
236
+
237
+ Xq_BxTxNxH = self.q_proj(Xq)
238
+ Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
239
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
240
+
241
+ attn_k: torch.Tensor | None = None
242
+ attn_v: torch.Tensor | None = None
243
+
244
+ if self.is_cross_attn:
245
+ attn_k, attn_v = cache.k, cache.v
246
+ else:
247
+ Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
248
+ Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
249
+ Xk_BxSxKxH = self.rotary_emb(
250
+ Xk_BxSxKxH, position=kv_positions
251
+ ) # (B, S, K, H)
252
+
253
+ Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
254
+ Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
255
+
256
+ if cache is None:
257
+ attn_k = Xk_BxKxSxH
258
+ attn_v = Xv_BxKxSxH
259
+ else:
260
+ if prefill:
261
+ attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
262
+ cache.prefill(attn_k, attn_v)
263
+ else:
264
+ attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)
265
+
266
+ attn_output = F.scaled_dot_product_attention(
267
+ Xq_BxNxTxH,
268
+ attn_k,
269
+ attn_v,
270
+ attn_mask=attn_mask,
271
+ scale=1.0,
272
+ enable_gqa=self.num_gqa_groups > 1,
273
+ is_causal=is_causal,
274
+ )
275
+
276
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
277
+ output = self.o_proj(attn_output)
278
+
279
+ return output.to(original_dtype)
280
+
281
+
282
+ class EncoderLayer(nn.Module):
283
+ """Transformer Encoder Layer using DenseGeneral."""
284
+
285
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
286
+ super().__init__()
287
+ self.config = config
288
+ model_config = config.model
289
+ enc_config = config.model.encoder
290
+ embed_dim = enc_config.n_embd
291
+
292
+ self.pre_sa_norm = RMSNorm(
293
+ embed_dim,
294
+ eps=model_config.normalization_layer_epsilon,
295
+ dtype=torch.float32,
296
+ )
297
+ self.self_attention = Attention(
298
+ config,
299
+ q_embed_dim=embed_dim,
300
+ kv_embed_dim=embed_dim,
301
+ num_query_heads=enc_config.n_head,
302
+ num_kv_heads=enc_config.n_head,
303
+ head_dim=enc_config.head_dim,
304
+ compute_dtype=compute_dtype,
305
+ is_cross_attn=False,
306
+ out_embed_dim=embed_dim,
307
+ )
308
+ self.post_sa_norm = RMSNorm(
309
+ embed_dim,
310
+ eps=model_config.normalization_layer_epsilon,
311
+ dtype=torch.float32,
312
+ )
313
+ self.mlp = MlpBlock(
314
+ embed_dim=embed_dim,
315
+ intermediate_dim=enc_config.n_hidden,
316
+ compute_dtype=compute_dtype,
317
+ )
318
+
319
+ def forward(
320
+ self,
321
+ x: torch.Tensor,
322
+ state: EncoderInferenceState,
323
+ ) -> torch.Tensor:
324
+ residual = x
325
+ x_norm = self.pre_sa_norm(x)
326
+ sa_out = self.self_attention(
327
+ Xq=x_norm,
328
+ Xkv=x_norm,
329
+ q_positions=state.positions,
330
+ kv_positions=state.positions,
331
+ attn_mask=state.attn_mask,
332
+ )
333
+ x = residual + sa_out
334
+
335
+ residual = x
336
+ x_norm = self.post_sa_norm(x)
337
+ mlp_out = self.mlp(x_norm)
338
+ x = residual + mlp_out
339
+
340
+ return x
341
+
342
+
343
+ class Encoder(nn.Module):
344
+ """Transformer Encoder Stack using DenseGeneral."""
345
+
346
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
347
+ super().__init__()
348
+ self.config = config
349
+ model_config = config.model
350
+ enc_config = config.model.encoder
351
+
352
+ self.embedding = nn.Embedding(
353
+ model_config.src_vocab_size,
354
+ enc_config.n_embd,
355
+ dtype=compute_dtype,
356
+ )
357
+ self.layers = nn.ModuleList(
358
+ [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
359
+ )
360
+ self.norm = RMSNorm(
361
+ enc_config.n_embd,
362
+ eps=model_config.normalization_layer_epsilon,
363
+ dtype=torch.float32,
364
+ )
365
+
366
+ def forward(
367
+ self,
368
+ x_ids: torch.Tensor,
369
+ state: EncoderInferenceState,
370
+ ) -> torch.Tensor:
371
+ x = self.embedding(x_ids)
372
+
373
+ for layer in self.layers:
374
+ x = layer(x, state)
375
+
376
+ x = self.norm(x)
377
+ return x
378
+
379
+
380
+ class DecoderLayer(nn.Module):
381
+ """Transformer Decoder Layer using DenseGeneral."""
382
+
383
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
384
+ super().__init__()
385
+ self.config = config
386
+ model_config = config.model
387
+ dec_config = config.model.decoder
388
+ enc_config = config.model.encoder
389
+ dec_embed_dim = dec_config.n_embd
390
+ enc_embed_dim = enc_config.n_embd
391
+
392
+ # Norms
393
+ self.pre_sa_norm = RMSNorm(
394
+ dec_embed_dim,
395
+ eps=model_config.normalization_layer_epsilon,
396
+ dtype=torch.float32,
397
+ )
398
+ self.pre_ca_norm = RMSNorm(
399
+ dec_embed_dim,
400
+ eps=model_config.normalization_layer_epsilon,
401
+ dtype=torch.float32,
402
+ )
403
+ self.pre_mlp_norm = RMSNorm(
404
+ dec_embed_dim,
405
+ eps=model_config.normalization_layer_epsilon,
406
+ dtype=torch.float32,
407
+ )
408
+
409
+ # Self-Attention (GQA) with Causal Masking
410
+ self.self_attention = Attention(
411
+ config,
412
+ q_embed_dim=dec_embed_dim,
413
+ kv_embed_dim=dec_embed_dim,
414
+ num_query_heads=dec_config.gqa_query_heads,
415
+ num_kv_heads=dec_config.kv_heads,
416
+ head_dim=dec_config.gqa_head_dim,
417
+ compute_dtype=compute_dtype,
418
+ is_cross_attn=False,
419
+ out_embed_dim=dec_embed_dim,
420
+ )
421
+ # Cross-Attention (MHA)
422
+ self.cross_attention = Attention(
423
+ config=config,
424
+ q_embed_dim=dec_embed_dim,
425
+ kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
426
+ num_query_heads=dec_config.cross_query_heads,
427
+ num_kv_heads=dec_config.cross_query_heads,
428
+ head_dim=dec_config.cross_head_dim,
429
+ compute_dtype=compute_dtype,
430
+ is_cross_attn=True,
431
+ out_embed_dim=dec_embed_dim,
432
+ )
433
+ # MLP
434
+ self.mlp = MlpBlock(
435
+ embed_dim=dec_embed_dim,
436
+ intermediate_dim=dec_config.n_hidden,
437
+ compute_dtype=compute_dtype,
438
+ )
439
+
440
+ def forward(
441
+ self,
442
+ x: torch.Tensor,
443
+ state: DecoderInferenceState,
444
+ self_attn_cache: KVCache | None = None,
445
+ cross_attn_cache: KVCache | None = None,
446
+ prefill: bool = False,
447
+ ) -> torch.Tensor:
448
+ residual = x
449
+ x_norm = self.pre_sa_norm(x)
450
+
451
+ sa_out = self.self_attention(
452
+ Xq=x_norm, # (2, 1, D)
453
+ Xkv=x_norm, # (2, 1, D)
454
+ q_positions=state.dec_positions, # (2, 1)
455
+ kv_positions=state.dec_positions, # (2, 1)
456
+ attn_mask=None,
457
+ cache=self_attn_cache,
458
+ prefill=prefill,
459
+ is_causal=prefill,
460
+ )
461
+
462
+ x = residual + sa_out
463
+
464
+ residual = x
465
+ x_norm = self.pre_ca_norm(x)
466
+ ca_out = self.cross_attention(
467
+ Xq=x_norm,
468
+ Xkv=state.enc_out,
469
+ q_positions=state.dec_positions,
470
+ kv_positions=state.enc_positions,
471
+ attn_mask=state.dec_cross_attn_mask,
472
+ cache=cross_attn_cache,
473
+ )
474
+ x = residual + ca_out
475
+
476
+ residual = x
477
+ x_norm = self.pre_mlp_norm(x)
478
+ mlp_out = self.mlp(x_norm)
479
+ x = residual + mlp_out
480
+
481
+ return x
482
+
483
+
484
+ class Decoder(nn.Module):
485
+ """Transformer Decoder Stack using DenseGeneral."""
486
+
487
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
488
+ super().__init__()
489
+ self.config = config
490
+ model_config = config.model
491
+ dec_config = config.model.decoder
492
+ data_config = config.data
493
+ self.num_channels = data_config.channels
494
+ self.num_layers = dec_config.n_layer
495
+
496
+ self.embeddings = nn.ModuleList(
497
+ [
498
+ nn.Embedding(
499
+ model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
500
+ )
501
+ for _ in range(self.num_channels)
502
+ ]
503
+ )
504
+ self.layers = nn.ModuleList(
505
+ [
506
+ DecoderLayer(config=config, compute_dtype=compute_dtype)
507
+ for _ in range(self.num_layers)
508
+ ]
509
+ )
510
+
511
+ self.norm = RMSNorm(
512
+ dec_config.n_embd,
513
+ eps=model_config.normalization_layer_epsilon,
514
+ dtype=torch.float32,
515
+ )
516
+
517
+ self.logits_dense = DenseGeneral(
518
+ in_shapes=(dec_config.n_embd,),
519
+ out_features=(self.num_channels, model_config.tgt_vocab_size),
520
+ axis=(-1,),
521
+ weight_dtype=compute_dtype,
522
+ )
523
+
524
+ def precompute_cross_attn_cache(
525
+ self,
526
+ enc_out: torch.Tensor, # (B, S, E)
527
+ enc_positions: torch.Tensor, # (B, S)
528
+ ) -> list[KVCache]:
529
+ """
530
+ Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
531
+ """
532
+ per_layer_kv_cache: list[KVCache] = []
533
+
534
+ for layer in self.layers:
535
+ cross_attn_module = layer.cross_attention
536
+ k_proj = cross_attn_module.k_proj(enc_out)
537
+ v_proj = cross_attn_module.v_proj(enc_out)
538
+
539
+ k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
540
+ k = k_proj.transpose(1, 2)
541
+ v = v_proj.transpose(1, 2)
542
+
543
+ per_layer_kv_cache.append(KVCache.from_kv(k, v))
544
+
545
+ return per_layer_kv_cache
546
+
547
+ def decode_step(
548
+ self,
549
+ tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
550
+ state: DecoderInferenceState,
551
+ ) -> torch.Tensor:
552
+ """
553
+ Performs a single decoding step, managing KV caches layer by layer.
554
+
555
+ Returns:
556
+ A tuple containing:
557
+ - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
558
+ """
559
+
560
+ x = None
561
+ for i in range(self.num_channels):
562
+ channel_tokens = tgt_ids_Bx1xC[..., i]
563
+ channel_embed = self.embeddings[i](channel_tokens)
564
+ x = channel_embed if x is None else x + channel_embed
565
+
566
+ for i, layer in enumerate(self.layers):
567
+ self_cache = state.self_attn_cache[i]
568
+ cross_cache = state.cross_attn_cache[i]
569
+ x = layer(
570
+ x, # (2, 1, D)
571
+ state,
572
+ self_attn_cache=self_cache,
573
+ cross_attn_cache=cross_cache,
574
+ )
575
+
576
+ x = self.norm(x)
577
+ logits_Bx1xCxV = self.logits_dense(x)
578
+
579
+ return logits_Bx1xCxV.to(torch.float32)
580
+
581
+ def forward(
582
+ self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
583
+ ) -> torch.Tensor:
584
+ """
585
+ Forward pass for the Decoder stack, managing KV caches.
586
+
587
+ Args:
588
+ tgt_ids_BxTxC: Target token IDs (B, T, C).
589
+ encoder_out: Output from the encoder (B, S, E).
590
+ tgt_positions: Positions for target sequence (B, T).
591
+ src_positions: Positions for source sequence (B, S).
592
+ self_attn_mask: Mask for self-attention.
593
+ cross_attn_mask: Mask for cross-attention.
594
+ past_key_values: List containing the self-attention KV cache for each layer
595
+ from the previous decoding step. `len(past_key_values)` should
596
+ equal `num_layers`.
597
+ precomputed_cross_attn_kv: A single tuple containing the pre-computed K/V cache
598
+ derived from `encoder_out`. This is passed identically
599
+ to all layers.
600
+
601
+ Returns:
602
+ A tuple containing:
603
+ - logits: The final output logits (B, T, C * V), cast to float32.
604
+ - present_key_values: A list containing the updated self-attention KV cache
605
+ for each layer for the *current* decoding step.
606
+ """
607
+ _, _, num_channels_in = tgt_ids_BxTxC.shape
608
+ assert num_channels_in == self.num_channels, "Input channels mismatch"
609
+
610
+ # Embeddings
611
+ x = None
612
+ for i in range(self.num_channels):
613
+ channel_tokens = tgt_ids_BxTxC[..., i]
614
+ channel_embed = self.embeddings[i](channel_tokens)
615
+ x = channel_embed if x is None else x + channel_embed
616
+
617
+ for i, layer in enumerate(self.layers):
618
+ self_cache = state.self_attn_cache[i]
619
+ cross_cache = state.cross_attn_cache[i]
620
+ x = layer(
621
+ x,
622
+ state,
623
+ self_attn_cache=self_cache,
624
+ cross_attn_cache=cross_cache,
625
+ prefill=True,
626
+ )
627
+
628
+ # Final Norm
629
+ x = self.norm(x)
630
+ logits_BxTxCxV = self.logits_dense(x)
631
+
632
+ return logits_BxTxCxV.to(torch.float32)
633
+
634
+
635
+ class DiaModel(nn.Module):
636
+ """PyTorch Dia Model using DenseGeneral."""
637
+
638
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
639
+ super().__init__()
640
+ self.config = config
641
+ self.encoder = Encoder(config, compute_dtype)
642
+ self.decoder = Decoder(config, compute_dtype)
dia/model.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from enum import Enum
3
+
4
+ import dac
5
+ import numpy as np
6
+ import torch
7
+ import torchaudio
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from .audio import (
11
+ apply_audio_delay,
12
+ build_delay_indices,
13
+ build_revert_indices,
14
+ decode,
15
+ revert_audio_delay,
16
+ )
17
+ from .config import DiaConfig
18
+ from .layers import DiaModel
19
+ from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
20
+
21
+
22
+ DEFAULT_SAMPLE_RATE = 44100
23
+
24
+
25
+ def _get_default_device():
26
+ if torch.cuda.is_available():
27
+ return torch.device("cuda")
28
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
29
+ return torch.device("mps")
30
+ return torch.device("cpu")
31
+
32
+
33
+ def _sample_next_token(
34
+ logits_BCxV: torch.Tensor,
35
+ temperature: float,
36
+ top_p: float,
37
+ cfg_filter_top_k: int | None = None,
38
+ ) -> torch.Tensor:
39
+ if temperature == 0.0:
40
+ return torch.argmax(logits_BCxV, dim=-1)
41
+
42
+ logits_BCxV = logits_BCxV / temperature
43
+ if cfg_filter_top_k is not None:
44
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
45
+ mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
46
+ mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
47
+ logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
48
+
49
+ if top_p < 1.0:
50
+ probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
51
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
52
+ probs_BCxV, dim=-1, descending=True
53
+ )
54
+ cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
55
+
56
+ sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
57
+ sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
58
+ ..., :-1
59
+ ].clone()
60
+ sorted_indices_to_remove_BCxV[..., 0] = 0
61
+
62
+ indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
63
+ indices_to_remove_BCxV.scatter_(
64
+ dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
65
+ )
66
+ logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
67
+
68
+ final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
69
+
70
+ sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
71
+ sampled_indices_C = sampled_indices_BC.squeeze(-1)
72
+ return sampled_indices_C
73
+
74
+
75
+ class ComputeDtype(str, Enum):
76
+ FLOAT32 = "float32"
77
+ FLOAT16 = "float16"
78
+ BFLOAT16 = "bfloat16"
79
+
80
+ def to_dtype(self) -> torch.dtype:
81
+ if self == ComputeDtype.FLOAT32:
82
+ return torch.float32
83
+ elif self == ComputeDtype.FLOAT16:
84
+ return torch.float16
85
+ elif self == ComputeDtype.BFLOAT16:
86
+ return torch.bfloat16
87
+ else:
88
+ raise ValueError(f"Unsupported compute dtype: {self}")
89
+
90
+
91
+ class Dia:
92
+ def __init__(
93
+ self,
94
+ config: DiaConfig,
95
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
96
+ device: torch.device | None = None,
97
+ ):
98
+ """Initializes the Dia model.
99
+
100
+ Args:
101
+ config: The configuration object for the model.
102
+ device: The device to load the model onto. If None, will automatically select the best available device.
103
+
104
+ Raises:
105
+ RuntimeError: If there is an error loading the DAC model.
106
+ """
107
+ super().__init__()
108
+ self.config = config
109
+ self.device = device if device is not None else _get_default_device()
110
+ if isinstance(compute_dtype, str):
111
+ compute_dtype = ComputeDtype(compute_dtype)
112
+ self.compute_dtype = compute_dtype.to_dtype()
113
+ self.model = DiaModel(config, self.compute_dtype)
114
+ self.dac_model = None
115
+
116
+ @classmethod
117
+ def from_local(
118
+ cls,
119
+ config_path: str,
120
+ checkpoint_path: str,
121
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
122
+ device: torch.device | None = None,
123
+ ) -> "Dia":
124
+ """Loads the Dia model from local configuration and checkpoint files.
125
+
126
+ Args:
127
+ config_path: Path to the configuration JSON file.
128
+ checkpoint_path: Path to the model checkpoint (.pth) file.
129
+ device: The device to load the model onto. If None, will automatically select the best available device.
130
+
131
+ Returns:
132
+ An instance of the Dia model loaded with weights and set to eval mode.
133
+
134
+ Raises:
135
+ FileNotFoundError: If the config or checkpoint file is not found.
136
+ RuntimeError: If there is an error loading the checkpoint.
137
+ """
138
+ config = DiaConfig.load(config_path)
139
+ if config is None:
140
+ raise FileNotFoundError(f"Config file not found at {config_path}")
141
+
142
+ dia = cls(config, compute_dtype, device)
143
+
144
+ try:
145
+ state_dict = torch.load(checkpoint_path, map_location=dia.device)
146
+ dia.model.load_state_dict(state_dict)
147
+ except FileNotFoundError:
148
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
149
+ except Exception as e:
150
+ raise RuntimeError(
151
+ f"Error loading checkpoint from {checkpoint_path}"
152
+ ) from e
153
+
154
+ dia.model.to(dia.device)
155
+ dia.model.eval()
156
+ dia._load_dac_model()
157
+ return dia
158
+
159
+ @classmethod
160
+ def from_pretrained(
161
+ cls,
162
+ model_name: str = "nari-labs/Dia-1.6B",
163
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
164
+ device: torch.device | None = None,
165
+ ) -> "Dia":
166
+ """Loads the Dia model from a Hugging Face Hub repository.
167
+
168
+ Downloads the configuration and checkpoint files from the specified
169
+ repository ID and then loads the model.
170
+
171
+ Args:
172
+ model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
173
+ device: The device to load the model onto. If None, will automatically select the best available device.
174
+
175
+ Returns:
176
+ An instance of the Dia model loaded with weights and set to eval mode.
177
+
178
+ Raises:
179
+ FileNotFoundError: If config or checkpoint download/loading fails.
180
+ RuntimeError: If there is an error loading the checkpoint.
181
+ """
182
+ config_path = hf_hub_download(repo_id=model_name, filename="config.json")
183
+ checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
184
+ return cls.from_local(config_path, checkpoint_path, compute_dtype, device)
185
+
186
+ def _load_dac_model(self):
187
+ try:
188
+ dac_model_path = dac.utils.download()
189
+ dac_model = dac.DAC.load(dac_model_path).to(self.device)
190
+ except Exception as e:
191
+ raise RuntimeError("Failed to load DAC model") from e
192
+ self.dac_model = dac_model
193
+
194
+ def _prepare_text_input(self, text: str) -> torch.Tensor:
195
+ """Encodes text prompt, pads, and creates attention mask and positions."""
196
+ text_pad_value = self.config.data.text_pad_value
197
+ max_len = self.config.data.text_length
198
+
199
+ byte_text = text.encode("utf-8")
200
+ replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
201
+ text_tokens = list(replaced_bytes)
202
+
203
+ current_len = len(text_tokens)
204
+ padding_needed = max_len - current_len
205
+ if padding_needed <= 0:
206
+ text_tokens = text_tokens[:max_len]
207
+ padded_text_np = np.array(text_tokens, dtype=np.uint8)
208
+ else:
209
+ padded_text_np = np.pad(
210
+ text_tokens,
211
+ (0, padding_needed),
212
+ mode="constant",
213
+ constant_values=text_pad_value,
214
+ ).astype(np.uint8)
215
+
216
+ src_tokens = (
217
+ torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
218
+ ) # [1, S]
219
+ return src_tokens
220
+
221
+ def _prepare_audio_prompt(
222
+ self, audio_prompt: torch.Tensor | None
223
+ ) -> tuple[torch.Tensor, int]:
224
+ num_channels = self.config.data.channels
225
+ audio_bos_value = self.config.data.audio_bos_value
226
+ audio_pad_value = self.config.data.audio_pad_value
227
+ delay_pattern = self.config.data.delay_pattern
228
+ max_delay_pattern = max(delay_pattern)
229
+
230
+ prefill = torch.full(
231
+ (1, num_channels),
232
+ fill_value=audio_bos_value,
233
+ dtype=torch.int,
234
+ device=self.device,
235
+ )
236
+
237
+ prefill_step = 1
238
+
239
+ if audio_prompt is not None:
240
+ prefill_step += audio_prompt.shape[0]
241
+ prefill = torch.cat([prefill, audio_prompt], dim=0)
242
+
243
+ delay_pad_tensor = torch.full(
244
+ (max_delay_pattern, num_channels),
245
+ fill_value=-1,
246
+ dtype=torch.int,
247
+ device=self.device,
248
+ )
249
+ prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
250
+
251
+ delay_precomp = build_delay_indices(
252
+ B=1,
253
+ T=prefill.shape[0],
254
+ C=num_channels,
255
+ delay_pattern=delay_pattern,
256
+ )
257
+
258
+ prefill = apply_audio_delay(
259
+ audio_BxTxC=prefill.unsqueeze(0),
260
+ pad_value=audio_pad_value,
261
+ bos_value=audio_bos_value,
262
+ precomp=delay_precomp,
263
+ ).squeeze(0)
264
+
265
+ return prefill, prefill_step
266
+
267
+ def _prepare_generation(
268
+ self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
269
+ ):
270
+ enc_input_cond = self._prepare_text_input(text)
271
+ enc_input_uncond = torch.zeros_like(enc_input_cond)
272
+ enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
273
+
274
+ if isinstance(audio_prompt, str):
275
+ audio_prompt = self.load_audio(audio_prompt)
276
+ prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
277
+
278
+ if verbose:
279
+ print("generate: data loaded")
280
+
281
+ enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
282
+ encoder_out = self.model.encoder(enc_input, enc_state)
283
+
284
+ dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
285
+ encoder_out, enc_state.positions
286
+ )
287
+ dec_state = DecoderInferenceState.new(
288
+ self.config,
289
+ enc_state,
290
+ encoder_out,
291
+ dec_cross_attn_cache,
292
+ self.compute_dtype,
293
+ )
294
+ dec_output = DecoderOutput.new(self.config, self.device)
295
+ dec_output.prefill(prefill, prefill_step)
296
+
297
+ dec_step = prefill_step - 1
298
+ if dec_step > 0:
299
+ dec_state.prepare_step(0, dec_step)
300
+ tokens_BxTxC = (
301
+ dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
302
+ )
303
+ self.model.decoder.forward(tokens_BxTxC, dec_state)
304
+
305
+ return dec_state, dec_output
306
+
307
+ def _decoder_step(
308
+ self,
309
+ tokens_Bx1xC: torch.Tensor,
310
+ dec_state: DecoderInferenceState,
311
+ cfg_scale: float,
312
+ temperature: float,
313
+ top_p: float,
314
+ cfg_filter_top_k: int,
315
+ ) -> torch.Tensor:
316
+ audio_eos_value = self.config.data.audio_eos_value
317
+ logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)
318
+
319
+ logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
320
+ uncond_logits_CxV = logits_last_BxCxV[0, :, :]
321
+ cond_logits_CxV = logits_last_BxCxV[1, :, :]
322
+
323
+ logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
324
+ logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
325
+ logits_CxV[1:, audio_eos_value:] = -torch.inf
326
+
327
+ pred_C = _sample_next_token(
328
+ logits_CxV.float(),
329
+ temperature=temperature,
330
+ top_p=top_p,
331
+ cfg_filter_top_k=cfg_filter_top_k,
332
+ )
333
+ return pred_C
334
+
335
+ def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
336
+ num_channels = self.config.data.channels
337
+ seq_length = generated_codes.shape[0]
338
+ delay_pattern = self.config.data.delay_pattern
339
+ audio_pad_value = self.config.data.audio_pad_value
340
+ max_delay_pattern = max(delay_pattern)
341
+
342
+ revert_precomp = build_revert_indices(
343
+ B=1,
344
+ T=seq_length,
345
+ C=num_channels,
346
+ delay_pattern=delay_pattern,
347
+ )
348
+
349
+ codebook = revert_audio_delay(
350
+ audio_BxTxC=generated_codes.unsqueeze(0),
351
+ pad_value=audio_pad_value,
352
+ precomp=revert_precomp,
353
+ T=seq_length,
354
+ )[:, :-max_delay_pattern, :]
355
+
356
+ min_valid_index = 0
357
+ max_valid_index = 1023
358
+ invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
359
+ codebook[invalid_mask] = 0
360
+
361
+ audio = decode(self.dac_model, codebook.transpose(1, 2))
362
+
363
+ return audio.squeeze().cpu().numpy()
364
+
365
+ def load_audio(self, audio_path: str) -> torch.Tensor:
366
+ audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T
367
+ if sr != DEFAULT_SAMPLE_RATE:
368
+ audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
369
+ audio = audio.to(self.device).unsqueeze(0) # 1, C, T
370
+ audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
371
+ _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data) # 1, C, T
372
+ return encoded_frame.squeeze(0).transpose(0, 1)
373
+
374
+ def save_audio(self, path: str, audio: np.ndarray):
375
+ import soundfile as sf
376
+
377
+ sf.write(path, audio, DEFAULT_SAMPLE_RATE)
378
+
379
+ @torch.inference_mode()
380
+ def generate(
381
+ self,
382
+ text: str,
383
+ max_tokens: int | None = None,
384
+ cfg_scale: float = 3.0,
385
+ temperature: float = 1.3,
386
+ top_p: float = 0.95,
387
+ use_torch_compile: bool = False,
388
+ cfg_filter_top_k: int = 35,
389
+ audio_prompt: str | torch.Tensor | None = None,
390
+ audio_prompt_path: str | None = None,
391
+ use_cfg_filter: bool | None = None,
392
+ verbose: bool = False,
393
+ ) -> np.ndarray:
394
+ audio_eos_value = self.config.data.audio_eos_value
395
+ audio_pad_value = self.config.data.audio_pad_value
396
+ delay_pattern = self.config.data.delay_pattern
397
+ max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
398
+ max_delay_pattern = max(delay_pattern)
399
+ self.model.eval()
400
+
401
+ if audio_prompt_path:
402
+ print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
403
+ audio_prompt = audio_prompt_path
404
+ if use_cfg_filter is not None:
405
+ print("Warning: use_cfg_filter is deprecated.")
406
+
407
+ if verbose:
408
+ total_start_time = time.time()
409
+
410
+ dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
411
+ dec_step = dec_output.prefill_step - 1
412
+
413
+ bos_countdown = max_delay_pattern
414
+ eos_detected = False
415
+ eos_countdown = -1
416
+
417
+ if use_torch_compile:
418
+ step_fn = torch.compile(self._decoder_step, mode="default")
419
+ else:
420
+ step_fn = self._decoder_step
421
+
422
+ if verbose:
423
+ print("generate: starting generation loop")
424
+ if use_torch_compile:
425
+ print(
426
+ "generate: by using use_torch_compile=True, the first step would take long"
427
+ )
428
+ start_time = time.time()
429
+
430
+ while dec_step < max_tokens:
431
+ dec_state.prepare_step(dec_step)
432
+ tokens_Bx1xC = (
433
+ dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
434
+ )
435
+ pred_C = step_fn(
436
+ tokens_Bx1xC,
437
+ dec_state,
438
+ cfg_scale,
439
+ temperature,
440
+ top_p,
441
+ cfg_filter_top_k,
442
+ )
443
+
444
+ if (
445
+ not eos_detected and pred_C[0] == audio_eos_value
446
+ ) or dec_step == max_tokens - max_delay_pattern - 1:
447
+ eos_detected = True
448
+ eos_countdown = max_delay_pattern
449
+
450
+ if eos_countdown > 0:
451
+ step_after_eos = max_delay_pattern - eos_countdown
452
+ for i, d in enumerate(delay_pattern):
453
+ if step_after_eos == d:
454
+ pred_C[i] = audio_eos_value
455
+ elif step_after_eos > d:
456
+ pred_C[i] = audio_pad_value
457
+ eos_countdown -= 1
458
+
459
+ bos_countdown = max(0, bos_countdown - 1)
460
+ dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)
461
+
462
+ if eos_countdown == 0:
463
+ break
464
+
465
+ dec_step += 1
466
+ if verbose and dec_step % 86 == 0:
467
+ duration = time.time() - start_time
468
+ print(
469
+ f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
470
+ )
471
+ start_time = time.time()
472
+
473
+ if dec_output.prefill_step >= dec_step + 1:
474
+ print("Warning: Nothing generated")
475
+ return None
476
+
477
+ generated_codes = dec_output.generated_tokens[
478
+ dec_output.prefill_step : dec_step + 1, :
479
+ ]
480
+
481
+ if verbose:
482
+ total_step = dec_step + 1 - dec_output.prefill_step
483
+ total_duration = time.time() - total_start_time
484
+ print(
485
+ f"generate: total step={total_step}, total duration={total_duration:.3f}s"
486
+ )
487
+
488
+ return self._generate_output(generated_codes)
dia/state.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from .config import DiaConfig
6
+
7
+
8
+ def create_attn_mask(
9
+ q_padding_mask_1d: torch.Tensor,
10
+ k_padding_mask_1d: torch.Tensor,
11
+ device: torch.device,
12
+ is_causal: bool = False,
13
+ ) -> torch.Tensor:
14
+ """
15
+ Creates the attention mask (self or cross) mimicking JAX segment ID logic.
16
+ """
17
+ B1, Tq = q_padding_mask_1d.shape
18
+ B2, Tk = k_padding_mask_1d.shape
19
+ assert B1 == B2, "Query and key batch dimensions must match"
20
+
21
+ p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
22
+ p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
23
+
24
+ # Condition A: Non-padding query attends to non-padding key
25
+ non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
26
+
27
+ # Condition B: Padding query attends to padding key
28
+ pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
29
+
30
+ # Combine: True if padding status is compatible (both non-pad OR both pad)
31
+ mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
32
+
33
+ if is_causal:
34
+ assert Tq == Tk, (
35
+ "Causal mask requires query and key sequence lengths to be equal"
36
+ )
37
+ causal_mask_2d = torch.tril(
38
+ torch.ones((Tq, Tk), dtype=torch.bool, device=device)
39
+ ) # Shape [Tq, Tk]
40
+ causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
41
+ return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
42
+ else:
43
+ return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
44
+
45
+
46
+ @dataclass
47
+ class EncoderInferenceState:
48
+ """Parameters specifically for encoder inference."""
49
+
50
+ max_seq_len: int
51
+ device: torch.device
52
+ positions: torch.Tensor
53
+ padding_mask: torch.Tensor
54
+ attn_mask: torch.Tensor
55
+
56
+ @classmethod
57
+ def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
58
+ """Creates EtorchrInferenceParams from DiaConfig and a device."""
59
+ device = cond_src.device
60
+
61
+ positions = (
62
+ torch.arange(config.data.text_length, device=device)
63
+ .to(torch.long)
64
+ .unsqueeze(0)
65
+ .expand(2, -1)
66
+ )
67
+ padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
68
+ attn_mask = create_attn_mask(
69
+ padding_mask, padding_mask, device, is_causal=False
70
+ )
71
+
72
+ return cls(
73
+ max_seq_len=config.data.text_length,
74
+ device=device,
75
+ positions=positions,
76
+ padding_mask=padding_mask,
77
+ attn_mask=attn_mask,
78
+ )
79
+
80
+
81
+ class KVCache:
82
+ def __init__(
83
+ self,
84
+ num_heads: int,
85
+ max_len: int,
86
+ head_dim: int,
87
+ dtype: torch.dtype,
88
+ device: torch.device,
89
+ k: torch.Tensor | None = None,
90
+ v: torch.Tensor | None = None,
91
+ ):
92
+ self.k = (
93
+ torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
94
+ if k is None
95
+ else k
96
+ )
97
+ self.v = (
98
+ torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
99
+ if v is None
100
+ else v
101
+ )
102
+ self.current_idx = torch.tensor(0)
103
+
104
+ @classmethod
105
+ def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
106
+ return cls(
107
+ num_heads=k.shape[1],
108
+ max_len=k.shape[2],
109
+ head_dim=k.shape[3],
110
+ dtype=k.dtype,
111
+ device=k.device,
112
+ k=k,
113
+ v=v,
114
+ )
115
+
116
+ def update(
117
+ self, k: torch.Tensor, v: torch.Tensor
118
+ ) -> tuple[torch.Tensor, torch.Tensor]:
119
+ self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
120
+ self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
121
+ self.current_idx += 1
122
+ return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]
123
+
124
+ def prefill(
125
+ self, k: torch.Tensor, v: torch.Tensor
126
+ ) -> tuple[torch.Tensor, torch.Tensor]:
127
+ prefill_len = k.shape[2]
128
+ self.k[:, :, :prefill_len, :] = k
129
+ self.v[:, :, :prefill_len, :] = v
130
+ self.current_idx = prefill_len - 1
131
+
132
+
133
+ @dataclass
134
+ class DecoderInferenceState:
135
+ """Parameters specifically for decoder inference."""
136
+
137
+ device: torch.device
138
+ dtype: torch.dtype
139
+ enc_out: torch.Tensor
140
+ enc_positions: torch.Tensor
141
+ dec_positions: torch.Tensor
142
+ dec_cross_attn_mask: torch.Tensor
143
+ self_attn_cache: list[KVCache]
144
+ cross_attn_cache: list[KVCache]
145
+
146
+ @classmethod
147
+ def new(
148
+ cls,
149
+ config: DiaConfig,
150
+ enc_state: EncoderInferenceState,
151
+ enc_out: torch.Tensor,
152
+ dec_cross_attn_cache: list[KVCache],
153
+ compute_dtype: torch.dtype,
154
+ ) -> "DecoderInferenceState":
155
+ """Creates DecoderInferenceParams from DiaConfig and a device."""
156
+ device = enc_out.device
157
+ max_audio_len = config.data.audio_length
158
+
159
+ dec_positions = torch.full(
160
+ (2, 1), fill_value=0, dtype=torch.long, device=device
161
+ )
162
+ tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
163
+ dec_cross_attn_mask = create_attn_mask(
164
+ tgt_padding_mask, enc_state.padding_mask, device, is_causal=False
165
+ )
166
+
167
+ self_attn_cache = [
168
+ KVCache(
169
+ config.model.decoder.kv_heads,
170
+ max_audio_len,
171
+ config.model.decoder.gqa_head_dim,
172
+ compute_dtype,
173
+ device,
174
+ )
175
+ for _ in range(config.model.decoder.n_layer)
176
+ ]
177
+
178
+ return cls(
179
+ device=device,
180
+ dtype=compute_dtype,
181
+ enc_out=enc_out,
182
+ enc_positions=enc_state.positions,
183
+ dec_positions=dec_positions,
184
+ dec_cross_attn_mask=dec_cross_attn_mask,
185
+ self_attn_cache=self_attn_cache,
186
+ cross_attn_cache=dec_cross_attn_cache,
187
+ )
188
+
189
+ def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
190
+ if step_to is None:
191
+ step_to = step_from + 1
192
+ self.dec_positions = (
193
+ torch.arange(step_from, step_to, device=self.device)
194
+ .unsqueeze(0)
195
+ .expand(2, -1)
196
+ )
197
+
198
+
199
+ @dataclass
200
+ class DecoderOutput:
201
+ generated_tokens: torch.Tensor
202
+ prefill_step: int
203
+
204
+ @classmethod
205
+ def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
206
+ max_audio_len = config.data.audio_length
207
+ return cls(
208
+ generated_tokens=torch.full(
209
+ (max_audio_len, config.data.channels),
210
+ fill_value=-1,
211
+ dtype=torch.int,
212
+ device=device,
213
+ ),
214
+ prefill_step=0,
215
+ )
216
+
217
+ def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
218
+ if step_to is None:
219
+ step_to = step_from + 1
220
+ return self.generated_tokens[step_from:step_to, :]
221
+
222
+ def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
223
+ if apply_mask:
224
+ mask = self.generated_tokens[step : step + 1, :] == -1
225
+ self.generated_tokens[step : step + 1, :] = torch.where(
226
+ mask, dec_out, self.generated_tokens[step : step + 1, :]
227
+ )
228
+ else:
229
+ self.generated_tokens[step : step + 1, :] = dec_out
230
+
231
+ def prefill(self, dec_out: torch.Tensor, prefill_step: int):
232
+ length = dec_out.shape[0]
233
+ self.generated_tokens[0:length, :] = dec_out
234
+ self.prefill_step = prefill_step
dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use official lightweight Python image
2
+ FROM python:3.9-slim
3
+
4
+ # Create a non-root user
5
+ RUN useradd -m -u 1000 user
6
+
7
+ # Switch to the new user
8
+ USER user
9
+
10
+ # Set environment variables
11
+ ENV PATH="/home/user/.local/bin:$PATH"
12
+
13
+ # Set working directory
14
+ WORKDIR /app
15
+
16
+ # Copy requirements first for caching
17
+ COPY --chown=user requirements.txt .
18
+
19
+ # Install dependencies
20
+ RUN pip install --no-cache-dir --upgrade pip \
21
+ && pip install --no-cache-dir -r requirements.txt
22
+
23
+ # Now copy the rest of the app code
24
+ COPY --chown=user . .
25
+
26
+ # Expose the port (optional, for documentation)
27
+ EXPOSE 7860
28
+
29
+ # Command to run the app
30
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ dac==0.4.3
2
+ fastapi==0.115.12
3
+ huggingface_hub==0.30.2
4
+ numpy==2.2.5
5
+ pydantic==2.11.3
6
+ soundfile==0.13.1
7
+ torch==2.6.0
8
+ torchaudio==2.6.0
9
+
templates/index.html ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Dia Text-to-Speech Converter</title>
7
+ <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.16/dist/tailwind.min.css" rel="stylesheet">
8
+ </head>
9
+ <body class="bg-gray-100 min-h-screen flex items-center justify-center p-4">
10
+ <div class="container max-w-3xl mx-auto p-8 bg-white rounded-xl shadow-lg">
11
+ <h1 class="text-3xl font-bold text-center text-indigo-700 mb-8">Dia Text-to-Speech Converter</h1>
12
+
13
+ <form action="/convertor" method="post" enctype="multipart/form-data" class="space-y-6">
14
+ <div class="mb-6">
15
+ <label for="paragraph" class="block text-gray-700 font-semibold mb-2">Enter Text to Convert</label>
16
+ <textarea
17
+ id="paragraph"
18
+ name="paragraph"
19
+ rows="6"
20
+ class="w-full p-3 border border-gray-300 rounded-lg focus:ring-2 focus:ring-indigo-500 focus:border-indigo-500 transition"
21
+ placeholder="Type or paste your text here..."
22
+ required
23
+ ></textarea>
24
+ </div>
25
+
26
+ <div class="mb-6">
27
+ <label for="action" class="block text-gray-700 font-semibold mb-2">Choose Action</label>
28
+ <select
29
+ id="action"
30
+ name="action"
31
+ class="w-full p-3 border border-gray-300 rounded-lg focus:ring-2 focus:ring-indigo-500 focus:border-indigo-500 transition"
32
+ >
33
+ <option value="audio">Convert to Audio</option>
34
+ <option value="summarize" disabled>Summarize (Coming Soon)</option>
35
+ </select>
36
+ </div>
37
+
38
+ <div class="flex justify-center">
39
+ <button
40
+ type="submit"
41
+ class="bg-indigo-600 hover:bg-indigo-700 text-white font-bold py-3 px-8 rounded-lg shadow-md transition duration-300 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-opacity-50"
42
+ >
43
+ Generate Audio
44
+ </button>
45
+ </div>
46
+ </form>
47
+
48
+ <div id="result" class="mt-8 text-center">
49
+ <!-- Audio player will appear here after conversion -->
50
+ </div>
51
+
52
+ <div class="mt-8 text-center text-gray-600 text-sm">
53
+ <p>Powered by Dia-1.6B AI Text-to-Speech Model</p>
54
+ </div>
55
+ </div>
56
+
57
+ <script>
58
+ document.addEventListener('DOMContentLoaded', function() {
59
+ const form = document.querySelector('form');
60
+ const resultDiv = document.getElementById('result');
61
+
62
+ form.addEventListener('submit', async function(e) {
63
+ e.preventDefault();
64
+
65
+ const submitButton = form.querySelector('button[type="submit"]');
66
+ submitButton.disabled = true;
67
+ submitButton.innerHTML = 'Processing...';
68
+
69
+ try {
70
+ const formData = new FormData(form);
71
+ const response = await fetch('/convertor', {
72
+ method: 'POST',
73
+ body: formData
74
+ });
75
+
76
+ if (response.ok) {
77
+ const blob = await response.blob();
78
+ const audioUrl = URL.createObjectURL(blob);
79
+
80
+ resultDiv.innerHTML = `
81
+ <h2 class="text-xl font-semibold text-gray-800 mb-4">Your Audio is Ready!</h2>
82
+ <audio controls class="mx-auto mb-4">
83
+ <source src="${audioUrl}" type="audio/wav">
84
+ Your browser does not support the audio element.
85
+ </audio>
86
+ <a href="${audioUrl}" download="generated_audio.wav" class="inline-block bg-green-600 hover:bg-green-700 text-white font-bold py-2 px-4 rounded-lg shadow-md transition duration-300">
87
+ Download Audio
88
+ </a>
89
+ `;
90
+ } else {
91
+ const errorData = await response.json();
92
+ resultDiv.innerHTML = `<p class="text-red-600">Error: ${errorData.detail}</p>`;
93
+ }
94
+ } catch (error) {
95
+ resultDiv.innerHTML = `<p class="text-red-600">Error: ${error.message}</p>`;
96
+ } finally {
97
+ submitButton.disabled = false;
98
+ submitButton.innerHTML = 'Generate Audio';
99
+ }
100
+ });
101
+ });
102
+ </script>
103
+ </body>
104
+ </html>