AbstractPhil commited on
Commit
6a487eb
·
verified ·
1 Parent(s): 0832ec9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +539 -0
app.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux-Lailah Gradio Demo
3
+ HuggingFace Spaces with ZeroGPU support
4
+ """
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import random
9
+ import spaces
10
+ import torch
11
+ from huggingface_hub import hf_hub_download
12
+ from safetensors.torch import load_file
13
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
14
+ from diffusers import AutoencoderKL
15
+ from PIL import Image
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Tuple
21
+
22
+ # ============================================================================
23
+ # MODEL DEFINITION (TinyFluxDeep / Lailah)
24
+ # ============================================================================
25
+
26
@dataclass
class TinyFluxDeepConfig:
    """Configuration for the TinyFluxDeep (Lailah) transformer.

    Attributes mirror the Flux architecture hyperparameters: token width
    (`hidden_size`), attention geometry, latent channel count, text-encoder
    dims, block counts for the double- and single-stream stages, MLP ratio,
    per-axis rotary dims, and whether a guidance embedder is used.
    """
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
    guidance_embeds: bool = True

    def __post_init__(self):
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently skip this consistency check.
        if self.num_attention_heads * self.attention_head_dim != self.hidden_size:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) * "
                f"attention_head_dim ({self.attention_head_dim}) must equal "
                f"hidden_size ({self.hidden_size})"
            )
43
+
44
+
45
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel gain.

    The statistic is computed in float32 for numerical stability, then the
    normalized tensor is cast back to the input dtype before the gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        orig_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).to(orig_dtype) * self.weight
56
+
57
+
58
class EmbedND(nn.Module):
    """Multi-axis rotary-embedding table.

    One inverse-frequency buffer is registered per axis. forward() maps a
    position-id tensor (..., n_axes) to interleaved (cos, sin) rotary
    components, concatenated across axes along the last dimension, so the
    output width equals sum(axes_dim).
    """

    def __init__(self, theta=10000.0, axes_dim=(16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        for axis, dim in enumerate(axes_dim):
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer(f'freqs_{axis}', inv_freq, persistent=True)

    def forward(self, ids):
        parts = []
        for axis in range(len(self.axes_dim)):
            inv_freq = getattr(self, f'freqs_{axis}').to(ids.device)
            angles = ids[..., axis:axis + 1] * inv_freq
            # Interleave as [cos0, sin0, cos1, sin1, ...] along the last dim.
            interleaved = torch.stack(
                [torch.cos(angles), torch.sin(angles)], dim=-1
            ).flatten(-2)
            parts.append(interleaved)
        return torch.cat(parts, dim=-1)
78
+
79
+
80
def apply_rope(x, rope):
    """Rotate consecutive channel pairs of x by a rotary-embedding table.

    Args:
        x: (B, H, N, D) per-head queries or keys.
        rope: (B, N', D') interleaved (cos, sin) table; truncated to (N, D)
            and broadcast over the head dimension.

    Returns:
        Tensor of shape (B, H, N, D) with each (even, odd) channel pair
        rotated by the corresponding (cos, sin) angle.
    """
    B, H, N, D = x.shape
    table = rope[:, :N, :D].unsqueeze(1)  # add head axis for broadcasting
    cos = table[..., 0::2]
    sin = table[..., 1::2]
    pairs = x.reshape(B, H, N, D // 2, 2)
    real, imag = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack(
        [real * cos - imag * sin, imag * cos + real * sin],
        dim=-1,
    )
    return rotated.flatten(-2)
92
+
93
+
94
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) used for conditioning inputs."""

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        hidden = F.silu(self.fc1(x))
        return self.fc2(hidden)
102
+
103
+
104
class QKNorm(nn.Module):
    """Independent RMS normalization for query and key tensors."""

    def __init__(self, dim):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q, k):
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q, normed_k
112
+
113
+
114
class DoubleAttention(nn.Module):
    """Joint text/image attention for the double-stream blocks.

    Each stream has its own QKV projection, QK norm, and output projection.
    RoPE is applied to the image queries/keys only. Both streams are
    concatenated (text first) for one attention pass, then split back.
    """

    def __init__(self, hidden_size, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        qkv_dim = num_heads * head_dim * 3
        self.img_qkv = nn.Linear(hidden_size, qkv_dim, bias=True)
        self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=True)
        self.txt_qkv = nn.Linear(hidden_size, qkv_dim, bias=True)
        self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=True)
        self.img_norm = QKNorm(head_dim)
        self.txt_norm = QKNorm(head_dim)

    def _split_heads(self, qkv, batch, tokens):
        # (B, N, 3*H*D) -> three tensors of shape (B, H, N, D).
        shaped = qkv.reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        return shaped.permute(2, 0, 3, 1, 4).unbind(0)

    def forward(self, img, txt, rope):
        batch, n_img, _ = img.shape
        n_txt = txt.shape[1]

        iq, ik, iv = self._split_heads(self.img_qkv(img), batch, n_img)
        iq, ik = self.img_norm(iq, ik)
        iq = apply_rope(iq, rope)
        ik = apply_rope(ik, rope)

        tq, tk, tv = self._split_heads(self.txt_qkv(txt), batch, n_txt)
        tq, tk = self.txt_norm(tq, tk)

        # Text tokens lead the joint sequence, matching the split below.
        joint = F.scaled_dot_product_attention(
            torch.cat([tq, iq], dim=2),
            torch.cat([tk, ik], dim=2),
            torch.cat([tv, iv], dim=2),
        )
        txt_part, img_part = joint.split([n_txt, n_img], dim=2)
        img_part = img_part.transpose(1, 2).reshape(batch, n_img, -1)
        txt_part = txt_part.transpose(1, 2).reshape(batch, n_txt, -1)
        return self.img_out(img_part), self.txt_out(txt_part)
146
+
147
+
148
class DoubleBlock(nn.Module):
    """AdaLN-modulated double-stream transformer block.

    Image and text streams keep separate norms, modulations, and MLPs, and
    interact only through the shared DoubleAttention layer. The conditioning
    vector produces six modulation vectors per stream: (scale, shift, gate)
    for the attention sub-layer and again for the MLP sub-layer.
    """

    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0):
        super().__init__()
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mod = nn.Linear(hidden_size, hidden_size * 6, bias=True)
        self.txt_mod = nn.Linear(hidden_size, hidden_size * 6, bias=True)
        self.attn = DoubleAttention(hidden_size, num_heads, head_dim)
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden, hidden_size, bias=True),
        )
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden, hidden_size, bias=True),
        )

    @staticmethod
    def _modulate(norm, stream, scale, shift):
        # AdaLN: per-batch scale/shift vectors broadcast over the token axis.
        return norm(stream) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, img, txt, cond, rope):
        # Chunk order must stay (scale1, shift1, gate1, scale2, shift2, gate2)
        # to match the trained checkpoint.
        i_s1, i_b1, i_g1, i_s2, i_b2, i_g2 = self.img_mod(cond).chunk(6, dim=-1)
        t_s1, t_b1, t_g1, t_s2, t_b2, t_g2 = self.txt_mod(cond).chunk(6, dim=-1)

        attn_img, attn_txt = self.attn(
            self._modulate(self.img_norm1, img, i_s1, i_b1),
            self._modulate(self.txt_norm1, txt, t_s1, t_b1),
            rope,
        )
        img = img + i_g1.unsqueeze(1) * attn_img
        txt = txt + t_g1.unsqueeze(1) * attn_txt

        img = img + i_g2.unsqueeze(1) * self.img_mlp(
            self._modulate(self.img_norm2, img, i_s2, i_b2)
        )
        txt = txt + t_g2.unsqueeze(1) * self.txt_mlp(
            self._modulate(self.txt_norm2, txt, t_s2, t_b2)
        )
        return img, txt
185
+
186
+
187
class SingleAttention(nn.Module):
    """Self-attention over the fused sequence with QK norm and RoPE."""

    def __init__(self, hidden_size, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.qkv = nn.Linear(hidden_size, num_heads * head_dim * 3, bias=True)
        self.norm = QKNorm(head_dim)

    def forward(self, x, rope):
        batch, tokens, _ = x.shape
        # (B, N, 3*H*D) -> three (B, H, N, D) tensors.
        q, k, v = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        q, k = self.norm(q, k)
        q = apply_rope(q, rope)
        k = apply_rope(k, rope)
        attn = F.scaled_dot_product_attention(q, k, v)
        return attn.transpose(1, 2).reshape(batch, tokens, -1)
204
+
205
+
206
class SingleBlock(nn.Module):
    """AdaLN-modulated single-stream block with parallel attention + MLP.

    Both branches consume the same modulated input and share one residual
    gate (Flux-style parallel sub-layers rather than sequential ones).
    """

    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mod = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.attn = SingleAttention(hidden_size, num_heads, head_dim)
        self.proj = nn.Linear(num_heads * head_dim, hidden_size, bias=True)
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden, hidden_size, bias=True),
        )

    def forward(self, x, cond, rope):
        # Chunk order must stay (scale, shift, gate) for checkpoint parity.
        scale, shift, gate = self.mod(cond).chunk(3, dim=-1)
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        branch_sum = self.proj(self.attn(modulated, rope)) + self.mlp(modulated)
        return x + gate.unsqueeze(1) * branch_sum
227
+
228
+
229
class TinyFluxDeep(nn.Module):
    """Compact Flux-style flow-matching transformer (the "Lailah" model).

    Image latent tokens and T5 text tokens pass through
    ``cfg.num_double_layers`` dual-stream blocks, are then fused as
    ``[txt; img]`` through ``cfg.num_single_layers`` single-stream blocks,
    and finally the image tokens are projected back to latent channels.
    Conditioning (timestep + optional guidance + pooled CLIP vector) is
    summed into one modulation vector shared by every block.
    """

    def __init__(self, cfg: TinyFluxDeepConfig):
        super().__init__()
        self.cfg = cfg
        # Token input projections to the shared hidden width.
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)
        # Conditioning embedders; time/guidance consume 256-dim sinusoids
        # produced by time_embed().
        self.time_in = MLPEmbedder(256, cfg.hidden_size)
        self.guidance_in = MLPEmbedder(256, cfg.hidden_size)
        self.vector_in = MLPEmbedder(cfg.pooled_projection_dim, cfg.hidden_size)
        self.rope = EmbedND(axes_dim=cfg.axes_dims_rope)
        self.double_blocks = nn.ModuleList([
            DoubleBlock(cfg.hidden_size, cfg.num_attention_heads, cfg.attention_head_dim, cfg.mlp_ratio)
            for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleBlock(cfg.hidden_size, cfg.num_attention_heads, cfg.attention_head_dim, cfg.mlp_ratio)
            for _ in range(cfg.num_single_layers)
        ])
        # Final AdaLN-style modulation + projection back to latent channels.
        self.final_norm = nn.LayerNorm(cfg.hidden_size, elementwise_affine=False, eps=1e-6)
        self.final_mod = nn.Linear(cfg.hidden_size, cfg.hidden_size * 2, bias=True)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)

    def time_embed(self, t):
        """256-dim sinusoidal embedding of t; t is scaled by 1000 before use."""
        half_dim = 128
        freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, device=t.device) / half_dim)
        args = t.unsqueeze(-1) * freqs * 1000
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    @staticmethod
    def create_img_ids(batch_size, h, w, device):
        """Build (batch, h*w, 3) position ids: axis 0 unused, 1=row, 2=col."""
        img_ids = torch.zeros(h, w, 3, device=device)
        img_ids[..., 1] = torch.arange(h, device=device).unsqueeze(1)
        img_ids[..., 2] = torch.arange(w, device=device).unsqueeze(0)
        return img_ids.reshape(1, h * w, 3).expand(batch_size, -1, -1)

    def forward(self, hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, guidance=None):
        """Predict the flow velocity for the given latent tokens.

        Args:
            hidden_states: (B, h*w, in_channels) image latent tokens.
            encoder_hidden_states: (B, N_txt, joint_attention_dim) T5 states.
            pooled_projections: (B, pooled_projection_dim) pooled CLIP vector.
            timestep: (B,) flow time values.
            img_ids: (B, h*w, 3) position ids from create_img_ids().
            guidance: optional (B,) guidance values; only used when
                cfg.guidance_embeds is True.

        Returns:
            (B, h*w, in_channels) predicted velocity.
        """
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)
        t_emb = self.time_embed(timestep)
        cond = self.time_in(t_emb)
        if guidance is not None and self.cfg.guidance_embeds:
            g_emb = self.time_embed(guidance)
            cond = cond + self.guidance_in(g_emb)
        cond = cond + self.vector_in(pooled_projections)
        # NOTE(review): rope is built from img_ids (h*w rows) but single_blocks
        # attend over the longer fused [txt; img] sequence; apply_rope slices
        # rope[:, :N] which cannot exceed h*w rows — confirm the intended
        # sequence lengths line up with how the checkpoint was trained.
        rope = self.rope(img_ids)
        for block in self.double_blocks:
            img, txt = block(img, txt, cond, rope)
        # Fuse streams, text first; the same ordering is used to split below.
        x = torch.cat([txt, img], dim=1)
        for block in self.single_blocks:
            x = block(x, cond, rope)
        img = x[:, txt.shape[1]:, :]
        mod = self.final_mod(cond)
        scale, shift = mod.chunk(2, dim=-1)
        img = self.final_norm(img) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.final_linear(img)
284
+
285
+
286
# ============================================================================
# GLOBALS
# ============================================================================
# NOTE(review): under HF ZeroGPU the GPU may not be visible at import time,
# so DEVICE can resolve to "cpu" here — verify against the Spaces runtime.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
MAX_SEED = np.iinfo(np.int32).max
# Flux timestep-shift factor consumed by flux_shift().
SHIFT = 3.0

# ============================================================================
# LOAD MODELS (outside GPU function for ZeroGPU compatibility)
# ============================================================================
print("Loading TinyFlux-Lailah...")

# Model
config = TinyFluxDeepConfig()
model = TinyFluxDeep(config)

# Load EMA weights (best quality)
weights_path = hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoints/step_286250_ema.safetensors")
weights = load_file(weights_path)
# NOTE(review): strict=False silently tolerates missing/unexpected keys —
# consider inspecting load_state_dict's return value to catch checkpoint drift.
model.load_state_dict(weights, strict=False)
model.eval()
model.to(DTYPE)
print(f"✓ Model loaded ({sum(p.numel() for p in model.parameters()):,} params)")

# Text encoders: T5 supplies per-token states, CLIP supplies the pooled vector.
print("Loading text encoders...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
print("✓ Text encoders loaded")

# VAE: FLUX.1-schnell's 16-channel latent autoencoder.
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=DTYPE)
VAE_SCALE = vae.config.scaling_factor
print("✓ VAE loaded")
324
+
325
+
326
+ # ============================================================================
327
+ # INFERENCE FUNCTIONS
328
+ # ============================================================================
329
def flux_shift(t, s=SHIFT):
    """Apply the Flux timestep shift t' = s*t / (1 + (s-1)*t).

    Monotonic on [0, 1] with fixed endpoints (0 -> 0, 1 -> 1); s > 1 skews
    the schedule toward larger shifted values.
    """
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
331
+
332
+
333
@spaces.GPU(duration=90)
def generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image with TinyFlux-Lailah.

    Encodes the prompt with T5 (per-token states) and CLIP (pooled vector),
    runs Euler flow-matching sampling over flux-shifted timesteps, then
    decodes the latents with the FLUX VAE — all under torch.inference_mode().

    Args:
        prompt: Text prompt, encoded by both T5 and CLIP.
        negative_prompt: Accepted for UI parity; not used by this model.
        seed: RNG seed; replaced with a random one when randomize_seed is set.
        randomize_seed: If True, draw a fresh seed in [0, MAX_SEED].
        width, height: Output resolution in pixels (latents are 1/8 scale).
        guidance_scale: Value fed to the model's guidance embedder.
        num_inference_steps: Number of Euler integration steps.
        progress: Gradio progress tracker (tracks tqdm automatically).

    Returns:
        Tuple of (PIL.Image, seed actually used).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # NOTE(review): DEVICE is resolved at import time; under ZeroGPU the GPU
    # attaches only inside this function — confirm DEVICE is "cuda" here.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Move models to GPU
    model.to(DEVICE)
    t5_enc.to(DEVICE)
    clip_enc.to(DEVICE)
    vae.to(DEVICE)

    with torch.inference_mode():
        # Encode prompt: T5 gives (1, 128, 768) token states.
        t5_in = t5_tok(
            prompt, max_length=128, padding="max_length",
            truncation=True, return_tensors="pt"
        ).to(DEVICE)
        t5_out = t5_enc(**t5_in).last_hidden_state.to(DTYPE)

        # CLIP gives the pooled (1, 768) projection vector.
        clip_in = clip_tok(
            prompt, max_length=77, padding="max_length",
            truncation=True, return_tensors="pt"
        ).to(DEVICE)
        clip_out = clip_enc(**clip_in).pooler_output.to(DTYPE)

        # Latent dimensions (VAE downsamples by 8; 16 latent channels).
        H_lat = height // 8
        W_lat = width // 8
        C = 16

        # Start from noise; tokens are flattened (h*w, C) rows.
        x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
        img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)

        # Timesteps with Flux shift (num_inference_steps + 1 boundary points).
        t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE, dtype=DTYPE)
        timesteps = flux_shift(t_linear, s=SHIFT)

        # Euler sampling: x <- x + v(x, t) * dt between consecutive timesteps.
        for i in range(num_inference_steps):
            t_curr = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_curr

            t_batch = t_curr.unsqueeze(0)
            guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)

            v = model(
                hidden_states=x,
                encoder_hidden_states=t5_out,
                pooled_projections=clip_out,
                timestep=t_batch,
                img_ids=img_ids,
                guidance=guidance,
            )
            x = x + v * dt

        # Decode: tokens -> (1, C, H_lat, W_lat) feature map for the VAE.
        # NOTE(review): the FLUX VAE config also defines a shift_factor; only
        # scaling_factor is applied here — confirm this matches how the
        # teacher latents were produced during training.
        latents = x.reshape(1, H_lat, W_lat, C).permute(0, 3, 1, 2)
        latents = latents / VAE_SCALE
        image = vae.decode(latents.to(vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        # To PIL
        image = image[0].float().permute(1, 2, 0).cpu().numpy()
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image)

    return image, seed
415
+
416
+
417
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Prompt suggestions shown under the input box.
examples = [
    "a photo of a cat sitting on a windowsill",
    "a portrait of a woman with red hair, professional photography",
    "a black backpack on white background, product photo",
    "astronaut riding a horse on mars, digital art",
    "a cozy coffee shop interior, warm lighting",
]

# Center the layout and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # TinyFlux-Lailah

        **241M parameter** flow-matching text-to-image model.
        Trained on teacher latents from Flux-Schnell.

        [Model Card](https://huggingface.co/AbstractPhil/tiny-flux-deep) |
        [GitHub](https://github.com/AbstractPhil)
        """)

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt...",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Settings", open=False):
            # Hidden: kept only so generate()'s signature matches the inputs
            # list; the model does not consume a negative prompt.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="(not used in this model)",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Resolution sliders step by 64 px so latent dims stay multiples of 8.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=768,
                    step=64,
                    value=512,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=768,
                    step=64,
                    value=512,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=3.5,
                )

                num_inference_steps = gr.Slider(
                    label="Steps",
                    minimum=10,
                    maximum=50,
                    step=1,
                    value=28,
                )

        gr.Examples(examples=examples, inputs=[prompt])

        gr.Markdown("""
        ---
        **Notes:**
        - Trained on 512×512 resolution
        - Best results at guidance 3.0-5.0
        - 20-30 steps recommended
        - Early checkpoint - quality improving with training
        """)

    # Fire generation from both the button and the prompt box's Enter key.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()