"""
TinyFlux-Lailah Inference

Loads the model code, the weights, and runs the inference based on the
settings below. Set up with only EULER for now.

No guarantees for any of this to work. It's pretty bad in its current phase;
just check on it later if you're interested.

LICENSE: MIT
"""
# NOTE: torch is imported here, before DEVICE / DTYPE are evaluated below.
import io
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import display, Image as IPyImage, update_display
from PIL import Image as PIL

POSITIVE_PROMPT = "woman"  # @param {type:"string"}
NEGATIVE_PROMPT = ""  # @param {type:"string"}
STEPS = 50  # @param {type:"integer"}
CFG_GUIDANCE = 5  # @param {type: "number"}
FLUX_SHIFT = 3  # @param {type: "number"}
SEED = 420  # @param {type: "integer"}
OUTPUT_PATH = "output.png"  # @param {type:"string"}
WIDTH = 512  # @param {type: "integer"}
HEIGHT = 512  # @param {type: "integer"}

# Model loading
HF_REPO = "AbstractPhil/tiny-flux-deep"  # @param {type:"string"}
# "hub", "hub:step_XXXXX", "local:/path/to/weights.safetensors"
LOAD_FROM = "hub:step_293750"  # @param {type:"string"}

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cuda":
    # Only query bf16 support when a CUDA device actually exists.
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    # Half precision on CPU is slow and not supported by every op; use fp32.
    DTYPE = torch.float32

#@title Preview (updates in-place)
_PREVIEW_DISPLAY_ID = "tf_preview"
preview_size = min(512, max(WIDTH, HEIGHT) // 2)


def _pil_to_png_bytes(img: PIL.Image) -> bytes:
    """Serialize a PIL image to PNG and return the raw bytes."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def init_preview(square: int = 256):
    """Show a black placeholder square once."""
    black = PIL.fromarray(np.zeros((square, square, 3), dtype=np.uint8))
    display(IPyImage(data=_pil_to_png_bytes(black)), display_id=_PREVIEW_DISPLAY_ID)


def set_preview_from_pil(img: PIL.Image, square: int = 256):
    """Update the preview in-place with a PIL image.

    The image is converted to RGB, shrunk to fit inside ``square`` x ``square``
    and centered on a black canvas of exactly that size.
    """
    im = img.convert("RGB").copy()
    im.thumbnail((square, square), resample=PIL.Resampling.LANCZOS)

    # pad to square (so it stays a square widget)
    canvas = PIL.fromarray(np.zeros((square, square, 3), dtype=np.uint8))
    x = (square - im.size[0]) // 2
    y = (square - im.size[1]) // 2
    canvas.paste(im, (x, y))

    update_display(IPyImage(data=_pil_to_png_bytes(canvas)), display_id=_PREVIEW_DISPLAY_ID)


def set_preview_from_path(path: str, square: int = 256):
    """Update preview from an image file path."""
    set_preview_from_pil(PIL.open(path), square=square)


# initialize placeholder
init_preview(square=preview_size)

# -----------------------------------------------------------------------------
# TinyFlux-Deep: Deeper variant with 15 double + 25 single blocks.
#
# Config derived from checkpoint step_285625.safetensors:
# - hidden_size: 512
# - num_attention_heads: 4
# - attention_head_dim: 128
# - num_double_layers: 15
# - num_single_layers: 25
# - Uses biases in MLP
# - Old RoPE format with cached freqs buffers
# -----------------------------------------------------------------------------


@dataclass
class TinyFluxDeepConfig:
    """Configuration for TinyFlux-Deep model."""
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
    guidance_embeds: bool = True

    def __post_init__(self):
        # The heads must tile the hidden size exactly, and the per-axis RoPE
        # dims must add up to one head.
        assert self.num_attention_heads * self.attention_head_dim == self.hidden_size, \
            "num_attention_heads * attention_head_dim must equal hidden_size"
        assert sum(self.axes_dims_rope) == self.attention_head_dim, \
            "sum(axes_dims_rope) must equal attention_head_dim"


# =============================================================================
# Normalization
# =============================================================================
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The inverse RMS is computed in float32, then the product is cast
        # back to the input dtype before the optional learned scale.
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        out = (x * norm).type_as(x)
        if self.weight is not None:
            out = out * self.weight
        return out


# =============================================================================
# RoPE - Old format with cached frequency buffers (checkpoint compatible)
# =============================================================================
class EmbedND(nn.Module):
    """
    Original TinyFlux RoPE with cached frequency buffers.
    Matches checkpoint format with rope.freqs_0, rope.freqs_1, rope.freqs_2
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

        # Register frequency buffers (matches checkpoint keys rope.freqs_*)
        for i, dim in enumerate(axes_dim):
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer(f'freqs_{i}', freqs, persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (N, 3) position indices [temporal, height, width]
        Returns:
            rope: (N, 1, head_dim) interleaved [cos, sin, cos, sin, ...]
        """
        device = ids.device
        n_axes = ids.shape[-1]
        emb_list = []

        for i in range(n_axes):
            freqs = getattr(self, f'freqs_{i}').to(device)
            pos = ids[:, i].float()
            angles = pos.unsqueeze(-1) * freqs.unsqueeze(0)  # (N, dim/2)

            # Interleave cos and sin
            cos = angles.cos()
            sin = angles.sin()
            emb = torch.stack([cos, sin], dim=-1).flatten(-2)  # (N, dim)
            emb_list.append(emb)

        rope = torch.cat(emb_list, dim=-1)  # (N, head_dim)
        return rope.unsqueeze(1)  # (N, 1, head_dim)


def apply_rotary_emb_old(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings (old interleaved format).

    Args:
        x: (B, H, N, D) query or key tensor
        freqs_cis: (N, 1, D) interleaved [cos0, sin0, cos1, sin1, ...]
    Returns:
        Rotated tensor of same shape (and same dtype as ``x``)
    """
    # freqs_cis is (N, 1, D) with interleaved cos/sin
    freqs = freqs_cis.squeeze(1)  # (N, D)

    # Split interleaved cos/sin; each value is duplicated so it lines up with
    # both members of its (real, imag) pair.
    cos = freqs[:, 0::2].repeat_interleave(2, dim=-1)  # (N, D)
    sin = freqs[:, 1::2].repeat_interleave(2, dim=-1)  # (N, D)

    cos = cos[None, None, :, :].to(x.device)  # (1, 1, N, D)
    sin = sin[None, None, :, :].to(x.device)

    # Split into real/imag pairs and rotate: (a, b) -> (-b, a)
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    # Math is done in float32, result is cast back to the input dtype.
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)


# =============================================================================
# Embeddings
# =============================================================================
class MLPEmbedder(nn.Module):
    """MLP for embedding scalars (timestep, guidance)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Input width 256 = 128 sin + 128 cos features built in forward().
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sinusoidal features are computed in x's own dtype and device.
        half_dim = 128
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
        emb = x.unsqueeze(-1) * emb.unsqueeze(0)
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return self.mlp(emb)


# =============================================================================
# AdaLayerNorm
# =============================================================================
class AdaLayerNormZero(nn.Module):
    """AdaLN-Zero for double-stream blocks (6 params)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return the modulated input plus the attention gate and MLP shift/scale/gate."""
        emb_out = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb_out.chunk(6, dim=-1)
        x = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroSingle(nn.Module):
    """AdaLN-Zero for single-stream blocks (3 params)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return the modulated input and the gate for the attention residual."""
        emb_out = self.linear(self.silu(emb))
        shift, scale, gate = emb_out.chunk(3, dim=-1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x, gate


# =============================================================================
# Attention (original format - no Q/K norm, matches checkpoint)
# =============================================================================
class Attention(nn.Module):
    """Multi-head attention (original TinyFlux format, no Q/K norm)."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Stored but not passed to scaled_dot_product_attention in forward().
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.out_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, _ = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # 3 x (B, H, N, D)

        # Apply RoPE
        if rope is not None:
            q = apply_rotary_emb_old(q, rope)
            k = apply_rotary_emb_old(k, rope)

        # Scaled dot-product attention
        attn = F.scaled_dot_product_attention(q, k, v)
        out = attn.transpose(1, 2).reshape(B, N, -1)
        return self.out_proj(out)


class JointAttention(nn.Module):
    """Joint attention for double-stream blocks (original format)."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Stored but not passed to scaled_dot_product_attention in forward().
        self.scale = head_dim ** -0.5

        self.txt_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)
        self.img_qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=use_bias)

        self.txt_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)
        self.img_out = nn.Linear(num_heads * head_dim, hidden_size, bias=use_bias)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, L, _ = txt.shape
        _, N, _ = img.shape

        txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
        img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)

        txt_q, txt_k, txt_v = txt_qkv.permute(2, 0, 3, 1, 4)
        img_q, img_k, img_v = img_qkv.permute(2, 0, 3, 1, 4)

        # Apply RoPE to image only
        if rope is not None:
            img_q = apply_rotary_emb_old(img_q, rope)
            img_k = apply_rotary_emb_old(img_k, rope)

        # Concatenate for joint attention: both streams attend over the same
        # text+image keys/values but keep separate queries and output heads.
        k = torch.cat([txt_k, img_k], dim=2)
        v = torch.cat([txt_v, img_v], dim=2)

        txt_out = F.scaled_dot_product_attention(txt_q, k, v)
        txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)

        img_out = F.scaled_dot_product_attention(img_q, k, v)
        img_out = img_out.transpose(1, 2).reshape(B, N, -1)

        return self.txt_out(txt_out), self.img_out(img_out)


# =============================================================================
# MLP (with bias - matches checkpoint)
# =============================================================================
class MLP(nn.Module):
    """Feed-forward network with GELU activation and biases."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, mlp_hidden, bias=True)  # bias=True for checkpoint compat
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(mlp_hidden, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))
# =============================================================================
# Transformer Blocks
# =============================================================================
class DoubleStreamBlock(nn.Module):
    """Double-stream transformer block.

    Text and image tokens keep separate norms, MLPs and attention projections
    but attend jointly through a shared JointAttention.
    """

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        self.img_norm1 = AdaLayerNormZero(hidden)
        self.txt_norm1 = AdaLayerNormZero(hidden)
        self.attn = JointAttention(hidden, heads, head_dim, use_bias=False)
        self.img_norm2 = RMSNorm(hidden)
        self.txt_norm2 = RMSNorm(hidden)
        self.img_mlp = MLP(hidden, config.mlp_ratio)
        self.txt_mlp = MLP(hidden, config.mlp_ratio)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_normed, img_gate_msa, img_shift_mlp, img_scale_mlp, img_gate_mlp = self.img_norm1(img, vec)
        txt_normed, txt_gate_msa, txt_shift_mlp, txt_scale_mlp, txt_gate_mlp = self.txt_norm1(txt, vec)

        # Gated attention residual
        txt_attn_out, img_attn_out = self.attn(txt_normed, img_normed, rope)
        txt = txt + txt_gate_msa.unsqueeze(1) * txt_attn_out
        img = img + img_gate_msa.unsqueeze(1) * img_attn_out

        # Modulated, gated MLP residual
        txt_mlp_in = self.txt_norm2(txt) * (1 + txt_scale_mlp.unsqueeze(1)) + txt_shift_mlp.unsqueeze(1)
        img_mlp_in = self.img_norm2(img) * (1 + img_scale_mlp.unsqueeze(1)) + img_shift_mlp.unsqueeze(1)
        txt = txt + txt_gate_mlp.unsqueeze(1) * self.txt_mlp(txt_mlp_in)
        img = img + img_gate_mlp.unsqueeze(1) * self.img_mlp(img_mlp_in)

        return txt, img


class SingleStreamBlock(nn.Module):
    """Single-stream transformer block.

    Text and image tokens are concatenated, processed as one sequence, and
    split back into the two streams on return.
    """

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim

        self.norm = AdaLayerNormZeroSingle(hidden)
        self.attn = Attention(hidden, heads, head_dim, use_bias=False)
        self.mlp = MLP(hidden, config.mlp_ratio)
        self.norm2 = RMSNorm(hidden)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        L = txt.shape[1]
        x = torch.cat([txt, img], dim=1)
        x_normed, gate = self.norm(x, vec)
        x = x + gate.unsqueeze(1) * self.attn(x_normed, rope)
        # The MLP residual is intentionally ungated here (only one gate is
        # produced by AdaLayerNormZeroSingle).
        x = x + self.mlp(self.norm2(x))
        txt, img = x.split([L, x.shape[1] - L], dim=1)
        return txt, img


# =============================================================================
# Main Model
# =============================================================================
class TinyFluxDeep(nn.Module):
    """TinyFlux-Deep: 15 double + 25 single blocks."""

    def __init__(self, config: Optional[TinyFluxDeepConfig] = None):
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config

        # Input projections (with bias to match checkpoint)
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)

        # Conditioning
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)

        # RoPE (old format with cached freqs)
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)

        # Transformer blocks
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])

        # Output
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every Linear (zero bias), then zero the output projection weight."""
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        self.apply(_init)
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.Tensor,
        img_ids: torch.Tensor,
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the model and return the projected image tokens.

        Args:
            hidden_states: image tokens fed to ``img_in``.
            encoder_hidden_states: text tokens fed to ``txt_in``.
            pooled_projections: pooled text vector fed to ``vector_in``.
            timestep: scalar-per-sample timestep fed to ``time_in``.
            img_ids: image position ids; a 3-D tensor is reduced to its first entry.
            txt_ids: text position ids; zeros are used when omitted.
            guidance: optional guidance value, used only when the config
                enables guidance embeddings.
        """
        L = encoder_hidden_states.shape[1]

        # Input projections
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)

        # Conditioning
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)

        # Handle img_ids shape
        if img_ids.ndim == 3:
            img_ids = img_ids[0]  # (N, 3)

        # Compute RoPE for image positions
        img_rope = self.rope(img_ids)  # (N, 1, head_dim)

        # Double-stream blocks
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)

        # Build full sequence RoPE for single-stream
        if txt_ids is None:
            txt_ids = torch.zeros(L, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        all_ids = torch.cat([txt_ids, img_ids], dim=0)
        full_rope = self.rope(all_ids)

        # Single-stream blocks
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)

        # Output
        img = self.final_norm(img)
        img = self.final_linear(img)

        return img

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create image position IDs for RoPE.

        Returns a (height * width, 3) tensor in row-major order whose columns
        are [0, row, col]. ``batch_size`` is accepted for interface
        compatibility and is not used.
        """
        img_ids = torch.zeros(height * width, 3, device=device)
        # Vectorized fill instead of one scalar write per pixel.
        img_ids[:, 1] = torch.arange(height, device=device).repeat_interleave(width)
        img_ids[:, 2] = torch.arange(width, device=device).repeat(height)
        return img_ids

    @staticmethod
    def create_txt_ids(text_len: int, device: torch.device) -> torch.Tensor:
        """Create text position IDs: (text_len, 3) with the token index in column 0."""
        txt_ids = torch.zeros(text_len, 3, device=device)
        txt_ids[:, 0] = torch.arange(text_len, device=device)
        return txt_ids

    def count_parameters(self) -> dict:
        """Count parameters by component."""
        def _numel(*modules):
            return sum(p.numel() for m in modules for p in m.parameters())

        counts = {}
        counts['img_in'] = _numel(self.img_in)
        counts['txt_in'] = _numel(self.txt_in)
        counts['time_in'] = _numel(self.time_in)
        counts['vector_in'] = _numel(self.vector_in)
        if hasattr(self, 'guidance_in'):
            counts['guidance_in'] = _numel(self.guidance_in)
        counts['double_blocks'] = _numel(self.double_blocks)
        counts['single_blocks'] = _numel(self.single_blocks)
        counts['final'] = _numel(self.final_norm, self.final_linear)
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts


# =============================================================================
# Test
# =============================================================================
def test_model():
    """Test TinyFlux-Deep model with random inputs (one forward pass)."""
    print("=" * 60)
    print("TinyFlux-Deep Test")
    print("=" * 60)

    config = TinyFluxDeepConfig()
    model = TinyFluxDeep(config)

    counts = model.count_parameters()
    print(f"\nConfig:")
    print(f" hidden_size: {config.hidden_size}")
    print(f" num_attention_heads: {config.num_attention_heads}")
    print(f" attention_head_dim: {config.attention_head_dim}")
    print(f" num_double_layers: {config.num_double_layers}")
    print(f" num_single_layers: {config.num_single_layers}")

    print(f"\nParameters:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    B, H, W = 2, 64, 64
    L = 77

    hidden_states = torch.randn(B, H * W, config.in_channels, device=device)
    encoder_hidden_states = torch.randn(B, L, config.joint_attention_dim, device=device)
    pooled_projections = torch.randn(B, config.pooled_projection_dim, device=device)
    timestep = torch.rand(B, device=device)
    img_ids = TinyFluxDeep.create_img_ids(B, H, W, device)
    txt_ids = TinyFluxDeep.create_txt_ids(L, device)
    guidance = torch.ones(B, device=device) * 3.5

    with torch.no_grad():
        output = model(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            guidance=guidance,
        )

    print(f"\nOutput shape: {output.shape}")
    print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
    print("\n✓ Forward pass successful!")


#if __name__ == "__main__":
#    test_model()

# ============================================================================
# TinyFlux-Deep Inference Cell - Euler Discrete Flow Matching
# ============================================================================
# Run the model cell before this one (defines TinyFluxDeep, TinyFluxDeepConfig)
# Loads from: AbstractPhil/tiny-flux-deep or local checkpoint
# ============================================================================
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from PIL import Image
import numpy as np
import os

# Generation settings
NUM_STEPS = STEPS
GUIDANCE_SCALE = CFG_GUIDANCE
SHIFT = FLUX_SHIFT

# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
print("Loading text encoders...")
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()

clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()

# ============================================================================
# LOAD VAE
# ============================================================================
print("Loading Flux VAE...")
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE
).to(DEVICE).eval()

# ============================================================================
# LOAD TINYFLUX-DEEP MODEL
# ============================================================================
print(f"Loading TinyFlux-Deep from: {LOAD_FROM}")

# Use TinyFluxDeep (512 hidden, 4 heads, 15 double, 25 single)
config = TinyFluxDeepConfig()
model = TinyFluxDeep(config).to(DEVICE).to(DTYPE)

# Deprecated keys that may exist in old checkpoints but aren't needed
DEPRECATED_KEYS = {'time_in.sin_basis', 'guidance_in.sin_basis'}


def _torch_load_checkpoint(path):
    """Load a pickled torch checkpoint onto DEVICE."""
    # SECURITY: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from sources you trust.
    return torch.load(path, map_location=DEVICE, weights_only=False)


def _unwrap_state_dict(ckpt):
    """Return the dict stored under "model" or "state_dict", else ckpt itself."""
    if isinstance(ckpt, dict):
        for key in ("model", "state_dict"):
            if key in ckpt:
                return ckpt[key]
    return ckpt


def load_weights(path):
    """Load weights from .safetensors or .pt file.

    Files with any other extension are tried as safetensors first and as a
    torch checkpoint second. A leading "_orig_mod." (added by torch.compile)
    is stripped from the keys.
    """
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        state_dict = _unwrap_state_dict(_torch_load_checkpoint(path))
    else:
        try:
            state_dict = load_file(path)
        except Exception:
            # Not a safetensors file; fall back to a torch checkpoint and
            # unwrap it the same way as an explicit ".pt".
            state_dict = _unwrap_state_dict(_torch_load_checkpoint(path))

    # Strip "_orig_mod." prefix from keys (added by torch.compile)
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        print(" Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    return state_dict


def load_model_weights(model, weights, source_name):
    """Load weights with verbose reporting.

    Deprecated keys are dropped, the rest is loaded non-strictly, and up to
    ten missing / unexpected keys are printed.
    """
    # Filter out deprecated keys
    filtered_weights = {k: v for k, v in weights.items() if k not in DEPRECATED_KEYS}
    deprecated_found = [k for k in weights.keys() if k in DEPRECATED_KEYS]

    if deprecated_found:
        print(f" ✓ Ignored deprecated keys: {deprecated_found}")

    missing, unexpected = model.load_state_dict(filtered_weights, strict=False)

    if missing:
        print(f" ⚠ Missing keys: {missing[:10]}{'...' if len(missing) > 10 else ''}")
    if unexpected:
        print(f" ⚠ Unexpected keys: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")

    if not missing and not unexpected:
        print(f" ✓ All weights loaded successfully")

    print(f"✓ Loaded from {source_name}")


if LOAD_FROM == "hub":
    try:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
    except Exception:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.pt")
    weights = load_weights(weights_path)
    load_model_weights(model, weights, HF_REPO)
elif LOAD_FROM.startswith("hub:"):
    ckpt_name = LOAD_FROM[4:]
    if ckpt_name.endswith((".safetensors", ".pt")):
        # Explicit extension: there is exactly one candidate to try.
        candidates = [ckpt_name if "/" in ckpt_name else f"checkpoints/{ckpt_name}"]
    else:
        candidates = [f"checkpoints/{ckpt_name}{ext}" for ext in (".safetensors", ".pt", "")]

    last_error = None
    for filename in candidates:
        try:
            weights_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
            weights = load_weights(weights_path)
            load_model_weights(model, weights, f"{HF_REPO}/{filename}")
            break
        except Exception as e:
            # Best-effort: remember why this candidate failed, try the next.
            last_error = e
            continue
    else:
        # Chain the last failure so the real cause is not hidden.
        raise ValueError(f"Could not find checkpoint: {ckpt_name}") from last_error
elif LOAD_FROM.startswith("local:"):
    weights_path = LOAD_FROM[6:]
    weights = load_weights(weights_path)
    load_model_weights(model, weights, weights_path)
else:
    raise ValueError(f"Unknown LOAD_FROM: {LOAD_FROM}")

model.eval()
print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")


# ============================================================================
# ENCODING FUNCTIONS
# ============================================================================
@torch.inference_mode()
def encode_prompt(prompt: str, max_length: int = 128):
    """Encode prompt with flan-t5-base and CLIP-L.

    Returns:
        (T5 last hidden state, CLIP pooled output), both cast to DTYPE.
    """
    t5_in = t5_tok(
        prompt,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    ).to(DEVICE)
    t5_out = t5_enc(
        input_ids=t5_in.input_ids,
        attention_mask=t5_in.attention_mask
    ).last_hidden_state

    clip_in = clip_tok(
        prompt,
        max_length=77,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    ).to(DEVICE)
    clip_out = clip_enc(
        input_ids=clip_in.input_ids,
        attention_mask=clip_in.attention_mask
    )
    clip_pooled = clip_out.pooler_output

    return t5_out.to(DTYPE), clip_pooled.to(DTYPE)


# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Flux timestep shift - biases towards higher t (closer to data)."""
    return s * t / (1 + (s - 1) * t)


# ============================================================================
# EULER DISCRETE FLOW MATCHING SAMPLER
# ============================================================================
@torch.inference_mode()
def euler_sample(
    model,
    prompt: str,
    negative_prompt: str = "",
    num_steps: int = 28,
    guidance_scale: float = 3.5,
    height: int = 512,
    width: int = 512,
    seed: Optional[int] = None,
):
    """
    Euler discrete sampler for rectified flow matching.

    Flow Matching formulation:
        x_t = (1 - t) * noise + t * data
        At t=0: noise, At t=1: data
        Velocity v = data - noise (constant)

    Sampling: Integrate from t=0 (noise) to t=1 (data)

    Args:
        model: the TinyFluxDeep model to sample from.
        prompt: positive text prompt.
        negative_prompt: prompt for the unconditional branch; pass None to
            disable classifier-free guidance.
        num_steps: number of Euler steps.
        guidance_scale: CFG scale; CFG is only applied when > 1.0.
        height, width: output size in pixels (latents are 1/8 of this,
            computed with integer division).
        seed: seeds both the global torch RNG and the noise generator.

    Returns:
        Latents shaped (1, 16, height // 8, width // 8).
    """
    if seed is not None:
        torch.manual_seed(seed)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
    else:
        generator = None

    H_lat = height // 8
    W_lat = width // 8
    C_lat = 16

    # Encode prompts
    t5_cond, clip_cond = encode_prompt(prompt)
    if guidance_scale > 1.0 and negative_prompt is not None:
        t5_uncond, clip_uncond = encode_prompt(negative_prompt)
    else:
        t5_uncond, clip_uncond = None, None
    use_cfg = guidance_scale > 1.0 and t5_uncond is not None

    # Start from pure noise (t=0)
    x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator)

    # Create image position IDs
    img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)

    # Timesteps: 0 → 1 with flux shift
    # NOTE(review): the schedule is built in DTYPE; with bf16/fp16 and many
    # steps neighbouring timesteps may round together - confirm if step
    # counts well above ~50 are needed.
    t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
    timesteps = flux_shift(t_linear, s=SHIFT)

    # Loop-invariant: the guidance embedding input never changes per step.
    guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
    report_every = max(1, num_steps // 5)

    print(f"Sampling with {num_steps} Euler steps (t: 0→1, shifted)...")

    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr

        t_batch = t_curr.unsqueeze(0)

        # Predict velocity
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_cond,
            pooled_projections=clip_cond,
            timestep=t_batch,
            img_ids=img_ids,
            guidance=guidance_embed,
        )

        # Classifier-free guidance
        if use_cfg:
            v_uncond = model(
                hidden_states=x,
                encoder_hidden_states=t5_uncond,
                pooled_projections=clip_uncond,
                timestep=t_batch,
                img_ids=img_ids,
                guidance=guidance_embed,
            )
            v = v_uncond + guidance_scale * (v_cond - v_uncond)
        else:
            v = v_cond

        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + v * dt

        if (i + 1) % report_every == 0 or i == num_steps - 1:
            print(f" Step {i+1}/{num_steps}, t={t_next.item():.3f}")

    # Reshape: (1, H*W, C) -> (1, C, H, W)
    latents = x.reshape(1, H_lat, W_lat, C_lat).permute(0, 3, 1, 2)
    return latents


# ============================================================================
# DECODE LATENTS TO IMAGE
# ============================================================================
@torch.inference_mode()
def decode_latents(latents):
    """Decode VAE latents to PIL Image."""
    # NOTE(review): only scaling_factor is undone here; the Flux VAE config
    # also carries a shift_factor. Confirm against how training encoded the
    # latents before changing this.
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents.to(vae.dtype)).sample
    image = (image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    image = image[0].float().permute(1, 2, 0).cpu().numpy()
    image = (image * 255).astype(np.uint8)
    return Image.fromarray(image)


# ============================================================================
# MAIN GENERATION FUNCTION
# ============================================================================
def generate(
    prompt: str = POSITIVE_PROMPT,
    negative_prompt: str = NEGATIVE_PROMPT,
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    save_path: str = OUTPUT_PATH,
):
    """
    Generate an image from a text prompt.

    Args:
        prompt: Text description of desired image
        negative_prompt: What to avoid (empty string for none)
        num_steps: Number of Euler steps (20-50 recommended)
        guidance_scale: CFG scale (1.0=none, 3-7 typical)
        height: Output height in pixels (divisible by 8)
        width: Output width in pixels (divisible by 8)
        seed: Random seed (None for random)
        save_path: Path to save image (None to skip)

    Returns:
        PIL.Image
    """
    #print(f"\nGenerating: '{prompt}'")
    #print(f"Settings: {num_steps} steps, cfg={guidance_scale}, {width}x{height}, seed={seed}")

    latents = euler_sample(
        model=model,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        seed=seed,
    )

    #print("Decoding latents...")
    image = decode_latents(latents)

    if save_path:
        image.save(save_path)
        #print(f"✓ Saved to {save_path}")

    # Push the result into the in-place notebook preview.
    set_preview_from_pil(image, square=512)
    print("✓ Done!")
    return image


# ============================================================================
# BATCH GENERATION
# ============================================================================
def generate_batch(
    prompts: list,
    negative_prompt: str = "",
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    output_dir: str = "./outputs",
):
    """Generate multiple images.

    Image ``i`` uses ``seed + i`` (or a random seed when ``seed`` is None) and
    is saved as ``{output_dir}/{i:03d}.png``. Returns the list of PIL images.
    """
    os.makedirs(output_dir, exist_ok=True)
    images = []

    for i, prompt in enumerate(prompts):
        img_seed = seed + i if seed is not None else None
        image = generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            seed=img_seed,
            save_path=os.path.join(output_dir, f"{i:03d}.png"),
        )
        images.append(image)

    return images


# ============================================================================
# QUICK TEST
# ============================================================================
#print("\n" + "="*60)
#print("TinyFlux-Deep Inference Ready!")
#print("="*60)
#print(f"Config: {config.hidden_size} hidden, {config.num_attention_heads} heads")
#print(f" {config.num_double_layers} double, {config.num_single_layers} single layers")
#print(f"Total: {sum(p.numel() for p in model.parameters()):,} parameters")

# Example usage: run one generation with the settings from the top of the
# file. The call must sit on its own line so it is not swallowed by the
# comments above.
image = generate()
# Uncomment to have the notebook render the result as the cell output:
#image