AbstractPhil commited on
Commit
cacfc43
·
verified ·
1 Parent(s): 502b985

Create model_v2.py

Browse files
Files changed (1) hide show
  1. model_v2.py +588 -0
model_v2.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux-Deep: Deeper variant with 15 double + 25 single blocks.
3
+
4
+ Config derived from checkpoint step_285625.safetensors:
5
+ - hidden_size: 512
6
+ - num_attention_heads: 4
7
+ - attention_head_dim: 128
8
+ - num_double_layers: 15
9
+ - num_single_layers: 25
10
+ - Uses biases in MLP
11
+ - Old RoPE format with cached freqs buffers
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, List
20
+
21
+
22
@dataclass
class TinyFluxDeepConfig:
    """Configuration for TinyFlux-Deep model.

    Defaults mirror checkpoint step_285625.safetensors: 512 hidden units,
    4 heads x 128 head dim, 15 double-stream + 25 single-stream blocks.
    """
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128

    # Latent input channels and (unused here) patchification factor.
    in_channels: int = 16
    patch_size: int = 1

    # Widths of the token-level and pooled text conditioning inputs.
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768

    # Transformer depth.
    num_double_layers: int = 15
    num_single_layers: int = 25

    mlp_ratio: float = 4.0
    # Per-axis RoPE dims [temporal, height, width]; must sum to attention_head_dim.
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
    guidance_embeds: bool = True

    def __post_init__(self):
        # Raise (rather than assert) so validation survives `python -O`.
        if self.num_attention_heads * self.attention_head_dim != self.hidden_size:
            raise ValueError(
                f"num_attention_heads * attention_head_dim "
                f"({self.num_attention_heads} * {self.attention_head_dim}) "
                f"must equal hidden_size ({self.hidden_size})"
            )
        if sum(self.axes_dims_rope) != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({sum(self.axes_dims_rope)}) "
                f"must equal attention_head_dim ({self.attention_head_dim})"
            )
45
+
46
+
47
+ # =============================================================================
48
+ # Normalization
49
+ # =============================================================================
50
+
51
class RMSNorm(nn.Module):
    """RMS normalization: x / sqrt(mean(x^2) + eps), with an optional learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # Per-channel gain when affine; a registered None otherwise so the
        # attribute always exists.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the inverse RMS in float32 for numerical stability, then
        # cast back to the input dtype.
        inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        result = (x * inv_rms).type_as(x)
        return result if self.weight is None else result * self.weight
69
+
70
+
71
+ # =============================================================================
72
+ # RoPE - Old format with cached frequency buffers (checkpoint compatible)
73
+ # =============================================================================
74
+
75
class EmbedND(nn.Module):
    """
    Multi-axis RoPE table with cached per-axis frequency buffers.

    Buffer names (freqs_0, freqs_1, freqs_2) match the checkpoint keys
    rope.freqs_* of the original TinyFlux format.
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

        # One inverse-frequency buffer per axis, persisted into the state dict.
        for axis, dim in enumerate(axes_dim):
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer(f'freqs_{axis}', inv_freq, persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (N, 3) position indices [temporal, height, width]
        Returns:
            (N, 1, head_dim) embedding, interleaved as [cos, sin, cos, sin, ...]
        """
        pieces = []
        for axis in range(ids.shape[-1]):
            inv_freq = getattr(self, f'freqs_{axis}').to(ids.device)
            # (N, dim/2) angles for this axis.
            angles = ids[:, axis].float().unsqueeze(-1) * inv_freq.unsqueeze(0)
            # Interleave cos/sin along the last dim: (N, dim/2, 2) -> (N, dim).
            pieces.append(torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2))

        # Concatenate all axes, then add the singleton head dimension.
        return torch.cat(pieces, dim=-1).unsqueeze(1)
115
+
116
+
117
def apply_rotary_emb_old(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary position embeddings (old interleaved format).

    Args:
        x: (B, H, N, D) query or key tensor
        freqs_cis: (N, 1, D) interleaved [cos0, sin0, cos1, sin1, ...]
    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    # Drop the singleton dim: (N, 1, D) -> (N, D).
    interleaved = freqs_cis.squeeze(1)

    # De-interleave cos/sin, then duplicate each entry so both elements of a
    # rotation pair see the same angle.
    cos_part = interleaved[:, 0::2].repeat_interleave(2, dim=-1)
    sin_part = interleaved[:, 1::2].repeat_interleave(2, dim=-1)

    # Broadcast over batch and heads: (1, 1, N, D).
    cos_part = cos_part.unsqueeze(0).unsqueeze(0).to(x.device)
    sin_part = sin_part.unsqueeze(0).unsqueeze(0).to(x.device)

    # Treat the last dim as (real, imag) pairs and build the 90°-rotated copy.
    real, imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    rotated = torch.stack((-imag, real), dim=-1).flatten(-2)

    # Rotation in float32, then cast back to the input dtype.
    return (x.float() * cos_part + rotated.float() * sin_part).to(x.dtype)
145
+
146
+
147
+ # =============================================================================
148
+ # Embeddings
149
+ # =============================================================================
150
+
151
+ class MLPEmbedder(nn.Module):
152
+ """MLP for embedding scalars (timestep, guidance)."""
153
+
154
+ def __init__(self, hidden_size: int):
155
+ super().__init__()
156
+ self.mlp = nn.Sequential(
157
+ nn.Linear(256, hidden_size),
158
+ nn.SiLU(),
159
+ nn.Linear(hidden_size, hidden_size),
160
+ )
161
+
162
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
163
+ half_dim = 128
164
+ emb = math.log(10000) / (half_dim - 1)
165
+ emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
166
+ emb = x.unsqueeze(-1) * emb.unsqueeze(0)
167
+ emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
168
+ return self.mlp(emb)
169
+
170
+
171
+ # =============================================================================
172
+ # AdaLayerNorm
173
+ # =============================================================================
174
+
175
class AdaLayerNormZero(nn.Module):
    """AdaLN-Zero modulation for double-stream blocks (6 parameter groups)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        # Projects the conditioning vector to shift/scale/gate pairs for the
        # attention and MLP branches.
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        params = self.linear(self.silu(emb)).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params
        # Modulate the normalized input for the attention branch; the MLP
        # shift/scale/gate are returned for the caller to apply later.
        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
189
+
190
+
191
class AdaLayerNormZeroSingle(nn.Module):
    """AdaLN-Zero modulation for single-stream blocks (3 parameter groups)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        # Projects the conditioning vector to a single shift/scale/gate triple.
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift, scale, gate = self.linear(self.silu(emb)).chunk(3, dim=-1)
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return modulated, gate
205
+
206
+
207
+ # =============================================================================
208
+ # Attention (original format - no Q/K norm, matches checkpoint)
209
+ # =============================================================================
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head attention (original TinyFlux format, no Q/K norm)."""
213
+
214
+ def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
215
+ super().__init__()
216
+ self.num_heads = num_heads
217
+ self.head_dim = head_dim
218
+ self.scale = head_dim ** -0.5
219
+
220
+ self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
221
+ self.out_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
222
+
223
+ def forward(
224
+ self,
225
+ x: torch.Tensor,
226
+ rope: Optional[torch.Tensor] = None,
227
+ ) -> torch.Tensor:
228
+ B, N, _ = x.shape
229
+
230
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
231
+ q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3 x (B, H, N, D)
232
+
233
+ # Apply RoPE
234
+ if rope is not None:
235
+ q = apply_rotary_emb_old(q, rope)
236
+ k = apply_rotary_emb_old(k, rope)
237
+
238
+ # Scaled dot-product attention
239
+ attn = F.scaled_dot_product_attention(q, k, v)
240
+ out = attn.transpose(1, 2).reshape(B, N, -1)
241
+ return self.out_proj(out)
242
+
243
+
244
+ class JointAttention(nn.Module):
245
+ """Joint attention for double-stream blocks (original format)."""
246
+
247
+ def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
248
+ super().__init__()
249
+ self.num_heads = num_heads
250
+ self.head_dim = head_dim
251
+ self.scale = head_dim ** -0.5
252
+
253
+ self.txt_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
254
+ self.img_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
255
+
256
+ self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
257
+ self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
258
+
259
+ def forward(
260
+ self,
261
+ txt: torch.Tensor,
262
+ img: torch.Tensor,
263
+ rope: Optional[torch.Tensor] = None,
264
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
265
+ B, L, _ = txt.shape
266
+ _, N, _ = img.shape
267
+
268
+ txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
269
+ img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
270
+
271
+ txt_q, txt_k, txt_v = txt_qkv.permute(2, 0, 3, 1, 4)
272
+ img_q, img_k, img_v = img_qkv.permute(2, 0, 3, 1, 4)
273
+
274
+ # Apply RoPE to image only
275
+ if rope is not None:
276
+ img_q = apply_rotary_emb_old(img_q, rope)
277
+ img_k = apply_rotary_emb_old(img_k, rope)
278
+
279
+ # Concatenate for joint attention
280
+ k = torch.cat([txt_k, img_k], dim=2)
281
+ v = torch.cat([txt_v, img_v], dim=2)
282
+
283
+ txt_out = F.scaled_dot_product_attention(txt_q, k, v)
284
+ txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
285
+
286
+ img_out = F.scaled_dot_product_attention(img_q, k, v)
287
+ img_out = img_out.transpose(1, 2).reshape(B, N, -1)
288
+
289
+ return self.txt_out(txt_out), self.img_out(img_out)
290
+
291
+
292
+ # =============================================================================
293
+ # MLP (with bias - matches checkpoint)
294
+ # =============================================================================
295
+
296
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> tanh-GELU -> Linear, with biases."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)
        # Biases are enabled so weights stay loadable from the existing checkpoint.
        self.fc1 = nn.Linear(hidden_size, inner_dim, bias=True)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner_dim, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
308
+
309
+
310
+ # =============================================================================
311
+ # Transformer Blocks
312
+ # =============================================================================
313
+
314
class DoubleStreamBlock(nn.Module):
    """Double-stream block: joint attention plus per-stream gated MLPs."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        # AdaLN-Zero modulation per stream.
        self.img_norm1 = AdaLayerNormZero(hidden)
        self.txt_norm1 = AdaLayerNormZero(hidden)

        self.attn = JointAttention(hidden, heads, head_dim, use_bias=False)

        self.img_norm2 = RMSNorm(hidden)
        self.txt_norm2 = RMSNorm(hidden)

        self.img_mlp = MLP(hidden, config.mlp_ratio)
        self.txt_mlp = MLP(hidden, config.mlp_ratio)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Modulated inputs plus gates/shifts/scales for each branch.
        img_mod, img_gate_attn, img_shift, img_scale, img_gate_ff = self.img_norm1(img, vec)
        txt_mod, txt_gate_attn, txt_shift, txt_scale, txt_gate_ff = self.txt_norm1(txt, vec)

        attn_txt, attn_img = self.attn(txt_mod, img_mod, rope)

        # Gated residual around joint attention.
        txt = txt + txt_gate_attn.unsqueeze(1) * attn_txt
        img = img + img_gate_attn.unsqueeze(1) * attn_img

        # Shift/scale-modulated inputs for the per-stream MLPs.
        txt_ff_in = self.txt_norm2(txt) * (1 + txt_scale.unsqueeze(1)) + txt_shift.unsqueeze(1)
        img_ff_in = self.img_norm2(img) * (1 + img_scale.unsqueeze(1)) + img_shift.unsqueeze(1)

        # Gated residual around the MLPs.
        txt = txt + txt_gate_ff.unsqueeze(1) * self.txt_mlp(txt_ff_in)
        img = img + img_gate_ff.unsqueeze(1) * self.img_mlp(img_ff_in)

        return txt, img
356
+
357
+
358
class SingleStreamBlock(nn.Module):
    """Single-stream block: text and image are concatenated and processed jointly.

    Note: only the attention branch is gated/modulated here; the MLP branch is a
    plain RMSNorm residual, as in the original implementation.
    """

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        self.norm = AdaLayerNormZeroSingle(hidden)
        self.attn = Attention(hidden, heads, head_dim, use_bias=False)
        self.mlp = MLP(hidden, config.mlp_ratio)
        self.norm2 = RMSNorm(hidden)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        txt_len = txt.shape[1]

        # Run attention and MLP over the fused [txt; img] sequence.
        seq = torch.cat([txt, img], dim=1)

        normed, gate = self.norm(seq, vec)
        seq = seq + gate.unsqueeze(1) * self.attn(normed, rope)
        seq = seq + self.mlp(self.norm2(seq))

        # Split the fused sequence back into its two streams.
        return seq[:, :txt_len], seq[:, txt_len:]
389
+
390
+
391
+ # =============================================================================
392
+ # Main Model
393
+ # =============================================================================
394
+
395
class TinyFluxDeep(nn.Module):
    """TinyFlux-Deep: 15 double-stream + 25 single-stream transformer blocks.

    Layout and parameter names match the checkpoint described in the module
    docstring (biased input/output linears, old-format RoPE with cached
    frequency buffers).
    """

    def __init__(self, config: Optional[TinyFluxDeepConfig] = None):
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config

        # Input projections (with bias to match checkpoint).
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)

        # Conditioning embedders: timestep, pooled text vector, optional guidance.
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)

        # RoPE (old format with cached freqs buffers).
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)

        # Transformer blocks.
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])

        # Output head: RMSNorm + projection back to latent channels.
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every linear (zero biases), then zero the output projection."""
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        self.apply(_init)
        # Zero-init the final projection so the untrained model outputs zeros.
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.Tensor,
        img_ids: torch.Tensor,
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (B, N, in_channels) packed image latents.
            encoder_hidden_states: (B, L, joint_attention_dim) text tokens.
            pooled_projections: (B, pooled_projection_dim) pooled text vector.
            timestep: (B,) diffusion timesteps.
            img_ids: (N, 3) or (B, N, 3) image position IDs.
            txt_ids: optional (L, 3) or (B, L, 3) text position IDs; defaults
                to all-zero positions for the text tokens.
            guidance: optional (B,) guidance values, used only when
                ``config.guidance_embeds`` is True.
        Returns:
            (B, N, in_channels) prediction for the image stream.
        """
        L = encoder_hidden_states.shape[1]

        # Project both streams into the transformer width.
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)

        # Fused conditioning vector: timestep + pooled text (+ guidance).
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)

        # Positions are shared across the batch; drop a leading batch dim if present.
        if img_ids.ndim == 3:
            img_ids = img_ids[0]  # (N, 3)

        # RoPE over image positions only (double-stream phase).
        img_rope = self.rope(img_ids)  # (N, 1, head_dim)

        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)

        # Single-stream phase attends over [txt; img], so build RoPE for the
        # full sequence; text defaults to all-zero positions.
        if txt_ids is None:
            txt_ids = torch.zeros(L, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        all_ids = torch.cat([txt_ids, img_ids], dim=0)
        full_rope = self.rope(all_ids)

        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)

        # Project the image stream back to latent channels.
        img = self.final_norm(img)
        img = self.final_linear(img)

        return img

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create (height*width, 3) image position IDs for RoPE.

        Columns are [temporal, row, col]; the temporal axis is always 0 for
        images. ``batch_size`` is accepted for API compatibility but unused:
        all batch elements share the same grid. Vectorized (no Python loop
        over pixels); output matches the original row-major layout.
        """
        img_ids = torch.zeros(height * width, 3, device=device)
        # Row index repeats each value `width` times; column index cycles.
        img_ids[:, 1] = torch.arange(height, device=device).repeat_interleave(width)
        img_ids[:, 2] = torch.arange(width, device=device).repeat(height)
        return img_ids

    @staticmethod
    def create_txt_ids(text_len: int, device: torch.device) -> torch.Tensor:
        """Create (text_len, 3) text position IDs (sequence index on axis 0)."""
        txt_ids = torch.zeros(text_len, 3, device=device)
        txt_ids[:, 0] = torch.arange(text_len, device=device)
        return txt_ids

    def count_parameters(self) -> dict:
        """Return a dict of parameter counts per component, plus 'total'."""
        def _numel(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {
            'img_in': _numel(self.img_in),
            'txt_in': _numel(self.txt_in),
            'time_in': _numel(self.time_in),
            'vector_in': _numel(self.vector_in),
        }
        if hasattr(self, 'guidance_in'):
            counts['guidance_in'] = _numel(self.guidance_in)
        counts['double_blocks'] = _numel(self.double_blocks)
        counts['single_blocks'] = _numel(self.single_blocks)
        counts['final'] = _numel(self.final_norm) + _numel(self.final_linear)
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts
530
+
531
+
532
+ # =============================================================================
533
+ # Test
534
+ # =============================================================================
535
+
536
def test_model():
    """Smoke-test TinyFlux-Deep: report parameter counts and run one forward pass."""
    banner = "=" * 60
    print(banner)
    print("TinyFlux-Deep Test")
    print(banner)

    config = TinyFluxDeepConfig()
    model = TinyFluxDeep(config)
    counts = model.count_parameters()

    print(f"\nConfig:")
    for field in ("hidden_size", "num_attention_heads", "attention_head_dim",
                  "num_double_layers", "num_single_layers"):
        print(f" {field}: {getattr(config, field)}")

    print(f"\nParameters:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # A small 64x64 latent grid with a 77-token text prompt.
    batch, height, width = 2, 64, 64
    seq_len = 77

    hidden_states = torch.randn(batch, height * width, config.in_channels, device=device)
    encoder_hidden_states = torch.randn(batch, seq_len, config.joint_attention_dim, device=device)
    pooled_projections = torch.randn(batch, config.pooled_projection_dim, device=device)
    timestep = torch.rand(batch, device=device)
    img_ids = TinyFluxDeep.create_img_ids(batch, height, width, device)
    txt_ids = TinyFluxDeep.create_txt_ids(seq_len, device)
    guidance = 3.5 * torch.ones(batch, device=device)

    with torch.no_grad():
        output = model(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            guidance=guidance,
        )

    print(f"\nOutput shape: {output.shape}")
    print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
    print("\n✓ Forward pass successful!")
585
+
586
+
587
# Run the smoke test when executed as a script.
if __name__ == "__main__":
    test_model()