AbstractPhil committed on
Commit
048b8bb
·
verified ·
1 Parent(s): 47c87d8

Create model_v3.py

Browse files
Files changed (1) hide show
  1. model_v3.py +829 -0
model_v3.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux-Deep with Expert Predictor
3
+
4
+ Integrates a distillation pathway for SD1.5-flow timestep expertise.
5
+ During training: learns to predict expert features from (timestep, CLIP).
6
+ During inference: runs standalone, no expert needed.
7
+
8
+ Based on TinyFlux-Deep: 15 double + 25 single blocks.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Dict
17
+
18
+
19
@dataclass
class TinyFluxDeepConfig:
    """Configuration for TinyFlux-Deep model."""
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128

    in_channels: int = 16
    patch_size: int = 1

    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768

    num_double_layers: int = 15
    num_single_layers: int = 25

    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Expert predictor config
    use_expert_predictor: bool = True
    expert_dim: int = 1280  # SD1.5 mid-block dimension
    expert_hidden_dim: int = 512
    expert_dropout: float = 0.1  # Dropout during training for robustness

    # Legacy guidance (disabled when using expert)
    guidance_embeds: bool = False

    def __post_init__(self):
        # Validate with explicit exceptions rather than `assert`, so the
        # checks survive running under `python -O` and report useful detail.
        if self.num_attention_heads * self.attention_head_dim != self.hidden_size:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) * "
                f"attention_head_dim ({self.attention_head_dim}) must equal "
                f"hidden_size ({self.hidden_size})"
            )
        if sum(self.axes_dims_rope) != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({sum(self.axes_dims_rope)}) must equal "
                f"attention_head_dim ({self.attention_head_dim})"
            )
50
+
51
+
52
+ # =============================================================================
53
+ # Normalization
54
+ # =============================================================================
55
+
56
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (no mean-centering).

    Normalizes by the RMS of the last dimension, computed in float32 for
    numerical stability, then casts back to the input dtype. An optional
    learnable per-channel gain is applied afterwards.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            # Per-channel learnable gain, initialized to identity.
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            # Register an explicit None so the parameter slot exists.
            self.register_parameter('weight', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inverse RMS over the last axis, in float32.
        inv_rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        result = (x * inv_rms).type_as(x)
        return result if self.weight is None else result * self.weight
74
+
75
+
76
+ # =============================================================================
77
+ # RoPE - Old format with cached frequency buffers
78
+ # =============================================================================
79
+
80
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding with cached frequency buffers.

    Each axis of the 3D position ids gets its own geometric frequency ladder
    (registered as a persistent buffer); the per-axis interleaved cos/sin
    tables are concatenated along the feature dimension.
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        # One buffer per axis: 1 / theta^(2k/dim) for k = 0 .. dim/2 - 1.
        for axis, dim in enumerate(axes_dim):
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer(f'freqs_{axis}', inv_freq, persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        parts = []
        for axis in range(ids.shape[-1]):
            inv_freq = getattr(self, f'freqs_{axis}').to(ids.device)
            angles = ids[:, axis].float().unsqueeze(-1) * inv_freq.unsqueeze(0)
            # Interleave as [cos0, sin0, cos1, sin1, ...] along features.
            parts.append(torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2))
        return torch.cat(parts, dim=-1).unsqueeze(1)
108
+
109
+
110
def apply_rotary_emb_old(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary embeddings (old interleaved cos/sin layout).

    Args:
        x: [..., seq, dim] query or key tensor.
        freqs_cis: [seq, 1, dim] table with cos at even and sin at odd
            feature positions (as produced by EmbedND-style encoders).

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    table = freqs_cis.squeeze(1)
    # Expand each cos/sin entry to cover its (real, imag) feature pair.
    cos = table[:, 0::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)
    sin = table[:, 1::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)
    # View features as complex pairs and rotate: (a, b) -> (-b, a).
    real, imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    rotated = torch.stack([-imag, real], dim=-1).flatten(-2)
    return (x.float() * cos + rotated.float() * sin).to(x.dtype)
120
+
121
+
122
+ # =============================================================================
123
+ # Embeddings
124
+ # =============================================================================
125
+
126
class MLPEmbedder(nn.Module):
    """Embed a scalar (e.g. timestep) via sinusoidal features plus an MLP.

    The scalar is expanded into a 256-dim sinusoidal embedding (128 sin +
    128 cos) and projected to ``hidden_size`` through a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half_dim = 128
        # Geometric frequency ladder, classic transformer schedule.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -step)
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        features = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return self.mlp(features)
144
+
145
+
146
+ # =============================================================================
147
+ # Expert Predictor
148
+ # =============================================================================
149
+
150
class ExpertPredictor(nn.Module):
    """
    Predicts SD1.5-flow expert features from (timestep_emb, CLIP_pooled).

    Training: learns to match real expert features via distillation loss.
    Inference: runs standalone, no expert model needed.

    The predictor learns:
    - What the expert "sees" at each timestep
    - How text conditioning modulates that view
    - Trajectory shape priors from the expert's knowledge
    """

    def __init__(
        self,
        time_dim: int = 512,
        clip_dim: int = 768,
        expert_dim: int = 1280,
        hidden_dim: int = 512,
        output_dim: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.expert_dim = expert_dim
        # NOTE: `dropout` is reused in forward() as the probability of
        # ignoring real expert features during training (scheduled
        # sampling), in addition to configuring nn.Dropout below.
        self.dropout = dropout

        # Input fusion: concatenated (time_emb, clip_pooled) -> hidden_dim
        self.input_proj = nn.Linear(time_dim + clip_dim, hidden_dim)

        # Predictor core - learns expert behavior
        self.predictor = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, expert_dim),
        )

        # Project predicted expert features to vec dimension
        self.output_proj = nn.Sequential(
            nn.LayerNorm(expert_dim),
            nn.Linear(expert_dim, output_dim),
        )

        # Learnable gate for expert influence (sigmoid(0.5) ~= 0.62 at init)
        self.expert_gate = nn.Parameter(torch.ones(1) * 0.5)

        self._init_weights()

    def _init_weights(self):
        # Conservative Xavier init (gain=0.5) keeps the expert signal small
        # at the start of training.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.5)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        time_emb: torch.Tensor,
        clip_pooled: torch.Tensor,
        real_expert_features: Optional[torch.Tensor] = None,
        force_predictor: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            time_emb: [B, time_dim] - timestep embedding from time_in
            clip_pooled: [B, clip_dim] - pooled CLIP features
            real_expert_features: [B, expert_dim] - real expert output (training only)
            force_predictor: if True, use predictor even when real features available

        Returns:
            dict with:
                - 'expert_signal': [B, output_dim] - signal to add to vec
                - 'expert_pred': [B, expert_dim] - predicted expert features (for loss)
                - 'expert_used': str - 'real' or 'predicted'
        """
        B = time_emb.shape[0]  # NOTE(review): unused local
        device = time_emb.device  # NOTE(review): unused local

        # Fuse inputs
        combined = torch.cat([time_emb, clip_pooled], dim=-1)
        hidden = self.input_proj(combined)

        # Predict expert features (always computed, so the distillation loss
        # has a prediction to train against even when real features are used)
        expert_pred = self.predictor(hidden)

        # Decide which features to use. The random draw is short-circuited:
        # torch.rand is only consumed when real features are available in
        # training mode.
        use_real = (
            real_expert_features is not None
            and self.training
            and not force_predictor
            and torch.rand(1).item() > self.dropout  # Sometimes use predictor even in training
        )

        if use_real:
            expert_features = real_expert_features
            expert_used = 'real'
        else:
            expert_features = expert_pred
            expert_used = 'predicted'

        # Project to output dimension with gating
        gate = torch.sigmoid(self.expert_gate)
        expert_signal = gate * self.output_proj(expert_features)

        return {
            'expert_signal': expert_signal,
            'expert_pred': expert_pred,
            'expert_used': expert_used,
        }

    def compute_distillation_loss(
        self,
        expert_pred: torch.Tensor,
        real_expert_features: torch.Tensor,
    ) -> torch.Tensor:
        """MSE loss between predicted and real expert features."""
        return F.mse_loss(expert_pred, real_expert_features)
273
+
274
+
275
+ # =============================================================================
276
+ # AdaLayerNorm
277
+ # =============================================================================
278
+
279
class AdaLayerNormZero(nn.Module):
    """AdaLN-Zero modulation for double-stream blocks.

    One linear layer maps the conditioning vector to six modulation tensors:
    shift/scale/gate for attention plus shift/scale/gate for the MLP. Only
    the attention shift/scale are applied here; the remaining four are
    returned for the caller to use around its own sub-layers.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        params = self.linear(self.silu(emb)).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params
        modulated = self.norm(x) * (scale_msa.unsqueeze(1) + 1) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
293
+
294
+
295
class AdaLayerNormZeroSingle(nn.Module):
    """AdaLN-Zero modulation for single-stream blocks.

    Maps the conditioning vector to a (shift, scale, gate) triple; shift
    and scale modulate the normalized input, while the gate is returned
    for the caller's residual connection.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift, scale, gate = self.linear(self.silu(emb)).chunk(3, dim=-1)
        modulated = self.norm(x) * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
        return modulated, gate
309
+
310
+
311
+ # =============================================================================
312
+ # Attention
313
+ # =============================================================================
314
+
315
class Attention(nn.Module):
    """Multi-head self-attention with optional rotary embeddings."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Kept for parity with the original; SDPA applies its own scaling.
        self.scale = head_dim ** -0.5
        inner = num_heads * head_dim
        self.qkv = nn.Linear(hidden_size, 3 * inner, bias=use_bias)
        self.out_proj = nn.Linear(inner, hidden_size, bias=use_bias)

    def forward(self, x: torch.Tensor, rope: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch, seq, _ = x.shape
        # [B, S, 3*H*D] -> three [B, H, S, D] tensors.
        projected = self.qkv(x).reshape(batch, seq, 3, self.num_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4)
        if rope is not None:
            q = apply_rotary_emb_old(q, rope)
            k = apply_rotary_emb_old(k, rope)
        attended = F.scaled_dot_product_attention(q, k, v)
        merged = attended.transpose(1, 2).reshape(batch, seq, -1)
        return self.out_proj(merged)
339
+
340
+
341
class JointAttention(nn.Module):
    """Joint text/image attention for double-stream blocks.

    Text and image streams are projected separately, their keys/values are
    concatenated into one sequence, and each stream's queries attend over
    the full sequence. RoPE is applied to the image stream only.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Kept for parity with the original; SDPA scales internally.
        self.scale = head_dim ** -0.5
        inner = num_heads * head_dim
        self.txt_qkv = nn.Linear(hidden_size, 3 * inner, bias=use_bias)
        self.img_qkv = nn.Linear(hidden_size, 3 * inner, bias=use_bias)
        self.txt_out = nn.Linear(inner, hidden_size, bias=use_bias)
        self.img_out = nn.Linear(inner, hidden_size, bias=use_bias)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, txt_len = txt.shape[0], txt.shape[1]
        img_len = img.shape[1]

        def split_heads(proj: torch.Tensor, length: int):
            # [B, L, 3*H*D] -> three tensors of [B, H, L, D]
            return proj.reshape(batch, length, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)

        txt_q, txt_k, txt_v = split_heads(self.txt_qkv(txt), txt_len)
        img_q, img_k, img_v = split_heads(self.img_qkv(img), img_len)

        if rope is not None:
            img_q = apply_rotary_emb_old(img_q, rope)
            img_k = apply_rotary_emb_old(img_k, rope)

        # Shared key/value sequence: text first, then image.
        keys = torch.cat([txt_k, img_k], dim=2)
        values = torch.cat([txt_v, img_v], dim=2)

        txt_attn = F.scaled_dot_product_attention(txt_q, keys, values)
        txt_attn = txt_attn.transpose(1, 2).reshape(batch, txt_len, -1)

        img_attn = F.scaled_dot_product_attention(img_q, keys, values)
        img_attn = img_attn.transpose(1, 2).reshape(batch, img_len, -1)

        return self.txt_out(txt_attn), self.img_out(img_attn)
385
+
386
+
387
+ # =============================================================================
388
+ # MLP
389
+ # =============================================================================
390
+
391
class MLP(nn.Module):
    """Two-layer feed-forward network with tanh-approximated GELU."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner_dim, bias=True)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner_dim, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
403
+
404
+
405
+ # =============================================================================
406
+ # Transformer Blocks
407
+ # =============================================================================
408
+
409
class DoubleStreamBlock(nn.Module):
    """Double-stream transformer block.

    Text and image are kept as separate streams, each with its own AdaLN
    modulation and MLP, coupled only through joint attention.
    """

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        # Module creation order preserved (matters for seeded param init).
        self.img_norm1 = AdaLayerNormZero(hidden)
        self.txt_norm1 = AdaLayerNormZero(hidden)
        self.attn = JointAttention(hidden, heads, head_dim, use_bias=False)
        self.img_norm2 = RMSNorm(hidden)
        self.txt_norm2 = RMSNorm(hidden)
        self.img_mlp = MLP(hidden, config.mlp_ratio)
        self.txt_mlp = MLP(hidden, config.mlp_ratio)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # AdaLN-modulated inputs plus per-branch gates/shifts/scales.
        img_mod, img_gate_attn, img_shift, img_scale, img_gate_ff = self.img_norm1(img, vec)
        txt_mod, txt_gate_attn, txt_shift, txt_scale, txt_gate_ff = self.txt_norm1(txt, vec)

        txt_attn, img_attn = self.attn(txt_mod, img_mod, rope)

        # Gated residual around attention.
        txt = txt + txt_gate_attn.unsqueeze(1) * txt_attn
        img = img + img_gate_attn.unsqueeze(1) * img_attn

        # Modulated norm, then gated residual around each stream's MLP.
        txt_ff_in = self.txt_norm2(txt) * (1 + txt_scale.unsqueeze(1)) + txt_shift.unsqueeze(1)
        img_ff_in = self.img_norm2(img) * (1 + img_scale.unsqueeze(1)) + img_shift.unsqueeze(1)
        txt = txt + txt_gate_ff.unsqueeze(1) * self.txt_mlp(txt_ff_in)
        img = img + img_gate_ff.unsqueeze(1) * self.img_mlp(img_ff_in)

        return txt, img
448
+
449
+
450
class SingleStreamBlock(nn.Module):
    """Single-stream transformer block.

    Text and image tokens are concatenated into one sequence so attention
    spans both modalities; the streams are split back apart on return.
    """

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        # Module creation order preserved (matters for seeded param init).
        self.norm = AdaLayerNormZeroSingle(hidden)
        self.attn = Attention(hidden, heads, head_dim, use_bias=False)
        self.mlp = MLP(hidden, config.mlp_ratio)
        self.norm2 = RMSNorm(hidden)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        txt_len = txt.shape[1]
        seq = torch.cat([txt, img], dim=1)

        normed, gate = self.norm(seq, vec)
        seq = seq + gate.unsqueeze(1) * self.attn(normed, rope)
        # NOTE: the MLP path is ungated here, unlike the attention path.
        seq = seq + self.mlp(self.norm2(seq))

        return seq[:, :txt_len], seq[:, txt_len:]
478
+
479
+
480
+ # =============================================================================
481
+ # Main Model
482
+ # =============================================================================
483
+
484
class TinyFluxDeep(nn.Module):
    """
    TinyFlux-Deep with Expert Predictor.

    The expert predictor learns to emulate SD1.5-flow's timestep expertise,
    allowing the model to benefit from trajectory priors without requiring
    the expert model at inference time.

    Architecture: double-stream blocks (separate text/image streams with
    joint attention) followed by single-stream blocks (one fused sequence),
    then an RMSNorm + zero-initialized linear output head.
    """

    def __init__(self, config: Optional["TinyFluxDeepConfig"] = None):
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config

        # Input projections (patch_size=1: a plain linear per token).
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)

        # Conditioning: timestep embedding + pooled text projection.
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )

        # Expert predictor (replaces guidance_in).
        if cfg.use_expert_predictor:
            self.expert_predictor = ExpertPredictor(
                time_dim=cfg.hidden_size,
                clip_dim=cfg.pooled_projection_dim,
                expert_dim=cfg.expert_dim,
                hidden_dim=cfg.expert_hidden_dim,
                output_dim=cfg.hidden_size,
                dropout=cfg.expert_dropout,
            )
        else:
            self.expert_predictor = None

        # Legacy guidance (for backward compat / comparison).
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)
        else:
            self.guidance_in = None

        # RoPE
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)

        # Transformer blocks
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])

        # Output head
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init all linears; zero the output head so the initial prediction is 0."""
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        self.apply(_init)
        # Zero-init of the final projection gives a zero velocity at step 0.
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.Tensor,
        img_ids: torch.Tensor,
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        expert_features: Optional[torch.Tensor] = None,
        return_expert_pred: bool = False,
    ):
        """
        Forward pass.

        Args:
            hidden_states: [B, N, C] - image latents
            encoder_hidden_states: [B, L, D] - T5 text embeddings
            pooled_projections: [B, D] - CLIP pooled features
            timestep: [B] - diffusion timestep
            img_ids: [N, 3] or [B, N, 3] - image position IDs
            txt_ids: [L, 3] or [B, L, 3] - text position IDs (optional)
            guidance: [B] - legacy guidance scale (if guidance_embeds=True)
            expert_features: [B, expert_dim] - real expert features (training only)
            return_expert_pred: if True, return (output, expert_info) tuple

        Returns:
            output: [B, N, C] - predicted velocity
            expert_info: dict (only if return_expert_pred=True; None when
                the expert predictor is disabled)
        """
        L = encoder_hidden_states.shape[1]

        # Input projections
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)

        # Conditioning: time + pooled text
        time_emb = self.time_in(timestep)
        vec = time_emb + self.vector_in(pooled_projections)

        # Expert predictor (third conditioning stream)
        expert_info = None
        if self.expert_predictor is not None:
            expert_out = self.expert_predictor(
                time_emb=time_emb,
                clip_pooled=pooled_projections,
                real_expert_features=expert_features,
            )
            vec = vec + expert_out['expert_signal']
            expert_info = expert_out
        # Legacy guidance (fallback)
        elif self.guidance_in is not None and guidance is not None:
            vec = vec + self.guidance_in(guidance)

        # Position IDs are shared across the batch; collapse [B, N, 3] -> [N, 3].
        if img_ids.ndim == 3:
            img_ids = img_ids[0]
        img_rope = self.rope(img_ids)

        # Double-stream blocks (RoPE over image tokens only)
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)

        # Build full-sequence RoPE for the single-stream blocks.
        if txt_ids is None:
            txt_ids = torch.zeros(L, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        all_ids = torch.cat([txt_ids, img_ids], dim=0)
        full_rope = self.rope(all_ids)

        # Single-stream blocks
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)

        # Output head
        output = self.final_linear(self.final_norm(img))

        if return_expert_pred:
            return output, expert_info
        return output

    def compute_loss(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        expert_pred: Optional[torch.Tensor] = None,
        real_expert_features: Optional[torch.Tensor] = None,
        distill_weight: float = 0.1,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute combined loss.

        Args:
            output: model prediction
            target: flow matching target (data - noise)
            expert_pred: predicted expert features
            real_expert_features: real expert features
            distill_weight: weight for distillation loss

        Returns:
            dict with 'total', 'main', 'distill' losses

        Raises:
            RuntimeError: if expert features are supplied while the model
                was built with use_expert_predictor=False.
        """
        # Main flow matching loss
        main_loss = F.mse_loss(output, target)

        losses = {
            'main': main_loss,
            'distill': torch.tensor(0.0, device=output.device),
            'total': main_loss,
        }

        # Distillation loss
        if expert_pred is not None and real_expert_features is not None:
            if self.expert_predictor is None:
                # Clearer failure than the AttributeError this would otherwise raise.
                raise RuntimeError(
                    "Expert features were provided but the model was built "
                    "with use_expert_predictor=False"
                )
            distill_loss = self.expert_predictor.compute_distillation_loss(
                expert_pred, real_expert_features
            )
            losses['distill'] = distill_loss
            losses['total'] = main_loss + distill_weight * distill_loss

        return losses

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create [H*W, 3] position IDs for RoPE: (0, row, col) per patch.

        ``batch_size`` is accepted for interface compatibility but unused:
        IDs are shared across the batch. Vectorized (no Python loops).
        """
        img_ids = torch.zeros(height * width, 3, device=device)
        rows = torch.arange(height, device=device, dtype=img_ids.dtype)
        cols = torch.arange(width, device=device, dtype=img_ids.dtype)
        # Row-major layout: index i*width + j holds (row=i, col=j).
        img_ids[:, 1] = rows.repeat_interleave(width)
        img_ids[:, 2] = cols.repeat(height)
        return img_ids

    @staticmethod
    def create_txt_ids(text_len: int, device: torch.device) -> torch.Tensor:
        """Create [L, 3] text position IDs; only axis 0 carries the token index."""
        txt_ids = torch.zeros(text_len, 3, device=device)
        txt_ids[:, 0] = torch.arange(text_len, device=device)
        return txt_ids

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
        def numel(module: nn.Module) -> int:
            # Helper: total parameter count of a submodule.
            return sum(p.numel() for p in module.parameters())

        counts = {
            'img_in': numel(self.img_in),
            'txt_in': numel(self.txt_in),
            'time_in': numel(self.time_in),
            'vector_in': numel(self.vector_in),
        }
        if self.expert_predictor is not None:
            counts['expert_predictor'] = numel(self.expert_predictor)
        if self.guidance_in is not None:
            counts['guidance_in'] = numel(self.guidance_in)
        counts['double_blocks'] = numel(self.double_blocks)
        counts['single_blocks'] = numel(self.single_blocks)
        counts['final'] = numel(self.final_norm) + numel(self.final_linear)
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts
719
+
720
+
721
+ # =============================================================================
722
+ # Test
723
+ # =============================================================================
724
+
725
def test_model():
    """Test TinyFlux-Deep with Expert Predictor.

    Smoke test on toy shapes: builds the model, runs a training-mode pass
    with simulated expert features, an eval-mode pass without them, and a
    combined loss computation. Success is simply running without error.
    """
    print("=" * 60)
    print("TinyFlux-Deep + Expert Predictor Test")
    print("=" * 60)

    config = TinyFluxDeepConfig(
        use_expert_predictor=True,
        expert_dim=1280,
        expert_hidden_dim=512,
        guidance_embeds=False,
    )
    model = TinyFluxDeep(config)

    counts = model.count_parameters()
    print(f"\nConfig:")
    print(f" hidden_size: {config.hidden_size}")
    print(f" num_double_layers: {config.num_double_layers}")
    print(f" num_single_layers: {config.num_single_layers}")
    print(f" expert_dim: {config.expert_dim}")
    print(f" use_expert_predictor: {config.use_expert_predictor}")

    print(f"\nParameters:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Toy problem size: 64x64 latent grid, 77-token text sequence.
    B, H, W = 2, 64, 64
    L = 77

    hidden_states = torch.randn(B, H * W, config.in_channels, device=device)
    encoder_hidden_states = torch.randn(B, L, config.joint_attention_dim, device=device)
    pooled_projections = torch.randn(B, config.pooled_projection_dim, device=device)
    timestep = torch.rand(B, device=device)
    img_ids = TinyFluxDeep.create_img_ids(B, H, W, device)
    txt_ids = TinyFluxDeep.create_txt_ids(L, device)

    # Simulated expert features
    expert_features = torch.randn(B, config.expert_dim, device=device)

    print("\n[Test 1: Training mode with expert features]")
    # train() so the expert predictor's real-vs-predicted branch is active;
    # no_grad keeps the smoke test cheap.
    model.train()
    with torch.no_grad():
        output, expert_info = model(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            expert_features=expert_features,
            return_expert_pred=True,
        )
    print(f" Output shape: {output.shape}")
    print(f" Expert used: {expert_info['expert_used']}")
    print(f" Expert pred shape: {expert_info['expert_pred'].shape}")

    print("\n[Test 2: Inference mode (no expert)]")
    model.eval()
    with torch.no_grad():
        output = model(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            expert_features=None,  # No expert at inference
        )
    print(f" Output shape: {output.shape}")
    print(f" Output range: [{output.min():.4f}, {output.max():.4f}]")

    print("\n[Test 3: Loss computation]")
    # Gradients flow here (no no_grad) to mimic a real training step.
    target = torch.randn_like(output)
    model.train()
    output, expert_info = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
        expert_features=expert_features,
        return_expert_pred=True,
    )
    losses = model.compute_loss(
        output=output,
        target=target,
        expert_pred=expert_info['expert_pred'],
        real_expert_features=expert_features,
        distill_weight=0.1,
    )
    print(f" Main loss: {losses['main']:.4f}")
    print(f" Distill loss: {losses['distill']:.4f}")
    print(f" Total loss: {losses['total']:.4f}")

    print("\n" + "=" * 60)
    print("✓ All tests passed!")
    print("=" * 60)
826
+
827
+
828
if __name__ == "__main__":
    # Run the smoke test when executed as a script.
    test_model()