tiny-flux-deep / scripts /inference_v3_colab.py
AbstractPhil's picture
Rename colab_inference_lailah_early.py to scripts/inference_v3_colab.py
dbd6197 verified
raw
history blame
37 kB
"""
TinyFlux-Lailah Inference
Loads the model code, the weights, and runs the inference based on the settings below.
Set up with only EULER for now.
No guarantees for any of this to work.
It's pretty bad in its current phase; just check on it later if you're interested.
LICENSE: MIT
"""
# ---------------------------------------------------------------------------
# User-facing settings. The trailing "# @param" comments are Colab form
# directives and must stay on the same line as the assignment.
# ---------------------------------------------------------------------------
POSITIVE_PROMPT = "woman" # @param {type:"string"}
NEGATIVE_PROMPT = "" # @param {type:"string"}
STEPS = 50 # @param {type:"integer"}
CFG_GUIDANCE = 5 # @param {type: "number"}
FLUX_SHIFT = 3 # @param {type: "number"}
SEED = 420 # @param {type: "integer"}
OUTPUT_PATH = "output.png" # @param {type:"string"}
WIDTH = 512 # @param {type: "integer"}
HEIGHT = 512 # @param {type: "integer"}
# Model loading
HF_REPO = "AbstractPhil/tiny-flux-deep" # @param {type:"string"}
# "hub", "hub:step_XXXXX", "local:/path/to/weights.safetensors"
LOAD_FROM = "hub:step_293750" # @param {type:"string"}

# torch is needed right here, before the file's later `import torch`;
# without this import the two lines below raise NameError on a fresh runtime.
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Only query bf16 support when CUDA is actually present: on CPU-only
# runtimes the query can fail on some torch versions.
# NOTE(review): the non-bf16 fallback stays float16 as in the original;
# float16 on CPU may be slow or unsupported for some ops - confirm.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
#@title Preview (updates in-place)
from IPython.display import display, Image as IPyImage, update_display
from PIL import Image as PIL
import numpy as np, io
# Display handle shared by init_preview() and set_preview_from_pil().
_PREVIEW_DISPLAY_ID = "tf_preview"
# Preview edge length: half of the longer output side, capped at 512.
preview_size = min(max(HEIGHT, WIDTH) // 2, 512)
def _pil_to_png_bytes(img: PIL) -> bytes:
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
def init_preview(square: int = 256):
    """Display an all-black ``square`` x ``square`` image under the preview display id."""
    blank_pixels = np.zeros((square, square, 3), dtype=np.uint8)
    placeholder = PIL.fromarray(blank_pixels)
    png_bytes = _pil_to_png_bytes(placeholder)
    display(IPyImage(data=png_bytes), display_id=_PREVIEW_DISPLAY_ID)
def set_preview_from_pil(img: PIL, square: int = 256):
    """Replace the preview with ``img``, shrunk to fit and centered on a black square.

    Args:
        img: Source image; it is converted to RGB and copied, so the caller's
            object is not modified.
        square: Edge length of the square canvas.
    """
    thumb = img.convert("RGB").copy()
    thumb.thumbnail((square, square), resample=PIL.Resampling.LANCZOS)

    # Center the thumbnail so the widget keeps a constant square footprint.
    thumb_w, thumb_h = thumb.size
    offset = ((square - thumb_w) // 2, (square - thumb_h) // 2)
    backdrop = PIL.fromarray(np.zeros((square, square, 3), dtype=np.uint8))
    backdrop.paste(thumb, offset)

    png_bytes = _pil_to_png_bytes(backdrop)
    update_display(IPyImage(data=png_bytes), display_id=_PREVIEW_DISPLAY_ID)
def set_preview_from_path(path: str, square: int = 256):
    """Update the preview from an image file.

    Args:
        path: Path of the image file to show.
        square: Edge length of the square preview canvas.
    """
    # Open inside a context manager so the file handle is closed once the
    # preview has been rendered (the original left it open).
    with PIL.open(path) as img:
        set_preview_from_pil(img, square=square)
# initialize placeholder
# Runs at import/cell-execution time: shows the black square that later
# preview updates replace in place.
init_preview(square=preview_size)
# Disabled example of pushing an existing image into the preview:
#set_preview_from_pil(image, square=preview_size)
"""
TinyFlux-Deep: Deeper variant with 15 double + 25 single blocks.
Config derived from checkpoint step_285625.safetensors:
- hidden_size: 512
- num_attention_heads: 4
- attention_head_dim: 128
- num_double_layers: 15
- num_single_layers: 25
- Uses biases in MLP
- Old RoPE format with cached freqs buffers
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Tuple, List
@dataclass
class TinyFluxDeepConfig:
    """Hyper-parameters of the TinyFlux-Deep transformer.

    Construction fails with ``AssertionError`` when the head geometry does not
    add up to ``hidden_size`` or when the RoPE axis widths do not add up to
    ``attention_head_dim``.
    """
    # --- width / attention geometry ---
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    # --- input / conditioning widths ---
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768
    # --- depth ---
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    # --- positional encoding / guidance ---
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
    guidance_embeds: bool = True

    def __post_init__(self):
        """Check that the configured dimensions are mutually consistent."""
        attn_width = self.num_attention_heads * self.attention_head_dim
        assert attn_width == self.hidden_size
        rope_width = sum(self.axes_dims_rope)
        assert rope_width == self.attention_head_dim
# =============================================================================
# Normalization
# =============================================================================
class RMSNorm(nn.Module):
    """RMS normalization over the last dimension with an optional learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            # Keep a 'weight' attribute (set to None) so forward() can test it.
            self.register_parameter('weight', None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        # Statistics are computed in float32; the result is cast back to x's dtype.
        mean_sq = x.float().pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = (x * inv_rms).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight
# =============================================================================
# RoPE - Old format with cached frequency buffers (checkpoint compatible)
# =============================================================================
class EmbedND(nn.Module):
    """
    Multi-axis RoPE table builder using persistent per-axis frequency buffers.
    The buffers are registered as freqs_0, freqs_1, ... (one per entry of
    ``axes_dim``) so they are stored in the module's state dict.
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        # One inverse-frequency vector of length width/2 per axis.
        for axis, width in enumerate(axes_dim):
            exponents = torch.arange(0, width, 2).float() / width
            inv_freq = 1.0 / (theta ** exponents)
            self.register_buffer(f'freqs_{axis}', inv_freq, persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (N, A) position indices, one column per axis
        Returns:
            (N, 1, sum of axis widths) tensor laid out [cos, sin, cos, sin, ...]
        """
        target = ids.device
        per_axis = []
        for axis in range(ids.shape[-1]):
            inv_freq = getattr(self, f'freqs_{axis}').to(target)
            positions = ids[:, axis].float()
            phase = positions.unsqueeze(-1) * inv_freq.unsqueeze(0)  # (N, width/2)
            # Stack then flatten so cos/sin alternate along the last dimension.
            pairs = torch.stack([phase.cos(), phase.sin()], dim=-1)
            per_axis.append(pairs.flatten(-2))  # (N, width)
        return torch.cat(per_axis, dim=-1).unsqueeze(1)
def apply_rotary_emb_old(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """
    Rotate adjacent channel pairs of ``x`` by the angles encoded in ``freqs_cis``.
    Args:
        x: (B, H, N, D) query or key tensor
        freqs_cis: (N, 1, D) table laid out [cos0, sin0, cos1, sin1, ...]
    Returns:
        Tensor with the shape and dtype of ``x``
    """
    table = freqs_cis.squeeze(1)  # (N, D)
    # De-interleave, then duplicate each value so both channels of a pair
    # see the same cos/sin; the leading [None, None] broadcasts over (B, H).
    cos = table[:, 0::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)
    sin = table[:, 1::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    # (a, b) -> (-b, a): the 90-degree-rotated companion of each pair.
    quarter_turn = torch.stack([-second, first], dim=-1).flatten(-2)
    # The arithmetic runs in float32 and is cast back to the input dtype.
    return (x.float() * cos + quarter_turn.float() * sin).to(x.dtype)
# =============================================================================
# Embeddings
# =============================================================================
class MLPEmbedder(nn.Module):
    """Sinusoidal featurization of a scalar per sample, followed by a two-layer MLP."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # 256 inputs = 128 sine + 128 cosine features produced in forward().
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape (B,) into a (B, hidden_size) tensor."""
        half_dim = 128
        log_step = math.log(10000) / (half_dim - 1)
        index = torch.arange(half_dim, device=x.device, dtype=x.dtype)
        freqs = torch.exp(index * -log_step)
        phase = x.unsqueeze(-1) * freqs.unsqueeze(0)
        features = torch.cat([phase.sin(), phase.cos()], dim=-1)
        return self.mlp(features)
# =============================================================================
# AdaLayerNorm
# =============================================================================
class AdaLayerNormZero(nn.Module):
    """Conditioned RMSNorm that emits six modulation tensors (double-stream blocks)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return the attention-modulated ``x`` plus the attention gate and MLP shift/scale/gate."""
        params = self.linear(self.silu(emb)).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params
        # unsqueeze(1) broadcasts the per-sample modulation over the token axis.
        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroSingle(nn.Module):
    """Conditioned RMSNorm that emits shift, scale and gate (single-stream blocks)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """Return the modulated ``x`` and the gate derived from ``emb``."""
        shift, scale, gate = self.linear(self.silu(emb)).chunk(3, dim=-1)
        # unsqueeze(1) broadcasts the per-sample modulation over the token axis.
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return modulated, gate
# =============================================================================
# Attention (original format - no Q/K norm, matches checkpoint)
# =============================================================================
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and no Q/K normalization."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        inner = num_heads * head_dim
        self.qkv = nn.Linear(hidden_size, 3 * inner, bias=use_bias)
        self.out_proj = nn.Linear(inner, hidden_size, bias=use_bias)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` (B, N, hidden); ``rope`` is applied to Q and K when given."""
        batch, seq_len, _ = x.shape
        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        # (3, B, H, N, D) -> three (B, H, N, D) tensors
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        if rope is not None:
            q, k = apply_rotary_emb_old(q, rope), apply_rotary_emb_old(k, rope)
        context = F.scaled_dot_product_attention(q, k, v)
        # Merge heads back: (B, H, N, D) -> (B, N, H*D)
        merged = context.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
class JointAttention(nn.Module):
    """Text/image joint attention: each stream has its own projections and both
    streams attend over the concatenated text+image keys and values."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        inner = num_heads * head_dim
        self.txt_qkv = nn.Linear(hidden_size, 3 * inner, bias=use_bias)
        self.img_qkv = nn.Linear(hidden_size, 3 * inner, bias=use_bias)
        self.txt_out = nn.Linear(inner, hidden_size, bias=use_bias)
        self.img_out = nn.Linear(inner, hidden_size, bias=use_bias)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the updated (txt, img) pair; ``rope`` rotates image Q/K only."""
        batch, txt_len, _ = txt.shape
        img_len = img.shape[1]
        heads, width = self.num_heads, self.head_dim

        txt_packed = self.txt_qkv(txt).reshape(batch, txt_len, 3, heads, width)
        img_packed = self.img_qkv(img).reshape(batch, img_len, 3, heads, width)
        txt_q, txt_k, txt_v = txt_packed.permute(2, 0, 3, 1, 4).unbind(0)
        img_q, img_k, img_v = img_packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Positional rotation is applied to the image stream only.
        if rope is not None:
            img_q = apply_rotary_emb_old(img_q, rope)
            img_k = apply_rotary_emb_old(img_k, rope)

        # Shared key/value memory: text tokens first, image tokens after.
        keys = torch.cat([txt_k, img_k], dim=2)
        values = torch.cat([txt_v, img_v], dim=2)

        txt_ctx = F.scaled_dot_product_attention(txt_q, keys, values)
        txt_ctx = txt_ctx.transpose(1, 2).reshape(batch, txt_len, -1)
        img_ctx = F.scaled_dot_product_attention(img_q, keys, values)
        img_ctx = img_ctx.transpose(1, 2).reshape(batch, img_len, -1)
        return self.txt_out(txt_ctx), self.img_out(img_ctx)
# =============================================================================
# MLP (with bias - matches checkpoint)
# =============================================================================
class MLP(nn.Module):
    """Two-layer feed-forward network (tanh-approximated GELU, biased linears)."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner_width = int(hidden_size * mlp_ratio)
        # Both projections carry a bias so the parameter set matches the checkpoint.
        self.fc1 = nn.Linear(hidden_size, inner_width, bias=True)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner_width, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate, and project back to ``hidden_size``."""
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)
# =============================================================================
# Transformer Blocks
# =============================================================================
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate text/image weights coupled by joint attention."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim
        self.img_norm1 = AdaLayerNormZero(hidden)
        self.txt_norm1 = AdaLayerNormZero(hidden)
        self.attn = JointAttention(hidden, heads, head_dim, use_bias=False)
        self.img_norm2 = RMSNorm(hidden)
        self.txt_norm2 = RMSNorm(hidden)
        self.img_mlp = MLP(hidden, config.mlp_ratio)
        self.txt_mlp = MLP(hidden, config.mlp_ratio)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run gated attention and gated MLP residual updates on both streams.

        Args:
            txt: Text tokens.
            img: Image tokens.
            vec: Conditioning vector that drives the modulation.
            rope: Optional rotary table forwarded to the joint attention.
        Returns:
            Updated (txt, img) pair.
        """
        img_mod, img_gate_attn, img_shift, img_scale, img_gate_ff = self.img_norm1(img, vec)
        txt_mod, txt_gate_attn, txt_shift, txt_scale, txt_gate_ff = self.txt_norm1(txt, vec)

        # Attention residual
        txt_delta, img_delta = self.attn(txt_mod, img_mod, rope)
        txt = txt + txt_gate_attn.unsqueeze(1) * txt_delta
        img = img + img_gate_attn.unsqueeze(1) * img_delta

        # MLP residual, text stream
        txt_ff_in = self.txt_norm2(txt) * (1 + txt_scale.unsqueeze(1)) + txt_shift.unsqueeze(1)
        txt = txt + txt_gate_ff.unsqueeze(1) * self.txt_mlp(txt_ff_in)

        # MLP residual, image stream
        img_ff_in = self.img_norm2(img) * (1 + img_scale.unsqueeze(1)) + img_shift.unsqueeze(1)
        img = img + img_gate_ff.unsqueeze(1) * self.img_mlp(img_ff_in)
        return txt, img
class SingleStreamBlock(nn.Module):
    """Transformer block that processes text and image tokens as one sequence."""

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        head_dim = config.attention_head_dim
        self.norm = AdaLayerNormZeroSingle(hidden)
        self.attn = Attention(hidden, heads, head_dim, use_bias=False)
        self.mlp = MLP(hidden, config.mlp_ratio)
        self.norm2 = RMSNorm(hidden)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Concatenate the streams, apply attention and MLP residuals, split again.

        Returns:
            (txt, img) with the same token counts as the inputs.
        """
        txt_len = txt.shape[1]
        tokens = torch.cat([txt, img], dim=1)
        modulated, gate = self.norm(tokens, vec)
        # Only the attention residual is gated; the MLP residual is added as-is.
        tokens = tokens + gate.unsqueeze(1) * self.attn(modulated, rope)
        tokens = tokens + self.mlp(self.norm2(tokens))
        img_len = tokens.shape[1] - txt_len
        txt, img = tokens.split([txt_len, img_len], dim=1)
        return txt, img
# =============================================================================
# Main Model
# =============================================================================
class TinyFluxDeep(nn.Module):
    """TinyFlux-Deep: double-stream blocks followed by single-stream blocks.

    Layer counts and widths come from ``TinyFluxDeepConfig`` (default: 15
    double + 25 single blocks).
    """

    def __init__(self, config: Optional["TinyFluxDeepConfig"] = None):
        """Build all sub-modules; a default config is used when none is given."""
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config
        # Input projections (with bias to match checkpoint)
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)
        # Conditioning
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)
        # RoPE (old format with cached freqs)
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)
        # Transformer blocks
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])
        # Output
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every linear layer, zero its bias, and zero the output weight."""
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        self.apply(_init)
        # With a zero output weight (and zero bias) a fresh model outputs zeros.
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.Tensor,
        img_ids: torch.Tensor,
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the transformer and return one output vector per image token.

        Args:
            hidden_states: Image tokens, (B, N, in_channels).
            encoder_hidden_states: Text tokens, (B, L, joint_attention_dim).
            pooled_projections: Pooled text vector, (B, pooled_projection_dim).
            timestep: Per-sample timestep, (B,).
            img_ids: Image position ids, (N, 3); a 3-D tensor is reduced to
                its first element.
            txt_ids: Text position ids, (L, 3); zeros are used when omitted,
                and a 3-D tensor is reduced to its first element.
            guidance: Optional per-sample guidance value, (B,); only used when
                the config enables guidance embeddings.

        Returns:
            Tensor of shape (B, N, in_channels).
        """
        text_len = encoder_hidden_states.shape[1]
        # Input projections
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)
        # Conditioning vector: timestep + pooled text (+ guidance)
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)
        # Position ids are shared across the batch; drop a leading batch axis.
        if img_ids.ndim == 3:
            img_ids = img_ids[0]  # (N, 3)
        # Double-stream blocks rotate image tokens only.
        img_rope = self.rope(img_ids)  # (N, 1, head_dim)
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)
        # Single-stream blocks see [text, image] as one sequence, so the
        # rotary table must cover both, in that order.
        if txt_ids is None:
            txt_ids = torch.zeros(text_len, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        all_ids = torch.cat([txt_ids, img_ids], dim=0)
        full_rope = self.rope(all_ids)
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)
        # Output head on the image tokens only
        img = self.final_norm(img)
        img = self.final_linear(img)
        return img

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create row-major image position ids for RoPE.

        Args:
            batch_size: Unused; kept so existing callers keep working.
            height: Number of token rows.
            width: Number of token columns.
            device: Device on which the ids are allocated.

        Returns:
            (height * width, 3) tensor whose rows are [0, row, col].
        """
        img_ids = torch.zeros(height * width, 3, device=device)
        # Vectorized fill: replaces a Python double loop that issued one
        # device write per element.
        rows = torch.arange(height, device=device)
        cols = torch.arange(width, device=device)
        img_ids[:, 1] = rows.repeat_interleave(width)
        img_ids[:, 2] = cols.repeat(height)
        return img_ids

    @staticmethod
    def create_txt_ids(text_len: int, device: torch.device) -> torch.Tensor:
        """Create text position ids: (text_len, 3) rows of [token_index, 0, 0]."""
        txt_ids = torch.zeros(text_len, 3, device=device)
        txt_ids[:, 0] = torch.arange(text_len, device=device)
        return txt_ids

    def count_parameters(self) -> dict:
        """Count parameters by component; the 'total' entry covers the whole model."""
        def _numel(module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {}
        counts['img_in'] = _numel(self.img_in)
        counts['txt_in'] = _numel(self.txt_in)
        counts['time_in'] = _numel(self.time_in)
        counts['vector_in'] = _numel(self.vector_in)
        # guidance_in only exists when the config enables guidance embeddings.
        if hasattr(self, 'guidance_in'):
            counts['guidance_in'] = _numel(self.guidance_in)
        counts['double_blocks'] = _numel(self.double_blocks)
        counts['single_blocks'] = _numel(self.single_blocks)
        counts['final'] = _numel(self.final_norm) + _numel(self.final_linear)
        counts['total'] = _numel(self)
        return counts
# =============================================================================
# Test
# =============================================================================
def test_model():
    """Smoke test: build the default model, print its size, run one forward pass."""
    banner = "=" * 60
    print(banner)
    print("TinyFlux-Deep Test")
    print(banner)

    config = TinyFluxDeepConfig()
    model = TinyFluxDeep(config)
    counts = model.count_parameters()

    print(f"\nConfig:")
    reported_fields = (
        "hidden_size",
        "num_attention_heads",
        "attention_head_dim",
        "num_double_layers",
        "num_single_layers",
    )
    for field_name in reported_fields:
        print(f" {field_name}: {getattr(config, field_name)}")

    print(f"\nParameters:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Random inputs: batch of 2, 64x64 image tokens, 77 text tokens.
    B, H, W = 2, 64, 64
    L = 77
    inputs = {
        "hidden_states": torch.randn(B, H * W, config.in_channels, device=device),
        "encoder_hidden_states": torch.randn(B, L, config.joint_attention_dim, device=device),
        "pooled_projections": torch.randn(B, config.pooled_projection_dim, device=device),
        "timestep": torch.rand(B, device=device),
        "img_ids": TinyFluxDeep.create_img_ids(B, H, W, device),
        "txt_ids": TinyFluxDeep.create_txt_ids(L, device),
        "guidance": torch.ones(B, device=device) * 3.5,
    }
    with torch.no_grad():
        output = model(**inputs)

    print(f"\nOutput shape: {output.shape}")
    print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
    print("\n✓ Forward pass successful!")
#if __name__ == "__main__":
# test_model()
# ============================================================================
# TinyFlux-Deep Inference Cell - Euler Discrete Flow Matching
# ============================================================================
# Run the model cell before this one (defines TinyFluxDeep, TinyFluxDeepConfig)
# Loads from: AbstractPhil/tiny-flux-deep or local checkpoint
# ============================================================================
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from PIL import Image
import numpy as np
import os
# Generation settings
# Aliases of the form settings; used as defaults by the sampling functions below.
NUM_STEPS, GUIDANCE_SCALE, SHIFT = STEPS, CFG_GUIDANCE, FLUX_SHIFT
# ============================================================================
# LOAD TEXT ENCODERS
# ============================================================================
print("Loading text encoders...")
# T5 supplies the per-token text states, CLIP the pooled text vector
# (see encode_prompt).
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE)
t5_enc = t5_enc.to(DEVICE).eval()
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()
# ============================================================================
# LOAD VAE
# ============================================================================
print("Loading Flux VAE...")
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=DTYPE,
)
vae = vae.to(DEVICE).eval()
# ============================================================================
# LOAD TINYFLUX-DEEP MODEL
# ============================================================================
print(f"Loading TinyFlux-Deep from: {LOAD_FROM}")
# Default config: 512 hidden, 4 heads, 15 double + 25 single blocks.
config = TinyFluxDeepConfig()
model = TinyFluxDeep(config)
model = model.to(DEVICE).to(DTYPE)
# Checkpoint keys that are dropped before loading (see load_model_weights).
DEPRECATED_KEYS = {
    'time_in.sin_basis',
    'guidance_in.sin_basis',
}
def _unwrap_state_dict(ckpt):
    """Return the weight mapping stored under 'model' or 'state_dict', else ``ckpt`` itself."""
    if isinstance(ckpt, dict):
        for wrapper_key in ("model", "state_dict"):
            if wrapper_key in ckpt:
                return ckpt[wrapper_key]
    return ckpt


def load_weights(path):
    """Load a state dict from a .safetensors or .pt file.

    Args:
        path: Checkpoint path. ``.safetensors`` files are read with
            ``load_file``; ``.pt`` files with ``torch.load``. Any other name
            is tried as safetensors first and as a torch checkpoint second.

    Returns:
        Mapping of parameter name to tensor, with any ``_orig_mod.`` fragment
        (added by torch.compile) removed from the keys.
    """
    # SECURITY NOTE: torch.load(..., weights_only=False) unpickles arbitrary
    # objects and can execute code from the file. Only load checkpoints from
    # sources you trust.
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        state_dict = _unwrap_state_dict(ckpt)
    else:
        try:
            state_dict = load_file(path)
        except Exception:
            # Not safetensors: fall back to a torch checkpoint. `except Exception`
            # (not a bare except) lets Ctrl-C through, and the result is
            # unwrapped exactly like the ".pt" branch.
            ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
            state_dict = _unwrap_state_dict(ckpt)
    # Strip "_orig_mod." prefix from keys (added by torch.compile)
    if any(k.startswith("_orig_mod.") for k in state_dict):
        print(" Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    return state_dict
def load_model_weights(model, weights, source_name):
    """Load ``weights`` into ``model`` non-strictly and print a report.

    Args:
        model: Module that receives the weights.
        weights: State dict; keys listed in DEPRECATED_KEYS are dropped first.
        source_name: Label printed in the final confirmation line.
    """
    deprecated_found = []
    filtered_weights = {}
    for key, tensor in weights.items():
        if key in DEPRECATED_KEYS:
            deprecated_found.append(key)
        else:
            filtered_weights[key] = tensor
    if deprecated_found:
        print(f" ✓ Ignored deprecated keys: {deprecated_found}")

    # strict=False: mismatched keys are reported below instead of raising.
    missing, unexpected = model.load_state_dict(filtered_weights, strict=False)
    if missing:
        print(f" ⚠ Missing keys: {missing[:10]}{'...' if len(missing) > 10 else ''}")
    if unexpected:
        print(f" ⚠ Unexpected keys: {unexpected[:10]}{'...' if len(unexpected) > 10 else ''}")
    if not (missing or unexpected):
        print(f" ✓ All weights loaded successfully")
    print(f"✓ Loaded from {source_name}")
# Resolve LOAD_FROM ("hub", "hub:<name>", or "local:<path>") to a weights
# file and load it into `model`.
if LOAD_FROM == "hub":
    # Prefer safetensors; fall back to the .pt file. `except Exception`
    # (not a bare except) so Ctrl-C still interrupts the download.
    try:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
    except Exception:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.pt")
    weights = load_weights(weights_path)
    load_model_weights(model, weights, HF_REPO)
elif LOAD_FROM.startswith("hub:"):
    ckpt_name = LOAD_FROM[4:]
    if ckpt_name.endswith((".safetensors", ".pt")):
        # An explicit file name yields a single candidate (no pointless retries).
        candidates = [ckpt_name if "/" in ckpt_name else f"checkpoints/{ckpt_name}"]
    else:
        candidates = [f"checkpoints/{ckpt_name}{ext}" for ext in (".safetensors", ".pt", "")]
    weights_path = None
    last_error = None
    for filename in candidates:
        # Only the download is retried; errors while reading or applying the
        # weights now propagate instead of being reported as "not found".
        try:
            weights_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
        except Exception as err:
            last_error = err
            continue
        break
    if weights_path is None:
        # Chain the last download error so the real cause stays visible.
        raise ValueError(f"Could not find checkpoint: {ckpt_name}") from last_error
    weights = load_weights(weights_path)
    load_model_weights(model, weights, f"{HF_REPO}/{filename}")
elif LOAD_FROM.startswith("local:"):
    weights_path = LOAD_FROM[6:]
    weights = load_weights(weights_path)
    load_model_weights(model, weights, weights_path)
else:
    raise ValueError(f"Unknown LOAD_FROM: {LOAD_FROM}")
model.eval()
print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")
# ============================================================================
# ENCODING FUNCTIONS
# ============================================================================
@torch.inference_mode()
def encode_prompt(prompt: str, max_length: int = 128):
    """Encode ``prompt`` with the T5 encoder and the CLIP text encoder.

    Args:
        prompt: Text to encode.
        max_length: Padded/truncated token length for T5 (CLIP always uses 77).

    Returns:
        Tuple of (T5 last hidden state, CLIP pooled output), both cast to DTYPE.
    """
    shared_tok_args = dict(padding="max_length", truncation=True, return_tensors="pt")

    t5_tokens = t5_tok(prompt, max_length=max_length, **shared_tok_args).to(DEVICE)
    t5_result = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    )
    t5_states = t5_result.last_hidden_state

    clip_tokens = clip_tok(prompt, max_length=77, **shared_tok_args).to(DEVICE)
    clip_result = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    )
    pooled = clip_result.pooler_output

    return t5_states.to(DTYPE), pooled.to(DTYPE)
# ============================================================================
# FLOW MATCHING HELPERS
# ============================================================================
def flux_shift(t, s=SHIFT):
    """Flux timestep shift: maps t to s*t / (1 + (s-1)*t), leaving 0 and 1 fixed.

    For s > 1 every interior t is moved towards 1.
    """
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
# ============================================================================
# EULER DISCRETE FLOW MATCHING SAMPLER
# ============================================================================
@torch.inference_mode()
def euler_sample(
    model,
    prompt: str,
    negative_prompt: str = "",
    num_steps: int = 28,
    guidance_scale: float = 3.5,
    height: int = 512,
    width: int = 512,
    seed: int = None,
):
    """
    Euler sampler for rectified flow matching.

    Convention used here:
        x_t = (1 - t) * noise + t * data, so t=0 is noise and t=1 is data,
        and the model output is treated as the velocity along that path.
    Sampling integrates from t=0 to t=1 over ``num_steps`` shifted steps.

    Args:
        model: Velocity model, called with the keyword arguments used below.
        prompt: Positive prompt.
        negative_prompt: Prompt for the unconditional branch; it is used only
            when ``guidance_scale > 1.0`` and it is not None.
        num_steps: Number of Euler steps.
        guidance_scale: CFG weight; also passed to the model as its guidance input.
        height: Output height in pixels (latent height is height // 8).
        width: Output width in pixels (latent width is width // 8).
        seed: Seeds both the global torch RNG and the noise generator; None
            leaves the RNG state untouched.

    Returns:
        Latents of shape (1, 16, height // 8, width // 8).
    """
    generator = None
    if seed is not None:
        torch.manual_seed(seed)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

    H_lat, W_lat, C_lat = height // 8, width // 8, 16

    # Prompt embeddings (the unconditional pair only when CFG is active)
    t5_cond, clip_cond = encode_prompt(prompt)
    use_cfg = guidance_scale > 1.0 and negative_prompt is not None
    if use_cfg:
        t5_uncond, clip_uncond = encode_prompt(negative_prompt)
    else:
        t5_uncond, clip_uncond = None, None

    # Pure noise at t=0, one token per latent pixel
    x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator)
    img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)

    # Shifted schedule 0 -> 1 with num_steps + 1 knots
    t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
    timesteps = flux_shift(t_linear, s=SHIFT)

    # The guidance input is the same at every step, so build it once.
    guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
    report_every = max(1, num_steps // 5)

    def _velocity(text_states, pooled, t_batch):
        return model(
            hidden_states=x,
            encoder_hidden_states=text_states,
            pooled_projections=pooled,
            timestep=t_batch,
            img_ids=img_ids,
            guidance=guidance_embed,
        )

    print(f"Sampling with {num_steps} Euler steps (t: 0→1, shifted)...")
    for step, (t_curr, t_next) in enumerate(zip(timesteps[:-1], timesteps[1:]), start=1):
        dt = t_next - t_curr
        t_batch = t_curr.unsqueeze(0)

        v_cond = _velocity(t5_cond, clip_cond, t_batch)
        if use_cfg:
            # Classifier-free guidance: extrapolate from the unconditional prediction.
            v_uncond = _velocity(t5_uncond, clip_uncond, t_batch)
            v = v_uncond + guidance_scale * (v_cond - v_uncond)
        else:
            v = v_cond

        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + v * dt

        if step % report_every == 0 or step == num_steps:
            print(f" Step {step}/{num_steps}, t={t_next.item():.3f}")

    # (1, H*W, C) -> (1, C, H, W)
    return x.reshape(1, H_lat, W_lat, C_lat).permute(0, 3, 1, 2)
# ============================================================================
# DECODE LATENTS TO IMAGE
# ============================================================================
@torch.inference_mode()
def decode_latents(latents):
    """Decode the first latent of the batch with the VAE and return it as a PIL image."""
    unscaled = latents / vae.config.scaling_factor
    decoded = vae.decode(unscaled.to(vae.dtype)).sample
    # Map the decoder output from [-1, 1] to [0, 1], clipping anything outside.
    unit_range = (decoded / 2 + 0.5).clamp(0, 1)
    # First sample only, channels-last, as a float numpy array on the CPU.
    pixels = unit_range[0].float().permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((pixels * 255).astype(np.uint8))
# ============================================================================
# MAIN GENERATION FUNCTION
# ============================================================================
def generate(
    prompt: str = POSITIVE_PROMPT,
    negative_prompt: str = NEGATIVE_PROMPT,
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    save_path: str = OUTPUT_PATH,
):
    """
    Sample latents for a prompt, decode them, optionally save, and refresh the preview.

    Args:
        prompt: Text description of desired image
        negative_prompt: What to avoid (empty string for none)
        num_steps: Number of Euler steps (20-50 recommended)
        guidance_scale: CFG scale (1.0=none, 3-7 typical)
        height: Output height in pixels (divisible by 8)
        width: Output width in pixels (divisible by 8)
        seed: Random seed (None for random)
        save_path: Path to save image (None or empty to skip)
    Returns:
        PIL.Image
    """
    sampler_args = dict(
        model=model,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        seed=seed,
    )
    image = decode_latents(euler_sample(**sampler_args))
    if save_path:
        image.save(save_path)
    # The preview is always rendered on a 512-pixel square here.
    set_preview_from_pil(image, square=512)
    print("✓ Done!")
    return image
# ============================================================================
# BATCH GENERATION
# ============================================================================
def generate_batch(
    prompts: list,
    negative_prompt: str = "",
    num_steps: int = NUM_STEPS,
    guidance_scale: float = GUIDANCE_SCALE,
    height: int = HEIGHT,
    width: int = WIDTH,
    seed: int = SEED,
    output_dir: str = "./outputs",
):
    """Generate one image per prompt and save them as 000.png, 001.png, ... in ``output_dir``.

    Prompt ``i`` is sampled with ``seed + i`` (or without a seed when ``seed``
    is None). Returns the list of generated images in prompt order.
    """
    os.makedirs(output_dir, exist_ok=True)
    images = []
    for index, prompt in enumerate(prompts):
        destination = os.path.join(output_dir, f"{index:03d}.png")
        img_seed = None if seed is None else seed + index
        images.append(
            generate(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_steps=num_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
                seed=img_seed,
                save_path=destination,
            )
        )
    return images
# ============================================================================
# QUICK TEST
# ============================================================================
#print("\n" + "="*60)
#print("TinyFlux-Deep Inference Ready!")
#print("="*60)
#print(f"Config: {config.hidden_size} hidden, {config.num_attention_heads} heads")
#print(f" {config.num_double_layers} double, {config.num_single_layers} single layers")
#print(f"Total: {sum(p.numel() for p in model.parameters()):,} parameters")
# Example usage:
# Runs one generation with the settings defined at the top of the file;
# generate() saves to OUTPUT_PATH and refreshes the preview.
image = generate()
# Uncomment to show the full-size result as the cell output:
#image