UDface11jkj committed on
Commit
b13f37a
·
verified ·
1 Parent(s): 619d84d

Update layers.py

Browse files
Files changed (1) hide show
  1. layers.py +642 -642
layers.py CHANGED
@@ -1,642 +1,642 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch import Tensor
5
- from torch.nn import RMSNorm
6
-
7
- from .config import DiaConfig
8
- from .state import DecoderInferenceState, EncoderInferenceState, KVCache
9
-
10
-
11
- def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
12
- return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
13
-
14
-
15
- class DenseGeneral(nn.Module):
16
- """
17
- PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
18
-
19
- Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
20
- for the generalized matrix multiplication. Weight/bias shapes are calculated
21
- and parameters created during initialization based on config.
22
- `load_weights` validates shapes and copies data.
23
-
24
- Attributes:
25
- axis (Tuple[int, ...]): Input axis or axes to contract.
26
- in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
27
- out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
28
- use_bias (bool): Whether to add a bias term.
29
- weight (nn.Parameter): The kernel parameter.
30
- bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
31
- """
32
-
33
- def __init__(
34
- self,
35
- in_shapes: tuple[int, ...],
36
- out_features: tuple[int, ...],
37
- axis: tuple[int, ...] = (-1,),
38
- weight_dtype: torch.dtype | None = None,
39
- device: torch.device | None = None,
40
- ):
41
- super().__init__()
42
- self.in_shapes = in_shapes
43
- self.out_features = out_features
44
- self.axis = axis
45
- self.kernel_shape = self.in_shapes + self.out_features
46
-
47
- factory_kwargs = {"device": device, "dtype": weight_dtype}
48
- self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
49
- self.register_parameter("bias", None)
50
-
51
- def forward(self, inputs: Tensor) -> Tensor:
52
- norm_axis = _normalize_axes(self.axis, inputs.ndim)
53
- kernel_contract_axes = tuple(range(len(norm_axis)))
54
-
55
- output = torch.tensordot(
56
- inputs.to(self.weight.dtype),
57
- self.weight,
58
- dims=(norm_axis, kernel_contract_axes),
59
- ).to(inputs.dtype)
60
- return output
61
-
62
-
63
class MlpBlock(nn.Module):
    """Gated (SiLU) feed-forward block built from DenseGeneral projections."""

    def __init__(
        self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
    ):
        super().__init__()
        self.dtype = compute_dtype

        # Fused gate/up projection: a single kernel emits both halves.
        self.wi_fused = DenseGeneral(
            in_shapes=(embed_dim,),
            out_features=(2, intermediate_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        # Down projection back to the model width.
        self.wo = DenseGeneral(
            in_shapes=(intermediate_dim,),
            out_features=(embed_dim,),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply silu(gate) * up gating, then project back down."""
        projected = self.wi_fused(x)

        gate = projected[..., 0, :]
        up = projected[..., 1, :]
        hidden = (F.silu(gate) * up).to(self.dtype)

        return self.wo(hidden)
97
-
98
-
99
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE), implemented in PyTorch."""

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10000,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        if embedding_dims % 2 != 0:
            raise ValueError("Embedding dim must be even for RoPE.")
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.dtype = dtype

        # Geometric progression of timescales over half the embedding dims.
        half = embedding_dims // 2
        exponent = (2.0 * torch.arange(0, half)) / embedding_dims
        scales = self.min_timescale * (self.max_timescale / self.min_timescale) ** exponent
        self.register_buffer("timescale", scales, persistent=False)

    def extra_repr(self) -> str:
        return f"{self.timescale.shape}"

    def forward(self, inputs: torch.Tensor, position: torch.Tensor):
        """Rotate `inputs` by angles derived from `position` (RoPE)."""
        # (B, T) -> (B, T, 1, 1) so angles broadcast over heads and dims.
        angles = position.unsqueeze(-1).unsqueeze(-1) / self.timescale.to(inputs.device)
        sin = torch.sin(angles).to(inputs.dtype)
        cos = torch.cos(angles).to(inputs.dtype)
        lo, hi = torch.chunk(inputs, 2, dim=-1)
        rotated_lo = lo * cos - hi * sin
        rotated_hi = hi * cos + lo * sin
        return torch.cat((rotated_lo, rotated_hi), dim=-1)
140
-
141
-
142
class Attention(nn.Module):
    """Multi-head attention (self- or cross-) using DenseGeneral projections.

    Supports grouped-query attention (when num_query_heads > num_kv_heads),
    rotary position embeddings, and optional KV caching for autoregressive
    decoding.
    """

    def __init__(
        self,
        config: DiaConfig,
        q_embed_dim: int,
        kv_embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        compute_dtype: torch.dtype,
        is_cross_attn: bool = False,
        out_embed_dim: int | None = None,
    ):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attn = is_cross_attn
        self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
        self.projected_query_dim = num_query_heads * head_dim
        if num_query_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
            )
        self.num_gqa_groups = num_query_heads // num_kv_heads

        # --- Projection Layers using DenseGeneral ---
        self.q_proj = DenseGeneral(
            in_shapes=(q_embed_dim,),
            out_features=(num_query_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.k_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.v_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        # Output projection contracts both the head and head-dim axes.
        self.o_proj = DenseGeneral(
            in_shapes=(num_query_heads, head_dim),
            out_features=(self.output_dim,),
            axis=(-2, -1),
            weight_dtype=compute_dtype,
        )

        # --- Rotary Embedding ---
        self.rotary_emb = RotaryEmbedding(
            embedding_dims=self.head_dim,
            min_timescale=config.model.rope_min_timescale,
            max_timescale=config.model.rope_max_timescale,
            dtype=compute_dtype,
        )

    def forward(
        self,
        Xq: torch.Tensor,  # (B, T, D) T = 1 in AR generation
        Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
        q_positions: torch.Tensor,  # (B, T)
        kv_positions: torch.Tensor | None = None,  # (B, S)
        attn_mask: torch.Tensor
        | None = None,  # None in Decoder Self Attention, Valid mask in Others
        cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
        prefill: bool = False,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Performs attention calculation with optional KV caching.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            Xkv: Key/Value source tensor (B, S, E). Ignored for cross-attention,
                where K/V come from the precomputed `cache` instead.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
            attn_mask: Attention mask.
            cache: KVCache. Required for cross-attention; optional for self-attention.
            prefill: If True, write the projected K/V into the cache via `prefill`.
            is_causal: If True, apply causal masking inside SDPA.

        Returns:
            The attention output tensor (B, T, output_dim), cast back to Xq's dtype.

        Raises:
            ValueError: If this is a cross-attention module and `cache` is None.
        """
        if kv_positions is None:
            kv_positions = q_positions
        original_dtype = Xq.dtype

        Xq_BxTxNxH = self.q_proj(Xq)
        Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
        Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        attn_k: torch.Tensor | None = None
        attn_v: torch.Tensor | None = None

        if self.is_cross_attn:
            # Cross-attention K/V are precomputed from the encoder output;
            # fail loudly instead of raising an opaque AttributeError.
            if cache is None:
                raise ValueError("Cross-attention requires a precomputed KV cache.")
            attn_k, attn_v = cache.k, cache.v
        else:
            Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
            Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
            Xk_BxSxKxH = self.rotary_emb(
                Xk_BxSxKxH, position=kv_positions
            )  # (B, S, K, H)

            Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)

            if cache is None:
                attn_k = Xk_BxKxSxH
                attn_v = Xv_BxKxSxH
            elif prefill:
                attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
                cache.prefill(attn_k, attn_v)
            else:
                attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)

        # NOTE(review): scale=1.0 disables SDPA's built-in 1/sqrt(d) scaling —
        # presumably the scaling is folded into the trained weights; confirm.
        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            scale=1.0,
            enable_gqa=self.num_gqa_groups > 1,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
        output = self.o_proj(attn_output)

        return output.to(original_dtype)
280
-
281
-
282
class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        embed_dim = enc_config.n_embd

        # Norm applied before self-attention (kept in float32).
        self.pre_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.self_attention = Attention(
            config,
            q_embed_dim=embed_dim,
            kv_embed_dim=embed_dim,
            num_query_heads=enc_config.n_head,
            num_kv_heads=enc_config.n_head,
            head_dim=enc_config.head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=embed_dim,
        )
        # Norm applied before the feed-forward sub-block.
        self.post_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.mlp = MlpBlock(
            embed_dim=embed_dim,
            intermediate_dim=enc_config.n_hidden,
            compute_dtype=compute_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
        state: EncoderInferenceState,
    ) -> torch.Tensor:
        """Run one encoder layer over `x` using positions/mask from `state`."""
        # Self-attention sub-block with residual connection.
        attn_in = self.pre_sa_norm(x)
        x = x + self.self_attention(
            Xq=attn_in,
            Xkv=attn_in,
            q_positions=state.positions,
            kv_positions=state.positions,
            attn_mask=state.attn_mask,
        )

        # Feed-forward sub-block with residual connection.
        return x + self.mlp(self.post_sa_norm(x))
341
-
342
-
343
class Encoder(nn.Module):
    """Transformer encoder stack: token embedding, N layers, final RMSNorm."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
        self.layers = nn.ModuleList(
            [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
        )
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

    def forward(
        self,
        x_ids: torch.Tensor,
        state: EncoderInferenceState,
    ) -> torch.Tensor:
        """Embed `x_ids` and pass the result through every encoder layer."""
        hidden = self.embedding(x_ids)

        for block in self.layers:
            hidden = block(hidden, state)

        return self.norm(hidden)
378
-
379
-
380
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: self-attn, cross-attn, then MLP."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        enc_config = config.model.encoder
        dec_embed_dim = dec_config.n_embd
        enc_embed_dim = enc_config.n_embd

        # Pre-norms for each of the three sub-blocks (kept in float32).
        self.pre_sa_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_ca_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_mlp_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Self-attention (GQA) with causal masking.
        self.self_attention = Attention(
            config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=dec_embed_dim,
            num_query_heads=dec_config.gqa_query_heads,
            num_kv_heads=dec_config.kv_heads,
            head_dim=dec_config.gqa_head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=dec_embed_dim,
        )
        # Cross-attention (MHA) over the encoder output width.
        self.cross_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=enc_embed_dim,  # keys/values come from the encoder
            num_query_heads=dec_config.cross_query_heads,
            num_kv_heads=dec_config.cross_query_heads,
            head_dim=dec_config.cross_head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=True,
            out_embed_dim=dec_embed_dim,
        )
        # Feed-forward block.
        self.mlp = MlpBlock(
            embed_dim=dec_embed_dim,
            intermediate_dim=dec_config.n_hidden,
            compute_dtype=compute_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
        state: DecoderInferenceState,
        self_attn_cache: KVCache | None = None,
        cross_attn_cache: KVCache | None = None,
        prefill: bool = False,
    ) -> torch.Tensor:
        """Run one decoder layer; caches are updated in place when provided."""
        # 1) Causal self-attention with residual.
        sa_in = self.pre_sa_norm(x)
        x = x + self.self_attention(
            Xq=sa_in,  # (2, 1, D) during AR decoding
            Xkv=sa_in,
            q_positions=state.dec_positions,  # (2, 1)
            kv_positions=state.dec_positions,
            attn_mask=None,
            cache=self_attn_cache,
            prefill=prefill,
            is_causal=prefill,
        )

        # 2) Cross-attention over encoder output with residual.
        ca_in = self.pre_ca_norm(x)
        x = x + self.cross_attention(
            Xq=ca_in,
            Xkv=state.enc_out,
            q_positions=state.dec_positions,
            kv_positions=state.enc_positions,
            attn_mask=state.dec_cross_attn_mask,
            cache=cross_attn_cache,
        )

        # 3) Feed-forward with residual.
        return x + self.mlp(self.pre_mlp_norm(x))
482
-
483
-
484
class Decoder(nn.Module):
    """Transformer Decoder Stack using DenseGeneral.

    Embeds multi-channel audio tokens (one embedding table per channel, summed),
    runs them through the decoder layers, and projects to per-channel logits.
    """

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        data_config = config.data
        self.num_channels = data_config.channels
        self.num_layers = dec_config.n_layer

        # One embedding table per audio channel; their outputs are summed.
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(
                    model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
                )
                for _ in range(self.num_channels)
            ]
        )
        self.layers = nn.ModuleList(
            [
                DecoderLayer(config=config, compute_dtype=compute_dtype)
                for _ in range(self.num_layers)
            ]
        )

        self.norm = RMSNorm(
            dec_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Single projection producing logits for every channel at once.
        self.logits_dense = DenseGeneral(
            in_shapes=(dec_config.n_embd,),
            out_features=(self.num_channels, model_config.tgt_vocab_size),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )

    def _embed_channels(self, tgt_ids: torch.Tensor) -> torch.Tensor:
        """Sum per-channel token embeddings: (B, T, C) ids -> (B, T, D)."""
        x = None
        for i in range(self.num_channels):
            channel_embed = self.embeddings[i](tgt_ids[..., i])
            x = channel_embed if x is None else x + channel_embed
        return x

    def precompute_cross_attn_cache(
        self,
        enc_out: torch.Tensor,  # (B, S, E)
        enc_positions: torch.Tensor,  # (B, S)
    ) -> list[KVCache]:
        """
        Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
        """
        per_layer_kv_cache: list[KVCache] = []

        for layer in self.layers:
            cross_attn_module = layer.cross_attention
            k_proj = cross_attn_module.k_proj(enc_out)
            v_proj = cross_attn_module.v_proj(enc_out)

            # RoPE is applied to keys only; values are position-independent.
            k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
            k = k_proj.transpose(1, 2)
            v = v_proj.transpose(1, 2)

            per_layer_kv_cache.append(KVCache.from_kv(k, v))

        return per_layer_kv_cache

    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
        state: DecoderInferenceState,
    ) -> torch.Tensor:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        Args:
            tgt_ids_Bx1xC: Target token IDs for this step (B, 1, C).
            state: Decoder inference state holding encoder output, positions,
                masks, and the per-layer self/cross attention KV caches.

        Returns:
            The output logits for the current step (B, 1, C, V), cast to float32.
        """
        x = self._embed_channels(tgt_ids_Bx1xC)

        for i, layer in enumerate(self.layers):
            self_cache = state.self_attn_cache[i]
            cross_cache = state.cross_attn_cache[i]
            x = layer(
                x,  # (2, 1, D)
                state,
                self_attn_cache=self_cache,
                cross_attn_cache=cross_cache,
            )

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

        return logits_Bx1xCxV.to(torch.float32)

    def forward(
        self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
    ) -> torch.Tensor:
        """
        Prefill forward pass for the decoder stack.

        Runs every layer in prefill mode so the self-attention KV caches in
        `state` are populated for subsequent `decode_step` calls.

        Args:
            tgt_ids_BxTxC: Target token IDs (B, T, C).
            state: Decoder inference state holding encoder output, positions,
                masks, and the per-layer self/cross attention KV caches.

        Returns:
            The output logits (B, T, C, V), cast to float32.
        """
        _, _, num_channels_in = tgt_ids_BxTxC.shape
        assert num_channels_in == self.num_channels, "Input channels mismatch"

        x = self._embed_channels(tgt_ids_BxTxC)

        for i, layer in enumerate(self.layers):
            self_cache = state.self_attn_cache[i]
            cross_cache = state.cross_attn_cache[i]
            x = layer(
                x,
                state,
                self_attn_cache=self_cache,
                cross_attn_cache=cross_cache,
                prefill=True,
            )

        # Final Norm
        x = self.norm(x)
        logits_BxTxCxV = self.logits_dense(x)

        return logits_BxTxCxV.to(torch.float32)
633
-
634
-
635
class DiaModel(nn.Module):
    """Top-level Dia model: a text Encoder paired with an audio-token Decoder."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        # Both halves share the same config and compute dtype.
        self.encoder = Encoder(config, compute_dtype)
        self.decoder = Decoder(config, compute_dtype)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+ from torch.nn import RMSNorm
6
+
7
+ from config import DiaConfig
8
+ from state import DecoderInferenceState, EncoderInferenceState, KVCache
9
+
10
+
11
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
12
+ return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
13
+
14
+
15
+ class DenseGeneral(nn.Module):
16
+ """
17
+ PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
18
+
19
+ Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
20
+ for the generalized matrix multiplication. Weight/bias shapes are calculated
21
+ and parameters created during initialization based on config.
22
+ `load_weights` validates shapes and copies data.
23
+
24
+ Attributes:
25
+ axis (Tuple[int, ...]): Input axis or axes to contract.
26
+ in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
27
+ out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
28
+ use_bias (bool): Whether to add a bias term.
29
+ weight (nn.Parameter): The kernel parameter.
30
+ bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ in_shapes: tuple[int, ...],
36
+ out_features: tuple[int, ...],
37
+ axis: tuple[int, ...] = (-1,),
38
+ weight_dtype: torch.dtype | None = None,
39
+ device: torch.device | None = None,
40
+ ):
41
+ super().__init__()
42
+ self.in_shapes = in_shapes
43
+ self.out_features = out_features
44
+ self.axis = axis
45
+ self.kernel_shape = self.in_shapes + self.out_features
46
+
47
+ factory_kwargs = {"device": device, "dtype": weight_dtype}
48
+ self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
49
+ self.register_parameter("bias", None)
50
+
51
+ def forward(self, inputs: Tensor) -> Tensor:
52
+ norm_axis = _normalize_axes(self.axis, inputs.ndim)
53
+ kernel_contract_axes = tuple(range(len(norm_axis)))
54
+
55
+ output = torch.tensordot(
56
+ inputs.to(self.weight.dtype),
57
+ self.weight,
58
+ dims=(norm_axis, kernel_contract_axes),
59
+ ).to(inputs.dtype)
60
+ return output
61
+
62
+
63
class MlpBlock(nn.Module):
    """Gated (SiLU) feed-forward block built from DenseGeneral projections."""

    def __init__(
        self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
    ):
        super().__init__()
        self.dtype = compute_dtype

        # Fused gate/up projection: a single kernel emits both halves.
        self.wi_fused = DenseGeneral(
            in_shapes=(embed_dim,),
            out_features=(2, intermediate_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        # Down projection back to the model width.
        self.wo = DenseGeneral(
            in_shapes=(intermediate_dim,),
            out_features=(embed_dim,),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply silu(gate) * up gating, then project back down."""
        projected = self.wi_fused(x)

        gate = projected[..., 0, :]
        up = projected[..., 1, :]
        hidden = (F.silu(gate) * up).to(self.dtype)

        return self.wo(hidden)
97
+
98
+
99
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE), implemented in PyTorch."""

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10000,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        if embedding_dims % 2 != 0:
            raise ValueError("Embedding dim must be even for RoPE.")
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.dtype = dtype

        # Geometric progression of timescales over half the embedding dims.
        half = embedding_dims // 2
        exponent = (2.0 * torch.arange(0, half)) / embedding_dims
        scales = self.min_timescale * (self.max_timescale / self.min_timescale) ** exponent
        self.register_buffer("timescale", scales, persistent=False)

    def extra_repr(self) -> str:
        return f"{self.timescale.shape}"

    def forward(self, inputs: torch.Tensor, position: torch.Tensor):
        """Rotate `inputs` by angles derived from `position` (RoPE)."""
        # (B, T) -> (B, T, 1, 1) so angles broadcast over heads and dims.
        angles = position.unsqueeze(-1).unsqueeze(-1) / self.timescale.to(inputs.device)
        sin = torch.sin(angles).to(inputs.dtype)
        cos = torch.cos(angles).to(inputs.dtype)
        lo, hi = torch.chunk(inputs, 2, dim=-1)
        rotated_lo = lo * cos - hi * sin
        rotated_hi = hi * cos + lo * sin
        return torch.cat((rotated_lo, rotated_hi), dim=-1)
140
+
141
+
142
class Attention(nn.Module):
    """Multi-head attention (self- or cross-) using DenseGeneral projections.

    Supports grouped-query attention (when num_query_heads > num_kv_heads),
    rotary position embeddings, and optional KV caching for autoregressive
    decoding.
    """

    def __init__(
        self,
        config: DiaConfig,
        q_embed_dim: int,
        kv_embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        compute_dtype: torch.dtype,
        is_cross_attn: bool = False,
        out_embed_dim: int | None = None,
    ):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attn = is_cross_attn
        self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
        self.projected_query_dim = num_query_heads * head_dim
        if num_query_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
            )
        self.num_gqa_groups = num_query_heads // num_kv_heads

        # --- Projection Layers using DenseGeneral ---
        self.q_proj = DenseGeneral(
            in_shapes=(q_embed_dim,),
            out_features=(num_query_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.k_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.v_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        # Output projection contracts both the head and head-dim axes.
        self.o_proj = DenseGeneral(
            in_shapes=(num_query_heads, head_dim),
            out_features=(self.output_dim,),
            axis=(-2, -1),
            weight_dtype=compute_dtype,
        )

        # --- Rotary Embedding ---
        self.rotary_emb = RotaryEmbedding(
            embedding_dims=self.head_dim,
            min_timescale=config.model.rope_min_timescale,
            max_timescale=config.model.rope_max_timescale,
            dtype=compute_dtype,
        )

    def forward(
        self,
        Xq: torch.Tensor,  # (B, T, D) T = 1 in AR generation
        Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
        q_positions: torch.Tensor,  # (B, T)
        kv_positions: torch.Tensor | None = None,  # (B, S)
        attn_mask: torch.Tensor
        | None = None,  # None in Decoder Self Attention, Valid mask in Others
        cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
        prefill: bool = False,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Performs attention calculation with optional KV caching.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            Xkv: Key/Value source tensor (B, S, E). Ignored for cross-attention,
                where K/V come from the precomputed `cache` instead.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
            attn_mask: Attention mask.
            cache: KVCache. Required for cross-attention; optional for self-attention.
            prefill: If True, write the projected K/V into the cache via `prefill`.
            is_causal: If True, apply causal masking inside SDPA.

        Returns:
            The attention output tensor (B, T, output_dim), cast back to Xq's dtype.

        Raises:
            ValueError: If this is a cross-attention module and `cache` is None.
        """
        if kv_positions is None:
            kv_positions = q_positions
        original_dtype = Xq.dtype

        Xq_BxTxNxH = self.q_proj(Xq)
        Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
        Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        attn_k: torch.Tensor | None = None
        attn_v: torch.Tensor | None = None

        if self.is_cross_attn:
            # Cross-attention K/V are precomputed from the encoder output;
            # fail loudly instead of raising an opaque AttributeError.
            if cache is None:
                raise ValueError("Cross-attention requires a precomputed KV cache.")
            attn_k, attn_v = cache.k, cache.v
        else:
            Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
            Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
            Xk_BxSxKxH = self.rotary_emb(
                Xk_BxSxKxH, position=kv_positions
            )  # (B, S, K, H)

            Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)

            if cache is None:
                attn_k = Xk_BxKxSxH
                attn_v = Xv_BxKxSxH
            elif prefill:
                attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
                cache.prefill(attn_k, attn_v)
            else:
                attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)

        # NOTE(review): scale=1.0 disables SDPA's built-in 1/sqrt(d) scaling —
        # presumably the scaling is folded into the trained weights; confirm.
        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            scale=1.0,
            enable_gqa=self.num_gqa_groups > 1,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
        output = self.o_proj(attn_output)

        return output.to(original_dtype)
280
+
281
+
282
class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        embed_dim = enc_config.n_embd

        # Norm applied before self-attention (kept in float32).
        self.pre_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.self_attention = Attention(
            config,
            q_embed_dim=embed_dim,
            kv_embed_dim=embed_dim,
            num_query_heads=enc_config.n_head,
            num_kv_heads=enc_config.n_head,
            head_dim=enc_config.head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=embed_dim,
        )
        # Norm applied before the feed-forward sub-block.
        self.post_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.mlp = MlpBlock(
            embed_dim=embed_dim,
            intermediate_dim=enc_config.n_hidden,
            compute_dtype=compute_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
        state: EncoderInferenceState,
    ) -> torch.Tensor:
        """Run one encoder layer over `x` using positions/mask from `state`."""
        # Self-attention sub-block with residual connection.
        attn_in = self.pre_sa_norm(x)
        x = x + self.self_attention(
            Xq=attn_in,
            Xkv=attn_in,
            q_positions=state.positions,
            kv_positions=state.positions,
            attn_mask=state.attn_mask,
        )

        # Feed-forward sub-block with residual connection.
        return x + self.mlp(self.post_sa_norm(x))
341
+
342
+
343
class Encoder(nn.Module):
    """Transformer encoder stack: token embedding, N layers, final RMSNorm."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
        self.layers = nn.ModuleList(
            [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
        )
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

    def forward(
        self,
        x_ids: torch.Tensor,
        state: EncoderInferenceState,
    ) -> torch.Tensor:
        """Embed `x_ids` and pass the result through every encoder layer."""
        hidden = self.embedding(x_ids)

        for block in self.layers:
            hidden = block(hidden, state)

        return self.norm(hidden)
378
+
379
+
380
class DecoderLayer(nn.Module):
    """Transformer Decoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_cfg = config.model
        dec_cfg = config.model.decoder
        enc_cfg = config.model.encoder
        d_model = dec_cfg.n_embd
        enc_dim = enc_cfg.n_embd

        # One pre-norm per sub-block (self-attn, cross-attn, MLP); all float32.
        norm_eps = model_cfg.normalization_layer_epsilon
        self.pre_sa_norm = RMSNorm(d_model, eps=norm_eps, dtype=torch.float32)
        self.pre_ca_norm = RMSNorm(d_model, eps=norm_eps, dtype=torch.float32)
        self.pre_mlp_norm = RMSNorm(d_model, eps=norm_eps, dtype=torch.float32)

        # Self-attention over decoder states (grouped-query: fewer KV heads).
        self.self_attention = Attention(
            config,
            q_embed_dim=d_model,
            kv_embed_dim=d_model,
            num_query_heads=dec_cfg.gqa_query_heads,
            num_kv_heads=dec_cfg.kv_heads,
            head_dim=dec_cfg.gqa_head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=d_model,
        )
        # Cross-attention into the encoder output; KV side uses the encoder width.
        self.cross_attention = Attention(
            config=config,
            q_embed_dim=d_model,
            kv_embed_dim=enc_dim,  # Note kv_embed_dim
            num_query_heads=dec_cfg.cross_query_heads,
            num_kv_heads=dec_cfg.cross_query_heads,
            head_dim=dec_cfg.cross_head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=True,
            out_embed_dim=d_model,
        )
        # Position-wise feed-forward network.
        self.mlp = MlpBlock(
            embed_dim=d_model,
            intermediate_dim=dec_cfg.n_hidden,
            compute_dtype=compute_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
        state: DecoderInferenceState,
        self_attn_cache: KVCache | None = None,
        cross_attn_cache: KVCache | None = None,
        prefill: bool = False,
    ) -> torch.Tensor:
        """Run self-attention, cross-attention, and MLP, each pre-normed with a residual add.

        Args:
            x: Decoder hidden states.
            state: Inference state providing positions, encoder output, and masks.
            self_attn_cache: Optional KV cache for self-attention.
            cross_attn_cache: Optional precomputed KV cache for cross-attention.
            prefill: When True, causal masking is enabled and the cache is prefilled.

        Returns:
            Hidden states of the same shape as ``x``.
        """
        # Self-attention sub-block. `prefill` drives both cache prefilling and
        # `is_causal` — causal masking is only enabled during prefill.
        sa_input = self.pre_sa_norm(x)
        x = x + self.self_attention(
            Xq=sa_input,  # (2, 1, D)
            Xkv=sa_input,  # (2, 1, D)
            q_positions=state.dec_positions,  # (2, 1)
            kv_positions=state.dec_positions,  # (2, 1)
            attn_mask=None,
            cache=self_attn_cache,
            prefill=prefill,
            is_causal=prefill,
        )

        # Cross-attention sub-block over the encoder output.
        x = x + self.cross_attention(
            Xq=self.pre_ca_norm(x),
            Xkv=state.enc_out,
            q_positions=state.dec_positions,
            kv_positions=state.enc_positions,
            attn_mask=state.dec_cross_attn_mask,
            cache=cross_attn_cache,
        )

        # Feed-forward sub-block.
        x = x + self.mlp(self.pre_mlp_norm(x))

        return x
482
+
483
+
484
class Decoder(nn.Module):
    """Transformer Decoder Stack using DenseGeneral.

    Sums one embedding table per audio channel, runs the decoder layers with
    per-layer self/cross-attention KV caches held in the inference state, and
    projects to per-channel vocabulary logits.
    """

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        data_config = config.data
        self.num_channels = data_config.channels
        self.num_layers = dec_config.n_layer

        # One embedding table per channel; their lookups are summed in
        # _embed_channels before entering the layer stack.
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(
                    model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
                )
                for _ in range(self.num_channels)
            ]
        )
        self.layers = nn.ModuleList(
            [
                DecoderLayer(config=config, compute_dtype=compute_dtype)
                for _ in range(self.num_layers)
            ]
        )

        # Final normalization; parameters are held in float32.
        self.norm = RMSNorm(
            dec_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Single projection producing logits for all channels at once:
        # (..., n_embd) -> (..., num_channels, tgt_vocab_size).
        self.logits_dense = DenseGeneral(
            in_shapes=(dec_config.n_embd,),
            out_features=(self.num_channels, model_config.tgt_vocab_size),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )

    def _embed_channels(self, tgt_ids_BxTxC: torch.Tensor) -> torch.Tensor:
        """Sum the per-channel embeddings of ``tgt_ids_BxTxC`` into (B, T, D)."""
        x = None
        for i in range(self.num_channels):
            channel_embed = self.embeddings[i](tgt_ids_BxTxC[..., i])
            x = channel_embed if x is None else x + channel_embed
        return x

    def precompute_cross_attn_cache(
        self,
        enc_out: torch.Tensor,  # (B, S, E)
        enc_positions: torch.Tensor,  # (B, S)
    ) -> list[KVCache]:
        """
        Computes the Key and Value tensors for cross-attention for each layer from
        the encoder output, so they can be reused across all decoding steps.
        """
        per_layer_kv_cache: list[KVCache] = []

        for layer in self.layers:
            cross_attn_module = layer.cross_attention
            k_proj = cross_attn_module.k_proj(enc_out)
            v_proj = cross_attn_module.v_proj(enc_out)

            # Rotary embedding is applied to keys only.
            # NOTE(review): transpose(1, 2) presumably reorders (B, S, H, Dh)
            # to (B, H, S, Dh) — confirm against the Attention projections.
            k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
            k = k_proj.transpose(1, 2)
            v = v_proj.transpose(1, 2)

            per_layer_kv_cache.append(KVCache.from_kv(k, v))

        return per_layer_kv_cache

    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
        state: DecoderInferenceState,
    ) -> torch.Tensor:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        Args:
            tgt_ids_Bx1xC: Target token IDs for the current step (B, 1, C).
            state: Inference state holding positions, encoder output, and the
                per-layer self/cross-attention KV caches.

        Returns:
            logits_Bx1xCxV: The output logits for the current step
            (B, 1, C, V), cast to float32.
        """
        x = self._embed_channels(tgt_ids_Bx1xC)

        for i, layer in enumerate(self.layers):
            x = layer(
                x,  # (2, 1, D)
                state,
                self_attn_cache=state.self_attn_cache[i],
                cross_attn_cache=state.cross_attn_cache[i],
            )

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

        return logits_Bx1xCxV.to(torch.float32)

    def forward(
        self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
    ) -> torch.Tensor:
        """
        Forward pass (prefill) over a full target prefix, filling KV caches.

        Args:
            tgt_ids_BxTxC: Target token IDs (B, T, C).
            state: Inference state holding positions, encoder output, and the
                per-layer self/cross-attention KV caches to prefill.

        Returns:
            logits_BxTxCxV: The output logits (B, T, C, V), cast to float32.
        """
        _, _, num_channels_in = tgt_ids_BxTxC.shape
        assert num_channels_in == self.num_channels, "Input channels mismatch"

        x = self._embed_channels(tgt_ids_BxTxC)

        for i, layer in enumerate(self.layers):
            x = layer(
                x,
                state,
                self_attn_cache=state.self_attn_cache[i],
                cross_attn_cache=state.cross_attn_cache[i],
                prefill=True,
            )

        # Final Norm
        x = self.norm(x)
        logits_BxTxCxV = self.logits_dense(x)

        return logits_BxTxCxV.to(torch.float32)
633
+
634
+
635
+ class DiaModel(nn.Module):
636
+ """PyTorch Dia Model using DenseGeneral."""
637
+
638
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
639
+ super().__init__()
640
+ self.config = config
641
+ self.encoder = Encoder(config, compute_dtype)
642
+ self.decoder = Decoder(config, compute_dtype)