# (Hugging Face Spaces page-status banner removed — scrape artifact, not part of app.py)
| """ | |
| TinyFlux-Lailah Gradio Demo | |
| HuggingFace Spaces with ZeroGPU support | |
| Euler discrete flow matching inference | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import spaces | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer | |
| from diffusers import AutoencoderKL | |
| from PIL import Image | |
| # ============================================================================ | |
| # MODEL DEFINITION - Exact copy from tinyflux_deep.py | |
| # ============================================================================ | |
@dataclass
class TinyFluxDeepConfig:
    """Hyperparameters for the TinyFluxDeep transformer.

    BUGFIX: the class was written as a dataclass (annotated fields with
    defaults plus a ``__post_init__`` validator) but the ``@dataclass``
    decorator was missing, so ``__post_init__`` never ran and keyword
    construction (``TinyFluxDeepConfig(hidden_size=...)``) raised TypeError.
    """
    hidden_size: int = 512                # transformer width
    num_attention_heads: int = 4
    attention_head_dim: int = 128         # heads * head_dim must equal hidden_size
    in_channels: int = 16                 # latent channels in/out
    patch_size: int = 1
    joint_attention_dim: int = 768        # per-token text-embedding width (T5)
    pooled_projection_dim: int = 768      # pooled text-embedding width (CLIP)
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # rope split; sums to head_dim
    guidance_embeds: bool = True          # enable distilled-guidance embedder

    def __post_init__(self):
        # Validate attention geometry; runs now that @dataclass is applied.
        assert self.num_attention_heads * self.attention_head_dim == self.hidden_size
        assert sum(self.axes_dims_rope) == self.attention_head_dim
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction) with an optional
    learned per-channel scale."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back.
        inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normalized = (x * inv_rms).type_as(x)
        return normalized if self.weight is None else normalized * self.weight
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding table.

    Each positional axis i gets its own frequency buffer (axes_dim[i] // 2
    frequencies); forward() maps integer position ids to interleaved
    (cos, sin) features concatenated across all axes.
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        # One persistent buffer per axis so the frequencies follow .to(device).
        for axis, dim in enumerate(axes_dim):
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer(f'freqs_{axis}', inv_freq, persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (seq, n_axes) positions; returns (seq, 1, sum(axes_dim)).
        parts = []
        for axis in range(ids.shape[-1]):
            inv_freq = getattr(self, f'freqs_{axis}').to(ids.device)
            angles = ids[:, axis].float().unsqueeze(-1) * inv_freq.unsqueeze(0)
            # Interleave along the last dim: [c0, s0, c1, s1, ...].
            parts.append(torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2))
        return torch.cat(parts, dim=-1).unsqueeze(1)
def apply_rotary_emb_old(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply interleaved rotary position embedding to a query/key tensor.

    x:         (batch, heads, seq, head_dim), as produced by the qkv permute
               in Attention/JointAttention.
    freqs_cis: output of EmbedND — (seq, 1, head_dim) with cos values at even
               and sin values at odd positions along the last dim.
    """
    # Drop the singleton axis inserted by EmbedND -> (seq, head_dim).
    freqs = freqs_cis.squeeze(1)
    # De-interleave cos/sin, then duplicate each angle so both elements of a
    # rotated pair see the same frequency: [c0, c0, c1, c1, ...].
    cos = freqs[:, 0::2].repeat_interleave(2, dim=-1)
    sin = freqs[:, 1::2].repeat_interleave(2, dim=-1)
    # Broadcast over batch and head dimensions.
    cos = cos[None, None, :, :].to(x.device)
    sin = sin[None, None, :, :].to(x.device)
    # Pairwise rotation: (a, b) -> (a*cos - b*sin, b*cos + a*sin).
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
    # Compute in float32 for precision, cast back to the input dtype.
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
class MLPEmbedder(nn.Module):
    """Sinusoidal scalar embedder (timestep / guidance) followed by a 2-layer MLP."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Keep the Sequential layout so state-dict keys remain mlp.0 / mlp.2.
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Classic transformer sinusoidal embedding with 128 frequencies,
        # yielding 128 sin + 128 cos = 256 features (the MLP input width).
        half_dim = 128
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -scale)
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        features = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return self.mlp(features)
class AdaLayerNormZero(nn.Module):
    """adaLN-Zero modulation: one conditioning vector -> six modulation signals
    (shift/scale/gate for attention, shift/scale/gate for the MLP)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        # Project the conditioning vector into the six per-channel signals.
        mod = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=-1)
        # Modulate the normalized activations; remaining signals go to the caller.
        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroSingle(nn.Module):
    """adaLN-Zero modulation for single-stream blocks: shift, scale, one gate."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        # Conditioning vector -> (shift, scale, gate), broadcast over tokens.
        mod = self.linear(self.silu(emb))
        shift, scale, gate = mod.chunk(3, dim=-1)
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return modulated, gate
class Attention(nn.Module):
    """Self-attention with a fused QKV projection and optional rotary embedding."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.out_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)

    def forward(self, x: torch.Tensor, rope: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        # (B, N, 3, H, D) -> three (B, H, N, D) tensors.
        projected = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4)
        if rope is not None:
            # Rotary position embedding applies to queries and keys only.
            q = apply_rotary_emb_old(q, rope)
            k = apply_rotary_emb_old(k, rope)
        attended = F.scaled_dot_product_attention(q, k, v)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
class JointAttention(nn.Module):
    """Joint text/image attention: each stream has its own QKV and output
    projection, but both attend over the shared [text; image] key/value pool."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.txt_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.img_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
        self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)

    def forward(self, txt: torch.Tensor, img: torch.Tensor, rope: Optional[torch.Tensor] = None):
        batch, txt_len, _ = txt.shape
        img_len = img.shape[1]
        txt_proj = self.txt_qkv(txt).reshape(batch, txt_len, 3, self.num_heads, self.head_dim)
        img_proj = self.img_qkv(img).reshape(batch, img_len, 3, self.num_heads, self.head_dim)
        txt_q, txt_k, txt_v = txt_proj.permute(2, 0, 3, 1, 4)
        img_q, img_k, img_v = img_proj.permute(2, 0, 3, 1, 4)
        if rope is not None:
            # Only image tokens carry spatial rotary positions in this path.
            img_q = apply_rotary_emb_old(img_q, rope)
            img_k = apply_rotary_emb_old(img_k, rope)
        # Shared key/value pool across both streams.
        keys = torch.cat([txt_k, img_k], dim=2)
        values = torch.cat([txt_v, img_v], dim=2)
        txt_attn = F.scaled_dot_product_attention(txt_q, keys, values)
        img_attn = F.scaled_dot_product_attention(img_q, keys, values)
        txt_merged = txt_attn.transpose(1, 2).reshape(batch, txt_len, -1)
        img_merged = img_attn.transpose(1, 2).reshape(batch, img_len, -1)
        return self.txt_out(txt_merged), self.img_out(img_merged)
class MLP(nn.Module):
    """Standard transformer feed-forward: expand by mlp_ratio, tanh-GELU, project back."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner, bias=True)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
class DoubleStreamBlock(nn.Module):
    """Flux-style double-stream block: text and image keep separate weights
    but exchange information through joint attention."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        width = config.hidden_size
        self.img_norm1 = AdaLayerNormZero(width)
        self.txt_norm1 = AdaLayerNormZero(width)
        self.attn = JointAttention(width, config.num_attention_heads,
                                   config.attention_head_dim, use_bias=False)
        self.img_norm2 = RMSNorm(width)
        self.txt_norm2 = RMSNorm(width)
        self.img_mlp = MLP(width, config.mlp_ratio)
        self.txt_mlp = MLP(width, config.mlp_ratio)

    def forward(self, txt, img, vec, rope=None):
        # adaLN-modulated inputs plus the gates/shifts for the MLP half.
        img_mod, img_gate_attn, img_shift, img_scale, img_gate_mlp = self.img_norm1(img, vec)
        txt_mod, txt_gate_attn, txt_shift, txt_scale, txt_gate_mlp = self.txt_norm1(txt, vec)
        # Joint attention mixes the two streams; gated residual per stream.
        txt_attn, img_attn = self.attn(txt_mod, img_mod, rope)
        txt = txt + txt_gate_attn.unsqueeze(1) * txt_attn
        img = img + img_gate_attn.unsqueeze(1) * img_attn
        # Per-stream gated MLPs with their own shift/scale modulation.
        txt_in = self.txt_norm2(txt) * (1 + txt_scale.unsqueeze(1)) + txt_shift.unsqueeze(1)
        img_in = self.img_norm2(img) * (1 + img_scale.unsqueeze(1)) + img_shift.unsqueeze(1)
        txt = txt + txt_gate_mlp.unsqueeze(1) * self.txt_mlp(txt_in)
        img = img + img_gate_mlp.unsqueeze(1) * self.img_mlp(img_in)
        return txt, img
class SingleStreamBlock(nn.Module):
    """Single-stream block: text and image tokens are concatenated into one
    sequence and processed with shared weights."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        width = config.hidden_size
        self.norm = AdaLayerNormZeroSingle(width)
        self.attn = Attention(width, config.num_attention_heads,
                              config.attention_head_dim, use_bias=False)
        self.mlp = MLP(width, config.mlp_ratio)
        self.norm2 = RMSNorm(width)

    def forward(self, txt, img, vec, rope=None):
        txt_len = txt.shape[1]
        tokens = torch.cat([txt, img], dim=1)
        # Gated attention over the joint sequence...
        modulated, gate = self.norm(tokens, vec)
        tokens = tokens + gate.unsqueeze(1) * self.attn(modulated, rope)
        # ...followed by an ungated MLP residual (no adaLN gate on this path).
        tokens = tokens + self.mlp(self.norm2(tokens))
        return tokens[:, :txt_len], tokens[:, txt_len:]
class TinyFluxDeep(nn.Module):
    """Flux-style flow-matching transformer.

    A stack of double-stream blocks (separate text/image weights, joint
    attention) followed by single-stream blocks over the concatenated
    sequence. forward() returns the predicted velocity for the image tokens.

    Change vs. original: removed the unused locals ``B`` and ``N`` in
    forward(); everything else is identical.
    """

    def __init__(self, config: Optional[TinyFluxDeepConfig] = None):
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config
        # Input projections for image latents and per-token text embeddings.
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)
        # Scalar conditioning embedders (timestep, pooled text, guidance).
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)

    def forward(self, hidden_states, encoder_hidden_states, pooled_projections, timestep,
                img_ids, txt_ids=None, guidance=None):
        """Predict velocity for a batch of latent image tokens.

        hidden_states:         (B, N, in_channels) image latent tokens.
        encoder_hidden_states: (B, L, joint_attention_dim) per-token text.
        pooled_projections:    (B, pooled_projection_dim) pooled text vector.
        timestep:              (B,) flow-matching time values.
        img_ids / txt_ids:     (N, 3) / (L, 3) position ids (batch dim stripped).
        guidance:              (B,) distilled-guidance values, optional.
        Returns (B, N, in_channels) predicted velocity.
        """
        L = encoder_hidden_states.shape[1]
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)
        # Conditioning vector: timestep + pooled text (+ distilled guidance).
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)
        if img_ids.ndim == 3:
            img_ids = img_ids[0]  # positions are shared across the batch
        # Double-stream phase: rope covers image tokens only.
        img_rope = self.rope(img_ids)
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)
        # Single-stream phase needs positions for the full [text; image] sequence.
        if txt_ids is None:
            txt_ids = torch.zeros(L, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        all_ids = torch.cat([txt_ids, img_ids], dim=0)
        full_rope = self.rope(all_ids)
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)
        img = self.final_norm(img)
        img = self.final_linear(img)
        return img
def create_img_ids(batch_size: int, height: int, width: int, device) -> torch.Tensor:
    """Build (height*width, 3) position ids for image tokens.

    Column 0 is a stream tag (always 0 for image tokens), column 1 the row
    index, column 2 the column index. ``batch_size`` is accepted for API
    compatibility but unused — positions are shared across the batch.

    Replaces the original O(H*W) Python double loop with vectorized
    torch.arange construction; the output tensor is identical.
    """
    img_ids = torch.zeros(height * width, 3, device=device)
    # Row index repeats each value `width` times; column index cycles 0..width-1.
    img_ids[:, 1] = torch.arange(height, device=device, dtype=img_ids.dtype).repeat_interleave(width)
    img_ids[:, 2] = torch.arange(width, device=device, dtype=img_ids.dtype).repeat(height)
    return img_ids
def create_txt_ids(text_len: int, device) -> torch.Tensor:
    """Build (text_len, 3) position ids for text tokens: sequential index in
    column 0, zeros in the two spatial columns."""
    ids = torch.zeros(text_len, 3, device=device)
    ids[:, 0] = torch.arange(text_len, device=device)
    return ids
| # ============================================================================ | |
| # GLOBALS | |
| # ============================================================================ | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 | |
| MAX_SEED = np.iinfo(np.int32).max | |
| SHIFT = 3.0 | |
| # ============================================================================ | |
| # LOAD MODELS | |
| # ============================================================================ | |
| print("Loading TinyFlux-Lailah...") | |
| config = TinyFluxDeepConfig() | |
| model = TinyFluxDeep(config) | |
| weights_path = hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoints/step_297500_ema.safetensors") | |
| weights = load_file(weights_path) | |
| model.load_state_dict(weights, strict=False) | |
| model.eval() | |
| model.to(DTYPE) | |
| print(f"✓ Model loaded ({sum(p.numel() for p in model.parameters()):,} params)") | |
| print("Loading text encoders...") | |
| t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base") | |
| t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE) | |
| clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | |
| clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE) | |
| print("✓ Text encoders loaded") | |
| print("Loading VAE...") | |
| vae = AutoencoderKL.from_pretrained("./vae", torch_dtype=DTYPE) | |
| vae.eval() | |
| VAE_SCALE = vae.config.scaling_factor | |
| print(f"✓ VAE loaded (scale={VAE_SCALE})") | |
| # ============================================================================ | |
| # EULER DISCRETE FLOW MATCHING SAMPLER WITH CFG | |
| # Training uses: x_t = (1-t)*noise + t*data, v = data - noise | |
| # So t=0 is noise, t=1 is data. We sample from t=0 to t=1. | |
| # ============================================================================ | |
def flux_shift(t, shift=SHIFT):
    """Flux time shift: remap uniform timesteps via s*t / (1 + (s-1)*t)."""
    numerator = shift * t
    denominator = 1 + (shift - 1) * t
    return numerator / denominator
@spaces.GPU  # required for ZeroGPU Spaces; the `spaces` import was otherwise unused
def generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_embed: float,
    cfg_scale: float,
    num_inference_steps: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Sample one image with Euler flow matching and decode it.

    Returns (PIL.Image, seed) for the Gradio outputs.

    Fixes vs. the original:
      * create_img_ids is a module-level function; calling it as
        TinyFluxDeep.create_img_ids raised AttributeError on every run.
      * @spaces.GPU so ZeroGPU Spaces allocate a GPU for this call.
      * autocast uses the actual DEVICE so the CPU fallback does not crash.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Move everything to the compute device (no-op after the first call).
    model.to(DEVICE)
    t5_enc.to(DEVICE)
    clip_enc.to(DEVICE)
    vae.to(DEVICE)

    with torch.inference_mode(), torch.autocast(device_type=DEVICE, dtype=DTYPE):
        # --- Encode the positive prompt ---
        t5_in = t5_tok(prompt, max_length=128, padding="max_length",
                       truncation=True, return_tensors="pt").to(DEVICE)
        t5_cond = t5_enc(**t5_in).last_hidden_state
        clip_in = clip_tok(prompt, max_length=77, padding="max_length",
                           truncation=True, return_tensors="pt").to(DEVICE)
        clip_cond = clip_enc(**clip_in).pooler_output

        # --- Encode the negative prompt when CFG is enabled ---
        do_cfg = cfg_scale > 1.0
        if do_cfg:
            neg_prompt = negative_prompt if negative_prompt else ""
            t5_neg_in = t5_tok(neg_prompt, max_length=128, padding="max_length",
                               truncation=True, return_tensors="pt").to(DEVICE)
            t5_uncond = t5_enc(**t5_neg_in).last_hidden_state
            clip_neg_in = clip_tok(neg_prompt, max_length=77, padding="max_length",
                                   truncation=True, return_tensors="pt").to(DEVICE)
            clip_uncond = clip_enc(**clip_neg_in).pooler_output
            # Stack [uncond, cond] so each step needs one batched forward pass.
            t5_batch = torch.cat([t5_uncond, t5_cond], dim=0)
            clip_batch = torch.cat([clip_uncond, clip_cond], dim=0)

        # --- Latent setup ---
        H_lat = height // 8  # VAE downsamples by 8x
        W_lat = width // 8
        C = 16  # latent channels (matches config.in_channels)
        # Training convention: x_t = (1-t)*noise + t*data, so t=0 is pure noise.
        x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
        # BUGFIX: create_img_ids is module-level, not a TinyFluxDeep attribute.
        img_ids = create_img_ids(1, H_lat, W_lat, DEVICE)

        # Timesteps 0 -> 1 (noise to data), warped by the Flux shift.
        t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
        timesteps = flux_shift(t_linear, shift=SHIFT)
        # Distilled guidance value fed directly into the model.
        guidance_tensor = torch.tensor([guidance_embed], device=DEVICE, dtype=DTYPE)

        # --- Euler integration: x_{t+dt} = x_t + v(x_t, t) * dt ---
        for i in range(num_inference_steps):
            t_curr = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_curr
            t_batch = t_curr.unsqueeze(0)
            if do_cfg:
                # One batched forward for [uncond, cond].
                x_batch = x.repeat(2, 1, 1)
                v_batch = model(
                    hidden_states=x_batch,
                    encoder_hidden_states=t5_batch,
                    pooled_projections=clip_batch,
                    timestep=t_batch.repeat(2),
                    img_ids=img_ids,
                    guidance=guidance_tensor.repeat(2),
                )
                v_uncond, v_cond = v_batch.chunk(2, dim=0)
                # Classifier-free guidance on the velocity field.
                v = v_uncond + cfg_scale * (v_cond - v_uncond)
            else:
                v = model(
                    hidden_states=x,
                    encoder_hidden_states=t5_cond,
                    pooled_projections=clip_cond,
                    timestep=t_batch,
                    img_ids=img_ids,
                    guidance=guidance_tensor,
                )
            x = x + v * dt

        # --- Decode latents to pixels ---
        latents = x.reshape(1, H_lat, W_lat, C).permute(0, 3, 1, 2)
        latents = latents / VAE_SCALE
        image = vae.decode(latents.to(vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)

    # Convert to a PIL image for Gradio.
    image = image[0].float().permute(1, 2, 0).cpu().numpy()
    image = (image * 255).astype(np.uint8)
    image = Image.fromarray(image)
    return image, seed
| # ============================================================================ | |
| # GRADIO INTERFACE | |
| # ============================================================================ | |
# Example prompts shown below the input box.
examples = [
    "a photo of a cat sitting on a windowsill",
    "a portrait of a woman with red hair, professional photography",
    "a black backpack on white background, product photo",
    "astronaut riding a horse on mars, digital art",
    "a cozy coffee shop interior, warm lighting",
]
# Center the UI column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
"""
# UI layout: header, prompt row, result image, collapsible settings,
# example prompts, and a usage-notes footer.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # TinyFlux-Lailah
        **241M parameter** flow-matching text-to-image model.
        Trained on teacher latents from Flux-Schnell.
        [Model Card](https://huggingface.co/AbstractPhil/tiny-flux-deep)
        """)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt...",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt (for CFG)",
                max_lines=1,
                placeholder="blurry, distorted, low quality",
                value="",
            )
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # Width/height step of 64 keeps latent dims divisible after the 8x VAE downsample.
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
            with gr.Row():
                guidance_embed = gr.Slider(
                    label="Guidance Embed (distilled)",
                    minimum=1.0, maximum=10.0, step=0.5, value=3.5,
                    info="Passed to model (trained 1-5 range)"
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale (two-pass)",
                    minimum=1.0, maximum=10.0, step=0.5, value=1.0,
                    info="1.0 = off (faster), >1 = CFG enabled"
                )
            num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=1, value=25)
        gr.Examples(examples=examples, inputs=[prompt])
        gr.Markdown("""
        ---
        **Guidance Embed**: Distilled guidance baked into model weights. Fast (1 pass). Trained with values 1-5.
        **CFG Scale**: Traditional classifier-free guidance. Slower (2 passes). Set to 1.0 to disable.
        """)
    # Fire generate() on button click or Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_embed, cfg_scale, num_inference_steps],
        outputs=[result, seed],
    )
if __name__ == "__main__":
    demo.launch()