"""
TinyFlux-Lailah Gradio Demo
HuggingFace Spaces with ZeroGPU support
Euler discrete flow matching inference
"""

import gradio as gr
import numpy as np
import random
import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Tuple
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from PIL import Image


# ============================================================================
# MODEL DEFINITION - adapted from tinyflux_deep.py
# Module names and forward math must stay in sync with the training code:
# state_dict keys are derived from module names, and any change to the math
# changes what the loaded weights compute.
# ============================================================================

@dataclass
class TinyFluxDeepConfig:
    """Architecture hyper-parameters for :class:`TinyFluxDeep`."""

    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16  # per-token latent width (input of img_in, output of final_linear)
    patch_size: int = 1
    joint_attention_dim: int = 768  # width of the text hidden states fed to txt_in
    pooled_projection_dim: int = 768  # width of the pooled vector fed to vector_in
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # RoPE channels per position axis
    guidance_embeds: bool = True

    def __post_init__(self):
        # Heads must tile the hidden size, and the RoPE axes must tile one head.
        assert self.num_attention_heads * self.attention_head_dim == self.hidden_size
        assert sum(self.axes_dims_rope) == self.attention_head_dim


class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with an optional learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The mean of squares is reduced in float32; the result is cast back to x's dtype.
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        out = (x * norm).type_as(x)
        if self.weight is not None:
            out = out * self.weight
        return out


class EmbedND(nn.Module):
    """Multi-axis rotary position table.

    For each position axis ``i`` (``axes_dim[i]`` channels) it computes
    interleaved ``(cos, sin)`` pairs of ``pos * freq`` and concatenates the
    axes along the last dimension.
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        for i, dim in enumerate(axes_dim):
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            # persistent=True: the buffers are included in state_dict().
            self.register_buffer(f'freqs_{i}', freqs, persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Map position ids of shape (N, n_axes) to a (N, 1, D) table.

        D is the sum of ``axes_dim`` over the ``n_axes`` axes present in ``ids``.
        """
        device = ids.device
        n_axes = ids.shape[-1]
        emb_list = []
        for i in range(n_axes):
            freqs = getattr(self, f'freqs_{i}').to(device)
            pos = ids[:, i].float()
            angles = pos.unsqueeze(-1) * freqs.unsqueeze(0)
            cos = angles.cos()
            sin = angles.sin()
            emb = torch.stack([cos, sin], dim=-1).flatten(-2)
            emb_list.append(emb)
        rope = torch.cat(emb_list, dim=-1)
        return rope.unsqueeze(1)


def apply_rotary_emb_old(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent channel pairs of ``x`` by the angles in ``freqs_cis``.

    Callers in this file pass ``x`` as (B, heads, N, head_dim) and ``freqs_cis``
    as the (N, 1, head_dim) table from :class:`EmbedND`. The rotation is computed
    in float32 and cast back to ``x.dtype``.
    """
    freqs = freqs_cis.squeeze(1)
    # Table entries alternate cos/sin; expand each so it covers both members of a pair.
    cos = freqs[:, 0::2].repeat_interleave(2, dim=-1)
    sin = freqs[:, 1::2].repeat_interleave(2, dim=-1)
    cos = cos[None, None, :, :].to(x.device)
    sin = sin[None, None, :, :].to(x.device)
    # (a, b) -> (-b, a) for every adjacent channel pair.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)


class MLPEmbedder(nn.Module):
    """Sinusoidal embedding of a per-sample scalar followed by a two-layer MLP."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 128 sin + 128 cos features = the 256 inputs of the first Linear.
        # The frequency table is built in x.dtype, so the input dtype sets the
        # precision of the embedding.
        half_dim = 128
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
        emb = x.unsqueeze(-1) * emb.unsqueeze(0)
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return self.mlp(emb)


class AdaLayerNormZero(nn.Module):
    """RMSNorm modulated by a conditioning vector (attention + MLP shift/scale/gate)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return the modulated attention input and the remaining modulation terms.

        Returns ``(norm(x) * (1 + scale_msa) + shift_msa, gate_msa, shift_mlp,
        scale_mlp, gate_mlp)``; the caller applies the gates and the MLP modulation.
        """
        emb_out = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb_out.chunk(6, dim=-1)
        x = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroSingle(nn.Module):
    """RMSNorm modulated by a conditioning vector (single shift/scale/gate)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return ``(norm(x) * (1 + scale) + shift, gate)``."""
        emb_out = self.linear(self.silu(emb))
        shift, scale, gate = emb_out.chunk(3, dim=-1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x, gate


class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and optional RoPE."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.out_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)

    def forward(self, x: torch.Tensor, rope: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # -> three tensors of shape (B, heads, N, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        if rope is not None:
            q = apply_rotary_emb_old(q, rope)
            k = apply_rotary_emb_old(k, rope)
        attn = F.scaled_dot_product_attention(q, k, v)
        out = attn.transpose(1, 2).reshape(B, N, -1)
        return self.out_proj(out)


class JointAttention(nn.Module):
    """Separate text/image projections attending over the concatenated keys/values.

    RoPE is applied to the image queries/keys only.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.txt_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.img_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
        self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)

    def forward(self, txt: torch.Tensor, img: torch.Tensor, rope: Optional[torch.Tensor] = None):
        B, L, _ = txt.shape
        _, N, _ = img.shape
        txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
        img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
        txt_q, txt_k, txt_v = txt_qkv.permute(2, 0, 3, 1, 4)
        img_q, img_k, img_v = img_qkv.permute(2, 0, 3, 1, 4)
        if rope is not None:
            img_q = apply_rotary_emb_old(img_q, rope)
            img_k = apply_rotary_emb_old(img_k, rope)
        # Keys/values are concatenated along the token axis (text first).
        k = torch.cat([txt_k, img_k], dim=2)
        v = torch.cat([txt_v, img_v], dim=2)
        txt_out = F.scaled_dot_product_attention(txt_q, k, v)
        txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
        img_out = F.scaled_dot_product_attention(img_q, k, v)
        img_out = img_out.transpose(1, 2).reshape(B, N, -1)
        return self.txt_out(txt_out), self.img_out(img_out)


class MLP(nn.Module):
    """Two-layer feed-forward network with tanh-approximated GELU."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, mlp_hidden, bias=True)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(mlp_hidden, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class DoubleStreamBlock(nn.Module):
    """Dual-stream block: per-stream modulation and MLPs, joint text+image attention."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim
        self.img_norm1 = AdaLayerNormZero(hidden)
        self.txt_norm1 = AdaLayerNormZero(hidden)
        self.attn = JointAttention(hidden, heads, head_dim, use_bias=False)
        self.img_norm2 = RMSNorm(hidden)
        self.txt_norm2 = RMSNorm(hidden)
        self.img_mlp = MLP(hidden, config.mlp_ratio)
        self.txt_mlp = MLP(hidden, config.mlp_ratio)

    def forward(self, txt, img, vec, rope=None):
        img_normed, img_gate_msa, img_shift_mlp, img_scale_mlp, img_gate_mlp = self.img_norm1(img, vec)
        txt_normed, txt_gate_msa, txt_shift_mlp, txt_scale_mlp, txt_gate_mlp = self.txt_norm1(txt, vec)
        txt_attn_out, img_attn_out = self.attn(txt_normed, img_normed, rope)
        txt = txt + txt_gate_msa.unsqueeze(1) * txt_attn_out
        img = img + img_gate_msa.unsqueeze(1) * img_attn_out
        txt_mlp_in = self.txt_norm2(txt) * (1 + txt_scale_mlp.unsqueeze(1)) + txt_shift_mlp.unsqueeze(1)
        img_mlp_in = self.img_norm2(img) * (1 + img_scale_mlp.unsqueeze(1)) + img_shift_mlp.unsqueeze(1)
        txt = txt + txt_gate_mlp.unsqueeze(1) * self.txt_mlp(txt_mlp_in)
        img = img + img_gate_mlp.unsqueeze(1) * self.img_mlp(img_mlp_in)
        return txt, img


class SingleStreamBlock(nn.Module):
    """Single-stream block over the concatenated [text, image] token sequence."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim
        self.norm = AdaLayerNormZeroSingle(hidden)
        self.attn = Attention(hidden, heads, head_dim, use_bias=False)
        self.mlp = MLP(hidden, config.mlp_ratio)
        self.norm2 = RMSNorm(hidden)

    def forward(self, txt, img, vec, rope=None):
        L = txt.shape[1]
        x = torch.cat([txt, img], dim=1)
        x_normed, gate = self.norm(x, vec)
        # Attention is gated by the conditioning; the MLP residual is not.
        x = x + gate.unsqueeze(1) * self.attn(x_normed, rope)
        x = x + self.mlp(self.norm2(x))
        txt, img = x.split([L, x.shape[1] - L], dim=1)
        return txt, img


class TinyFluxDeep(nn.Module):
    """Flux-style dual/single-stream transformer over latent tokens."""

    def __init__(self, config: Optional[TinyFluxDeepConfig] = None):
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)

    def forward(self, hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids=None, guidance=None):
        """Run the transformer.

        Args:
            hidden_states: latent tokens, (B, N, in_channels).
            encoder_hidden_states: text tokens, (B, L, joint_attention_dim).
            pooled_projections: pooled text vector, (B, pooled_projection_dim).
            timestep: per-sample scalar, embedded by ``time_in``.
            img_ids: (N, 3) position ids, or (B, N, 3) of which only the first is used.
            txt_ids: optional (L, 3) / (B, L, 3) text position ids; defaults to all zeros.
            guidance: optional per-sample scalar, embedded only if ``guidance_embeds``.

        Returns:
            (B, N, in_channels) output for the image tokens only.
        """
        L = encoder_hidden_states.shape[1]
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)
        if img_ids.ndim == 3:
            img_ids = img_ids[0]
        # Double-stream blocks rotate image tokens only.
        img_rope = self.rope(img_ids)
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)
        if txt_ids is None:
            txt_ids = torch.zeros(L, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        # Single-stream blocks rotate the concatenated [text, image] sequence.
        all_ids = torch.cat([txt_ids, img_ids], dim=0)
        full_rope = self.rope(all_ids)
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)
        img = self.final_norm(img)
        img = self.final_linear(img)
        return img

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device) -> torch.Tensor:
        """Return (height * width, 3) position ids ``[0, row, col]`` in row-major order.

        ``batch_size`` is accepted for API compatibility and is unused; the ids
        are 2-D. Built with vectorised ``arange`` ops instead of per-element
        writes, which issued one tiny device op per entry.
        """
        rows = torch.arange(height, device=device).repeat_interleave(width)
        cols = torch.arange(width, device=device).repeat(height)
        img_ids = torch.zeros(height * width, 3, device=device)
        img_ids[:, 1] = rows
        img_ids[:, 2] = cols
        return img_ids

    @staticmethod
    def create_txt_ids(text_len: int, device) -> torch.Tensor:
        """Return (text_len, 3) ids with the token index on axis 0 and zeros elsewhere."""
        txt_ids = torch.zeros(text_len, 3, device=device)
        txt_ids[:, 0] = torch.arange(text_len, device=device)
        return txt_ids


# ============================================================================
# GLOBALS
# ============================================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Check is_available() first: some torch releases may try to initialise CUDA
# inside is_bf16_supported() and fail when no GPU is present.
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
MAX_SEED = np.iinfo(np.int32).max
SHIFT = 3.0

# ============================================================================
# LOAD MODELS
# ============================================================================
print("Loading TinyFlux-Lailah...")
config = TinyFluxDeepConfig()
model = TinyFluxDeep(config)
weights_path = hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoints/step_297500_ema.safetensors")
weights = load_file(weights_path)
# strict=False tolerates key mismatches; report them so a renamed or prefixed
# checkpoint does not silently leave layers at their random initialisation.
incompatible = model.load_state_dict(weights, strict=False)
if incompatible.missing_keys:
    print(f"⚠ Missing keys ({len(incompatible.missing_keys)}): {incompatible.missing_keys[:10]}")
if incompatible.unexpected_keys:
    print(f"⚠ Unexpected keys ({len(incompatible.unexpected_keys)}): {incompatible.unexpected_keys[:10]}")
del weights
model.eval()
# Cast every submodule except the RoPE frequency buffers. A blanket
# model.to(DTYPE) would round rope.freqs_* to bf16/fp16, and the resulting
# angle error grows with the position index.
for child_name, child in model.named_children():
    if child_name != "rope":
        child.to(DTYPE)
print(f"✓ Model loaded ({sum(p.numel() for p in model.parameters()):,} params)")

print("Loading text encoders...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
print("✓ Text encoders loaded")

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained("./vae", torch_dtype=DTYPE)
vae.eval()
VAE_SCALE = vae.config.scaling_factor
print(f"✓ VAE loaded (scale={VAE_SCALE})")

# ============================================================================
# EULER DISCRETE FLOW MATCHING SAMPLER WITH CFG
# Training uses: x_t = (1-t)*noise + t*data, v = data - noise
# So t=0 is noise, t=1 is data. We sample from t=0 to t=1.
# ============================================================================

def flux_shift(t, shift=SHIFT):
    """Flux time shift: s*t / (1 + (s-1)*t)"""
    # NOTE(review): with shift > 1 this map is concave, so steps from a uniform
    # grid are spaced densely near t=1 (data here). Flux/SD3 apply the same
    # formula to sigma (1 = noise), which instead concentrates steps near noise.
    # Keep as-is only if training sampled timesteps this way. TODO confirm.
    return shift * t / (1 + (shift - 1) * t)


def _encode_prompt(text: str):
    """Encode ``text`` with both text encoders.

    Returns ``(t5_last_hidden_state, clip_pooled_output)``. T5 input is padded or
    truncated to 128 tokens and CLIP input to 77. Expects to run inside the
    caller's inference/autocast context with the encoders already on DEVICE.
    """
    t5_in = t5_tok(text, max_length=128, padding="max_length", truncation=True, return_tensors="pt").to(DEVICE)
    t5_hidden = t5_enc(**t5_in).last_hidden_state
    clip_in = clip_tok(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt").to(DEVICE)
    clip_pooled = clip_enc(**clip_in).pooler_output
    return t5_hidden, clip_pooled


@spaces.GPU(duration=90)
def generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_embed: float,
    cfg_scale: float,
    num_inference_steps: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with Euler flow-matching sampling.

    Args:
        prompt: text prompt.
        negative_prompt: prompt for the unconditional branch. Used only when
            ``cfg_scale > 1``; empty or None means "".
        seed: RNG seed. Replaced by a random one when ``randomize_seed`` is set.
        randomize_seed: draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: output size in pixels. The latent grid is ``size // 8``.
        guidance_embed: scalar fed to the model's guidance embedder.
        cfg_scale: classifier-free guidance scale; values <= 1 disable CFG.
        num_inference_steps: number of Euler steps.
        progress: Gradio progress tracker (tqdm tracking).

    Returns:
        ``(PIL.Image, seed_used)``.
    """
    # Gradio sliders / API callers may pass floats, but tensor shapes, the RNG
    # seed and range() all require ints.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Move to GPU
    model.to(DEVICE)
    t5_enc.to(DEVICE)
    clip_enc.to(DEVICE)
    vae.to(DEVICE)

    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=DTYPE):
        # Encode prompts
        t5_cond, clip_cond = _encode_prompt(prompt)

        # Encode negative prompt for CFG; uncond/cond are batched into one forward pass.
        do_cfg = cfg_scale > 1.0
        if do_cfg:
            t5_uncond, clip_uncond = _encode_prompt(negative_prompt or "")
            t5_batch = torch.cat([t5_uncond, t5_cond], dim=0)
            clip_batch = torch.cat([clip_uncond, clip_cond], dim=0)

        # Latent dimensions
        H_lat = height // 8
        W_lat = width // 8
        C = model.config.in_channels

        # Start from noise (t=0 in this convention). The noise is drawn in DTYPE
        # so a seed reproduces the same starting noise as before. Integration then
        # runs in float32, so small per-step updates are not rounded away in bf16.
        x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator).float()

        # Position IDs
        img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)

        # Timesteps: 0 -> 1 (noise to data) with Flux shift (float32).
        t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
        timesteps = flux_shift(t_linear, shift=SHIFT)

        # Guidance embedding (distilled into model during training). Passed as
        # float32, like the timestep: MLPEmbedder builds its sinusoid table in
        # the input's dtype.
        guidance_tensor = torch.tensor([guidance_embed], device=DEVICE, dtype=torch.float32)

        # Euler flow matching: x_{t+dt} = x_t + v * dt
        for i in range(num_inference_steps):
            t_curr = timesteps[i]
            dt = timesteps[i + 1] - t_curr
            t_batch = t_curr.unsqueeze(0)
            x_in = x.to(DTYPE)

            if do_cfg:
                v_batch = model(
                    hidden_states=x_in.repeat(2, 1, 1),
                    encoder_hidden_states=t5_batch,
                    pooled_projections=clip_batch,
                    timestep=t_batch.repeat(2),
                    img_ids=img_ids,
                    guidance=guidance_tensor.repeat(2),
                ).float()
                v_uncond, v_cond = v_batch.chunk(2, dim=0)
                v = v_uncond + cfg_scale * (v_cond - v_uncond)
            else:
                v = model(
                    hidden_states=x_in,
                    encoder_hidden_states=t5_cond,
                    pooled_projections=clip_cond,
                    timestep=t_batch,
                    img_ids=img_ids,
                    guidance=guidance_tensor,
                ).float()

            x = x + v * dt

        # Decode latents: tokens (1, H*W, C) -> (1, C, H, W)
        latents = x.reshape(1, H_lat, W_lat, C).permute(0, 3, 1, 2)
        # NOTE(review): diffusers' FluxPipeline decodes with
        # `latents / scaling_factor + shift_factor`. Only the scaling is undone
        # here. If the training latents were shifted, add
        # vae.config.shift_factor. TODO confirm against the training pipeline.
        latents = latents / VAE_SCALE
        image = vae.decode(latents.to(vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)

    # To PIL (round rather than truncate when quantising to uint8)
    image = image[0].float().permute(1, 2, 0).cpu().numpy()
    image = (image * 255).round().astype(np.uint8)
    image = Image.fromarray(image)

    return image, seed


# ============================================================================
# GRADIO INTERFACE
# ============================================================================

examples = [
    "a photo of a cat sitting on a windowsill",
    "a portrait of a woman with red hair, professional photography",
    "a black backpack on white background, product photo",
    "astronaut riding a horse on mars, digital art",
    "a cozy coffee shop interior, warm lighting",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# TinyFlux-Lailah

**241M parameter** flow-matching text-to-image model.
Trained on teacher latents from Flux-Schnell.

[Model Card](https://huggingface.co/AbstractPhil/tiny-flux-deep)
""")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt...",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt (for CFG)",
                max_lines=1,
                placeholder="blurry, distorted, low quality",
                value="",
            )

            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)

            with gr.Row():
                guidance_embed = gr.Slider(
                    label="Guidance Embed (distilled)",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=3.5,
                    info="Passed to model (trained 1-5 range)"
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale (two-pass)",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=1.0,
                    info="1.0 = off (faster), >1 = CFG enabled"
                )

            num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=1, value=25)

        gr.Examples(examples=examples, inputs=[prompt])

        gr.Markdown("""
---
**Guidance Embed**: Distilled guidance baked into model weights. Fast (1 pass). Trained with values 1-5.

**CFG Scale**: Traditional classifier-free guidance. Slower (2 passes). Set to 1.0 to disable.
""")

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_embed, cfg_scale, num_inference_steps],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()