AbstractPhil committed on
Commit
25315af
·
verified ·
1 Parent(s): eac9446

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +587 -0
model.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux: A /12 scaled Flux architecture for experimentation.
3
+ OPTIMIZED VERSION - Flash Attention, vectorized RoPE, caching
4
+
5
+ Architecture:
6
+ - hidden: 256 (3072/12)
7
+ - num_heads: 2 (24/12)
8
+ - head_dim: 128 (preserved for RoPE compatibility)
9
+ - in_channels: 16 (Flux VAE output channels)
10
+ - double_layers: 3
11
+ - single_layers: 3
12
+
13
+ Optimizations:
14
+ - Flash Attention (F.scaled_dot_product_attention)
15
+ - Vectorized RoPE with precomputed frequencies
16
+ - Vectorized img_ids creation (no Python loops)
17
+ - Caching for img_ids and RoPE embeddings
18
+ - Precomputed sinusoidal embeddings
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import math
25
+ from dataclasses import dataclass
26
+ from typing import Optional, Tuple, Dict
27
+
28
+
29
@dataclass
class TinyFluxConfig:
    """Configuration for TinyFlux model.

    A /12 scaled Flux: hidden_size 256, 2 heads of head_dim 128 (head_dim
    kept at full size for RoPE compatibility), with text-encoder interface
    dims matching flan-t5-base / CLIP-L.
    """
    # Core dimensions
    hidden_size: int = 256
    num_attention_heads: int = 2
    attention_head_dim: int = 128  # Preserved for RoPE

    # Input/output (Flux VAE has 16 channels)
    in_channels: int = 16
    patch_size: int = 1

    # Text encoder interfaces
    joint_attention_dim: int = 768    # flan-t5-base output dim
    pooled_projection_dim: int = 768  # CLIP-L pooled dim

    # Layers
    num_double_layers: int = 3
    num_single_layers: int = 3

    # MLP
    mlp_ratio: float = 4.0

    # RoPE per-axis dims (temporal, height, width); must sum to head_dim
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Misc
    guidance_embeds: bool = True

    def __post_init__(self):
        # Raise explicit exceptions rather than `assert` so validation
        # survives `python -O` (asserts are stripped under optimization).
        if self.num_attention_heads * self.attention_head_dim != self.hidden_size:
            raise ValueError(
                f"heads ({self.num_attention_heads}) * head_dim ({self.attention_head_dim}) "
                f"!= hidden ({self.hidden_size})"
            )
        if sum(self.axes_dims_rope) != self.attention_head_dim:
            raise ValueError(
                f"RoPE dims {self.axes_dims_rope} must sum to head_dim {self.attention_head_dim}"
            )
63
+
64
+
65
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dim, computed in
        # fp32 for numerical stability, then rescale by the learned weight.
        mean_sq = x.float().square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return (x * inv_rms).type_as(x) * self.weight
76
+
77
+
78
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding with per-axis precomputed frequencies.

    Inverse-frequency tables for the three position axes (temporal,
    height, width) are registered as buffers at construction time, so the
    forward pass is a broadcasted multiply — no runtime frequency loops.
    """

    def __init__(self, dim: int, axes_dims: Tuple[int, int, int], theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.axes_dims = axes_dims
        self.theta = theta

        # One inverse-frequency buffer per axis: 1 / theta^(2i / axis_dim).
        for axis, axis_dim in enumerate(axes_dims):
            inv_freq = 1.0 / (theta ** (torch.arange(0, axis_dim, 2, dtype=torch.float32) / axis_dim))
            self.register_buffer(f'freqs_{axis}', inv_freq)

    def forward(self, ids: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
        """
        Args:
            ids: (B, N, 3) — temporal, height, width indices.
            dtype: optional output dtype; defaults to ``ids.dtype``.

        Returns:
            (B, N, sum(axes_dims)) interleaved cos/sin embeddings.
        """
        B, N, _ = ids.shape
        out_dtype = ids.dtype if dtype is None else dtype

        per_axis = []
        for axis in range(3):
            pos = ids[:, :, axis:axis + 1].float()            # (B, N, 1)
            angles = pos * getattr(self, f'freqs_{axis}')     # (B, N, axis_dim/2)
            # Interleave as (cos, sin) pairs: (..., d/2, 2) -> (..., d).
            per_axis.append(torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2))

        return torch.cat(per_axis, dim=-1).to(out_dtype)
116
+
117
+
118
def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
    """Rotate query/key coordinate pairs by precomputed (cos, sin) values.

    Args:
        x: (B, heads, N, head_dim) queries or keys.
        rope: (B, N, head_dim) interleaved cos/sin embeddings.

    Returns:
        Tensor with the same shape as ``x``, rotary rotation applied.
    """
    B, H, N, D = x.shape

    # Broadcast rope across heads, then view both tensors as (cos, sin) pairs.
    pairs = rope.to(x.dtype).unsqueeze(1).reshape(B, 1, N, D // 2, 2)
    cos, sin = pairs[..., 0], pairs[..., 1]

    xp = x.reshape(B, H, N, D // 2, 2)
    even, odd = xp[..., 0], xp[..., 1]

    # Standard 2D rotation of each (even, odd) coordinate pair.
    rotated = torch.stack(
        [even * cos - odd * sin, odd * cos + even * sin], dim=-1
    )
    return rotated.flatten(-2)
140
+
141
+
142
class MLPEmbedder(nn.Module):
    """Embed a scalar (e.g. timestep or guidance) via sinusoidal features + MLP.

    The 128-frequency sinusoidal basis is computed once at construction and
    registered as a buffer, so the forward pass only does one outer product.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        # Geometric frequency ladder: exp(-log(10000) * i / (half_dim - 1)).
        half_dim = 128
        step = math.log(10000) / (half_dim - 1)
        self.register_buffer('sin_basis', torch.exp(torch.arange(half_dim) * -step))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Outer product against the precomputed basis, then sin/cos features.
        angles = x.unsqueeze(-1) * self.sin_basis.to(x.dtype)
        features = torch.cat([angles.sin(), angles.cos()], dim=-1)  # (..., 256)
        return self.mlp(features)
163
+
164
+
165
class AdaLayerNormZero(nn.Module):
    """AdaLN-Zero modulation for double-stream blocks.

    Projects the conditioning vector into six signals (shift/scale/gate for
    attention and for the MLP), applies the attention-side shift/scale to
    the RMS-normalized input, and returns the remaining signals for the
    caller to use in its residual branches.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(
        self, x: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        mods = self.linear(self.silu(emb)).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods

        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
181
+
182
+
183
class AdaLayerNormZeroSingle(nn.Module):
    """AdaLN-Zero modulation for single-stream blocks (shift/scale/gate)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(
        self, x: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Three modulation signals from the conditioning vector.
        shift, scale, gate = self.linear(self.silu(emb)).chunk(3, dim=-1)
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return modulated, gate
199
+
200
+
201
+ class Attention(nn.Module):
202
+ """Multi-head attention - OPTIMIZED with Flash Attention."""
203
+
204
+ def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
205
+ super().__init__()
206
+ self.num_heads = num_heads
207
+ self.head_dim = head_dim
208
+ self.scale = head_dim ** -0.5
209
+
210
+ self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False)
211
+ self.out_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
212
+
213
+ def forward(
214
+ self,
215
+ x: torch.Tensor,
216
+ rope: Optional[torch.Tensor] = None,
217
+ mask: Optional[torch.Tensor] = None
218
+ ) -> torch.Tensor:
219
+ B, N, _ = x.shape
220
+ dtype = x.dtype
221
+
222
+ if rope is not None:
223
+ rope = rope.to(dtype)
224
+
225
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
226
+ q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3 x (B, heads, N, head_dim)
227
+
228
+ if rope is not None:
229
+ q = apply_rope(q, rope)
230
+ k = apply_rope(k, rope)
231
+
232
+ # Flash Attention - faster and memory efficient
233
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=self.scale)
234
+ out = out.transpose(1, 2).reshape(B, N, -1)
235
+ return self.out_proj(out)
236
+
237
+
238
+ class JointAttention(nn.Module):
239
+ """Joint attention - OPTIMIZED with Flash Attention."""
240
+
241
+ def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
242
+ super().__init__()
243
+ self.num_heads = num_heads
244
+ self.head_dim = head_dim
245
+ self.scale = head_dim ** -0.5
246
+
247
+ self.txt_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False)
248
+ self.img_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False)
249
+
250
+ self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
251
+ self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
252
+
253
+ def forward(
254
+ self,
255
+ txt: torch.Tensor,
256
+ img: torch.Tensor,
257
+ rope: Optional[torch.Tensor] = None,
258
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
259
+ B, L, _ = txt.shape
260
+ _, N, _ = img.shape
261
+
262
+ dtype = img.dtype
263
+ txt = txt.to(dtype)
264
+ if rope is not None:
265
+ rope = rope.to(dtype)
266
+
267
+ # Compute Q, K, V for both streams
268
+ txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
269
+ img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
270
+
271
+ txt_q, txt_k, txt_v = txt_qkv.permute(2, 0, 3, 1, 4)
272
+ img_q, img_k, img_v = img_qkv.permute(2, 0, 3, 1, 4)
273
+
274
+ # Apply RoPE to image only
275
+ if rope is not None:
276
+ img_q = apply_rope(img_q, rope)
277
+ img_k = apply_rope(img_k, rope)
278
+
279
+ # Concatenate for joint attention
280
+ k = torch.cat([txt_k, img_k], dim=2)
281
+ v = torch.cat([txt_v, img_v], dim=2)
282
+
283
+ # Flash Attention for both streams
284
+ txt_out = F.scaled_dot_product_attention(txt_q, k, v, scale=self.scale)
285
+ img_out = F.scaled_dot_product_attention(img_q, k, v, scale=self.scale)
286
+
287
+ txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
288
+ img_out = img_out.transpose(1, 2).reshape(B, N, -1)
289
+
290
+ return self.txt_out(txt_out), self.img_out(img_out)
291
+
292
+
293
class MLP(nn.Module):
    """Position-wise feed-forward network with tanh-approximated GELU."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
305
+
306
+
307
class DoubleStreamBlock(nn.Module):
    """Double-stream transformer block (MMDiT style).

    Text and image keep separate weights and interact only through the
    joint attention; each stream is modulated by AdaLN-Zero signals
    derived from the shared conditioning vector.
    """

    def __init__(self, config: TinyFluxConfig):
        super().__init__()
        hidden = config.hidden_size

        self.img_norm1 = AdaLayerNormZero(hidden)
        self.txt_norm1 = AdaLayerNormZero(hidden)
        self.attn = JointAttention(hidden, config.num_attention_heads, config.attention_head_dim)
        self.img_norm2 = RMSNorm(hidden)
        self.txt_norm2 = RMSNorm(hidden)
        self.img_mlp = MLP(hidden, config.mlp_ratio)
        self.txt_mlp = MLP(hidden, config.mlp_ratio)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # AdaLN-Zero: normalize + modulate; collect the residual gates.
        img_mod, img_gate_attn, img_shift, img_scale, img_gate_mlp = self.img_norm1(img, vec)
        txt_mod, txt_gate_attn, txt_shift, txt_scale, txt_gate_mlp = self.txt_norm1(txt, vec)

        # Joint attention over both streams (RoPE applied to image only).
        txt_attn, img_attn = self.attn(txt_mod, img_mod, rope)

        # Gated attention residuals.
        txt = txt + txt_gate_attn.unsqueeze(1) * txt_attn
        img = img + img_gate_attn.unsqueeze(1) * img_attn

        # Second modulation + gated MLP residual, per stream.
        txt_in = self.txt_norm2(txt) * (1 + txt_scale.unsqueeze(1)) + txt_shift.unsqueeze(1)
        img_in = self.img_norm2(img) * (1 + img_scale.unsqueeze(1)) + img_shift.unsqueeze(1)

        txt = txt + txt_gate_mlp.unsqueeze(1) * self.txt_mlp(txt_in)
        img = img + img_gate_mlp.unsqueeze(1) * self.img_mlp(img_in)

        return txt, img
346
+
347
+
348
class SingleStreamBlock(nn.Module):
    """Single-stream transformer block over the concatenated sequence.

    Text and image tokens are fused into one sequence, processed by a
    shared attention + MLP, then split back apart.
    """

    def __init__(self, config: TinyFluxConfig):
        super().__init__()
        hidden = config.hidden_size

        self.norm = AdaLayerNormZeroSingle(hidden)
        self.attn = Attention(hidden, config.num_attention_heads, config.attention_head_dim)
        self.mlp = MLP(hidden, config.mlp_ratio)
        self.norm2 = RMSNorm(hidden)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        txt_rope: Optional[torch.Tensor] = None,  # NOTE(review): accepted but never read — confirm intended
        img_rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        txt_len = txt.shape[1]
        fused = torch.cat([txt, img], dim=1)

        # Text tokens are padded with an all-zero rope; image tokens keep theirs.
        # NOTE(review): zero rope means cos=0/sin=0, which zeroes text q/k after
        # rotation — an identity rotation would be cos=1/sin=0. Confirm intent.
        if img_rope is None:
            rope = None
        else:
            B, _, D = img_rope.shape
            txt_zeros = torch.zeros(B, txt_len, D, device=img_rope.device, dtype=img_rope.dtype)
            rope = torch.cat([txt_zeros, img_rope], dim=1)

        normed, gate = self.norm(fused, vec)
        fused = fused + gate.unsqueeze(1) * self.attn(normed, rope)
        fused = fused + self.mlp(self.norm2(fused))

        return fused.split([txt_len, fused.shape[1] - txt_len], dim=1)
387
+
388
+
389
# Global cache for img_ids (they don't change for same resolution).
# NOTE(review): keyed by (batch, height, width, device) and never evicted —
# grows unboundedly if many resolutions/batch sizes are used; verify acceptable.
_IMG_IDS_CACHE: Dict[Tuple, torch.Tensor] = {}
391
+
392
+
393
class TinyFlux(nn.Module):
    """
    TinyFlux: A scaled-down Flux diffusion transformer.

    Pipeline: project image latents and text embeddings into the hidden
    size, build a conditioning vector from the timestep (+ pooled text,
    + optional guidance), run double-stream then single-stream blocks,
    and project the image stream back to latent channels.
    """

    def __init__(self, config: Optional[TinyFluxConfig] = None):
        super().__init__()
        self.config = config or TinyFluxConfig()
        cfg = self.config

        # Input projections (patch_size == 1, so a plain Linear suffices).
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size)

        # Conditioning projections.
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)

        # RoPE over (temporal, height, width) axes.
        self.rope = RotaryEmbedding(cfg.attention_head_dim, cfg.axes_dims_rope)

        # Transformer blocks.
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])

        # Output head.
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels)

        # (Removed dead `_rope_cache` attribute: it was declared but never
        # read or written anywhere in the class.)
        self._init_weights()

    def _init_weights(self):
        """Xavier-init all Linear weights, zero all biases, then zero the
        final projection weight so the model's initial output is exactly
        zero (its bias was already zeroed by the generic init)."""
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.apply(_init)
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.Tensor,
        img_ids: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Denoising forward pass.

        Args:
            hidden_states: (B, N, in_channels) image latent tokens.
            encoder_hidden_states: (B, L, joint_attention_dim) text tokens.
            pooled_projections: (B, pooled_projection_dim) pooled text embedding.
            timestep: (B,) timestep scalars.
            img_ids: (B, N, 3) position ids, e.g. from `create_img_ids`.
            guidance: optional (B,) guidance values; only used when
                `config.guidance_embeds` is True.

        Returns:
            (B, N, in_channels) prediction for the image latents.
        """
        # Input projections.
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)

        # Conditioning vector: timestep + pooled text (+ optional guidance).
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)

        # Rotary embeddings for image token positions.
        img_rope = self.rope(img_ids, dtype=img.dtype)

        # Double-stream blocks (separate txt/img weights).
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)

        # Single-stream blocks (fused sequence).
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, img_rope=img_rope)

        # Output head.
        img = self.final_norm(img)
        img = self.final_linear(img)

        return img

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create (B, H*W, 3) position ids [temporal=0, row, col] — vectorized and cached."""
        # Normalize so 'cpu' (str) and torch.device('cpu') share one cache entry
        # (previously they produced duplicate entries).
        device = torch.device(device)

        cache_key = (batch_size, height, width, device)
        cached = _IMG_IDS_CACHE.get(cache_key)
        if cached is not None:
            return cached

        # Vectorized creation using meshgrid (no Python loops).
        h_ids = torch.arange(height, device=device, dtype=torch.float32)
        w_ids = torch.arange(width, device=device, dtype=torch.float32)
        grid_h, grid_w = torch.meshgrid(h_ids, w_ids, indexing='ij')

        # (H*W, 3): temporal axis is constant 0 for still images.
        img_ids = torch.stack([
            torch.zeros(height * width, device=device),
            grid_h.flatten(),
            grid_w.flatten(),
        ], dim=-1)

        # Broadcast view across the batch (expand makes no copy).
        img_ids = img_ids.unsqueeze(0).expand(batch_size, -1, -1)

        _IMG_IDS_CACHE[cache_key] = img_ids
        return img_ids

    def count_parameters(self) -> dict:
        """Return a dict of parameter counts per component plus 'total'."""
        counts = {}
        counts['img_in'] = sum(p.numel() for p in self.img_in.parameters())
        counts['txt_in'] = sum(p.numel() for p in self.txt_in.parameters())
        counts['time_in'] = sum(p.numel() for p in self.time_in.parameters())
        counts['vector_in'] = sum(p.numel() for p in self.vector_in.parameters())
        if hasattr(self, 'guidance_in'):
            counts['guidance_in'] = sum(p.numel() for p in self.guidance_in.parameters())
        counts['double_blocks'] = sum(p.numel() for p in self.double_blocks.parameters())
        counts['single_blocks'] = sum(p.numel() for p in self.single_blocks.parameters())
        counts['final'] = sum(p.numel() for p in self.final_norm.parameters()) + \
                          sum(p.numel() for p in self.final_linear.parameters())
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts
531
+
532
+
533
def test_tiny_flux():
    """Smoke-test TinyFlux: build the model, run forward passes, benchmark on CUDA.

    Bug fix: `output` was previously assigned only inside the CUDA-only
    benchmark branch, so the final shape print raised NameError on CPU.
    The warmup loop now binds `output` on every device.
    """
    print("=" * 60)
    print("TinyFlux OPTIMIZED Model Test")
    print("=" * 60)

    config = TinyFluxConfig()
    print(f"\nConfig:")
    print(f"  hidden_size: {config.hidden_size}")
    print(f"  num_heads: {config.num_attention_heads}")
    print(f"  head_dim: {config.attention_head_dim}")

    model = TinyFlux(config)

    counts = model.count_parameters()
    print(f"\nParameters: {counts['total']:,} ({counts['total'] / 1e6:.2f}M)")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    batch_size = 4
    latent_h, latent_w = 64, 64
    num_patches = latent_h * latent_w
    text_len = 77

    hidden_states = torch.randn(batch_size, num_patches, config.in_channels, device=device)
    encoder_hidden_states = torch.randn(batch_size, text_len, config.joint_attention_dim, device=device)
    pooled_projections = torch.randn(batch_size, config.pooled_projection_dim, device=device)
    timestep = torch.rand(batch_size, device=device)
    img_ids = TinyFlux.create_img_ids(batch_size, latent_h, latent_w, device)
    guidance = torch.ones(batch_size, device=device) * 3.5

    # Warmup — keep the last output so the shape check below works on CPU too.
    with torch.no_grad():
        for _ in range(3):
            output = model(hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, guidance)

    # Benchmark (CUDA only, where synchronize gives meaningful timings).
    if device == 'cuda':
        torch.cuda.synchronize()
        import time
        start = time.time()
        with torch.no_grad():
            for _ in range(10):
                output = model(hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, guidance)
        torch.cuda.synchronize()
        elapsed = (time.time() - start) / 10
        print(f"\nAverage forward pass: {elapsed*1000:.2f}ms")

    print(f"Output shape: {output.shape}")
    print("\n✓ Forward pass successful!")
584
+
585
+
586
# Standard script entry point — restored from a commented-out guard so the
# module's self-test can actually be run with `python model.py`. Importing
# the module still has no side effects.
if __name__ == "__main__":
    test_tiny_flux()