AbstractPhil commited on
Commit
a29d3c5
·
verified ·
1 Parent(s): 5086269

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -305
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  TinyFlux-Lailah Gradio Demo
3
  HuggingFace Spaces with ZeroGPU support
 
4
  """
5
 
6
  import gradio as gr
@@ -8,19 +9,20 @@ import numpy as np
8
  import random
9
  import spaces
10
  import torch
 
 
 
 
 
11
  from huggingface_hub import hf_hub_download
12
  from safetensors.torch import load_file
13
  from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
14
  from diffusers import AutoencoderKL
15
  from PIL import Image
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- import math
19
- from dataclasses import dataclass
20
- from typing import Tuple
21
 
22
  # ============================================================================
23
- # MODEL DEFINITION (TinyFluxDeep / Lailah)
24
  # ============================================================================
25
 
26
  @dataclass
@@ -40,23 +42,29 @@ class TinyFluxDeepConfig:
40
 
41
  def __post_init__(self):
42
  assert self.num_attention_heads * self.attention_head_dim == self.hidden_size
 
43
 
44
 
45
  class RMSNorm(nn.Module):
46
- def __init__(self, dim, eps=1e-6):
47
  super().__init__()
48
- self.weight = nn.Parameter(torch.ones(dim))
49
  self.eps = eps
 
 
 
 
 
50
 
51
- def forward(self, x):
52
- dtype = x.dtype
53
- x = x.float()
54
- norm = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
55
- return (x * norm).to(dtype) * self.weight
 
56
 
57
 
58
  class EmbedND(nn.Module):
59
- def __init__(self, theta=10000.0, axes_dim=(16, 56, 56)):
60
  super().__init__()
61
  self.theta = theta
62
  self.axes_dim = axes_dim
@@ -64,231 +72,263 @@ class EmbedND(nn.Module):
64
  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
65
  self.register_buffer(f'freqs_{i}', freqs, persistent=True)
66
 
67
- def forward(self, ids):
68
- rope_components = []
69
- for i, dim in enumerate(self.axes_dim):
70
- freqs = getattr(self, f'freqs_{i}').to(ids.device)
71
- axis_ids = ids[..., i:i+1]
72
- angles = axis_ids * freqs
73
- cos = torch.cos(angles)
74
- sin = torch.sin(angles)
75
- interleaved = torch.stack([cos, sin], dim=-1).flatten(-2)
76
- rope_components.append(interleaved)
77
- return torch.cat(rope_components, dim=-1)
78
-
79
-
80
- def apply_rope(x, rope):
81
- B, H, N, D = x.shape
82
- rope = rope[:, :N, :D]
83
- rope = rope.unsqueeze(1)
84
- x_pairs = x.reshape(B, H, N, D // 2, 2)
85
- cos = rope[..., 0::2]
86
- sin = rope[..., 1::2]
87
- x_rot = torch.stack([
88
- x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
89
- x_pairs[..., 1] * cos + x_pairs[..., 0] * sin,
90
- ], dim=-1)
91
- return x_rot.flatten(-2)
92
 
93
 
94
  class MLPEmbedder(nn.Module):
95
- def __init__(self, in_dim, hidden_dim):
96
  super().__init__()
97
- self.fc1 = nn.Linear(in_dim, hidden_dim)
98
- self.fc2 = nn.Linear(hidden_dim, hidden_dim)
 
 
 
99
 
100
- def forward(self, x):
101
- return self.fc2(F.silu(self.fc1(x)))
 
 
 
 
 
102
 
103
 
104
- class QKNorm(nn.Module):
105
- def __init__(self, dim):
106
  super().__init__()
107
- self.query_norm = RMSNorm(dim)
108
- self.key_norm = RMSNorm(dim)
 
109
 
110
- def forward(self, q, k):
111
- return self.query_norm(q), self.key_norm(k)
 
 
 
112
 
113
 
114
- class DoubleAttention(nn.Module):
115
- def __init__(self, hidden_size, num_heads, head_dim):
116
  super().__init__()
117
- self.num_heads = num_heads
118
- self.head_dim = head_dim
119
- qkv_dim = num_heads * head_dim * 3
120
- self.img_qkv = nn.Linear(hidden_size, qkv_dim, bias=True)
121
- self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=True)
122
- self.txt_qkv = nn.Linear(hidden_size, qkv_dim, bias=True)
123
- self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=True)
124
- self.img_norm = QKNorm(head_dim)
125
- self.txt_norm = QKNorm(head_dim)
126
-
127
- def forward(self, img, txt, rope):
128
- B, N_img, _ = img.shape
129
- N_txt = txt.shape[1]
130
- img_qkv = self.img_qkv(img).reshape(B, N_img, 3, self.num_heads, self.head_dim)
131
- img_q, img_k, img_v = img_qkv.permute(2, 0, 3, 1, 4).unbind(0)
132
- img_q, img_k = self.img_norm(img_q, img_k)
133
- img_q = apply_rope(img_q, rope)
134
- img_k = apply_rope(img_k, rope)
135
- txt_qkv = self.txt_qkv(txt).reshape(B, N_txt, 3, self.num_heads, self.head_dim)
136
- txt_q, txt_k, txt_v = txt_qkv.permute(2, 0, 3, 1, 4).unbind(0)
137
- txt_q, txt_k = self.txt_norm(txt_q, txt_k)
138
- q = torch.cat([txt_q, img_q], dim=2)
139
- k = torch.cat([txt_k, img_k], dim=2)
140
- v = torch.cat([txt_v, img_v], dim=2)
141
- attn_out = F.scaled_dot_product_attention(q, k, v)
142
- txt_out, img_out = attn_out.split([N_txt, N_img], dim=2)
143
- img_out = img_out.transpose(1, 2).reshape(B, N_img, -1)
144
- txt_out = txt_out.transpose(1, 2).reshape(B, N_txt, -1)
145
- return self.img_out(img_out), self.txt_out(txt_out)
146
 
 
 
 
 
 
147
 
148
- class DoubleBlock(nn.Module):
149
- def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0):
150
- super().__init__()
151
- self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
152
- self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
- self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
154
- self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
155
- self.img_mod = nn.Linear(hidden_size, hidden_size * 6, bias=True)
156
- self.txt_mod = nn.Linear(hidden_size, hidden_size * 6, bias=True)
157
- self.attn = DoubleAttention(hidden_size, num_heads, head_dim)
158
- mlp_hidden = int(hidden_size * mlp_ratio)
159
- self.img_mlp = nn.Sequential(
160
- nn.Linear(hidden_size, mlp_hidden, bias=True),
161
- nn.GELU(approximate="tanh"),
162
- nn.Linear(mlp_hidden, hidden_size, bias=True),
163
- )
164
- self.txt_mlp = nn.Sequential(
165
- nn.Linear(hidden_size, mlp_hidden, bias=True),
166
- nn.GELU(approximate="tanh"),
167
- nn.Linear(mlp_hidden, hidden_size, bias=True),
168
- )
169
 
170
- def forward(self, img, txt, cond, rope):
171
- img_mod = self.img_mod(cond)
172
- img_scale1, img_shift1, img_gate1, img_scale2, img_shift2, img_gate2 = img_mod.chunk(6, dim=-1)
173
- txt_mod = self.txt_mod(cond)
174
- txt_scale1, txt_shift1, txt_gate1, txt_scale2, txt_shift2, txt_gate2 = txt_mod.chunk(6, dim=-1)
175
- img_normed = self.img_norm1(img) * (1 + img_scale1.unsqueeze(1)) + img_shift1.unsqueeze(1)
176
- txt_normed = self.txt_norm1(txt) * (1 + txt_scale1.unsqueeze(1)) + txt_shift1.unsqueeze(1)
177
- img_attn, txt_attn = self.attn(img_normed, txt_normed, rope)
178
- img = img + img_gate1.unsqueeze(1) * img_attn
179
- txt = txt + txt_gate1.unsqueeze(1) * txt_attn
180
- img_normed2 = self.img_norm2(img) * (1 + img_scale2.unsqueeze(1)) + img_shift2.unsqueeze(1)
181
- txt_normed2 = self.txt_norm2(txt) * (1 + txt_scale2.unsqueeze(1)) + txt_shift2.unsqueeze(1)
182
- img = img + img_gate2.unsqueeze(1) * self.img_mlp(img_normed2)
183
- txt = txt + txt_gate2.unsqueeze(1) * self.txt_mlp(txt_normed2)
184
- return img, txt
185
-
186
-
187
- class SingleAttention(nn.Module):
188
- def __init__(self, hidden_size, num_heads, head_dim):
189
  super().__init__()
190
  self.num_heads = num_heads
191
  self.head_dim = head_dim
192
- self.qkv = nn.Linear(hidden_size, num_heads * head_dim * 3, bias=True)
193
- self.norm = QKNorm(head_dim)
194
 
195
- def forward(self, x, rope):
196
  B, N, _ = x.shape
197
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
198
- q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
199
- q, k = self.norm(q, k)
200
- q = apply_rope(q, rope)
201
- k = apply_rope(k, rope)
202
- out = F.scaled_dot_product_attention(q, k, v)
203
- return out.transpose(1, 2).reshape(B, N, -1)
 
204
 
205
 
206
- class SingleBlock(nn.Module):
207
- def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  super().__init__()
209
- self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
210
- self.mod = nn.Linear(hidden_size, hidden_size * 3, bias=True)
211
- self.attn = SingleAttention(hidden_size, num_heads, head_dim)
212
- self.proj = nn.Linear(num_heads * head_dim, hidden_size, bias=True)
213
  mlp_hidden = int(hidden_size * mlp_ratio)
214
- self.mlp = nn.Sequential(
215
- nn.Linear(hidden_size, mlp_hidden, bias=True),
216
- nn.GELU(approximate="tanh"),
217
- nn.Linear(mlp_hidden, hidden_size, bias=True),
218
- )
 
219
 
220
- def forward(self, x, cond, rope):
221
- mod = self.mod(cond)
222
- scale, shift, gate = mod.chunk(3, dim=-1)
223
- normed = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
224
- attn_out = self.proj(self.attn(normed, rope))
225
- mlp_out = self.mlp(normed)
226
- return x + gate.unsqueeze(1) * (attn_out + mlp_out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
 
229
  class TinyFluxDeep(nn.Module):
230
- def __init__(self, cfg: TinyFluxDeepConfig):
231
  super().__init__()
232
- self.cfg = cfg
 
233
  self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
234
  self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)
235
- self.time_in = MLPEmbedder(256, cfg.hidden_size)
236
- self.guidance_in = MLPEmbedder(256, cfg.hidden_size)
237
- self.vector_in = MLPEmbedder(cfg.pooled_projection_dim, cfg.hidden_size)
238
- self.rope = EmbedND(axes_dim=cfg.axes_dims_rope)
 
 
 
 
239
  self.double_blocks = nn.ModuleList([
240
- DoubleBlock(cfg.hidden_size, cfg.num_attention_heads, cfg.attention_head_dim, cfg.mlp_ratio)
241
- for _ in range(cfg.num_double_layers)
242
  ])
243
  self.single_blocks = nn.ModuleList([
244
- SingleBlock(cfg.hidden_size, cfg.num_attention_heads, cfg.attention_head_dim, cfg.mlp_ratio)
245
- for _ in range(cfg.num_single_layers)
246
  ])
247
- self.final_norm = nn.LayerNorm(cfg.hidden_size, elementwise_affine=False, eps=1e-6)
248
- self.final_mod = nn.Linear(cfg.hidden_size, cfg.hidden_size * 2, bias=True)
249
  self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)
250
 
251
- def time_embed(self, t):
252
- half_dim = 128
253
- freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, device=t.device) / half_dim)
254
- args = t.unsqueeze(-1) * freqs * 1000
255
- return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
256
-
257
- @staticmethod
258
- def create_img_ids(batch_size, h, w, device):
259
- img_ids = torch.zeros(h, w, 3, device=device)
260
- img_ids[..., 1] = torch.arange(h, device=device).unsqueeze(1)
261
- img_ids[..., 2] = torch.arange(w, device=device).unsqueeze(0)
262
- return img_ids.reshape(1, h * w, 3).expand(batch_size, -1, -1)
263
 
264
- def forward(self, hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, guidance=None):
265
  img = self.img_in(hidden_states)
266
  txt = self.txt_in(encoder_hidden_states)
267
- t_emb = self.time_embed(timestep).to(hidden_states.dtype)
268
- cond = self.time_in(t_emb)
269
- if guidance is not None and self.cfg.guidance_embeds:
270
- g_emb = self.time_embed(guidance).to(hidden_states.dtype)
271
- cond = cond + self.guidance_in(g_emb)
272
- cond = cond + self.vector_in(pooled_projections)
273
- rope = self.rope(img_ids)
 
 
 
274
  for block in self.double_blocks:
275
- img, txt = block(img, txt, cond, rope)
276
- x = torch.cat([txt, img], dim=1)
277
-
278
- # Pad rope with identity for text positions (text has no positional encoding)
279
- # RoPE format is interleaved [cos0, sin0, cos1, sin1, ...], identity = cos=1, sin=0
280
- txt_len = txt.shape[1]
281
- identity_rope = torch.zeros(rope.shape[0], txt_len, rope.shape[-1], device=rope.device, dtype=rope.dtype)
282
- identity_rope[..., 0::2] = 1.0 # cos positions = 1, sin positions stay 0
283
- full_rope = torch.cat([identity_rope, rope], dim=1)
284
-
285
  for block in self.single_blocks:
286
- x = block(x, cond, full_rope)
287
- img = x[:, txt.shape[1]:, :]
288
- mod = self.final_mod(cond)
289
- scale, shift = mod.chunk(2, dim=-1)
290
- img = self.final_norm(img) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
291
- return self.final_linear(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
 
294
  # ============================================================================
@@ -299,16 +339,15 @@ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
299
  MAX_SEED = np.iinfo(np.int32).max
300
  SHIFT = 3.0
301
 
 
302
  # ============================================================================
303
- # LOAD MODELS (outside GPU function for ZeroGPU compatibility)
304
  # ============================================================================
305
  print("Loading TinyFlux-Lailah...")
306
 
307
- # Model
308
  config = TinyFluxDeepConfig()
309
  model = TinyFluxDeep(config)
310
 
311
- # Load EMA weights (best quality)
312
  weights_path = hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoints/step_297500_ema.safetensors")
313
  weights = load_file(weights_path)
314
  model.load_state_dict(weights, strict=False)
@@ -316,7 +355,6 @@ model.eval()
316
  model.to(DTYPE)
317
  print(f"✓ Model loaded ({sum(p.numel() for p in model.parameters()):,} params)")
318
 
319
- # Text encoders
320
  print("Loading text encoders...")
321
  t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
322
  t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
@@ -324,19 +362,19 @@ clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
324
  clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
325
  print("✓ Text encoders loaded")
326
 
327
- # VAE (local diffusers format)
328
  print("Loading VAE...")
329
  vae = AutoencoderKL.from_pretrained("./vae", torch_dtype=DTYPE)
330
  vae.eval()
331
  VAE_SCALE = vae.config.scaling_factor
332
- print("✓ VAE loaded")
333
 
334
 
335
  # ============================================================================
336
- # INFERENCE FUNCTIONS
337
  # ============================================================================
338
- def flux_shift(t, s=SHIFT):
339
- return s * t / (1 + (s - 1) * t)
 
340
 
341
 
342
  @spaces.GPU(duration=90)
@@ -351,76 +389,75 @@ def generate(
351
  num_inference_steps: int,
352
  progress=gr.Progress(track_tqdm=True),
353
  ):
354
- """Generate image with TinyFlux-Lailah."""
355
  if randomize_seed:
356
  seed = random.randint(0, MAX_SEED)
357
-
358
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
359
-
360
- # Move models to GPU
361
  model.to(DEVICE)
362
  t5_enc.to(DEVICE)
363
  clip_enc.to(DEVICE)
364
  vae.to(DEVICE)
365
-
366
- with torch.inference_mode():
367
  # Encode prompt
368
- t5_in = t5_tok(
369
- prompt, max_length=128, padding="max_length",
370
- truncation=True, return_tensors="pt"
371
- ).to(DEVICE)
372
- t5_out = t5_enc(**t5_in).last_hidden_state.to(DTYPE)
373
-
374
- clip_in = clip_tok(
375
- prompt, max_length=77, padding="max_length",
376
- truncation=True, return_tensors="pt"
377
- ).to(DEVICE)
378
- clip_out = clip_enc(**clip_in).pooler_output.to(DTYPE)
379
-
380
  # Latent dimensions
381
  H_lat = height // 8
382
  W_lat = width // 8
383
  C = 16
384
-
385
- # Start from noise
 
386
  x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
387
- img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
388
-
389
- # Timesteps with Flux shift
390
- t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE, dtype=DTYPE)
391
- timesteps = flux_shift(t_linear, s=SHIFT)
392
 
393
- # Euler sampling
 
 
 
 
 
 
 
 
394
  for i in range(num_inference_steps):
395
  t_curr = timesteps[i]
396
  t_next = timesteps[i + 1]
397
- dt = t_next - t_curr
398
-
399
  t_batch = t_curr.unsqueeze(0)
400
  guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
401
-
402
- with torch.autocast(device_type="cuda", dtype=DTYPE):
403
- v = model(
404
- hidden_states=x,
405
- encoder_hidden_states=t5_out,
406
- pooled_projections=clip_out,
407
- timestep=t_batch,
408
- img_ids=img_ids,
409
- guidance=guidance,
410
- )
411
  x = x + v * dt
412
-
413
- # Decode
414
  latents = x.reshape(1, H_lat, W_lat, C).permute(0, 3, 1, 2)
415
  latents = latents / VAE_SCALE
416
  image = vae.decode(latents.to(vae.dtype)).sample
417
  image = (image / 2 + 0.5).clamp(0, 1)
418
-
419
  # To PIL
420
  image = image[0].float().permute(1, 2, 0).cpu().numpy()
421
  image = (image * 255).astype(np.uint8)
422
  image = Image.fromarray(image)
423
-
424
  return image, seed
425
 
426
 
@@ -450,13 +487,13 @@ with gr.Blocks(css=css) as demo:
450
  **241M parameter** flow-matching text-to-image model.
451
  Trained on teacher latents from Flux-Schnell.
452
 
453
- [Model Card](https://huggingface.co/AbstractPhil/tiny-flux-deep) |
454
- [GitHub](https://github.com/AbstractPhil)
455
  """)
456
 
457
  with gr.Row():
458
  prompt = gr.Text(
459
  label="Prompt",
 
460
  show_label=False,
461
  max_lines=2,
462
  placeholder="Enter your prompt...",
@@ -470,78 +507,31 @@ with gr.Blocks(css=css) as demo:
470
  negative_prompt = gr.Text(
471
  label="Negative prompt",
472
  max_lines=1,
473
- placeholder="(not used in this model)",
474
  visible=False,
475
  )
476
-
477
- seed = gr.Slider(
478
- label="Seed",
479
- minimum=0,
480
- maximum=MAX_SEED,
481
- step=1,
482
- value=42,
483
- )
484
-
485
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
486
 
487
  with gr.Row():
488
- width = gr.Slider(
489
- label="Width",
490
- minimum=256,
491
- maximum=768,
492
- step=64,
493
- value=512,
494
- )
495
-
496
- height = gr.Slider(
497
- label="Height",
498
- minimum=256,
499
- maximum=768,
500
- step=64,
501
- value=512,
502
- )
503
 
504
  with gr.Row():
505
- guidance_scale = gr.Slider(
506
- label="Guidance scale",
507
- minimum=1.0,
508
- maximum=10.0,
509
- step=0.5,
510
- value=3.5,
511
- )
512
-
513
- num_inference_steps = gr.Slider(
514
- label="Steps",
515
- minimum=10,
516
- maximum=50,
517
- step=1,
518
- value=28,
519
- )
520
 
521
  gr.Examples(examples=examples, inputs=[prompt])
522
-
523
  gr.Markdown("""
524
  ---
525
- **Notes:**
526
- - Trained on 512×512 resolution
527
- - Best results at guidance 3.0-5.0
528
- - 20-30 steps recommended
529
- - Early checkpoint - quality improving with training
530
  """)
531
 
532
  gr.on(
533
  triggers=[run_button.click, prompt.submit],
534
  fn=generate,
535
- inputs=[
536
- prompt,
537
- negative_prompt,
538
- seed,
539
- randomize_seed,
540
- width,
541
- height,
542
- guidance_scale,
543
- num_inference_steps,
544
- ],
545
  outputs=[result, seed],
546
  )
547
 
 
1
  """
2
  TinyFlux-Lailah Gradio Demo
3
  HuggingFace Spaces with ZeroGPU support
4
+ Euler discrete flow matching inference
5
  """
6
 
7
  import gradio as gr
 
9
  import random
10
  import spaces
11
  import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple
17
  from huggingface_hub import hf_hub_download
18
  from safetensors.torch import load_file
19
  from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
20
  from diffusers import AutoencoderKL
21
  from PIL import Image
22
+
 
 
 
 
23
 
24
  # ============================================================================
25
+ # MODEL DEFINITION - Exact copy from tinyflux_deep.py
26
  # ============================================================================
27
 
28
  @dataclass
 
42
 
43
  def __post_init__(self):
44
  assert self.num_attention_heads * self.attention_head_dim == self.hidden_size
45
+ assert sum(self.axes_dims_rope) == self.attention_head_dim
46
 
47
 
48
  class RMSNorm(nn.Module):
49
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
50
  super().__init__()
 
51
  self.eps = eps
52
+ self.elementwise_affine = elementwise_affine
53
+ if elementwise_affine:
54
+ self.weight = nn.Parameter(torch.ones(dim))
55
+ else:
56
+ self.register_parameter('weight', None)
57
 
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
60
+ out = (x * norm).type_as(x)
61
+ if self.weight is not None:
62
+ out = out * self.weight
63
+ return out
64
 
65
 
66
  class EmbedND(nn.Module):
67
+ def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
68
  super().__init__()
69
  self.theta = theta
70
  self.axes_dim = axes_dim
 
72
  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
73
  self.register_buffer(f'freqs_{i}', freqs, persistent=True)
74
 
75
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
76
+ device = ids.device
77
+ n_axes = ids.shape[-1]
78
+ emb_list = []
79
+ for i in range(n_axes):
80
+ freqs = getattr(self, f'freqs_{i}').to(device)
81
+ pos = ids[:, i].float()
82
+ angles = pos.unsqueeze(-1) * freqs.unsqueeze(0)
83
+ cos = angles.cos()
84
+ sin = angles.sin()
85
+ emb = torch.stack([cos, sin], dim=-1).flatten(-2)
86
+ emb_list.append(emb)
87
+ rope = torch.cat(emb_list, dim=-1)
88
+ return rope.unsqueeze(1)
89
+
90
+
91
+ def apply_rotary_emb_old(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
92
+ freqs = freqs_cis.squeeze(1)
93
+ cos = freqs[:, 0::2].repeat_interleave(2, dim=-1)
94
+ sin = freqs[:, 1::2].repeat_interleave(2, dim=-1)
95
+ cos = cos[None, None, :, :].to(x.device)
96
+ sin = sin[None, None, :, :].to(x.device)
97
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
98
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
99
+ return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
100
 
101
 
102
  class MLPEmbedder(nn.Module):
103
+ def __init__(self, hidden_size: int):
104
  super().__init__()
105
+ self.mlp = nn.Sequential(
106
+ nn.Linear(256, hidden_size),
107
+ nn.SiLU(),
108
+ nn.Linear(hidden_size, hidden_size),
109
+ )
110
 
111
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
112
+ half_dim = 128
113
+ emb = math.log(10000) / (half_dim - 1)
114
+ emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
115
+ emb = x.unsqueeze(-1) * emb.unsqueeze(0)
116
+ emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
117
+ return self.mlp(emb)
118
 
119
 
120
+ class AdaLayerNormZero(nn.Module):
121
+ def __init__(self, hidden_size: int):
122
  super().__init__()
123
+ self.silu = nn.SiLU()
124
+ self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
125
+ self.norm = RMSNorm(hidden_size)
126
 
127
+ def forward(self, x: torch.Tensor, emb: torch.Tensor):
128
+ emb_out = self.linear(self.silu(emb))
129
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb_out.chunk(6, dim=-1)
130
+ x = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
131
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
132
 
133
 
134
+ class AdaLayerNormZeroSingle(nn.Module):
135
+ def __init__(self, hidden_size: int):
136
  super().__init__()
137
+ self.silu = nn.SiLU()
138
+ self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
139
+ self.norm = RMSNorm(hidden_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ def forward(self, x: torch.Tensor, emb: torch.Tensor):
142
+ emb_out = self.linear(self.silu(emb))
143
+ shift, scale, gate = emb_out.chunk(3, dim=-1)
144
+ x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
145
+ return x, gate
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
+ class Attention(nn.Module):
149
+ def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  super().__init__()
151
  self.num_heads = num_heads
152
  self.head_dim = head_dim
153
+ self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
154
+ self.out_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
155
 
156
+ def forward(self, x: torch.Tensor, rope: Optional[torch.Tensor] = None) -> torch.Tensor:
157
  B, N, _ = x.shape
158
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
159
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
160
+ if rope is not None:
161
+ q = apply_rotary_emb_old(q, rope)
162
+ k = apply_rotary_emb_old(k, rope)
163
+ attn = F.scaled_dot_product_attention(q, k, v)
164
+ out = attn.transpose(1, 2).reshape(B, N, -1)
165
+ return self.out_proj(out)
166
 
167
 
168
+ class JointAttention(nn.Module):
169
+ def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
170
+ super().__init__()
171
+ self.num_heads = num_heads
172
+ self.head_dim = head_dim
173
+ self.txt_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
174
+ self.img_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
175
+ self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
176
+ self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
177
+
178
+ def forward(self, txt: torch.Tensor, img: torch.Tensor, rope: Optional[torch.Tensor] = None):
179
+ B, L, _ = txt.shape
180
+ _, N, _ = img.shape
181
+ txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
182
+ img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
183
+ txt_q, txt_k, txt_v = txt_qkv.permute(2, 0, 3, 1, 4)
184
+ img_q, img_k, img_v = img_qkv.permute(2, 0, 3, 1, 4)
185
+ if rope is not None:
186
+ img_q = apply_rotary_emb_old(img_q, rope)
187
+ img_k = apply_rotary_emb_old(img_k, rope)
188
+ k = torch.cat([txt_k, img_k], dim=2)
189
+ v = torch.cat([txt_v, img_v], dim=2)
190
+ txt_out = F.scaled_dot_product_attention(txt_q, k, v)
191
+ txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
192
+ img_out = F.scaled_dot_product_attention(img_q, k, v)
193
+ img_out = img_out.transpose(1, 2).reshape(B, N, -1)
194
+ return self.txt_out(txt_out), self.img_out(img_out)
195
+
196
+
197
+ class MLP(nn.Module):
198
+ def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
199
  super().__init__()
 
 
 
 
200
  mlp_hidden = int(hidden_size * mlp_ratio)
201
+ self.fc1 = nn.Linear(hidden_size, mlp_hidden, bias=True)
202
+ self.act = nn.GELU(approximate='tanh')
203
+ self.fc2 = nn.Linear(mlp_hidden, hidden_size, bias=True)
204
+
205
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
206
+ return self.fc2(self.act(self.fc1(x)))
207
 
208
+
209
+ class DoubleStreamBlock(nn.Module):
210
+ def __init__(self, config: TinyFluxDeepConfig):
211
+ super().__init__()
212
+ hidden = config.hidden_size
213
+ heads = config.num_attention_heads
214
+ head_dim = config.attention_head_dim
215
+ self.img_norm1 = AdaLayerNormZero(hidden)
216
+ self.txt_norm1 = AdaLayerNormZero(hidden)
217
+ self.attn = JointAttention(hidden, heads, head_dim, use_bias=False)
218
+ self.img_norm2 = RMSNorm(hidden)
219
+ self.txt_norm2 = RMSNorm(hidden)
220
+ self.img_mlp = MLP(hidden, config.mlp_ratio)
221
+ self.txt_mlp = MLP(hidden, config.mlp_ratio)
222
+
223
+ def forward(self, txt, img, vec, rope=None):
224
+ img_normed, img_gate_msa, img_shift_mlp, img_scale_mlp, img_gate_mlp = self.img_norm1(img, vec)
225
+ txt_normed, txt_gate_msa, txt_shift_mlp, txt_scale_mlp, txt_gate_mlp = self.txt_norm1(txt, vec)
226
+ txt_attn_out, img_attn_out = self.attn(txt_normed, img_normed, rope)
227
+ txt = txt + txt_gate_msa.unsqueeze(1) * txt_attn_out
228
+ img = img + img_gate_msa.unsqueeze(1) * img_attn_out
229
+ txt_mlp_in = self.txt_norm2(txt) * (1 + txt_scale_mlp.unsqueeze(1)) + txt_shift_mlp.unsqueeze(1)
230
+ img_mlp_in = self.img_norm2(img) * (1 + img_scale_mlp.unsqueeze(1)) + img_shift_mlp.unsqueeze(1)
231
+ txt = txt + txt_gate_mlp.unsqueeze(1) * self.txt_mlp(txt_mlp_in)
232
+ img = img + img_gate_mlp.unsqueeze(1) * self.img_mlp(img_mlp_in)
233
+ return txt, img
234
+
235
+
236
+ class SingleStreamBlock(nn.Module):
237
+ def __init__(self, config: TinyFluxDeepConfig):
238
+ super().__init__()
239
+ hidden = config.hidden_size
240
+ heads = config.num_attention_heads
241
+ head_dim = config.attention_head_dim
242
+ self.norm = AdaLayerNormZeroSingle(hidden)
243
+ self.attn = Attention(hidden, heads, head_dim, use_bias=False)
244
+ self.mlp = MLP(hidden, config.mlp_ratio)
245
+ self.norm2 = RMSNorm(hidden)
246
+
247
+ def forward(self, txt, img, vec, rope=None):
248
+ L = txt.shape[1]
249
+ x = torch.cat([txt, img], dim=1)
250
+ x_normed, gate = self.norm(x, vec)
251
+ x = x + gate.unsqueeze(1) * self.attn(x_normed, rope)
252
+ x = x + self.mlp(self.norm2(x))
253
+ txt, img = x.split([L, x.shape[1] - L], dim=1)
254
+ return txt, img
255
 
256
 
257
  class TinyFluxDeep(nn.Module):
258
+ def __init__(self, config: Optional[TinyFluxDeepConfig] = None):
259
  super().__init__()
260
+ self.config = config or TinyFluxDeepConfig()
261
+ cfg = self.config
262
  self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
263
  self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)
264
+ self.time_in = MLPEmbedder(cfg.hidden_size)
265
+ self.vector_in = nn.Sequential(
266
+ nn.SiLU(),
267
+ nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
268
+ )
269
+ if cfg.guidance_embeds:
270
+ self.guidance_in = MLPEmbedder(cfg.hidden_size)
271
+ self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)
272
  self.double_blocks = nn.ModuleList([
273
+ DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
 
274
  ])
275
  self.single_blocks = nn.ModuleList([
276
+ SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
 
277
  ])
278
+ self.final_norm = RMSNorm(cfg.hidden_size)
 
279
  self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)
280
 
281
+ def forward(self, hidden_states, encoder_hidden_states, pooled_projections, timestep,
282
+ img_ids, txt_ids=None, guidance=None):
283
+ B = hidden_states.shape[0]
284
+ L = encoder_hidden_states.shape[1]
285
+ N = hidden_states.shape[1]
 
 
 
 
 
 
 
286
 
 
287
  img = self.img_in(hidden_states)
288
  txt = self.txt_in(encoder_hidden_states)
289
+
290
+ vec = self.time_in(timestep)
291
+ vec = vec + self.vector_in(pooled_projections)
292
+ if self.config.guidance_embeds and guidance is not None:
293
+ vec = vec + self.guidance_in(guidance)
294
+
295
+ if img_ids.ndim == 3:
296
+ img_ids = img_ids[0]
297
+ img_rope = self.rope(img_ids)
298
+
299
  for block in self.double_blocks:
300
+ txt, img = block(txt, img, vec, img_rope)
301
+
302
+ if txt_ids is None:
303
+ txt_ids = torch.zeros(L, 3, device=img_ids.device, dtype=img_ids.dtype)
304
+ elif txt_ids.ndim == 3:
305
+ txt_ids = txt_ids[0]
306
+ all_ids = torch.cat([txt_ids, img_ids], dim=0)
307
+ full_rope = self.rope(all_ids)
308
+
 
309
  for block in self.single_blocks:
310
+ txt, img = block(txt, img, vec, full_rope)
311
+
312
+ img = self.final_norm(img)
313
+ img = self.final_linear(img)
314
+ return img
315
+
316
+ @staticmethod
317
+ def create_img_ids(batch_size: int, height: int, width: int, device) -> torch.Tensor:
318
+ img_ids = torch.zeros(height * width, 3, device=device)
319
+ for i in range(height):
320
+ for j in range(width):
321
+ idx = i * width + j
322
+ img_ids[idx, 0] = 0
323
+ img_ids[idx, 1] = i
324
+ img_ids[idx, 2] = j
325
+ return img_ids
326
+
327
+ @staticmethod
328
+ def create_txt_ids(text_len: int, device) -> torch.Tensor:
329
+ txt_ids = torch.zeros(text_len, 3, device=device)
330
+ txt_ids[:, 0] = torch.arange(text_len, device=device)
331
+ return txt_ids
332
 
333
 
334
  # ============================================================================
 
339
  MAX_SEED = np.iinfo(np.int32).max
340
  SHIFT = 3.0
341
 
342
+
343
  # ============================================================================
344
+ # LOAD MODELS
345
  # ============================================================================
346
  print("Loading TinyFlux-Lailah...")
347
 
 
348
  config = TinyFluxDeepConfig()
349
  model = TinyFluxDeep(config)
350
 
 
351
  weights_path = hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoints/step_297500_ema.safetensors")
352
  weights = load_file(weights_path)
353
  model.load_state_dict(weights, strict=False)
 
355
  model.to(DTYPE)
356
  print(f"✓ Model loaded ({sum(p.numel() for p in model.parameters()):,} params)")
357
 
 
358
  print("Loading text encoders...")
359
  t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
360
  t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
 
362
  clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
363
  print("✓ Text encoders loaded")
364
 
 
365
  print("Loading VAE...")
366
  vae = AutoencoderKL.from_pretrained("./vae", torch_dtype=DTYPE)
367
  vae.eval()
368
  VAE_SCALE = vae.config.scaling_factor
369
+ print(f"✓ VAE loaded (scale={VAE_SCALE})")
370
 
371
 
372
  # ============================================================================
373
+ # EULER DISCRETE FLOW MATCHING SAMPLER
374
  # ============================================================================
375
def flux_shift(t, shift=SHIFT):
    """Flux time shift: map timestep(s) ``t`` through ``s*t / (1 + (s-1)*t)``.

    For shift > 1 this bends a uniform schedule toward t = 1 (the noise end).
    """
    scaled = shift * t
    return scaled / (1 + (shift - 1) * t)
378
 
379
 
380
@spaces.GPU(duration=90)
def generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Sample one image from TinyFlux-Lailah with an Euler flow-matching loop.

    Args:
        prompt: Text prompt; encoded by both T5 (tokens) and CLIP (pooled).
        negative_prompt: Accepted for Gradio interface compatibility but unused
            (the UI hides it and marks it "(not used)").
        seed: RNG seed; replaced by a random one when randomize_seed is True.
        randomize_seed: Draw a fresh seed in [0, MAX_SEED] when True.
        width: Output image width in pixels (latents are width // 8).
        height: Output image height in pixels (latents are height // 8).
        guidance_scale: Guidance value fed to the model's guidance embedding.
        num_inference_steps: Number of Euler integration steps from t=1 to t=0.
        progress: Gradio progress tracker (tqdm-linked).

    Returns:
        (PIL.Image, seed) — the decoded image and the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Move all weights onto the (ZeroGPU-allocated) device for this call.
    model.to(DEVICE)
    t5_enc.to(DEVICE)
    clip_enc.to(DEVICE)
    vae.to(DEVICE)

    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=DTYPE):
        # Encode prompt: T5 token embeddings + CLIP pooled vector.
        t5_in = t5_tok(prompt, max_length=128, padding="max_length",
                       truncation=True, return_tensors="pt").to(DEVICE)
        t5_out = t5_enc(**t5_in).last_hidden_state

        clip_in = clip_tok(prompt, max_length=77, padding="max_length",
                           truncation=True, return_tensors="pt").to(DEVICE)
        clip_out = clip_enc(**clip_in).pooler_output

        # Latent dimensions (VAE downsamples by 8; 16 latent channels).
        H_lat = height // 8
        W_lat = width // 8
        C = 16
        L = 128  # T5 sequence length

        # Start from pure noise (t = 1 in flow matching).
        x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)

        # RoPE position ids for image grid and text sequence.
        img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
        txt_ids = TinyFluxDeep.create_txt_ids(L, DEVICE)

        # Timesteps 1 -> 0 with the Flux shift applied.
        t_linear = torch.linspace(1, 0, num_inference_steps + 1, device=DEVICE)
        timesteps = flux_shift(t_linear, shift=SHIFT)

        # Loop-invariant: the guidance tensor was previously rebuilt every step.
        guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)

        # Euler discrete flow matching: x <- x + v * dt, with dt < 0 (1 -> 0).
        for i in range(num_inference_steps):
            t_curr = timesteps[i]
            dt = timesteps[i + 1] - t_curr  # negative step size

            v = model(
                hidden_states=x,
                encoder_hidden_states=t5_out,
                pooled_projections=clip_out,
                timestep=t_curr.unsqueeze(0),
                img_ids=img_ids,
                txt_ids=txt_ids,
                guidance=guidance,
            )
            x = x + v * dt

        # Unpack the token sequence back into a latent image and decode.
        latents = x.reshape(1, H_lat, W_lat, C).permute(0, 3, 1, 2)
        latents = latents / VAE_SCALE
        image = vae.decode(latents.to(vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        # Convert to a PIL image.
        image = image[0].float().permute(1, 2, 0).cpu().numpy()
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image)

    return image, seed
462
 
463
 
 
487
  **241M parameter** flow-matching text-to-image model.
488
  Trained on teacher latents from Flux-Schnell.
489
 
490
+ [Model Card](https://huggingface.co/AbstractPhil/tiny-flux-deep)
 
491
  """)
492
 
493
  with gr.Row():
494
  prompt = gr.Text(
495
  label="Prompt",
496
+ value="cat",
497
  show_label=False,
498
  max_lines=2,
499
  placeholder="Enter your prompt...",
 
507
  negative_prompt = gr.Text(
508
  label="Negative prompt",
509
  max_lines=1,
510
+ placeholder="(not used)",
511
  visible=False,
512
  )
513
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
 
 
 
 
 
 
 
 
514
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
515
 
516
  with gr.Row():
517
+ width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
518
+ height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
  with gr.Row():
521
+ guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=10.0, step=0.5, value=3.5)
522
+ num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=1, value=28)
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
  gr.Examples(examples=examples, inputs=[prompt])
525
+
526
  gr.Markdown("""
527
  ---
528
+ **Notes:** Trained at 512×512. Best results at guidance 3.0-5.0, 20-30 steps.
 
 
 
 
529
  """)
530
 
531
  gr.on(
532
  triggers=[run_button.click, prompt.submit],
533
  fn=generate,
534
+ inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
 
 
 
 
 
 
 
 
 
535
  outputs=[result, seed],
536
  )
537