akrao9 committed on
Commit
659ddad
·
verified ·
1 Parent(s): 1ef675a

Upload modeling_boomer_fla.py

Browse files
Files changed (1) hide show
  1. transformer/modeling_boomer_fla.py +1267 -0
transformer/modeling_boomer_fla.py ADDED
@@ -0,0 +1,1267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BoomerFLADiT model — self-contained for HuggingFace trust_remote_code distribution.
2
+
3
+ All dependencies inlined: no boomer package import needed.
4
+ External pip requirements: torch, flash-linear-attention (fla).
5
+ """
6
+ # ── inlined from boomer/models/latent_dit.py ──────────────────────────────────
7
+ from __future__ import annotations
8
+ import math
9
+ import sys
10
+ import types
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.checkpoint import checkpoint as _ckpt
18
+
19
+
20
class AttentionRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learnable per-channel gain.

    Args:
        dim: Size of the last (normalised) dimension.
        scale_factor: Initial value of every entry of the gain vector.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, scale_factor: float = 0.01, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim) * scale_factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``weight * x / rms(x)`` computed in float32, cast back to ``x``'s dtype."""
        # Upcast before squaring: squaring in fp16 overflows to inf once |x| exceeds
        # roughly 255, which would turn the reciprocal RMS (and the output) into 0.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Reshape the gain so it broadcasts against the trailing dimension of ``x``.
        weight = self.weight.view(*([1] * (x.ndim - 2)), -1)
        return (weight * (x_fp32 * inv_rms)).type_as(x)
30
+
31
+
32
class CaptionEmbedder(nn.Module):
    """Project caption features to the model width and hold a learned null caption.

    Args:
        in_channels: Feature size of the incoming caption tokens.
        hidden_size: Output feature size of the projection.
        token_num: Number of tokens in the learned null caption.
    """

    def __init__(self, in_channels: int, hidden_size: int, token_num: int) -> None:
        super().__init__()
        projection_layers = [
            nn.Linear(in_channels, hidden_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.y_proj = nn.Sequential(*projection_layers)
        # Learned null caption of shape (1, token_num, in_channels).
        null_init = torch.randn(token_num, in_channels) / math.sqrt(in_channels)
        self.null_text_embedding = nn.Parameter(null_init[None])

    def forward(self, caption: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer projection to ``caption``."""
        return self.y_proj(caption)

    def null_condition(self, batch_size, *, device, dtype, mask_dtype=None, token_num=None):
        """Build a batch of null captions and the matching token mask.

        Args:
            batch_size: Number of copies of the null caption to return.
            device: Target device for both returned tensors.
            dtype: Target dtype of the returned caption tensor.
            mask_dtype: Dtype of the mask; ``torch.long`` when not given.
            token_num: Requested token count. Shorter than the learned caption
                truncates it; longer appends zero tokens. ``None`` keeps the
                learned length.

        Returns:
            ``(text, mask)`` where ``mask`` is 1 for learned tokens and 0 for
            appended zero tokens.
        """
        learned = self.null_text_embedding
        learned_len = learned.shape[1]
        target_len = learned_len if token_num is None else token_num

        if target_len < learned_len:
            text = learned[:, :target_len]
        elif target_len > learned_len:
            filler = learned.new_zeros(learned.shape[0], target_len - learned_len, learned.shape[2])
            text = torch.cat([learned, filler], dim=1)
        else:
            text = learned

        text = text.expand(batch_size, -1, -1).to(device=device, dtype=dtype)
        mask = torch.ones(batch_size, text.shape[1], device=device, dtype=mask_dtype or torch.long)
        if target_len > learned_len:
            # Appended zero tokens are marked as invalid.
            mask[:, learned_len:] = 0
        return text, mask
59
+
60
+
61
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps with a Linear -> SiLU -> Linear stack.

    Args:
        hidden_dim: Width of the hidden layer and of the returned embedding.
    """

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Flatten ``timesteps`` to a column, match the layer dtype, and embed it."""
        first_layer = self.net[0]
        column = timesteps.to(dtype=first_layer.weight.dtype).view(-1, 1)
        return self.net(column)
69
+
70
+
71
+ # ── rest of boomer_fla_dit.py below (unchanged except no boomer imports) ──────
72
+
73
+
74
+
75
@dataclass(frozen=True)
class BoomerFLADiTConfig:
    """Immutable hyper-parameter bundle for the BoomerFLA DiT modules.

    Field order is part of the positional-constructor interface and must not change.
    """

    model_type: str = "boomer_fla"

    # Latent / text input description.
    latent_channels: int = 32
    latent_size: int = 16
    text_dim: int = 1536
    text_seq_len: int = 300

    # Transformer trunk dimensions.
    hidden_dim: int = 1152
    depth: int = 28
    num_heads: int = 16
    mlp_ratio: float = 2.5

    # Caption-norm switches (presumably applied to projected text tokens — confirm at the use site).
    y_norm: bool = True
    y_norm_scale_factor: float = 0.01

    # Token mixer selection: "fla_linear", "fla_gated_deltanet", "fla_gla", or
    # "torch"/"fallback" for plain multi-head attention.
    mixer_type: str = "fla_linear"
    fla_mode: str = "chunk"
    fla_feature_map: str = "relu"
    fla_bidirectional: bool = False
    use_short_conv: bool = False
    conv_size: int = 4

    # Full image self-attention every N-th block (0 disables it);
    # backend is one of "sdpa", "flash3", "flash4", "auto".
    image_attention_every: int = 0
    image_attention_backend: str = "sdpa"
    image_attention_rope: bool = False
    image_rope_theta: float = 10000.0

    # Cross-attention backend: "sdpa", "xformers", "auto", or "mha"
    # ("mha" cannot be combined with cross_attention_qk_norm).
    cross_attention_backend: str = "sdpa"
    cross_attention_qk_norm: bool = True

    # Remaining architecture switches.
    parallel_block: bool = False
    dual_stream_depth: int = 0
    multimodal_coord_ids: bool = False
    use_abs_pos_embed: bool = True
    patch_size: int = 1
    gradient_checkpointing: bool = False
106
+
107
+
108
def maybe_add_sibling_fla_repo() -> None:
    """Prepend known local flash-linear-attention checkouts to ``sys.path``.

    A candidate directory is added only when it contains an ``fla`` sub-directory
    and is not already on ``sys.path``. Missing candidates are ignored.
    """
    candidates = []
    ancestors = Path(__file__).resolve().parents
    # ``parents[3]`` raises IndexError when this file lives fewer than four
    # directories deep, so only consider the sibling checkout when it can exist.
    if len(ancestors) > 3:
        candidates.append(ancestors[3] / "flash-linear-attention")
    candidates.extend(
        [
            Path("/content/flash-linear-attention"),
            Path("/content/flame"),
        ]
    )
    for path in candidates:
        if (path / "fla").is_dir() and str(path) not in sys.path:
            sys.path.insert(0, str(path))
117
+
118
+
119
def maybe_add_sibling_flash_attention_repo() -> None:
    """Prepend known local flash-attention checkouts to ``sys.path``.

    Every candidate directory that exists and is not already on ``sys.path`` is
    inserted at position 0, in list order. Missing candidates are ignored.
    """
    candidates = []
    ancestors = Path(__file__).resolve().parents
    # ``parents[3]`` raises IndexError when this file lives fewer than four
    # directories deep, so only consider the sibling checkout when it can exist.
    if len(ancestors) > 3:
        sibling = ancestors[3] / "flash-attention"
        candidates.extend([sibling / "hopper", sibling])
    candidates.extend(
        [
            Path("/work/flash-attention/hopper"),
            Path("/work/flash-attention"),
            Path("/home/jovyan/work/flash-attention"),
            Path("/content/flash-attention/hopper"),
            Path("/content/flash-attention"),
        ]
    )
    for path in candidates:
        if path.exists() and str(path) not in sys.path:
            sys.path.insert(0, str(path))
132
+
133
+
134
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Return ``x`` scaled by ``1 + scale`` and then offset by ``shift``."""
    gain = 1.0 + scale
    scaled = x * gain
    return scaled + shift
136
+
137
+
138
class ConvLayer(nn.Module):
    """A 2-D convolution followed by an optional SiLU activation.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count.
        kernel_size: Square kernel size; padding is ``kernel_size // 2``.
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        act: ``"silu"`` enables SiLU; any other value leaves the output unchanged.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size: int,
        *,
        groups: int = 1,
        bias: bool = False,
        act: str | None = None,
    ) -> None:
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            padding=half_kernel,
            groups=groups,
            bias=bias,
        )
        if act == "silu":
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` and apply the configured activation."""
        features = self.conv(x)
        return self.act(features)
162
+
163
+
164
class GLUMBConv(nn.Module):
    """Sana GLUMBConv FFN: 1x1 expand, depthwise spatial conv, GLU, 1x1 project.

    Args:
        hidden_dim: Channel count of the token features.
        mlp_ratio: Expansion ratio; the inner width is ``int(hidden_dim * mlp_ratio)``.
    """

    def __init__(self, hidden_dim: int, mlp_ratio: float) -> None:
        super().__init__()
        self.inner_dim = int(hidden_dim * mlp_ratio)
        doubled = self.inner_dim * 2  # value half + gate half
        self.inverted_conv = ConvLayer(hidden_dim, doubled, 1, bias=True, act="silu")
        self.depth_conv = ConvLayer(doubled, doubled, 3, groups=doubled, bias=True)
        self.point_conv = ConvLayer(self.inner_dim, hidden_dim, 1, bias=False)
        # Zero-initialised output projection: the block outputs zeros at initialisation.
        nn.init.zeros_(self.point_conv.conv.weight)
        self.glu_act = nn.SiLU()

    def forward(self, x: torch.Tensor, *, height: int, width: int) -> torch.Tensor:
        """Run the FFN on ``(batch, tokens, channels)`` tokens laid out on a grid.

        Raises:
            ValueError: If the token count is not ``height * width``.
        """
        batch, tokens, channels = x.shape
        expected_tokens = height * width
        if tokens != expected_tokens:
            raise ValueError(f"Expected {expected_tokens} image tokens, got {tokens}")
        # Tokens -> channels-first feature map for the convolutions.
        grid = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        expanded = self.depth_conv(self.inverted_conv(grid))
        value, gate = expanded.chunk(2, dim=1)
        projected = self.point_conv(value * self.glu_act(gate))
        # Feature map -> tokens.
        return projected.reshape(batch, channels, tokens).transpose(1, 2).contiguous()
188
+
189
+
190
class TorchSelfAttention(nn.Module):
    """Self-attention built on ``nn.MultiheadAttention`` with batch-first inputs.

    Args:
        hidden_dim: Embedding size of the tokens.
        num_heads: Number of attention heads.
    """

    def __init__(self, hidden_dim: int, num_heads: int) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend ``x`` to itself and return only the attended tokens."""
        result = self.attn(x, x, x, need_weights=False)
        attended = result[0]
        return attended
197
+
198
+
199
class TokenMLP(nn.Module):
    """Per-token feed-forward network: Linear -> tanh-GELU -> Linear.

    Args:
        hidden_dim: Input and output feature size.
        mlp_ratio: Expansion ratio; the inner width is ``int(hidden_dim * mlp_ratio)``.
    """

    def __init__(self, hidden_dim: int, mlp_ratio: float) -> None:
        super().__init__()
        inner_dim = int(hidden_dim * mlp_ratio)
        layers = [
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(inner_dim, hidden_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward stack to ``x``."""
        return self.net(x)
211
+
212
+
213
+ class MultimodalCoordinateRoPE(nn.Module):
214
+ """FLUX-style coordinate-ID RoPE for joint text/image attention."""
215
+
216
+ def __init__(self, head_dim: int, *, image_size: int, text_seq_len: int, theta: float = 10000.0) -> None:
217
+ super().__init__()
218
+ if head_dim < 6 or head_dim % 2 != 0:
219
+ raise ValueError(f"head_dim={head_dim} must be even and at least 6 for multimodal RoPE")
220
+ if theta <= 0.0:
221
+ raise ValueError(f"theta must be positive, got {theta}")
222
+ type_dim = max(2, (head_dim // 4) // 2 * 2)
223
+ while type_dim > 2 and (head_dim - type_dim) % 4 != 0:
224
+ type_dim -= 2
225
+ remaining = head_dim - type_dim
226
+ row_dim = max(2, (remaining // 2) // 2 * 2)
227
+ col_dim = remaining - row_dim
228
+ if col_dim < 2 or col_dim % 2 != 0:
229
+ raise ValueError(f"could not split head_dim={head_dim} into even multimodal RoPE axes")
230
+ self.axes_dim = (type_dim, row_dim, col_dim)
231
+ self.head_dim = head_dim
232
+ self.image_size = image_size
233
+ self.text_seq_len = text_seq_len
234
+
235
+ for index, dim in enumerate(self.axes_dim):
236
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
237
+ self.register_buffer(f"inv_freq_{index}", inv_freq, persistent=False)
238
+
239
+ @staticmethod
240
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
241
+ x1, x2 = x.chunk(2, dim=-1)
242
+ return torch.cat((-x2, x1), dim=-1)
243
+
244
+ def image_ids(self, batch_size: int, *, height: int, width: int, device: torch.device | str) -> torch.Tensor:
245
+ token_idx = torch.arange(height * width, device=device)
246
+ rows = token_idx // width
247
+ cols = token_idx % width
248
+ token_type = torch.ones_like(rows)
249
+ ids = torch.stack([token_type, rows, cols], dim=-1)
250
+ return ids.unsqueeze(0).expand(batch_size, -1, -1)
251
+
252
+ def text_ids(self, batch_size: int, token_count: int, *, device: torch.device | str) -> torch.Tensor:
253
+ positions = torch.arange(token_count, device=device)
254
+ token_type = torch.zeros_like(positions)
255
+ zeros = torch.zeros_like(positions)
256
+ ids = torch.stack([token_type, positions, zeros], dim=-1)
257
+ return ids.unsqueeze(0).expand(batch_size, -1, -1)
258
+
259
+ def _axis_apply(self, x: torch.Tensor, axis_ids: torch.Tensor, axis_index: int) -> torch.Tensor:
260
+ inv_freq = getattr(self, f"inv_freq_{axis_index}")
261
+ angles = axis_ids.float().unsqueeze(-1) * inv_freq.to(device=x.device).view(1, 1, -1)
262
+ cos = torch.cat([angles.cos(), angles.cos()], dim=-1).unsqueeze(2).to(dtype=x.dtype)
263
+ sin = torch.cat([angles.sin(), angles.sin()], dim=-1).unsqueeze(2).to(dtype=x.dtype)
264
+ return x * cos + self._rotate_half(x) * sin
265
+
266
+ def apply(
267
+ self,
268
+ q: torch.Tensor,
269
+ k: torch.Tensor,
270
+ ids: torch.Tensor,
271
+ ) -> tuple[torch.Tensor, torch.Tensor]:
272
+ if q.shape[-1] != self.head_dim or k.shape[-1] != self.head_dim:
273
+ raise ValueError(f"expected head_dim={self.head_dim}, got q={q.shape[-1]} k={k.shape[-1]}")
274
+ if ids.shape[:2] != q.shape[:2] or ids.shape[-1] != len(self.axes_dim):
275
+ raise ValueError(f"expected ids shape (B, T, {len(self.axes_dim)}), got {tuple(ids.shape)}")
276
+ q_chunks = q.split(self.axes_dim, dim=-1)
277
+ k_chunks = k.split(self.axes_dim, dim=-1)
278
+ q_out = []
279
+ k_out = []
280
+ for index, (q_axis, k_axis) in enumerate(zip(q_chunks, k_chunks, strict=True)):
281
+ q_out.append(self._axis_apply(q_axis, ids[..., index], index))
282
+ k_out.append(self._axis_apply(k_axis, ids[..., index], index))
283
+ return torch.cat(q_out, dim=-1), torch.cat(k_out, dim=-1)
284
+
285
+
286
class RoPE2D(nn.Module):
    """2D RoPE for image tokens on a fixed H×W grid (row-major flattening).

    The first half of ``head_dim`` is rotated by the row index and the second
    half by the column index. The cos/sin tables for both axes are precomputed
    for a square ``grid_size`` × ``grid_size`` grid and stored as
    non-persistent buffers.

    Args:
        head_dim: Per-head feature size; must be divisible by 4.
        grid_size: Side length of the square token grid; must be positive.
        theta: RoPE base; must be positive.

    Raises:
        ValueError: If any argument violates the constraints above.
    """

    def __init__(self, head_dim: int, grid_size: int, *, theta: float = 10000.0) -> None:
        super().__init__()
        if head_dim % 4 != 0:
            raise ValueError(
                f"head_dim={head_dim} must be divisible by 4 for 2D RoPE "
                f"(half for H, half for W, each needing pairs)"
            )
        if grid_size <= 0:
            raise ValueError(f"grid_size must be positive, got {grid_size}")
        if theta <= 0.0:
            raise ValueError(f"theta must be positive, got {theta}")
        self.head_dim = head_dim
        self.grid_size = grid_size
        self.half_dim = head_dim // 2

        exponents = torch.arange(0, self.half_dim, 2).float() / self.half_dim
        pair_freqs = 1.0 / (theta ** exponents)
        flat_index = torch.arange(grid_size * grid_size)
        axis_positions = (
            ("h", flat_index // grid_size),  # row of each row-major token
            ("w", flat_index % grid_size),  # column of each row-major token
        )
        for axis, positions in axis_positions:
            phase = torch.outer(positions.float(), pair_freqs)
            # Both halves of a rotated pair share one angle; add batch/head dims.
            phase = torch.cat([phase, phase], dim=-1)[None, :, None, :]
            self.register_buffer(f"cos_{axis}", phase.cos(), persistent=False)
            self.register_buffer(f"sin_{axis}", phase.sin(), persistent=False)

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Map the two halves ``(a, b)`` of the last dimension to ``(-b, a)``."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    def _apply_axis_rope(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate ``x`` with tables cast to ``x``'s dtype."""
        cos_x = cos.to(dtype=x.dtype)
        sin_x = sin.to(dtype=x.dtype)
        return x * cos_x + self._rotate_half(x) * sin_x

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``(batch, tokens, heads, head_dim)`` queries and keys.

        Raises:
            ValueError: If ``head_dim`` or the token count does not match the
                configuration given at construction time.
        """
        batch, tokens, num_heads, head_dim = q.shape
        if head_dim != self.head_dim:
            raise ValueError(f"expected head_dim={self.head_dim}, got {head_dim}")
        expected_tokens = self.grid_size * self.grid_size
        if tokens != expected_tokens:
            raise ValueError(f"expected {expected_tokens} image tokens, got {tokens}")

        tables = ((self.cos_h, self.sin_h), (self.cos_w, self.sin_w))
        rotated = []
        for tensor in (q, k):
            halves = tensor.chunk(2, dim=-1)
            parts = [self._apply_axis_rope(half, cos, sin) for half, (cos, sin) in zip(halves, tables)]
            rotated.append(torch.cat(parts, dim=-1))
        return rotated[0], rotated[1]
354
+
355
+
356
class FullImageSelfAttention(nn.Module):
    """Full image-token attention for the small DC-AE latent grid.

    Args:
        hidden_dim: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        backend: ``"sdpa"``, ``"flash3"``, ``"flash4"`` or ``"auto"``. ``"auto"``
            tries FlashAttention-4, then FlashAttention-3, then falls back to
            ``F.scaled_dot_product_attention``.
        grid_size: Side length of the token grid; required when ``rope`` is set.
        rope: Whether to rotate queries/keys with ``RoPE2D``.
        rope_theta: RoPE base forwarded to ``RoPE2D``.

    Raises:
        ValueError: On an invalid head split, an unknown backend, or
            ``rope=True`` without ``grid_size``.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        *,
        backend: str = "sdpa",
        grid_size: int | None = None,
        rope: bool = False,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}")
        if backend not in {"sdpa", "flash3", "flash4", "auto"}:
            raise ValueError(f"Unsupported image_attention_backend: {backend}")
        if rope and grid_size is None:
            raise ValueError("grid_size is required when rope=True")
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.backend = backend
        self.qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        # Zero-initialised output projection: the layer outputs zeros at initialisation.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)
        self.rope = (
            RoPE2D(self.head_dim, grid_size, theta=rope_theta)
            if rope and grid_size is not None
            else None
        )
        # Lazily resolved kernel entry points; each import is attempted once.
        self._flash3_attn_func = None
        self._flash3_import_attempted = False
        self._flash4_attn_func = None
        self._flash4_import_attempted = False

    def _get_flash3_attn_func(self):
        """Import and cache the FlashAttention-3 kernel; ``None`` when unavailable."""
        if self._flash3_import_attempted:
            return self._flash3_attn_func
        self._flash3_import_attempted = True
        maybe_add_sibling_flash_attention_repo()
        try:
            from flash_attn_interface import flash_attn_func
        except Exception:
            try:
                from flash_attn.flash_attn_interface import flash_attn_func
            except Exception:
                flash_attn_func = None
        self._flash3_attn_func = flash_attn_func
        return self._flash3_attn_func

    @staticmethod
    def _import_flash4_from_checkout():
        """Try to import the CuTe kernel from a local flash-attention checkout.

        A stub ``flash_attn`` package pointing at the checkout is installed in
        ``sys.modules`` for the import. If the import fails, the previously
        loaded ``flash_attn`` modules are restored so other users of the
        package (including the FlashAttention-3 fallback) are not left with
        the stub.
        """
        flash4_paths = []
        ancestors = Path(__file__).resolve().parents
        # ``parents[3]`` raises IndexError for files fewer than four levels deep.
        if len(ancestors) > 3:
            flash4_paths.append(ancestors[3] / "flash-attention" / "flash_attn")
        flash4_paths.extend(
            [
                Path("/work/flash-attention/flash_attn"),
                Path("/home/jovyan/work/flash-attention/flash_attn"),
                Path("/content/flash-attention/flash_attn"),
            ]
        )
        existing_paths = [str(path) for path in flash4_paths if (path / "cute").is_dir()]
        if not existing_paths:
            return None

        def is_flash_attn(name: str) -> bool:
            return name == "flash_attn" or name.startswith("flash_attn.")

        displaced = {name: module for name, module in list(sys.modules.items()) if is_flash_attn(name)}
        for name in displaced:
            del sys.modules[name]
        flash_attn_pkg = types.ModuleType("flash_attn")
        flash_attn_pkg.__path__ = existing_paths
        sys.modules["flash_attn"] = flash_attn_pkg
        try:
            from flash_attn.cute.interface import flash_attn_func
        except Exception:
            # Roll back: drop the stub (and anything partially imported under it)
            # and put the previously imported modules back.
            for name in [loaded for loaded in sys.modules if is_flash_attn(loaded)]:
                del sys.modules[name]
            sys.modules.update(displaced)
            return None
        return flash_attn_func

    def _get_flash4_attn_func(self):
        """Import and cache the FlashAttention-4/CuTe kernel; ``None`` when unavailable."""
        if self._flash4_import_attempted:
            return self._flash4_attn_func
        self._flash4_import_attempted = True
        maybe_add_sibling_flash_attention_repo()
        try:
            from flash_attn.cute.interface import flash_attn_func
        except Exception:
            flash_attn_func = self._import_flash4_from_checkout()
        self._flash4_attn_func = flash_attn_func
        return self._flash4_attn_func

    def _flash3_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Run non-causal FlashAttention-3; raises ImportError when it is missing."""
        flash_attn_func = self._get_flash3_attn_func()
        if flash_attn_func is None:
            raise ImportError(
                "image_attention_backend='flash3' requires FlashAttention-3. "
                "Install it or use --image-attn-backend sdpa."
            )
        out = flash_attn_func(q, k, v, causal=False)
        # Some builds return a tuple whose first element is the attention output.
        if isinstance(out, tuple):
            out = out[0]
        return out

    def _flash4_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Run non-causal FlashAttention-4; raises ImportError when it is missing."""
        flash_attn_func = self._get_flash4_attn_func()
        if flash_attn_func is None:
            raise ImportError(
                "image_attention_backend='flash4' requires FlashAttention-4/CuTe. "
                "Install flash-attn-4 or use --image-attn-backend sdpa."
            )
        out = flash_attn_func(q, k, v, causal=False)
        if isinstance(out, tuple):
            out = out[0]
        return out

    @staticmethod
    def _flash_compute_dtype(x: torch.Tensor) -> torch.dtype | None:
        """FA kernels need fp16/bf16; fp32 master weights + compile may still pass fp32 activations.

        Returns ``None`` for non-CUDA tensors, the tensor's own dtype when it is
        already fp16/bf16, the CUDA autocast dtype when autocast is enabled,
        and bf16 otherwise.
        """
        if not x.is_cuda:
            return None
        if x.dtype in {torch.float16, torch.bfloat16}:
            return x.dtype
        if torch.is_autocast_enabled():
            return torch.get_autocast_dtype("cuda")
        return torch.bfloat16

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over all image tokens of ``x`` with shape ``(batch, tokens, channels)``.

        Raises:
            RuntimeError: If a flash backend is requested explicitly but the
                input is not eligible for flash kernels.
            ImportError: If an explicitly requested flash kernel is not installed.
        """
        batch, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        if self.rope is not None:
            q, k = self.rope(q, k)

        flash_dtype = self._flash_compute_dtype(x)
        use_flash = self.backend in {"flash3", "flash4", "auto"} and flash_dtype is not None
        if self.backend in {"flash3", "flash4"} and not use_flash:
            raise RuntimeError(
                f"image_attention_backend='{self.backend}' requires CUDA fp16/bf16 compute; got {x.device} {x.dtype}"
            )

        out = None
        if use_flash:
            # Cast copies only: the originals are kept so that an SDPA fallback
            # runs in the projection's own dtype rather than the flash dtype.
            fq, fk, fv = (t if t.dtype == flash_dtype else t.to(flash_dtype) for t in (q, k, v))
            if self.backend == "flash4":
                out = self._flash4_attention(fq, fk, fv)
            elif self.backend == "flash3":
                out = self._flash3_attention(fq, fk, fv)
            else:  # "auto": best effort, newest kernel first
                try:
                    out = self._flash4_attention(fq, fk, fv)
                except Exception:
                    try:
                        out = self._flash3_attention(fq, fk, fv)
                    except Exception:
                        out = None
            if out is not None and out.dtype != x.dtype:
                out = out.to(dtype=x.dtype)

        if out is None:
            # SDPA expects (batch, heads, tokens, head_dim).
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                dropout_p=0.0,
                is_causal=False,
            )
            out = out.transpose(1, 2)

        out = out.reshape(batch, tokens, channels)
        return self.out_proj(out)
514
+
515
+
516
class SanaMultiHeadCrossAttention(nn.Module):
    """Sana-style cross-attention with optional q/k norm and SDPA/xformers kernels.

    Args:
        hidden_dim: Feature size of both inputs; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        backend: ``"sdpa"``, ``"xformers"`` or ``"auto"`` (xformers when eligible,
            otherwise SDPA).
        qk_norm: Whether projected queries and keys go through ``AttentionRMSNorm``.

    Raises:
        ValueError: On an invalid head split or an unknown backend.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        *,
        backend: str = "sdpa",
        qk_norm: bool = True,
    ) -> None:
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}")
        if backend not in {"sdpa", "xformers", "auto"}:
            raise ValueError(f"Unsupported cross_attention_backend: {backend}")
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.backend = backend
        self.q_linear = nn.Linear(hidden_dim, hidden_dim)
        self.kv_linear = nn.Linear(hidden_dim, hidden_dim * 2)
        if qk_norm:
            self.q_norm = AttentionRMSNorm(hidden_dim, scale_factor=1.0, eps=1e-6)
            self.k_norm = AttentionRMSNorm(hidden_dim, scale_factor=1.0, eps=1e-6)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        # adaLN-Zero style: cross-attn starts as a no-op so Gemma text cannot spike GDN states early.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)
        # xformers is imported lazily, at most once.
        self._xformers_ops = None
        self._xformers_import_attempted = False

    def _get_xformers_ops(self):
        """Import and cache ``xformers.ops``; ``None`` when the import fails."""
        if not self._xformers_import_attempted:
            self._xformers_import_attempted = True
            try:
                import xformers.ops as xops
            except Exception:
                xops = None
            self._xformers_ops = xops
        return self._xformers_ops

    def _xformers_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        key_padding_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Attend with xformers using a block-diagonal mask over packed sequences.

        Keys/values flagged by ``key_padding_mask`` are dropped before packing.

        Raises:
            ImportError: If xformers is not installed.
            ValueError: If a sample has no valid text token.
        """
        xops = self._get_xformers_ops()
        if xops is None:
            raise ImportError(
                "cross_attention_backend='xformers' requires xformers. "
                "Install it or use --cross-attn-backend sdpa."
            )

        batch, image_tokens = q.shape[:2]
        text_tokens = k.shape[1]
        packed_shape = (self.num_heads, self.head_dim)
        q_lens = [image_tokens] * batch
        q_compact = q.reshape(1, batch * image_tokens, *packed_shape)
        if key_padding_mask is not None:
            valid_mask = ~key_padding_mask.bool()
            kv_lens = valid_mask.sum(dim=1).tolist()
            if any(length <= 0 for length in kv_lens):
                raise ValueError("xformers cross-attention received a sample with zero valid text tokens")
            kept_k = [k[index, valid_mask[index]] for index in range(batch)]
            kept_v = [v[index, valid_mask[index]] for index in range(batch)]
            k_compact = torch.cat(kept_k, dim=0).unsqueeze(0)
            v_compact = torch.cat(kept_v, dim=0).unsqueeze(0)
        else:
            kv_lens = [text_tokens] * batch
            k_compact = k.reshape(1, batch * text_tokens, *packed_shape)
            v_compact = v.reshape(1, batch * text_tokens, *packed_shape)

        attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(q_lens, kv_lens)
        out = xops.memory_efficient_attention(q_compact, k_compact, v_compact, attn_bias=attn_bias, p=0.0)
        return out.reshape(batch, image_tokens, self.num_heads, self.head_dim)

    def _sdpa_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        key_padding_mask: torch.Tensor | None,
        attn_bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend with ``F.scaled_dot_product_attention``.

        ``attn_bias`` is used as the mask when given; otherwise positive entries
        of ``key_padding_mask`` become an additive bias of -10000.
        """
        q, k, v = (tensor.transpose(1, 2) for tensor in (q, k, v))
        mask = attn_bias
        if mask is None and key_padding_mask is not None:
            padded = key_padding_mask[:, None, None, :].to(dtype=q.dtype)
            mask = padded.masked_fill(padded > 0, -10000.0)
        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
        return attended.transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        cond: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
        attn_bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Let tokens ``x`` attend to conditioning tokens ``cond``.

        Args:
            x: Query tokens of shape ``(batch, image_tokens, hidden_dim)``.
            cond: Key/value tokens of shape ``(batch, text_tokens, hidden_dim)``.
            key_padding_mask: Optional mask whose positive entries mark padded keys.
            attn_bias: Optional mask passed to SDPA instead of one derived from
                ``key_padding_mask``; not used on the xformers path.

        Raises:
            RuntimeError: If ``backend="xformers"`` but the input is not a CUDA
                fp16/bf16 tensor.
        """
        batch, image_tokens, channels = x.shape
        text_tokens = cond.shape[1]
        heads = (self.num_heads, self.head_dim)
        # Sana order: linear projection first, then per-token q/k RMSNorm before head split.
        # This caps dot-product growth when cond carries high-magnitude Gemma caption states.
        q = self.q_norm(self.q_linear(x)).reshape(batch, image_tokens, *heads)
        k, v = self.kv_linear(cond).chunk(2, dim=-1)
        k = self.k_norm(k).reshape(batch, text_tokens, *heads)
        v = v.reshape(batch, text_tokens, *heads)

        half_precision_cuda = x.is_cuda and x.dtype in {torch.float16, torch.bfloat16}
        use_xformers = self.backend in {"xformers", "auto"} and half_precision_cuda
        if use_xformers:
            try:
                out = self._xformers_attention(q, k, v, key_padding_mask)
            except Exception:
                # An explicit request propagates the failure; "auto" falls back to SDPA.
                if self.backend == "xformers":
                    raise
                use_xformers = False
        if self.backend == "xformers" and not use_xformers:
            raise RuntimeError(
                f"cross_attention_backend='xformers' requires CUDA fp16/bf16 tensors; got {x.device} {x.dtype}"
            )
        if not use_xformers:
            out = self._sdpa_attention(q, k, v, key_padding_mask, attn_bias)

        return self.proj(out.reshape(batch, image_tokens, channels))
645
+
646
+
647
class FLASelfMixer(nn.Module):
    """Token mixer backed by a flash-linear-attention (``fla``) layer.

    ``config.mixer_type`` selects ``LinearAttention`` (``"fla_linear"``),
    ``GatedDeltaNet`` (``"fla_gated_deltanet"``) or ``GatedLinearAttention``
    (``"fla_gla"``). With ``config.fla_bidirectional`` a second mixer processes
    the reversed sequence and both outputs are fused by a zero-initialised
    linear layer.

    Raises:
        ValueError: If ``config.mixer_type`` is not one of the values above.
    """

    def __init__(self, config: BoomerFLADiTConfig, *, layer_idx: int) -> None:
        super().__init__()
        try:
            import fla.layers as fla_layers
        except Exception:
            # Retry once after adding known local checkouts to sys.path.
            maybe_add_sibling_fla_repo()
            import fla.layers as fla_layers

        hidden_dim = config.hidden_dim
        self.bidirectional = config.fla_bidirectional

        def make_mixer() -> nn.Module:
            kind = config.mixer_type
            # Arguments common to every supported fla layer.
            shared = {
                "hidden_size": hidden_dim,
                "num_heads": config.num_heads,
                "mode": config.fla_mode,
                "layer_idx": layer_idx,
            }
            if kind == "fla_linear":
                return fla_layers.LinearAttention(
                    feature_map=config.fla_feature_map,
                    output_norm="rmsnorm",
                    **shared,
                )
            if kind == "fla_gated_deltanet":
                return fla_layers.GatedDeltaNet(
                    head_dim=hidden_dim // config.num_heads,
                    expand_v=1,
                    use_short_conv=config.use_short_conv,
                    conv_size=config.conv_size,
                    **shared,
                )
            if kind == "fla_gla":
                return fla_layers.GatedLinearAttention(
                    feature_map=config.fla_feature_map,
                    use_short_conv=config.use_short_conv,
                    conv_size=config.conv_size,
                    **shared,
                )
            raise ValueError(f"Unsupported FLA mixer_type: {config.mixer_type}")

        self.mixer_fwd = make_mixer()
        self.mixer_bwd = make_mixer() if self.bidirectional else None
        if self.bidirectional:
            self.out_proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
            nn.init.zeros_(self.out_proj.weight)

    @staticmethod
    def _run_mixer(mixer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Call ``mixer`` and keep only the first element of a tuple result."""
        result = mixer(x)
        return result[0] if isinstance(result, tuple) else result

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix tokens along dimension 1, in one or both directions.

        Raises:
            RuntimeError: If bidirectional mode is on but no backward mixer exists.
        """
        forward_out = self._run_mixer(self.mixer_fwd, x)
        if not self.bidirectional:
            return forward_out
        if self.mixer_bwd is None:
            raise RuntimeError("bidirectional FLASelfMixer is missing the backward mixer")
        # Reverse the sequence, mix, and reverse back so tokens line up again.
        reversed_out = self._run_mixer(self.mixer_bwd, x.flip(1))
        backward_out = reversed_out.flip(1)
        return self.out_proj(torch.cat([forward_out, backward_out], dim=-1))
713
+
714
+
715
class BoomerFLABlock(nn.Module):
    """Single-stream block: self mixer, optional full image attention, text cross-attention, MLP.

    Every branch is wrapped in shift/scale/gate values obtained by adding a
    learned per-block table to a projection of the timestep embedding. With
    ``config.parallel_block`` all branches read the same input and their
    outputs are summed onto it; otherwise the branches are applied one after
    another, each as a residual update.
    """

    def __init__(self, config: BoomerFLADiTConfig, *, layer_idx: int) -> None:
        super().__init__()
        dim = config.hidden_dim
        heads = config.num_heads

        def plain_norm() -> nn.LayerNorm:
            # Parameter-free LayerNorm shared by every pre-branch normalisation.
            return nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

        self.parallel_block = config.parallel_block
        # Full image attention is enabled on every `image_attention_every`-th
        # layer (1-based); a non-positive value disables it everywhere.
        every = config.image_attention_every
        self.use_image_attention = every > 0 and (layer_idx + 1) % every == 0

        self.norm1 = plain_norm()
        uses_torch_attention = config.mixer_type in {"torch", "fallback"}
        self.self_attn = (
            TorchSelfAttention(dim, heads)
            if uses_torch_attention
            else FLASelfMixer(config, layer_idx=layer_idx)
        )

        if self.use_image_attention:
            self.image_attn_norm = plain_norm()
            self.image_attn_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 3))
            self.image_attn = FullImageSelfAttention(
                dim,
                heads,
                backend=config.image_attention_backend,
                grid_size=config.latent_size // config.patch_size,
                rope=config.image_attention_rope,
                rope_theta=config.image_rope_theta,
            )
            self.image_attn_scale_shift_table = nn.Parameter(torch.zeros(3, dim))

        backend = config.cross_attention_backend
        if backend == "mha":
            # nn.MultiheadAttention has no q/k normalisation option.
            if config.cross_attention_qk_norm:
                raise ValueError(
                    "cross_attention_qk_norm requires SanaMultiHeadCrossAttention "
                    "(cross_attention_backend sdpa/xformers/auto), not mha"
                )
            self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        else:
            self.cross_attn = SanaMultiHeadCrossAttention(
                dim,
                heads,
                backend=backend,
                qk_norm=config.cross_attention_qk_norm,
            )

        # 9 = (shift, scale, gate) x (self mixer, cross-attention, MLP).
        self.mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 9))
        self.norm2 = plain_norm()
        self.norm3 = plain_norm()
        self.mlp = GLUMBConv(dim, config.mlp_ratio)
        self.scale_shift_table = nn.Parameter(torch.zeros(9, dim))

    def _cross_attention(
        self,
        x: torch.Tensor,
        text_tokens: torch.Tensor,
        text_key_padding_mask: torch.Tensor,
        text_attn_bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Attend from ``x`` to ``text_tokens`` using whichever cross-attention module was built.

        The ``nn.MultiheadAttention`` path uses only the key padding mask; the
        other backend receives both the padding mask and ``text_attn_bias``.
        """
        attn = self.cross_attn
        if not isinstance(attn, nn.MultiheadAttention):
            return attn(x, text_tokens, text_key_padding_mask, text_attn_bias)
        attended, _weights = attn(
            x,
            text_tokens,
            text_tokens,
            key_padding_mask=text_key_padding_mask,
            need_weights=False,
        )
        return attended

    def _image_attention_branch(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """Return the gated full-image-attention update for ``x``.

        Only valid when ``use_image_attention`` is true, since the modules it
        reads are created only in that case.
        """
        params = self.image_attn_scale_shift_table[None] + self.image_attn_mod(t_emb).reshape(
            x.shape[0], 3, -1
        )
        shift_img, scale_img, gate_img = params.chunk(3, dim=1)
        return gate_img * self.image_attn(modulate(self.image_attn_norm(x), shift_img, scale_img))

    def forward(
        self,
        x: torch.Tensor,
        text_tokens: torch.Tensor,
        t_emb: torch.Tensor,
        text_key_padding_mask: torch.Tensor,
        text_attn_bias: torch.Tensor | None,
        *,
        height: int,
        width: int,
    ) -> torch.Tensor:
        """Apply the block to image tokens ``x`` and return the updated tokens.

        ``height`` and ``width`` are forwarded to the MLP; ``text_tokens`` and
        the two text masks are only used by the cross-attention branch.
        """
        params = self.scale_shift_table[None] + self.mod(t_emb).reshape(x.shape[0], 9, -1)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_cross,
            scale_cross,
            gate_cross,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = params.chunk(9, dim=1)

        def self_branch(h: torch.Tensor) -> torch.Tensor:
            return gate_msa * self.self_attn(modulate(self.norm1(h), shift_msa, scale_msa))

        def cross_branch(h: torch.Tensor) -> torch.Tensor:
            return gate_cross * self._cross_attention(
                modulate(self.norm3(h), shift_cross, scale_cross),
                text_tokens,
                text_key_padding_mask,
                text_attn_bias,
            )

        def mlp_branch(h: torch.Tensor) -> torch.Tensor:
            return gate_mlp * self.mlp(
                modulate(self.norm2(h), shift_mlp, scale_mlp), height=height, width=width
            )

        if self.parallel_block:
            # All branches see the same un-updated input.
            branches = [self_branch(x), cross_branch(x), mlp_branch(x)]
            if self.use_image_attention:
                branches.append(self._image_attention_branch(x, t_emb))
            return x + sum(branches)

        # Sequential mode: each branch consumes the previous branch's output.
        x = x + self_branch(x)
        if self.use_image_attention:
            x = x + self._image_attention_branch(x, t_emb)
        x = x + cross_branch(x)
        return x + mlp_branch(x)
839
+
840
+
841
class BoomerFLADualStreamBlock(nn.Module):
    """FLUX-style early block with one joint text+image attention operation.

    Image and text tokens keep separate modulation, QKV, output-projection and
    MLP weights, but attend over the concatenated (text, image) sequence in a
    single scaled-dot-product attention call. Unlike the single-stream block,
    ``forward`` returns updated text tokens as well as updated image tokens.
    """

    # Lets the caller tell from the class that forward() returns (x, text_tokens).
    updates_text = True

    def __init__(self, config: BoomerFLADiTConfig, *, layer_idx: int) -> None:
        super().__init__()
        dim = config.hidden_dim
        heads = config.num_heads
        if dim % heads != 0:
            raise ValueError(f"hidden_dim={dim} must be divisible by num_heads={heads}")
        self.num_heads = heads
        self.head_dim = dim // heads
        self.hidden_dim = dim
        self.parallel_block = config.parallel_block

        def plain_norm() -> nn.LayerNorm:
            return nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

        def head_norm() -> AttentionRMSNorm:
            return AttentionRMSNorm(self.head_dim, scale_factor=1.0, eps=1e-6)

        # Image stream. 6 = (shift, scale, gate) x (attention, MLP).
        self.image_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
        self.image_norm1 = plain_norm()
        self.image_qkv = nn.Linear(dim, dim * 3)
        self.image_q_norm = head_norm()
        self.image_k_norm = head_norm()
        self.image_out_proj = nn.Linear(dim, dim)
        self.image_norm2 = plain_norm()
        self.image_mlp = GLUMBConv(dim, config.mlp_ratio)
        self.image_scale_shift_table = nn.Parameter(torch.zeros(6, dim))

        # Text stream: same layout, but with a token-wise MLP.
        self.text_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
        self.text_norm1 = plain_norm()
        self.text_qkv = nn.Linear(dim, dim * 3)
        self.text_q_norm = head_norm()
        self.text_k_norm = head_norm()
        self.text_out_proj = nn.Linear(dim, dim)
        self.text_norm2 = plain_norm()
        self.text_mlp = TokenMLP(dim, config.mlp_ratio)
        self.text_scale_shift_table = nn.Parameter(torch.zeros(6, dim))

    def _qkv(
        self,
        x: torch.Tensor,
        qkv: nn.Linear,
        q_norm: AttentionRMSNorm,
        k_norm: AttentionRMSNorm,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` to per-head q, k, v of shape (batch, tokens, heads, head_dim); q and k are normalised."""
        batch, tokens, _channels = x.shape
        projected = qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = projected.unbind(dim=2)
        return q_norm(q), k_norm(k), v

    def _joint_attention(
        self,
        image_tokens: torch.Tensor,
        text_tokens: torch.Tensor,
        text_key_padding_mask: torch.Tensor,
        coord_rope: MultimodalCoordinateRoPE | None,
        image_coord_ids: torch.Tensor | None,
        text_coord_ids: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Attend over the concatenated (text, image) sequence.

        Returns:
            ``(image_out, text_out)``, each passed through its own output projection.

        Raises:
            ValueError: if ``coord_rope`` is given without both coordinate id tensors.
        """
        batch = image_tokens.shape[0]
        n_image = image_tokens.shape[1]
        n_text = text_tokens.shape[1]

        image_q, image_k, image_v = self._qkv(
            image_tokens, self.image_qkv, self.image_q_norm, self.image_k_norm
        )
        text_q, text_k, text_v = self._qkv(
            text_tokens, self.text_qkv, self.text_q_norm, self.text_k_norm
        )
        # Text comes first on the token axis; the split below relies on that order.
        q = torch.cat([text_q, image_q], dim=1)
        k = torch.cat([text_k, image_k], dim=1)
        v = torch.cat([text_v, image_v], dim=1)

        if coord_rope is not None:
            if image_coord_ids is None or text_coord_ids is None:
                raise ValueError("coordinate ids are required when multimodal coord RoPE is enabled")
            joint_ids = torch.cat([text_coord_ids, image_coord_ids], dim=1)
            q, k = coord_rope.apply(q, k, joint_ids)

        # Image tokens are never padded, so their part of the key mask is all zeros.
        image_mask = torch.zeros(
            batch,
            n_image,
            device=image_tokens.device,
            dtype=text_key_padding_mask.dtype,
        )
        key_padding_mask = torch.cat([text_key_padding_mask, image_mask], dim=1)
        # Additive bias: 0 for kept keys, a large negative value for padded keys.
        attn_bias = key_padding_mask[:, None, None, :].to(dtype=q.dtype)
        attn_bias = attn_bias.masked_fill(attn_bias > 0, -10000.0)

        attended = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_mask=attn_bias,
            dropout_p=0.0,
            is_causal=False,
        )
        attended = attended.transpose(1, 2).reshape(batch, n_text + n_image, -1)
        text_out, image_out = attended.split([n_text, n_image], dim=1)
        return self.image_out_proj(image_out), self.text_out_proj(text_out)

    def forward(
        self,
        x: torch.Tensor,
        text_tokens: torch.Tensor,
        t_emb: torch.Tensor,
        text_key_padding_mask: torch.Tensor,
        text_attn_bias: torch.Tensor | None,
        *,
        height: int,
        width: int,
        coord_rope: MultimodalCoordinateRoPE | None = None,
        image_coord_ids: torch.Tensor | None = None,
        text_coord_ids: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Update both streams and return ``(image_tokens, text_tokens)``.

        ``text_attn_bias`` is accepted for signature parity with the
        single-stream block but unused: the joint mask is rebuilt from
        ``text_key_padding_mask``.
        """
        del text_attn_bias

        image_params = self.image_scale_shift_table[None] + self.image_mod(t_emb).reshape(
            x.shape[0], 6, -1
        )
        text_params = self.text_scale_shift_table[None] + self.text_mod(t_emb).reshape(
            text_tokens.shape[0], 6, -1
        )
        (
            image_shift_attn,
            image_scale_attn,
            image_gate_attn,
            image_shift_mlp,
            image_scale_mlp,
            image_gate_mlp,
        ) = image_params.chunk(6, dim=1)
        (
            text_shift_attn,
            text_scale_attn,
            text_gate_attn,
            text_shift_mlp,
            text_scale_mlp,
            text_gate_mlp,
        ) = text_params.chunk(6, dim=1)

        image_attn, text_attn = self._joint_attention(
            modulate(self.image_norm1(x), image_shift_attn, image_scale_attn),
            modulate(self.text_norm1(text_tokens), text_shift_attn, text_scale_attn),
            text_key_padding_mask,
            coord_rope,
            image_coord_ids,
            text_coord_ids,
        )

        image_after_attn = x + image_gate_attn * image_attn
        text_after_attn = text_tokens + text_gate_attn * text_attn
        # Parallel mode feeds the MLPs the block input; sequential mode feeds
        # them the attention-updated tokens. The residual sum is the same shape
        # either way: (input + attention update) + MLP update.
        if self.parallel_block:
            image_mlp_src, text_mlp_src = x, text_tokens
        else:
            image_mlp_src, text_mlp_src = image_after_attn, text_after_attn

        image_out = image_after_attn + image_gate_mlp * self.image_mlp(
            modulate(self.image_norm2(image_mlp_src), image_shift_mlp, image_scale_mlp),
            height=height,
            width=width,
        )
        text_out = text_after_attn + text_gate_mlp * self.text_mlp(
            modulate(self.text_norm2(text_mlp_src), text_shift_mlp, text_scale_mlp)
        )
        return image_out, text_out
998
+
999
+
1000
class BoomerFLADiT(nn.Module):
    """Boomer DiT with FLA mixers, optional full image attention, and GLUMBConv FFNs.

    The first ``config.dual_stream_depth`` blocks are dual-stream blocks that
    update both image and text tokens; the remaining blocks are single-stream
    blocks that update image tokens only.
    """

    def __init__(self, config: BoomerFLADiTConfig | None = None) -> None:
        """Build the model.

        Args:
            config: model hyper-parameters. When omitted, a fresh
                ``BoomerFLADiTConfig()`` is created for this instance.

        Raises:
            ValueError: if ``patch_size`` is not positive, ``latent_size`` is
                not divisible by ``patch_size``, or ``dual_stream_depth`` is
                negative or larger than ``depth``.
        """
        super().__init__()
        # Create the default per call: a default argument would be evaluated
        # once at definition time and shared by every default-built model.
        if config is None:
            config = BoomerFLADiTConfig()
        if config.patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {config.patch_size}")
        if config.latent_size % config.patch_size != 0:
            raise ValueError(
                f"latent_size={config.latent_size} must be divisible by patch_size={config.patch_size}"
            )
        if config.dual_stream_depth < 0:
            raise ValueError(f"dual_stream_depth must be non-negative, got {config.dual_stream_depth}")
        if config.dual_stream_depth > config.depth:
            raise ValueError(f"dual_stream_depth={config.dual_stream_depth} exceeds depth={config.depth}")
        self.config = config
        hidden_dim = config.hidden_dim
        self.patch_size = config.patch_size
        self.token_grid_size = config.latent_size // config.patch_size
        token_count = self.token_grid_size * self.token_grid_size
        # With patch_size == 1 each latent position is one token, so a Linear
        # over channels is used; otherwise a strided Conv2d does the patching.
        self.x_embedder = (
            nn.Linear(config.latent_channels, hidden_dim)
            if config.patch_size == 1
            else nn.Conv2d(
                config.latent_channels,
                hidden_dim,
                kernel_size=config.patch_size,
                stride=config.patch_size,
            )
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, hidden_dim)) if config.use_abs_pos_embed else None
        self.t_embedder = TimestepEmbedder(hidden_dim)
        self.caption_embedder = CaptionEmbedder(config.text_dim, hidden_dim, config.text_seq_len)
        self.attention_y_norm = (
            AttentionRMSNorm(hidden_dim, scale_factor=config.y_norm_scale_factor) if config.y_norm else None
        )
        self.coord_embedder = (
            MultimodalCoordinateRoPE(
                hidden_dim // config.num_heads,
                image_size=self.token_grid_size,
                text_seq_len=config.text_seq_len,
                theta=config.image_rope_theta,
            )
            if config.multimodal_coord_ids
            else None
        )
        self.blocks = nn.ModuleList(
            [
                (
                    BoomerFLADualStreamBlock(config, layer_idx=i)
                    if i < config.dual_stream_depth
                    else BoomerFLABlock(config, layer_idx=i)
                )
                for i in range(config.depth)
            ]
        )
        self.final_norm = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.final_t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 2))
        self.out_proj = nn.Linear(hidden_dim, config.latent_channels * config.patch_size * config.patch_size)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Initialise the positional embedding, modulation layers and output head.

        Timestep-modulation linears, the final modulation linear and the output
        projection are zeroed; the per-block scale/shift tables and the
        positional embedding are drawn from a normal with std 0.02.
        """
        if self.pos_embed is not None:
            nn.init.normal_(self.pos_embed, std=0.02)

        for block in self.blocks:
            if isinstance(block, BoomerFLADualStreamBlock):
                nn.init.zeros_(block.image_mod[1].weight)
                nn.init.zeros_(block.image_mod[1].bias)
                nn.init.zeros_(block.text_mod[1].weight)
                nn.init.zeros_(block.text_mod[1].bias)
                nn.init.normal_(block.image_scale_shift_table, std=0.02)
                nn.init.normal_(block.text_scale_shift_table, std=0.02)
            else:
                nn.init.zeros_(block.mod[1].weight)
                nn.init.zeros_(block.mod[1].bias)
                nn.init.normal_(block.scale_shift_table, std=0.02)
                if block.use_image_attention:
                    nn.init.zeros_(block.image_attn_mod[1].weight)
                    nn.init.zeros_(block.image_attn_mod[1].bias)
                    nn.init.normal_(block.image_attn_scale_shift_table, std=0.02)

        nn.init.zeros_(self.final_t_block[1].weight)
        nn.init.zeros_(self.final_t_block[1].bias)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def apply_y_norm(self, caption_tokens: torch.Tensor) -> torch.Tensor:
        """Normalise caption tokens when ``config.y_norm`` enabled the norm; otherwise return them unchanged."""
        if self.attention_y_norm is None:
            return caption_tokens
        return self.attention_y_norm(caption_tokens)

    def null_condition(
        self,
        batch_size: int,
        *,
        device: torch.device | str,
        dtype: torch.dtype,
        mask_dtype: torch.dtype | None = None,
        token_num: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the caption embedder's null ``(text_embedding, attention_mask)`` pair.

        All arguments are forwarded unchanged to ``caption_embedder.null_condition``.
        """
        return self.caption_embedder.null_condition(
            batch_size,
            device=device,
            dtype=dtype,
            mask_dtype=mask_dtype,
            token_num=token_num,
        )

    def apply_condition_dropout(
        self,
        text_embedding: torch.Tensor,
        attention_mask: torch.Tensor,
        probability: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Replace each sample's text condition with the null condition with the given probability.

        Args:
            text_embedding: text condition; dim 0 is the batch dimension.
            attention_mask: mask matching ``text_embedding``; dim 0 is the batch dimension.
            probability: per-sample drop probability. Values ``<= 0`` return
                the inputs untouched.

        Returns:
            ``(text_embedding, attention_mask)`` with dropped samples replaced.
        """
        if probability <= 0.0:
            return text_embedding, attention_mask
        batch_size = text_embedding.shape[0]
        null_text, null_mask = self.null_condition(
            batch_size,
            device=text_embedding.device,
            dtype=text_embedding.dtype,
            mask_dtype=attention_mask.dtype,
            token_num=text_embedding.shape[-2],
        )
        # torch.where over a per-sample bool. Avoids the bool(drop.any()) CUDA
        # sync (which would defeat the training-loop sync removal) and skips
        # the full-tensor .clone() that the previous in-place path required.
        drop = torch.rand(batch_size, device=text_embedding.device) < probability
        # Reshape to (batch, 1, ..., 1) so the per-sample flag broadcasts.
        drop_text = drop.view(batch_size, *([1] * (text_embedding.dim() - 1)))
        drop_mask = drop.view(batch_size, *([1] * (attention_mask.dim() - 1)))
        text_embedding = torch.where(drop_text, null_text, text_embedding)
        attention_mask = torch.where(drop_mask, null_mask, attention_mask)
        return text_embedding, attention_mask

    def forward(
        self,
        noisy_latent: torch.Tensor,
        timesteps: torch.Tensor,
        text_embedding: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the model on a batch of noisy latents.

        Args:
            noisy_latent: tensor of shape (batch, latent_channels, height, width).
            timesteps: passed to the timestep embedder.
            text_embedding: text features whose last dim equals ``config.text_dim``.
            attention_mask: text mask; positions equal to 0 are treated as padding.

        Returns:
            Tensor with the same (batch, channels, height, width) shape as ``noisy_latent``.

        Raises:
            ValueError: on a channel mismatch, a height/width not divisible by
                ``patch_size``, a token count that does not match the absolute
                positional embedding, or a wrong ``text_embedding`` last dim.
        """
        batch, channels, height, width = noisy_latent.shape
        if channels != self.config.latent_channels:
            raise ValueError(
                f"Expected latent_channels={self.config.latent_channels}, got shape {tuple(noisy_latent.shape)}"
            )
        if height % self.patch_size != 0 or width % self.patch_size != 0:
            raise ValueError(
                f"latent height/width must be divisible by patch_size={self.patch_size}, got {(height, width)}"
            )
        token_height = height // self.patch_size
        token_width = width // self.patch_size
        token_count = token_height * token_width
        if self.pos_embed is not None and token_count != self.pos_embed.shape[1]:
            raise ValueError(
                f"absolute pos_embed expects {self.pos_embed.shape[1]} latent tokens, got {token_count}. "
                "Disable it with --no-abs-pos-embed for variable latent sizes."
            )
        if text_embedding.shape[-1] != self.config.text_dim:
            raise ValueError(f"text_embedding last dim must be {self.config.text_dim}, got {text_embedding.shape[-1]}")

        text_tokens = self.caption_embedder(text_embedding)
        text_tokens = self.apply_y_norm(text_tokens)
        text_key_padding_mask = attention_mask == 0

        # Tokenise the latent into (batch, tokens, hidden_dim).
        if self.patch_size == 1:
            x = noisy_latent.flatten(2).transpose(1, 2)
            x = self.x_embedder(x)
        else:
            x = self.x_embedder(noisy_latent).flatten(2).transpose(1, 2)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        image_coord_ids = None
        text_coord_ids = None
        if self.coord_embedder is not None:
            image_coord_ids = self.coord_embedder.image_ids(
                batch,
                height=token_height,
                width=token_width,
                device=x.device,
            )
            text_coord_ids = self.coord_embedder.text_ids(batch, text_tokens.shape[1], device=text_tokens.device)

        # Additive attention bias: 0 for real text tokens, -10000 for padding.
        text_attn_bias = text_key_padding_mask[:, None, None, :].to(dtype=x.dtype)
        text_attn_bias = text_attn_bias.masked_fill(text_attn_bias > 0, -10000.0)
        t_emb = self.t_embedder(timesteps)
        use_ckpt = self.config.gradient_checkpointing and self.training

        grid_kwargs = {"height": token_height, "width": token_width}
        dual_kwargs = {
            **grid_kwargs,
            "coord_rope": self.coord_embedder,
            "image_coord_ids": image_coord_ids,
            "text_coord_ids": text_coord_ids,
        }
        for block in self.blocks:
            # Dual-stream blocks return (x, text_tokens); single-stream blocks return x only.
            updates_text = getattr(block, "updates_text", False)
            block_kwargs = dual_kwargs if updates_text else grid_kwargs
            if use_ckpt:
                # Tensors go through the checkpoint call positionally; the
                # block and its keyword arguments are bound as defaults so the
                # closure does not late-bind to the loop variable.
                def _block_fn(x_, tt, te, mk, bi, _blk=block, _kw=block_kwargs):
                    return _blk(x_, tt, te, mk, bi, **_kw)

                result = _ckpt(
                    _block_fn,
                    x,
                    text_tokens,
                    t_emb,
                    text_key_padding_mask,
                    text_attn_bias,
                    use_reentrant=False,
                    preserve_rng_state=False,
                )
            else:
                result = block(
                    x,
                    text_tokens,
                    t_emb,
                    text_key_padding_mask,
                    text_attn_bias,
                    **block_kwargs,
                )
            if updates_text:
                x, text_tokens = result
            else:
                x = result

        final_mod = self.final_t_block(t_emb)
        shift, scale = final_mod.reshape(batch, 2, -1).chunk(2, dim=1)
        x = modulate(self.final_norm(x), shift, scale)
        x = self.out_proj(x)
        if self.patch_size == 1:
            return x.transpose(1, 2).reshape(batch, channels, height, width)
        # Un-patchify: (B, th, tw, C, p, p) -> (B, C, th, p, tw, p) -> (B, C, H, W).
        patch = self.patch_size
        x = x.reshape(batch, token_height, token_width, channels, patch, patch)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
        return x.reshape(batch, channels, height, width)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: str | None = None, **kwargs):
        """Load BoomerFLADiT weights from a local snapshot directory.

        Reads ``config.json`` and ``diffusion_pytorch_model.safetensors`` from
        the directory (or from ``subfolder`` inside it). Config keys starting
        with ``_`` are not passed to ``BoomerFLADiTConfig``; they are kept on
        the returned model as ``_boomer_cfg``.

        Weights are loaded non-strictly, so a checkpoint that does not match
        the architecture still loads, but a warning lists the missing and
        unexpected keys instead of leaving the mismatch silent.

        Args:
            pretrained_model_name_or_path: path to the local snapshot directory.
            subfolder: optional sub-directory holding the config and weights.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            The constructed model with weights loaded.
        """
        import json
        import warnings
        from pathlib import Path
        from safetensors.torch import load_file

        path = Path(pretrained_model_name_or_path)
        if subfolder:
            path = path / subfolder

        cfg_raw = json.loads((path / "config.json").read_text(encoding="utf-8"))
        cfg_clean = {k: v for k, v in cfg_raw.items() if not k.startswith("_")}
        model_config = BoomerFLADiTConfig(**cfg_clean)

        model = cls(model_config)
        sd = load_file(str(path / "diffusion_pytorch_model.safetensors"))
        load_result = model.load_state_dict(sd, strict=False)
        missing = list(load_result.missing_keys)
        unexpected = list(load_result.unexpected_keys)
        if missing or unexpected:
            # Missing keys leave those parameters at their initial values, so
            # make the mismatch visible rather than dropping it.
            warnings.warn(
                f"Checkpoint at {path} does not exactly match BoomerFLADiT: "
                f"missing keys={missing}, unexpected keys={unexpected}",
                stacklevel=2,
            )

        # Attach inference metadata (latent stats, component repos, etc.)
        # so BoomerPipeline.__init__ can read them without a separate config file.
        model._boomer_cfg = {k: v for k, v in cfg_raw.items() if k.startswith("_")}
        return model