AbstractPhil committed on
Commit
e4ba6b1
·
verified ·
1 Parent(s): 120062a

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +623 -0
model.py ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux: A /12 scaled Flux architecture for experimentation.
3
+
4
+ Architecture:
5
+ - hidden: 256 (3072/12)
6
+ - num_heads: 2 (24/12)
7
+ - head_dim: 128 (preserved for RoPE compatibility)
8
+ - in_channels: 16 (Flux VAE output channels)
9
+ - double_layers: 3
10
+ - single_layers: 3
11
+
12
+ Text Encoders (runtime):
13
+ - flan-t5-base (768 dim) → txt_in projects to hidden
14
+ - CLIP-L (768 dim pooled) → vector_in projects to hidden
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+
24
+
25
@dataclass
class TinyFluxConfig:
    """Configuration for TinyFlux model.

    Validated on construction: heads * head_dim must equal hidden_size,
    and the RoPE axis dims must sum to head_dim.
    """
    # Core dimensions
    hidden_size: int = 256
    num_attention_heads: int = 2
    attention_head_dim: int = 128  # Preserved for RoPE

    # Input/output (Flux VAE has 16 channels)
    in_channels: int = 16  # Flux VAE output channels
    patch_size: int = 1  # No 2x2 patchification, raw latent tokens

    # Text encoder interfaces (runtime encoding)
    joint_attention_dim: int = 768  # flan-t5-base output dim
    pooled_projection_dim: int = 768  # CLIP-L pooled dim

    # Layers
    num_double_layers: int = 3
    num_single_layers: int = 3

    # MLP
    mlp_ratio: float = 4.0

    # RoPE (must sum to head_dim)
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Misc
    guidance_embeds: bool = True

    def __post_init__(self):
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently disable this validation.
        if self.num_attention_heads * self.attention_head_dim != self.hidden_size:
            raise ValueError(
                f"heads ({self.num_attention_heads}) * head_dim ({self.attention_head_dim}) "
                f"!= hidden ({self.hidden_size})"
            )
        if sum(self.axes_dims_rope) != self.attention_head_dim:
            raise ValueError(
                f"RoPE dims {self.axes_dims_rope} must sum to head_dim {self.attention_head_dim}"
            )
59
+
60
+
61
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the inverse RMS in float32 for numerical stability,
        # then cast back to the input dtype before applying the gain.
        mean_square = x.float().pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inv_rms).type_as(x) * self.weight
71
+
72
+
73
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding over three axes (temporal, height, width).

    Each axis contributes ``axes_dims[i]`` channels of interleaved cos/sin
    values; the per-axis embeddings are concatenated along the channel dim.
    """

    def __init__(self, dim: int, axes_dims: Tuple[int, int, int], theta: float = 10000.0):
        super().__init__()
        self.dim = dim  # total channels; should equal sum(axes_dims)
        self.axes_dims = axes_dims  # (temporal, height, width)
        self.theta = theta

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (B, N, 3) position indices — temporal, height, width.
        Returns:
            (B, N, dim) rotary embeddings with interleaved cos/sin pairs.
        """
        # NOTE(review): removed unused locals from the original
        # (`B, N, _` unpack and a `dim_offset` counter that was never read).
        device = ids.device
        dtype = torch.float32

        embeddings = []
        for axis_idx, axis_dim in enumerate(self.axes_dims):
            # Inverse frequencies for this axis: theta^(-2i / axis_dim).
            freqs = 1.0 / (self.theta ** (torch.arange(0, axis_dim, 2, device=device, dtype=dtype) / axis_dim))
            positions = ids[:, :, axis_idx].float()  # (B, N)
            # Outer product: (B, N, 1) * (1, 1, axis_dim/2) -> (B, N, axis_dim/2)
            angles = positions.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
            # Interleave along the last dim: [cos0, sin0, cos1, sin1, ...]
            emb = torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2)
            embeddings.append(emb)

        return torch.cat(embeddings, dim=-1)  # (B, N, dim)
107
+
108
+
109
def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
    """Rotate query/key channels by their positional cos/sin pairs.

    Args:
        x: (B, heads, N, head_dim) tensor.
        rope: (B, N, head_dim) embeddings with interleaved (cos, sin) pairs.
    Returns:
        Tensor of the same shape as ``x`` with each even/odd channel pair
        rotated by the corresponding angle.
    """
    B, H, N, D = x.shape

    # View channels as D/2 two-vectors; broadcast rope across heads.
    pairs = x.reshape(B, H, N, D // 2, 2)
    rot = rope.unsqueeze(1).reshape(B, 1, N, D // 2, 2)

    cos, sin = rot.unbind(dim=-1)
    real, imag = pairs.unbind(dim=-1)

    # Standard 2D rotation: (a + bi) * (cos + i sin).
    rotated = torch.stack(
        [real * cos - imag * sin, imag * cos + real * sin],
        dim=-1,
    )
    return rotated.flatten(-2)
130
+
131
+
132
class MLPEmbedder(nn.Module):
    """Embed a scalar (timestep or guidance) via sinusoidal features + MLP.

    The scalar is first expanded to a fixed 256-dim sinusoidal embedding
    (128 sin + 128 cos channels), then projected to ``hidden_size``.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Classic transformer sinusoidal embedding with 128 frequencies.
        half_dim = 128
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-log_step * torch.arange(half_dim, device=x.device, dtype=x.dtype))
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        features = torch.cat([angles.sin(), angles.cos()], dim=-1)  # (B, 256)
        return self.mlp(features)
150
+
151
+
152
class AdaLayerNormZero(nn.Module):
    """
    AdaLN-Zero modulation for double-stream blocks.

    Projects the conditioning vector to six per-channel modulation params:
    (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp).
    The attention shift/scale are applied here; the remaining four are
    returned for the caller to use around attention and the MLP.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(
        self, x: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: hidden states (B, N, D)
            emb: conditioning embedding (B, D)
        Returns:
            (modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        """
        params = self.linear(self.silu(emb)).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params
        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
177
+
178
+
179
class AdaLayerNormZeroSingle(nn.Module):
    """
    AdaLN-Zero modulation for single-stream blocks.

    Projects the conditioning vector to three per-channel modulation
    params: (shift, scale, gate). Shift/scale are applied here; the gate
    is returned for the caller's residual connection.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(
        self, x: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: hidden states (B, N, D)
            emb: conditioning embedding (B, D)
        Returns:
            (modulated_x, gate)
        """
        shift, scale, gate = self.linear(self.silu(emb)).chunk(3, dim=-1)
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return modulated, gate
204
+
205
+
206
class Attention(nn.Module):
    """Multi-head self-attention with optional rotary position embedding."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        # Fused QKV projection; bias-free as in Flux.
        self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False)
        self.out_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        # (B, N, 3, H, Dh) -> three tensors of (B, H, N, Dh).
        projected = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        if rope is not None:
            q, k = apply_rope(q, rope), apply_rope(k, rope)

        # Manual scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Positions where mask == 0 are excluded from attention.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = scores.softmax(dim=-1)

        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, -1)
        return self.out_proj(context)
240
+
241
+
242
class JointAttention(nn.Module):
    """Joint text/image attention for double-stream blocks.

    Each stream has its own QKV and output projections, but keys and
    values are shared so both streams attend over the full (text + image)
    sequence.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        # Per-stream projections (bias-free, as in Flux).
        self.txt_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False)
        self.img_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False)

        self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
        self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, txt_len, _ = txt.shape
        img_len = img.shape[1]

        txt_q, txt_k, txt_v = (
            self.txt_qkv(txt)
            .reshape(batch, txt_len, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        img_q, img_k, img_v = (
            self.img_qkv(img)
            .reshape(batch, img_len, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        # RoPE applies only to image tokens; text tokens carry no positions.
        if rope is not None:
            img_q = apply_rope(img_q, rope)
            img_k = apply_rope(img_k, rope)

        # Both streams attend over the concatenated sequence. Running one
        # joint softmax-attention and splitting the rows is equivalent to
        # attending per stream, since softmax acts row-wise.
        q = torch.cat([txt_q, img_q], dim=2)  # (B, H, L+N, Dh)
        k = torch.cat([txt_k, img_k], dim=2)
        v = torch.cat([txt_v, img_v], dim=2)

        weights = (torch.matmul(q, k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        joint = torch.matmul(weights, v).transpose(1, 2).reshape(batch, txt_len + img_len, -1)

        txt_ctx, img_ctx = joint.split([txt_len, img_len], dim=1)
        return self.txt_out(txt_ctx), self.img_out(img_ctx)
293
+
294
+
295
class MLP(nn.Module):
    """Position-wise feed-forward network with tanh-approximated GELU."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner_dim)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
306
+
307
+
308
class DoubleStreamBlock(nn.Module):
    """
    Double-stream transformer block (MMDiT style).

    Text and image keep separate weights but attend jointly to each other.
    Each stream is modulated by AdaLN-Zero with 6 params: the attention
    shift/scale are applied inside norm1; the gates and the MLP
    shift/scale are applied in ``forward``.
    """

    def __init__(self, config: TinyFluxConfig):
        super().__init__()
        # NOTE(review): dropped the original's unused local `mlp_hidden`;
        # MLP sizing is handled inside MLP via `mlp_ratio`.
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        # AdaLN-Zero per stream (each emits 6 modulation params).
        self.img_norm1 = AdaLayerNormZero(hidden)
        self.txt_norm1 = AdaLayerNormZero(hidden)

        # Joint attention with separate per-stream QKV projections.
        self.attn = JointAttention(hidden, heads, head_dim)

        # Plain norms before the MLPs; modulation reuses norm1's params.
        self.img_norm2 = RMSNorm(hidden)
        self.txt_norm2 = RMSNorm(hidden)

        # Per-stream feed-forward networks.
        self.img_mlp = MLP(hidden, config.mlp_ratio)
        self.txt_mlp = MLP(hidden, config.mlp_ratio)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            txt: text hidden states (B, L, D)
            img: image hidden states (B, N, D)
            vec: conditioning vector (B, D)
            rope: optional image rotary embeddings (B, N, head_dim)
        Returns:
            Updated (txt, img) hidden states.
        """
        # Pre-attention norm + modulation; MLP params are used further down.
        img_normed, img_gate_msa, img_shift_mlp, img_scale_mlp, img_gate_mlp = self.img_norm1(img, vec)
        txt_normed, txt_gate_msa, txt_shift_mlp, txt_scale_mlp, txt_gate_mlp = self.txt_norm1(txt, vec)

        # Joint attention across both streams.
        txt_attn_out, img_attn_out = self.attn(txt_normed, img_normed, rope)

        # Gated residual connections.
        txt = txt + txt_gate_msa.unsqueeze(1) * txt_attn_out
        img = img + img_gate_msa.unsqueeze(1) * img_attn_out

        # MLP path, modulated with the shift/scale from norm1's output.
        txt_mlp_in = self.txt_norm2(txt) * (1 + txt_scale_mlp.unsqueeze(1)) + txt_shift_mlp.unsqueeze(1)
        img_mlp_in = self.img_norm2(img) * (1 + img_scale_mlp.unsqueeze(1)) + img_shift_mlp.unsqueeze(1)

        txt = txt + txt_gate_mlp.unsqueeze(1) * self.txt_mlp(txt_mlp_in)
        img = img + img_gate_mlp.unsqueeze(1) * self.img_mlp(img_mlp_in)

        return txt, img
364
+
365
+
366
class SingleStreamBlock(nn.Module):
    """
    Single-stream transformer block.

    Text and image tokens are concatenated and processed with shared
    weights. AdaLN-Zero supplies 3 modulation params (shift, scale, gate);
    the MLP residual is ungated.
    """

    def __init__(self, config: TinyFluxConfig):
        super().__init__()
        # NOTE(review): dropped the original's unused local `mlp_hidden`.
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        # AdaLN-Zero (emits shift, scale, gate).
        self.norm = AdaLayerNormZeroSingle(hidden)

        # Shared self-attention over the concatenated sequence.
        self.attn = Attention(hidden, heads, head_dim)

        # Feed-forward network.
        self.mlp = MLP(hidden, config.mlp_ratio)

        # Pre-MLP norm (not modulated in single-stream blocks).
        self.norm2 = RMSNorm(hidden)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        txt_rope: Optional[torch.Tensor] = None,  # unused; text gets an identity rotation below (kept for API symmetry)
        img_rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            txt: text hidden states (B, L, D)
            img: image hidden states (B, N, D)
            vec: conditioning vector (B, D)
            img_rope: optional image rotary embeddings (B, N, head_dim)
        Returns:
            Updated (txt, img) hidden states.
        """
        txt_len = txt.shape[1]

        # Concatenate text and image into one sequence.
        x = torch.cat([txt, img], dim=1)

        if img_rope is not None:
            B, N, D = img_rope.shape
            # BUG FIX: the original used an all-zero rope for text tokens.
            # `apply_rope` stores precomputed (cos, sin) pairs, so zeros
            # mean cos=0/sin=0, which zeroes out the text Q/K channels
            # entirely. The identity rotation (what zero *positions* would
            # produce through RotaryEmbedding) is cos=1 on even channels
            # and sin=0 on odd channels.
            txt_rope_identity = torch.zeros(B, txt_len, D, device=img_rope.device, dtype=img_rope.dtype)
            txt_rope_identity[..., 0::2] = 1.0
            rope = torch.cat([txt_rope_identity, img_rope], dim=1)
        else:
            rope = None

        # Norm + modulation (only 3 params for single stream).
        x_normed, gate = self.norm(x, vec)

        # Attention with gated residual.
        x = x + gate.unsqueeze(1) * self.attn(x_normed, rope)

        # MLP (no separate modulation in single-stream Flux).
        x = x + self.mlp(self.norm2(x))

        # Split back into the two streams.
        txt, img = x.split([txt_len, x.shape[1] - txt_len], dim=1)
        return txt, img
425
+
426
+
427
class TinyFlux(nn.Module):
    """
    TinyFlux: A scaled-down Flux diffusion transformer.

    Scaling: /12 from original Flux
    - hidden: 3072 → 256
    - heads: 24 → 2
    - head_dim: 128 (preserved)
    - in_channels: 16 (Flux VAE)
    """

    def __init__(self, config: Optional[TinyFluxConfig] = None):
        super().__init__()
        self.config = config or TinyFluxConfig()
        cfg = self.config

        # Input projections: latent channels / T5 token states -> hidden.
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size)

        # Conditioning: sinusoidal MLPs for timestep (and optionally
        # guidance), plus a projection of the CLIP pooled vector.
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)

        # RoPE over (temporal, height, width) axes.
        self.rope = RotaryEmbedding(cfg.attention_head_dim, cfg.axes_dims_rope)

        # Transformer trunk: double-stream then single-stream blocks.
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])

        # Output head.
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init all linear layers; zero-init the output projection."""
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        self.apply(_init)

        # Zero-init the output projection so the model's initial
        # prediction is zero (standard DiT-style residual init).
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,  # (B, N, in_channels) - image patches
        encoder_hidden_states: torch.Tensor,  # (B, L, joint_attention_dim) - T5 tokens
        pooled_projections: torch.Tensor,  # (B, pooled_projection_dim) - CLIP pooled
        timestep: torch.Tensor,  # (B,) - diffusion timestep
        img_ids: torch.Tensor,  # (B, N, 3) - image position ids
        guidance: Optional[torch.Tensor] = None,  # (B,) - guidance scale
    ) -> torch.Tensor:
        """
        Forward pass.

        Returns:
            Predicted noise/velocity of shape (B, N, in_channels)
        """
        # Input projections.
        img = self.img_in(hidden_states)  # (B, N, hidden)
        txt = self.txt_in(encoder_hidden_states)  # (B, L, hidden)

        # Conditioning vector: timestep + pooled text (+ optional guidance).
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)

        # RoPE for image positions.
        img_rope = self.rope(img_ids)

        # Double-stream blocks (separate text/image weights).
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)

        # Single-stream blocks (shared weights over the joint sequence).
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, img_rope=img_rope)

        # Output head (image tokens only).
        img = self.final_norm(img)
        img = self.final_linear(img)

        return img

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create (B, H*W, 3) position ids for RoPE: (temporal=0, row, col).

        height/width are in latent space (image_size / 8). Vectorized with
        meshgrid instead of the original O(H*W) Python loop.
        """
        rows = torch.arange(height, device=device, dtype=torch.float32)
        cols = torch.arange(width, device=device, dtype=torch.float32)
        grid_r, grid_c = torch.meshgrid(rows, cols, indexing='ij')

        ids = torch.zeros(height * width, 3, device=device)
        ids[:, 1] = grid_r.reshape(-1)  # height index
        ids[:, 2] = grid_c.reshape(-1)  # width index
        # Column 0 (temporal) stays 0 for still images.

        # Copy per batch element, matching the original's independent rows.
        return ids.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

    def count_parameters(self) -> dict:
        """Count parameters by component; 'total' covers the whole model."""
        def _numel(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {
            'img_in': _numel(self.img_in),
            'txt_in': _numel(self.txt_in),
            'time_in': _numel(self.time_in),
            'vector_in': _numel(self.vector_in),
        }
        if hasattr(self, 'guidance_in'):
            counts['guidance_in'] = _numel(self.guidance_in)
        counts['double_blocks'] = _numel(self.double_blocks)
        counts['single_blocks'] = _numel(self.single_blocks)
        counts['final'] = _numel(self.final_norm) + _numel(self.final_linear)
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts
556
+
557
+
558
def test_tiny_flux():
    """Smoke test: build TinyFlux, report parameter counts, run one forward pass.

    Fixed f-string misuse from the original: several `print(f"...")` calls
    had no placeholders.
    """
    print("=" * 60)
    print("TinyFlux Model Test")
    print("=" * 60)

    config = TinyFluxConfig()
    print("\nConfig:")
    print(f"  hidden_size: {config.hidden_size}")
    print(f"  num_heads: {config.num_attention_heads}")
    print(f"  head_dim: {config.attention_head_dim}")
    print(f"  in_channels: {config.in_channels}")
    print(f"  double_layers: {config.num_double_layers}")
    print(f"  single_layers: {config.num_single_layers}")
    print(f"  joint_attention_dim: {config.joint_attention_dim}")
    print(f"  pooled_projection_dim: {config.pooled_projection_dim}")

    model = TinyFlux(config)

    # Parameter breakdown by component.
    counts = model.count_parameters()
    print("\nParameters:")
    for name, count in counts.items():
        print(f"  {name}: {count:,} ({count/1e6:.2f}M)")

    # Run the forward pass on GPU when available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    batch_size = 2
    latent_h, latent_w = 64, 64  # 512x512 image / 8
    num_patches = latent_h * latent_w
    text_len = 77

    # Dummy inputs matching the runtime-encoder interfaces.
    hidden_states = torch.randn(batch_size, num_patches, config.in_channels, device=device)
    encoder_hidden_states = torch.randn(batch_size, text_len, config.joint_attention_dim, device=device)
    pooled_projections = torch.randn(batch_size, config.pooled_projection_dim, device=device)
    timestep = torch.rand(batch_size, device=device)
    img_ids = TinyFlux.create_img_ids(batch_size, latent_h, latent_w, device)
    guidance = torch.ones(batch_size, device=device) * 3.5

    print("\nInput shapes:")
    print(f"  hidden_states: {hidden_states.shape}")
    print(f"  encoder_hidden_states: {encoder_hidden_states.shape}")
    print(f"  pooled_projections: {pooled_projections.shape}")
    print(f"  img_ids: {img_ids.shape}")

    # Inference-only forward pass.
    with torch.no_grad():
        output = model(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=timestep,
            img_ids=img_ids,
            guidance=guidance,
        )

    print(f"\nOutput shape: {output.shape}")
    print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
    print("\n✓ Forward pass successful!")
620
+
621
+
622
+ if __name__ == "__main__":
623
+ test_tiny_flux()