AbstractPhil commited on
Commit
a859db1
·
verified ·
1 Parent(s): 1ae1073

Create colab_inference_lailah_early.py

Browse files
Files changed (1) hide show
  1. colab_inference_lailah_early.py +1057 -0
colab_inference_lailah_early.py ADDED
@@ -0,0 +1,1057 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux-Lailah Inference
3
+ Loads the model code, the weights, and runs the inference based on the settings below.
4
+ Set up with only EULER for now.
5
+
6
+ No guarantees for any of this to work.
7
+
8
+ It's pretty bad in its current phase; check back on it later if you're interested.
9
+ LICENSE: MIT
10
+ """
11
+
12
+
13
# torch has to be imported before DEVICE/DTYPE are evaluated; the model cell's
# own `import torch` only appears further down the file.
import torch

# --- Generation settings (Colab form fields) ---------------------------------
POSITIVE_PROMPT = "woman" # @param {type:"string"}
NEGATIVE_PROMPT = "" # @param {type:"string"}
STEPS = 50 # @param {type:"integer"}
CFG_GUIDANCE = 5 # @param {type: "number"}
FLUX_SHIFT = 3 # @param {type: "number"}
SEED = 420 # @param {type: "integer"}
OUTPUT_PATH = "output.png" # @param {type:"string"}
WIDTH = 512 # @param {type: "integer"}
HEIGHT = 512 # @param {type: "integer"}

# --- Model loading -----------------------------------------------------------
HF_REPO = "AbstractPhil/tiny-flux-deep" # @param {type:"string"}
# "hub", "hub:step_XXXXX", "local:/path/to/weights.safetensors"
LOAD_FROM = "hub:step_293750" # @param {type:"string"}

# --- Runtime -----------------------------------------------------------------
_HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if _HAS_CUDA else "cpu"
# Only ask about bf16 when a CUDA device is actually present; without one the
# answer is "no" and the query itself is not guaranteed to be safe on every
# torch build. The float16 fallback is unchanged.
DTYPE = torch.bfloat16 if (_HAS_CUDA and torch.cuda.is_bf16_supported()) else torch.float16
30
+
31
+
32
+ #@title Preview (updates in-place)
33
+ from IPython.display import display, Image as IPyImage, update_display
34
+ from PIL import Image as PIL
35
+ import numpy as np, io
36
+
37
# Display handle shared by display()/update_display() so the preview is
# redrawn in the same output slot instead of appending new images.
_PREVIEW_DISPLAY_ID = "tf_preview"

# Preview edge length: half of the longer requested side, capped at 512.
_longest_side = max(WIDTH, HEIGHT)
preview_size = min(512, _longest_side // 2)
40
+
41
+
42
+ def _pil_to_png_bytes(img: PIL) -> bytes:
43
+ buf = io.BytesIO()
44
+ img.save(buf, format="PNG")
45
+ return buf.getvalue()
46
+
47
def init_preview(square: int = 256):
    """Create the preview output slot, filled with a black square.

    Args:
        square: Edge length of the placeholder image.
    """
    blank_pixels = np.zeros((square, square, 3), dtype=np.uint8)
    placeholder = PIL.fromarray(blank_pixels)
    png_bytes = _pil_to_png_bytes(placeholder)
    # Registering the display id here is what lets later update_display()
    # calls redraw this same output.
    display(IPyImage(data=png_bytes), display_id=_PREVIEW_DISPLAY_ID)
51
+
52
def set_preview_from_pil(img: "PIL.Image", square: int = 256):
    """Redraw the preview in place from a Pillow image.

    The image is converted to RGB, reduced with ``thumbnail`` so it fits inside
    ``square`` x ``square``, centred on a black canvas of that size, and pushed
    to the existing preview display id.

    Args:
        img: Source image (``PIL`` is this file's alias of the ``PIL.Image``
            module, so ``PIL.Image`` is the image class).
        square: Edge length of the square preview canvas.
    """
    thumb = img.convert("RGB").copy()
    thumb.thumbnail((square, square), resample=PIL.Resampling.LANCZOS)

    # Letterbox onto a black square so the widget keeps a constant footprint.
    canvas = PIL.fromarray(np.zeros((square, square, 3), dtype=np.uint8))
    thumb_w, thumb_h = thumb.size
    offset = ((square - thumb_w) // 2, (square - thumb_h) // 2)
    canvas.paste(thumb, offset)

    update_display(IPyImage(data=_pil_to_png_bytes(canvas)), display_id=_PREVIEW_DISPLAY_ID)
62
+
63
def set_preview_from_path(path: str, square: int = 256):
    """Redraw the preview from an image file on disk.

    Args:
        path: Location of the image file.
        square: Edge length of the square preview canvas.
    """
    # Open inside a ``with`` block so the underlying file handle is released
    # as soon as the preview has been rendered (previously it was left open).
    with PIL.open(path) as img:
        set_preview_from_pil(img, square=square)
66
+
67
+ # initialize placeholder
68
+ init_preview(square=preview_size)
69
+ #set_preview_from_pil(image, square=preview_size)
70
+
71
+
72
+
73
+ """
74
+ TinyFlux-Deep: Deeper variant with 15 double + 25 single blocks.
75
+
76
+ Config derived from checkpoint step_285625.safetensors:
77
+ - hidden_size: 512
78
+ - num_attention_heads: 4
79
+ - attention_head_dim: 128
80
+ - num_double_layers: 15
81
+ - num_single_layers: 25
82
+ - Uses biases in MLP
83
+ - Old RoPE format with cached freqs buffers
84
+ """
85
+
86
+ import torch
87
+ import torch.nn as nn
88
+ import torch.nn.functional as F
89
+ import math
90
+ from dataclasses import dataclass
91
+ from typing import Optional, Tuple, List
92
+
93
@dataclass
class TinyFluxDeepConfig:
    """Configuration for TinyFlux-Deep model.

    Raises:
        ValueError: if ``num_attention_heads * attention_head_dim`` differs
            from ``hidden_size``, or if ``axes_dims_rope`` does not sum to
            ``attention_head_dim``.
    """
    # Transformer width and attention layout (heads * head_dim must equal width).
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128

    # Channel count of the latent tokens fed to / produced by the model.
    in_channels: int = 16
    patch_size: int = 1

    # Feature sizes of the text-sequence and pooled-text conditioning inputs.
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768

    # Depth of the double-stream and single-stream stacks.
    num_double_layers: int = 15
    num_single_layers: int = 25

    mlp_ratio: float = 4.0
    # Per-axis RoPE dimensions; together they must fill one attention head.
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
    guidance_embeds: bool = True

    def __post_init__(self):
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``
        # and report the offending values.
        attn_width = self.num_attention_heads * self.attention_head_dim
        if attn_width != self.hidden_size:
            raise ValueError(
                f"num_attention_heads * attention_head_dim ({attn_width}) "
                f"must equal hidden_size ({self.hidden_size})"
            )
        rope_width = sum(self.axes_dims_rope)
        if rope_width != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({rope_width}) "
                f"must equal attention_head_dim ({self.attention_head_dim})"
            )
116
+
117
+
118
+ # =============================================================================
119
+ # Normalization
120
+ # =============================================================================
121
+
122
class RMSNorm(nn.Module):
    """RMS layer norm over the last dimension, with an optional learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # Registering ``None`` keeps ``self.weight`` defined without adding a
        # state-dict entry when the gain is disabled.
        gain = nn.Parameter(torch.ones(dim)) if elementwise_affine else None
        self.register_parameter('weight', gain)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics are computed in float32, then the result is cast back to
        # the input dtype before the gain is applied.
        mean_square = x.float().pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = (x * inv_rms).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight
140
+
141
+
142
+ # =============================================================================
143
+ # RoPE - Old format with cached frequency buffers (checkpoint compatible)
144
+ # =============================================================================
145
+
146
class EmbedND(nn.Module):
    """
    Multi-axis rotary position table with per-axis frequency buffers.

    The buffers are registered as ``freqs_0``, ``freqs_1``, ... (persistent),
    which is the layout the checkpoint keys ``rope.freqs_*`` expect.
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

        # One inverse-frequency vector (length dim/2) per positional axis.
        for axis, dim in enumerate(axes_dim):
            exponents = torch.arange(0, dim, 2).float() / dim
            self.register_buffer(f'freqs_{axis}', 1.0 / (theta ** exponents), persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (N, n_axes) position indices, one column per axis
        Returns:
            (N, 1, sum(axes_dim)) table laid out as [cos, sin, cos, sin, ...]
            within each axis, axes concatenated in order
        """
        device = ids.device
        per_axis = []

        for axis in range(ids.shape[-1]):
            freqs = getattr(self, f'freqs_{axis}').to(device)
            angles = ids[:, axis].float().unsqueeze(-1) * freqs.unsqueeze(0)  # (N, dim/2)
            # stack + flatten interleaves the pair as cos0, sin0, cos1, sin1, ...
            interleaved = torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2)
            per_axis.append(interleaved)

        return torch.cat(per_axis, dim=-1).unsqueeze(1)
186
+
187
+
188
def apply_rotary_emb_old(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """
    Rotate query/key features with an interleaved cos/sin table.

    Args:
        x: (B, H, N, D) query or key tensor
        freqs_cis: (N, 1, D) table laid out as [cos0, sin0, cos1, sin1, ...]
    Returns:
        Tensor with the shape and dtype of ``x``
    """
    table = freqs_cis.squeeze(1)  # (N, D)

    # De-interleave, then duplicate each value so both members of a feature
    # pair see the same angle; add broadcast dims for batch and heads.
    cos = table[:, 0::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)
    sin = table[:, 1::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)

    # View features as (even, odd) pairs and build the 90-degree rotated copy.
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    even, odd = pairs[..., 0], pairs[..., 1]
    quarter_turn = torch.stack([-odd, even], dim=-1).flatten(-2)

    # The arithmetic runs in float32; the result is cast back to x's dtype.
    rotated = x.float() * cos + quarter_turn.float() * sin
    return rotated.to(x.dtype)
216
+
217
+
218
+ # =============================================================================
219
+ # Embeddings
220
+ # =============================================================================
221
+
222
class MLPEmbedder(nn.Module):
    """Embed scalars (timestep, guidance): sinusoidal features followed by a 2-layer MLP."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Input width 256 = 128 sine + 128 cosine features built in forward().
        layers = [
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half_dim = 128
        log_step = math.log(10000) / (half_dim - 1)
        # Frequencies are built on x's device and in x's dtype.
        positions = torch.arange(half_dim, device=x.device, dtype=x.dtype)
        freqs = torch.exp(positions * -log_step)
        phase = x.unsqueeze(-1) * freqs.unsqueeze(0)
        features = torch.cat([phase.sin(), phase.cos()], dim=-1)
        return self.mlp(features)
240
+
241
+
242
+ # =============================================================================
243
+ # AdaLayerNorm
244
+ # =============================================================================
245
+
246
class AdaLayerNormZero(nn.Module):
    """Conditioned norm for double-stream blocks: one projection yields 6 modulation tensors."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return (modulated x, gate_msa, shift_mlp, scale_mlp, gate_mlp)."""
        params = self.linear(self.silu(emb)).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa = params[:3]
        mlp_params = params[3:]  # shift_mlp, scale_mlp, gate_mlp, handed back to the caller
        # unsqueeze(1) broadcasts the per-sample modulation across the sequence axis.
        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return (modulated, gate_msa, *mlp_params)
260
+
261
+
262
class AdaLayerNormZeroSingle(nn.Module):
    """Conditioned norm for single-stream blocks: one projection yields shift, scale and gate."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return (modulated x, gate)."""
        projected = self.linear(self.silu(emb))
        shift, scale, gate = projected.chunk(3, dim=-1)
        # unsqueeze(1) broadcasts the per-sample modulation across the sequence axis.
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return modulated, gate
276
+
277
+
278
+ # =============================================================================
279
+ # Attention (original format - no Q/K norm, matches checkpoint)
280
+ # =============================================================================
281
+
282
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and no Q/K normalisation."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Stored for reference; forward() does not pass it to SDPA.
        self.scale = head_dim ** -0.5

        inner_dim = num_heads * head_dim
        self.qkv = nn.Linear(hidden_size, 3 * inner_dim, bias=use_bias)
        self.out_proj = nn.Linear(inner_dim, hidden_size, bias=use_bias)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        # (3, B, H, N, D) -> three (B, H, N, D) tensors
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Rotary embedding is applied to queries and keys only.
        if rope is not None:
            q, k = apply_rotary_emb_old(q, rope), apply_rotary_emb_old(k, rope)

        context = F.scaled_dot_product_attention(q, k, v)
        merged = context.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
313
+
314
+
315
class JointAttention(nn.Module):
    """Joint text/image attention: separate projections per stream, shared keys and values."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Stored for reference; forward() does not pass it to SDPA.
        self.scale = head_dim ** -0.5

        inner_dim = num_heads * head_dim
        self.txt_qkv = nn.Linear(hidden_size, 3 * inner_dim, bias=use_bias)
        self.img_qkv = nn.Linear(hidden_size, 3 * inner_dim, bias=use_bias)

        self.txt_out = nn.Linear(inner_dim, hidden_size, bias=use_bias)
        self.img_out = nn.Linear(inner_dim, hidden_size, bias=use_bias)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, txt_len, _ = txt.shape
        _, img_len, _ = img.shape
        heads = (3, self.num_heads, self.head_dim)

        # Each projection is split into q/k/v of shape (B, H, seq, D).
        txt_q, txt_k, txt_v = (
            self.txt_qkv(txt).reshape(batch, txt_len, *heads).permute(2, 0, 3, 1, 4).unbind(0)
        )
        img_q, img_k, img_v = (
            self.img_qkv(img).reshape(batch, img_len, *heads).permute(2, 0, 3, 1, 4).unbind(0)
        )

        # Rotary embedding is applied to the image stream only.
        if rope is not None:
            img_q = apply_rotary_emb_old(img_q, rope)
            img_k = apply_rotary_emb_old(img_k, rope)

        # Both streams attend over the concatenated [text, image] keys/values.
        joint_k = torch.cat([txt_k, img_k], dim=2)
        joint_v = torch.cat([txt_v, img_v], dim=2)

        txt_ctx = F.scaled_dot_product_attention(txt_q, joint_k, joint_v)
        txt_ctx = txt_ctx.transpose(1, 2).reshape(batch, txt_len, -1)

        img_ctx = F.scaled_dot_product_attention(img_q, joint_k, joint_v)
        img_ctx = img_ctx.transpose(1, 2).reshape(batch, img_len, -1)

        return self.txt_out(txt_ctx), self.img_out(img_ctx)
361
+
362
+
363
+ # =============================================================================
364
+ # MLP (with bias - matches checkpoint)
365
+ # =============================================================================
366
+
367
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> tanh-approximated GELU -> Linear, with biases."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)
        # Biases are kept on both layers so the parameter names line up with the checkpoint.
        self.fc1 = nn.Linear(hidden_size, inner_dim, bias=True)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner_dim, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)
379
+
380
+
381
+ # =============================================================================
382
+ # Transformer Blocks
383
+ # =============================================================================
384
+
385
class DoubleStreamBlock(nn.Module):
    """Double-stream block: text and image keep separate weights and meet in joint attention."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        width = config.hidden_size

        self.img_norm1 = AdaLayerNormZero(width)
        self.txt_norm1 = AdaLayerNormZero(width)

        self.attn = JointAttention(
            width, config.num_attention_heads, config.attention_head_dim, use_bias=False
        )

        self.img_norm2 = RMSNorm(width)
        self.txt_norm2 = RMSNorm(width)

        self.img_mlp = MLP(width, config.mlp_ratio)
        self.txt_mlp = MLP(width, config.mlp_ratio)

    @staticmethod
    def _modulate(normed: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        """Apply per-sample scale/shift, broadcast across the sequence axis."""
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run gated joint attention then gated per-stream MLPs; returns (txt, img)."""
        img_mod, img_gate_attn, img_shift, img_scale, img_gate_mlp = self.img_norm1(img, vec)
        txt_mod, txt_gate_attn, txt_shift, txt_scale, txt_gate_mlp = self.txt_norm1(txt, vec)

        # Attention sub-layer: gated residual on each stream.
        txt_attn, img_attn = self.attn(txt_mod, img_mod, rope)
        txt = txt + txt_gate_attn.unsqueeze(1) * txt_attn
        img = img + img_gate_attn.unsqueeze(1) * img_attn

        # MLP sub-layer: second norm, modulation from the first norm's outputs, gated residual.
        txt_ff_in = self._modulate(self.txt_norm2(txt), txt_scale, txt_shift)
        img_ff_in = self._modulate(self.img_norm2(img), img_scale, img_shift)
        txt = txt + txt_gate_mlp.unsqueeze(1) * self.txt_mlp(txt_ff_in)
        img = img + img_gate_mlp.unsqueeze(1) * self.img_mlp(img_ff_in)

        return txt, img
427
+
428
+
429
class SingleStreamBlock(nn.Module):
    """Single-stream block: text and image tokens are concatenated and share all weights."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        width = config.hidden_size

        self.norm = AdaLayerNormZeroSingle(width)
        self.attn = Attention(
            width, config.num_attention_heads, config.attention_head_dim, use_bias=False
        )
        self.mlp = MLP(width, config.mlp_ratio)
        self.norm2 = RMSNorm(width)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process [txt, img] as one sequence and hand the two parts back separately."""
        txt_len = txt.shape[1]
        seq = torch.cat([txt, img], dim=1)

        # Gated attention residual, then an ungated MLP residual.
        modulated, gate = self.norm(seq, vec)
        seq = seq + gate.unsqueeze(1) * self.attn(modulated, rope)
        seq = seq + self.mlp(self.norm2(seq))

        # Undo the concatenation: text tokens first, image tokens after.
        return seq[:, :txt_len], seq[:, txt_len:]
460
+
461
+
462
+ # =============================================================================
463
+ # Main Model
464
+ # =============================================================================
465
+
466
class TinyFluxDeep(nn.Module):
    """TinyFlux-Deep: 15 double + 25 single blocks.

    Projects latent tokens and text tokens to a shared width, conditions every
    block on a vector built from timestep, pooled text and (optionally)
    guidance, and returns a per-token prediction with ``in_channels`` features.
    """

    def __init__(self, config: Optional["TinyFluxDeepConfig"] = None):
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config

        # Input projections (with bias to match checkpoint)
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)

        # Conditioning
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)

        # RoPE (old format with cached freqs)
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)

        # Transformer blocks
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])

        # Output
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every Linear (zero bias), then zero the output projection weight."""
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        self.apply(_init)
        # With a zero weight (and zero bias) the freshly built model outputs zeros.
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.Tensor,
        img_ids: torch.Tensor,
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the full transformer.

        Args:
            hidden_states: image/latent tokens, last dim ``in_channels``.
            encoder_hidden_states: text tokens, last dim ``joint_attention_dim``.
            pooled_projections: pooled text vector, last dim ``pooled_projection_dim``.
            timestep: one scalar per sample.
            img_ids: image position ids, (N, 3) or batched (B, N, 3); when
                batched only the first sample's ids are used.
            txt_ids: optional text position ids, same conventions; all-zero
                ids are used when omitted.
            guidance: optional per-sample guidance scalar; ignored unless the
                config enables guidance embeddings.

        Returns:
            Per-image-token output with ``in_channels`` features.
        """
        txt_len = encoder_hidden_states.shape[1]

        # Input projections
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)

        # Conditioning vector shared by every block
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)

        # Position ids are shared across the batch: keep the first sample's.
        if img_ids.ndim == 3:
            img_ids = img_ids[0]  # (N, 3)

        # Double-stream blocks rotate the image stream only.
        img_rope = self.rope(img_ids)  # (N, 1, head_dim)
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)

        # Single-stream blocks see [text, image] as one sequence, so the RoPE
        # table must cover both parts in that order.
        if txt_ids is None:
            txt_ids = torch.zeros(txt_len, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        full_rope = self.rope(torch.cat([txt_ids, img_ids], dim=0))
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)

        # Output head is applied to the image tokens only.
        return self.final_linear(self.final_norm(img))

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create image position IDs for RoPE.

        Args:
            batch_size: unused; kept so existing callers keep working.
            height: number of token rows.
            width: number of token columns.
            device: device for the returned tensor.

        Returns:
            (height * width, 3) tensor in row-major order whose columns are
            [0, row, column].
        """
        img_ids = torch.zeros(height * width, 3, device=device)
        # Vectorised fill: the previous per-element Python loop issued one
        # tensor write per coordinate.
        img_ids[:, 1] = torch.arange(height, device=device).repeat_interleave(width)
        img_ids[:, 2] = torch.arange(width, device=device).repeat(height)
        return img_ids

    @staticmethod
    def create_txt_ids(text_len: int, device: torch.device) -> torch.Tensor:
        """Create text position IDs: (text_len, 3) with the token index in column 0."""
        txt_ids = torch.zeros(text_len, 3, device=device)
        txt_ids[:, 0] = torch.arange(text_len, device=device)
        return txt_ids

    def count_parameters(self) -> dict:
        """Count parameters by component; ``'total'`` covers the whole model."""
        def _numel(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {
            'img_in': _numel(self.img_in),
            'txt_in': _numel(self.txt_in),
            'time_in': _numel(self.time_in),
            'vector_in': _numel(self.vector_in),
        }
        # guidance_in only exists when the config enables guidance embeddings.
        if hasattr(self, 'guidance_in'):
            counts['guidance_in'] = _numel(self.guidance_in)
        counts['double_blocks'] = _numel(self.double_blocks)
        counts['single_blocks'] = _numel(self.single_blocks)
        counts['final'] = _numel(self.final_norm) + _numel(self.final_linear)
        counts['total'] = _numel(self)
        return counts
601
+
602
+
603
+ # =============================================================================
604
+ # Test
605
+ # =============================================================================
606
+
607
def test_model():
    """Smoke test: build the default model, print its size, run one forward pass on random inputs."""
    banner = "=" * 60
    print(banner)
    print("TinyFlux-Deep Test")
    print(banner)

    config = TinyFluxDeepConfig()
    model = TinyFluxDeep(config)
    counts = model.count_parameters()

    print(f"\nConfig:")
    for field in (
        "hidden_size",
        "num_attention_heads",
        "attention_head_dim",
        "num_double_layers",
        "num_single_layers",
    ):
        print(f" {field}: {getattr(config, field)}")

    print(f"\nParameters:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Batch of 2, a 64x64 token grid and 77 text tokens.
    B, H, W = 2, 64, 64
    L = 77

    # Random tensors are drawn in the same order as the keyword list below.
    inputs = dict(
        hidden_states=torch.randn(B, H * W, config.in_channels, device=device),
        encoder_hidden_states=torch.randn(B, L, config.joint_attention_dim, device=device),
        pooled_projections=torch.randn(B, config.pooled_projection_dim, device=device),
        timestep=torch.rand(B, device=device),
        img_ids=TinyFluxDeep.create_img_ids(B, H, W, device),
        txt_ids=TinyFluxDeep.create_txt_ids(L, device),
        guidance=torch.ones(B, device=device) * 3.5,
    )

    with torch.no_grad():
        output = model(**inputs)

    print(f"\nOutput shape: {output.shape}")
    print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
    print("\n✓ Forward pass successful!")
656
+
657
+
658
+ #if __name__ == "__main__":
659
+ # test_model()
660
+
661
+ # ============================================================================
662
+ # TinyFlux-Deep Inference Cell - Euler Discrete Flow Matching
663
+ # ============================================================================
664
+ # Run the model cell before this one (defines TinyFluxDeep, TinyFluxDeepConfig)
665
+ # Loads from: AbstractPhil/tiny-flux-deep or local checkpoint
666
+ # ============================================================================
667
+
668
+ import torch
669
+ from huggingface_hub import hf_hub_download
670
+ from safetensors.torch import load_file
671
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
672
+ from diffusers import AutoencoderKL
673
+ from PIL import Image
674
+ import numpy as np
675
+ import os
676
+
677
+
678
+ # Generation settings
679
# Generation settings: sampler-facing aliases of the form fields above.
NUM_STEPS, GUIDANCE_SCALE, SHIFT = STEPS, CFG_GUIDANCE, FLUX_SHIFT

# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
print("Loading text encoders...")
_T5_REPO = "google/flan-t5-base"
_CLIP_REPO = "openai/clip-vit-large-patch14"

t5_tok = T5Tokenizer.from_pretrained(_T5_REPO)
t5_enc = T5EncoderModel.from_pretrained(_T5_REPO, torch_dtype=DTYPE)
t5_enc = t5_enc.to(DEVICE).eval()

clip_tok = CLIPTokenizer.from_pretrained(_CLIP_REPO)
clip_enc = CLIPTextModel.from_pretrained(_CLIP_REPO, torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()

# ============================================================================
# LOAD VAE
# ============================================================================
print("Loading Flux VAE...")
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae = vae.to(DEVICE).eval()

# ============================================================================
# LOAD TINYFLUX-DEEP MODEL
# ============================================================================
print(f"Loading TinyFlux-Deep from: {LOAD_FROM}")

# Default config: 512 hidden, 4 heads, 15 double, 25 single.
config = TinyFluxDeepConfig()
model = TinyFluxDeep(config)
model = model.to(DEVICE).to(DTYPE)

# Keys that may exist in old checkpoints but have no counterpart in this model.
DEPRECATED_KEYS = {'time_in.sin_basis', 'guidance_in.sin_basis'}
714
+
715
+
716
def _unwrap_state_dict(ckpt):
    """Return the weight mapping from a ``torch.load`` result, unwrapping a "model"/"state_dict" container."""
    if isinstance(ckpt, dict):
        for container_key in ("model", "state_dict"):
            if container_key in ckpt:
                return ckpt[container_key]
    return ckpt


def load_weights(path):
    """Load weights from .safetensors or .pt file.

    Args:
        path: Checkpoint path. ``.safetensors`` files are read with
            ``load_file`` and ``.pt`` files with ``torch.load``; any other
            name is tried as safetensors first, then as a torch checkpoint.

    Returns:
        The state dict, with any ``_orig_mod.`` key prefix (added by
        torch.compile) removed.
    """
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from a trusted source.
        state_dict = _unwrap_state_dict(
            torch.load(path, map_location=DEVICE, weights_only=False)
        )
    else:
        try:
            state_dict = load_file(path)
        except Exception:
            # Narrowed from a bare ``except`` so Ctrl-C / SystemExit still
            # propagate. The fallback now unwraps "model"/"state_dict"
            # containers exactly like the ``.pt`` branch does.
            state_dict = _unwrap_state_dict(
                torch.load(path, map_location=DEVICE, weights_only=False)
            )

    # Strip "_orig_mod." prefix from keys (added by torch.compile)
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        print(" Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    return state_dict
743
+
744
+
745
def load_model_weights(model, weights, source_name):
    """Load ``weights`` into ``model`` non-strictly and print a report.

    Keys listed in DEPRECATED_KEYS are dropped before loading; missing and
    unexpected keys are printed (first ten of each) rather than raised.
    """
    # Single pass: separate the keys to load from the deprecated ones.
    kept, dropped = {}, []
    for key, tensor in weights.items():
        if key in DEPRECATED_KEYS:
            dropped.append(key)
        else:
            kept[key] = tensor

    if dropped:
        print(f" ✓ Ignored deprecated keys: {dropped}")

    result = model.load_state_dict(kept, strict=False)
    missing, unexpected = result.missing_keys, result.unexpected_keys

    if missing:
        print(f" ⚠ Missing keys: {missing[:10]}{'...' if len(missing) > 10 else ''}")
    if unexpected:
        print(f" ⚠ Unexpected keys: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")
    if not (missing or unexpected):
        print(f" ✓ All weights loaded successfully")

    print(f"✓ Loaded from {source_name}")
764
+
765
+
766
# Resolve LOAD_FROM ("hub", "hub:<checkpoint>", "local:<path>") and load the weights.
if LOAD_FROM == "hub":
    try:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
    except Exception:
        # Narrowed from a bare ``except`` so Ctrl-C / SystemExit are not
        # swallowed; any download failure still falls back to model.pt.
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.pt")
    weights = load_weights(weights_path)
    load_model_weights(model, weights, HF_REPO)

elif LOAD_FROM.startswith("hub:"):
    ckpt_name = LOAD_FROM[4:]
    # Build each candidate filename once. A name that already carries an
    # extension yields a single candidate instead of being retried three times.
    if ckpt_name.endswith((".safetensors", ".pt")):
        candidates = [ckpt_name if "/" in ckpt_name else f"checkpoints/{ckpt_name}"]
    else:
        candidates = [f"checkpoints/{ckpt_name}{ext}" for ext in (".safetensors", ".pt", "")]

    last_error = None
    for filename in candidates:
        try:
            weights_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
            weights = load_weights(weights_path)
            load_model_weights(model, weights, f"{HF_REPO}/{filename}")
            break
        except Exception as err:
            # Keep trying the remaining candidates, but remember why this one
            # failed so the final error can report it.
            last_error = err
    else:
        # Chain the last failure: download and load errors are otherwise
        # indistinguishable behind this message.
        raise ValueError(f"Could not find checkpoint: {ckpt_name}") from last_error

elif LOAD_FROM.startswith("local:"):
    weights_path = LOAD_FROM[6:]
    weights = load_weights(weights_path)
    load_model_weights(model, weights, weights_path)

else:
    raise ValueError(f"Unknown LOAD_FROM: {LOAD_FROM}")

model.eval()
print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")
801
+
802
+ # ============================================================================
803
+ # ENCODING FUNCTIONS
804
+ # ============================================================================
805
@torch.inference_mode()
def encode_prompt(prompt: str, max_length: int = 128):
    """Encode a prompt with flan-t5-base and CLIP-L.

    Args:
        prompt: Text to encode.
        max_length: Length the T5 tokens are padded/truncated to. The CLIP
            tokens always use a fixed length of 77.

    Returns:
        Tuple ``(t5_hidden, clip_pooled)``: the T5 encoder's
        ``last_hidden_state`` and the CLIP encoder's ``pooler_output``,
        both cast to ``DTYPE``.
    """
    # Tokenizer options shared by both encoders.
    shared_tok_args = dict(padding="max_length", truncation=True, return_tensors="pt")

    # T5: per-token sequence embedding.
    t5_tokens = t5_tok(prompt, max_length=max_length, **shared_tok_args).to(DEVICE)
    t5_result = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    )
    t5_hidden = t5_result.last_hidden_state

    # CLIP: pooled embedding only.
    clip_tokens = clip_tok(prompt, max_length=77, **shared_tok_args).to(DEVICE)
    clip_result = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    )
    clip_pooled = clip_result.pooler_output

    return t5_hidden.to(DTYPE), clip_pooled.to(DTYPE)
834
+
835
+
836
+ # ============================================================================
837
+ # FLOW MATCHING HELPERS
838
+ # ============================================================================
839
def flux_shift(t, s=SHIFT):
    """Apply the Flux timestep shift ``s*t / (1 + (s-1)*t)``.

    Biases towards higher t (closer to data).
    """
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
842
+
843
+
844
+ # ============================================================================
845
+ # EULER DISCRETE FLOW MATCHING SAMPLER
846
+ # ============================================================================
847
@torch.inference_mode()
def euler_sample(
    model,
    prompt: str,
    negative_prompt: str = "",
    num_steps: int = 28,
    guidance_scale: float = 3.5,
    height: int = 512,
    width: int = 512,
    seed: int = None,
):
    """
    Euler discrete sampler for rectified flow matching.

    Flow Matching formulation:
        x_t = (1 - t) * noise + t * data
        At t=0: noise, At t=1: data
        Velocity v = data - noise (constant)

    Sampling: Integrate from t=0 (noise) to t=1 (data)

    Args:
        model: Velocity-prediction model, called with ``hidden_states``,
            ``encoder_hidden_states``, ``pooled_projections``, ``timestep``,
            ``img_ids`` and ``guidance`` keyword arguments.
        prompt: Positive text prompt.
        negative_prompt: Prompt for the unconditional branch of
            classifier-free guidance; pass ``None`` to disable CFG.
        num_steps: Number of Euler steps (must be >= 1).
        guidance_scale: CFG scale; CFG is only applied when this is > 1.0.
        height: Image height in pixels (latent height is ``height // 8``).
        width: Image width in pixels (latent width is ``width // 8``).
        seed: Optional seed; when given it seeds both the global torch RNG
            and the generator used for the initial noise.

    Returns:
        Latent tensor of shape ``(1, 16, height // 8, width // 8)``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1 (the loop would not
            run and pure noise would be returned).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    if seed is not None:
        torch.manual_seed(seed)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
    else:
        generator = None

    H_lat = height // 8
    W_lat = width // 8
    C_lat = 16

    # Encode prompts. The unconditional branch is only needed when CFG is active.
    use_cfg = guidance_scale > 1.0 and negative_prompt is not None
    t5_cond, clip_cond = encode_prompt(prompt)
    if use_cfg:
        t5_uncond, clip_uncond = encode_prompt(negative_prompt)
    else:
        t5_uncond, clip_uncond = None, None

    # Start from pure noise (t=0)
    x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator)

    # Create image position IDs
    img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)

    # Timesteps: 0 → 1 with flux shift
    # NOTE(review): the schedule is built in DTYPE; if DTYPE is a low-precision
    # type the step sizes are quantised accordingly — confirm this is intended.
    t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
    timesteps = flux_shift(t_linear, s=SHIFT)

    # The guidance embedding does not depend on the step, so build it once.
    guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)

    print(f"Sampling with {num_steps} Euler steps (t: 0→1, shifted)...")

    report_every = max(1, num_steps // 5)
    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr

        t_batch = t_curr.unsqueeze(0)

        # Predict velocity
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_cond,
            pooled_projections=clip_cond,
            timestep=t_batch,
            img_ids=img_ids,
            guidance=guidance_embed,
        )

        # Classifier-free guidance
        if use_cfg:
            v_uncond = model(
                hidden_states=x,
                encoder_hidden_states=t5_uncond,
                pooled_projections=clip_uncond,
                timestep=t_batch,
                img_ids=img_ids,
                guidance=guidance_embed,
            )
            v = v_uncond + guidance_scale * (v_cond - v_uncond)
        else:
            v = v_cond

        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + v * dt

        if (i + 1) % report_every == 0 or i == num_steps - 1:
            print(f"  Step {i+1}/{num_steps}, t={t_next.item():.3f}")

    # Reshape: (1, H*W, C) -> (1, C, H, W)
    latents = x.reshape(1, H_lat, W_lat, C_lat).permute(0, 3, 1, 2)

    return latents
939
+
940
+
941
+ # ============================================================================
942
+ # DECODE LATENTS TO IMAGE
943
+ # ============================================================================
944
@torch.inference_mode()
def decode_latents(latents):
    """Decode VAE latents to a PIL Image.

    Only the first item of the batch is converted.

    Args:
        latents: Latent tensor to decode with the module-level ``vae``.

    Returns:
        PIL.Image built from the decoded uint8 array.
    """
    # Undo the VAE scaling factor and match the VAE's dtype before decoding.
    unscaled = latents / vae.config.scaling_factor
    decoded = vae.decode(unscaled.to(vae.dtype)).sample

    # Rescale with x/2 + 0.5 and clamp into [0, 1].
    unit_range = (decoded / 2 + 0.5).clamp(0, 1)

    # Channels-last float array on the CPU, then 8-bit pixels.
    pixels = unit_range[0].float().permute(1, 2, 0).cpu().numpy()
    pixels_u8 = (pixels * 255).astype(np.uint8)
    return Image.fromarray(pixels_u8)
953
+
954
+
955
+ # ============================================================================
956
+ # MAIN GENERATION FUNCTION
957
+ # ============================================================================
958
def generate(
    prompt: str = POSITIVE_PROMPT,
    negative_prompt: str = NEGATIVE_PROMPT,
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    save_path: str = OUTPUT_PATH,
):
    """
    Generate an image from a text prompt.

    Samples latents with ``euler_sample``, decodes them with
    ``decode_latents``, optionally saves the result, updates the preview via
    ``set_preview_from_pil`` and returns the image.

    Args:
        prompt: Text description of desired image
        negative_prompt: What to avoid (empty string for none)
        num_steps: Number of Euler steps (20-50 recommended)
        guidance_scale: CFG scale (1.0=none, 3-7 typical)
        height: Output height in pixels (divisible by 8)
        width: Output width in pixels (divisible by 8)
        seed: Random seed (None for random)
        save_path: Path to save image (None to skip)

    Returns:
        PIL.Image
    """
    sampler_args = dict(
        model=model,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        seed=seed,
    )
    image = decode_latents(euler_sample(**sampler_args))

    # A falsy save_path (None or "") skips writing to disk.
    if save_path:
        image.save(save_path)

    set_preview_from_pil(image, square=512)

    print("✓ Done!")
    return image
1009
+
1010
+
1011
+ # ============================================================================
1012
+ # BATCH GENERATION
1013
+ # ============================================================================
1014
def generate_batch(
    prompts: list,
    negative_prompt: str = "",
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    output_dir: str = "./outputs",
):
    """Generate one image per prompt and save them into ``output_dir``.

    Images are written as ``000.png``, ``001.png``, ... in prompt order.
    When ``seed`` is given, prompt ``i`` uses ``seed + i``; otherwise each
    image is generated with ``seed=None``.

    Returns:
        List of the generated images, in the same order as ``prompts``.
    """
    os.makedirs(output_dir, exist_ok=True)

    results = []
    for index, text in enumerate(prompts):
        per_image_seed = None if seed is None else seed + index
        target_file = os.path.join(output_dir, f"{index:03d}.png")
        results.append(
            generate(
                prompt=text,
                negative_prompt=negative_prompt,
                num_steps=num_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
                seed=per_image_seed,
                save_path=target_file,
            )
        )

    return results
1043
+
1044
+
1045
+ # ============================================================================
1046
+ # QUICK TEST
1047
+ # ============================================================================
1048
+ #print("\n" + "="*60)
1049
+ #print("TinyFlux-Deep Inference Ready!")
1050
+ #print("="*60)
1051
+ #print(f"Config: {config.hidden_size} hidden, {config.num_attention_heads} heads")
1052
+ #print(f" {config.num_double_layers} double, {config.num_single_layers} single layers")
1053
+ #print(f"Total: {sum(p.numel() for p in model.parameters()):,} parameters")
1054
+
1055
# Example usage: runs one generation with all default settings as soon as this
# script/cell executes; the image returned by generate() is kept in `image`.
image = generate()
#image