# tiny-flux-deep / model_v2.py
# Commit cacfc43 (verified) by AbstractPhil: "Create model_v2.py"
"""
TinyFlux-Deep: Deeper variant with 15 double + 25 single blocks.
Config derived from checkpoint step_285625.safetensors:
- hidden_size: 512
- num_attention_heads: 4
- attention_head_dim: 128
- num_double_layers: 15
- num_single_layers: 25
- Uses biases in MLP
- Old RoPE format with cached freqs buffers
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Tuple, List
@dataclass
class TinyFluxDeepConfig:
    """Configuration for TinyFlux-Deep model.

    Defaults mirror the values derived from checkpoint ``step_285625``.

    ``__post_init__`` enforces the shape invariants the model relies on:

    * ``num_attention_heads * attention_head_dim == hidden_size`` -- the
      attention output projections map ``heads * head_dim`` channels back
      onto the residual width;
    * ``sum(axes_dims_rope) == attention_head_dim`` -- the per-axis RoPE
      tables are concatenated to span exactly one attention head;
    * every entry of ``axes_dims_rope`` is even -- RoPE rotates channel
      pairs, and an odd axis size yields a table one channel too wide.

    Raises:
        ValueError: if any invariant is violated. These are explicit raises
            rather than ``assert`` so validation still runs under ``python -O``.
    """

    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16  # channels per input/output image token
    patch_size: int = 1
    joint_attention_dim: int = 768  # width of incoming text token embeddings
    pooled_projection_dim: int = 768  # width of the pooled text embedding
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # RoPE channels per position axis
    guidance_embeds: bool = True  # build the guidance embedder

    def __post_init__(self):
        inner_dim = self.num_attention_heads * self.attention_head_dim
        if inner_dim != self.hidden_size:
            raise ValueError(
                f"num_attention_heads * attention_head_dim "
                f"({self.num_attention_heads} * {self.attention_head_dim} = {inner_dim}) "
                f"must equal hidden_size ({self.hidden_size})"
            )
        rope_total = sum(self.axes_dims_rope)
        if rope_total != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) = {rope_total} must equal "
                f"attention_head_dim ({self.attention_head_dim})"
            )
        odd_axes = [d for d in self.axes_dims_rope if d % 2]
        if odd_axes:
            raise ValueError(
                f"every entry of axes_dims_rope must be even for RoPE; got {tuple(self.axes_dims_rope)}"
            )
# =============================================================================
# Normalization
# =============================================================================
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides ``x`` by the root-mean-square of its last axis, with the statistic
    computed in float32. When ``elementwise_affine`` is true, the result is then
    scaled by a learned per-channel ``weight`` initialised to ones.

    Args:
        dim: size of the normalized (last) axis.
        eps: added to the mean square before the reciprocal square root.
        elementwise_affine: whether to create the learnable ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter('weight', None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reduce in float32 so low-precision inputs do not lose the statistic.
        mean_square = x.float().pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = (x * inv_rms).type_as(x)
        return normed if self.weight is None else normed * self.weight
# =============================================================================
# RoPE - Old format with cached frequency buffers (checkpoint compatible)
# =============================================================================
class EmbedND(nn.Module):
    """
    Original TinyFlux RoPE with cached frequency buffers.

    One persistent buffer ``freqs_<axis>`` is registered per position axis.
    The names match the checkpoint keys ``rope.freqs_0/1/2``.
    Each buffer holds ``1 / theta ** (arange(0, dim, 2) / dim)`` for that axis's
    channel budget ``dim``.
    """

    def __init__(self, theta: float = 10000.0, axes_dim: Tuple[int, int, int] = (16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        for axis, axis_dim in enumerate(axes_dim):
            exponents = torch.arange(0, axis_dim, 2).float() / axis_dim
            self.register_buffer(f'freqs_{axis}', 1.0 / (theta ** exponents), persistent=True)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (N, n_axes) position indices, one column per axis
                (e.g. [temporal, height, width]).
        Returns:
            rope: (N, 1, head_dim) table whose channels alternate
            [cos, sin, cos, sin, ...], axis by axis.
        """
        target = ids.device
        per_axis = []
        for axis in range(ids.shape[-1]):
            inv_freq = getattr(self, f'freqs_{axis}').to(target)
            positions = ids[:, axis].float()
            angles = positions[:, None] * inv_freq[None, :]  # (N, dim/2)
            # Pair every angle's cosine with its sine, then flatten the pairs.
            pairs = torch.stack([angles.cos(), angles.sin()], dim=-1)
            per_axis.append(pairs.flatten(-2))  # (N, dim)
        table = torch.cat(per_axis, dim=-1)  # (N, head_dim)
        return table[:, None, :]
def apply_rotary_emb_old(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings (old interleaved format).

    Each channel pair (x[2k], x[2k+1]) is rotated by the angle whose
    cosine/sine sit at freqs_cis[..., 2k] / freqs_cis[..., 2k+1].

    Args:
        x: (B, H, N, D) query or key tensor
        freqs_cis: (N, 1, D) interleaved [cos0, sin0, cos1, sin1, ...]
    Returns:
        Rotated tensor with the same shape and dtype as ``x``
        (the arithmetic itself runs in float32).
    """
    table = freqs_cis.squeeze(1)  # (N, D)
    # Spread each cos/sin across both members of its channel pair, then
    # add leading broadcast axes -> (1, 1, N, D).
    cos = table[:, 0::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)
    sin = table[:, 1::2].repeat_interleave(2, dim=-1)[None, None].to(x.device)

    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    # 90-degree turn of every pair: (a, b) -> (-b, a)
    turned = torch.stack([-second, first], dim=-1).flatten(-2)

    rotated = x.float() * cos + turned.float() * sin
    return rotated.to(x.dtype)
# =============================================================================
# Embeddings
# =============================================================================
class MLPEmbedder(nn.Module):
    """MLP for embedding scalars (timestep, guidance).

    Each scalar is expanded into 256 sinusoidal features (128 sines followed by
    128 cosines). These features go through Linear(256, hidden) -> SiLU ->
    Linear(hidden, hidden). The Sequential indices (``mlp.0``, ``mlp.2``) are
    part of the state-dict keys.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_freqs = 128
        log_step = math.log(10000) / (n_freqs - 1)
        # Geometric frequency ladder from 1 down to 1/10000, in x's dtype.
        ladder = (torch.arange(n_freqs, device=x.device, dtype=x.dtype) * -log_step).exp()
        phase = x[..., None] * ladder[None]
        features = torch.cat([phase.sin(), phase.cos()], dim=-1)
        return self.mlp(features)
# =============================================================================
# AdaLayerNorm
# =============================================================================
class AdaLayerNormZero(nn.Module):
    """AdaLN-Zero for double-stream blocks (6 params).

    ``emb`` is mapped through SiLU + Linear to six hidden-sized chunks:
    (shift, scale, gate) for attention and (shift, scale, gate) for the MLP.
    Only the attention shift/scale are applied here. The gate and the MLP
    modulation are returned for the caller to use.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        chunks = self.linear(self.silu(emb)).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks
        # [:, None] inserts the sequence axis so the modulation broadcasts over tokens.
        modulated = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroSingle(nn.Module):
    """AdaLN-Zero for single-stream blocks (3 params).

    ``emb`` is mapped through SiLU + Linear to (shift, scale, gate). Shift and
    scale modulate the RMS-normalized input; the gate is returned unapplied.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift, scale, gate = self.linear(self.silu(emb)).chunk(3, dim=-1)
        modulated = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return modulated, gate
# =============================================================================
# Attention (original format - no Q/K norm, matches checkpoint)
# =============================================================================
class Attention(nn.Module):
    """Multi-head attention (original TinyFlux format, no Q/K norm).

    A fused ``qkv`` projection is split into heads. Optional interleaved RoPE
    is applied to q and k, then ``F.scaled_dot_product_attention`` runs with its
    default 1/sqrt(head_dim) scaling, and ``out_proj`` maps back to
    ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Same value SDPA uses by default; kept as an attribute for reference.
        self.scale = head_dim ** -0.5
        inner_dim = num_heads * head_dim
        self.qkv = nn.Linear(hidden_size, 3 * inner_dim, bias=use_bias)
        self.out_proj = nn.Linear(inner_dim, hidden_size, bias=use_bias)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, N, D)

        if rope is not None:
            q, k = apply_rotary_emb_old(q, rope), apply_rotary_emb_old(k, rope)

        attended = F.scaled_dot_product_attention(q, k, v)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
class JointAttention(nn.Module):
    """Joint attention for double-stream blocks (original format).

    Text and image tokens have separate qkv/out projections. Both streams
    attend over the concatenated [text, image] keys and values. RoPE, when
    given, is applied to the image queries/keys only.
    """

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, use_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        inner_dim = num_heads * head_dim
        self.txt_qkv = nn.Linear(hidden_size, 3 * inner_dim, bias=use_bias)
        self.img_qkv = nn.Linear(hidden_size, 3 * inner_dim, bias=use_bias)
        self.txt_out = nn.Linear(inner_dim, hidden_size, bias=use_bias)
        self.img_out = nn.Linear(inner_dim, hidden_size, bias=use_bias)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, txt_len, _ = txt.shape
        _, img_len, _ = img.shape

        def to_heads(proj, tokens, length):
            # (B, S, hidden) -> q, k, v each shaped (B, H, S, D)
            packed = proj(tokens).reshape(batch, length, 3, self.num_heads, self.head_dim)
            return packed.permute(2, 0, 3, 1, 4).unbind(0)

        txt_q, txt_k, txt_v = to_heads(self.txt_qkv, txt, txt_len)
        img_q, img_k, img_v = to_heads(self.img_qkv, img, img_len)

        if rope is not None:
            img_q = apply_rotary_emb_old(img_q, rope)
            img_k = apply_rotary_emb_old(img_k, rope)

        keys = torch.cat([txt_k, img_k], dim=2)
        values = torch.cat([txt_v, img_v], dim=2)

        def attend(queries, length):
            out = F.scaled_dot_product_attention(queries, keys, values)
            return out.transpose(1, 2).reshape(batch, length, -1)

        txt_ctx = attend(txt_q, txt_len)
        img_ctx = attend(img_q, img_len)
        return self.txt_out(txt_ctx), self.img_out(img_ctx)
# =============================================================================
# MLP (with bias - matches checkpoint)
# =============================================================================
class MLP(nn.Module):
    """Feed-forward network: Linear -> tanh-approximated GELU -> Linear, both with biases.

    Args:
        hidden_size: input and output width.
        mlp_ratio: expansion factor; the inner width is ``int(hidden_size * mlp_ratio)``.
    """

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner = int(hidden_size * mlp_ratio)
        # Biases are enabled on both layers for checkpoint compatibility.
        self.fc1 = nn.Linear(hidden_size, inner, bias=True)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
# =============================================================================
# Transformer Blocks
# =============================================================================
class DoubleStreamBlock(nn.Module):
    """Double-stream transformer block.

    Text and image tokens keep separate AdaLN modulation, norms and MLPs, but
    share one joint attention. Both residual updates (attention, then MLP) are
    gated by the AdaLN outputs computed from ``vec``.
    """

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        width = config.hidden_size
        self.img_norm1 = AdaLayerNormZero(width)
        self.txt_norm1 = AdaLayerNormZero(width)
        self.attn = JointAttention(
            width, config.num_attention_heads, config.attention_head_dim, use_bias=False
        )
        self.img_norm2 = RMSNorm(width)
        self.txt_norm2 = RMSNorm(width)
        self.img_mlp = MLP(width, config.mlp_ratio)
        self.txt_mlp = MLP(width, config.mlp_ratio)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_mod, img_attn_gate, img_shift, img_scale, img_mlp_gate = self.img_norm1(img, vec)
        txt_mod, txt_attn_gate, txt_shift, txt_scale, txt_mlp_gate = self.txt_norm1(txt, vec)

        # Gated attention residual.
        txt_delta, img_delta = self.attn(txt_mod, img_mod, rope)
        txt = txt + txt_attn_gate[:, None] * txt_delta
        img = img + img_attn_gate[:, None] * img_delta

        # Modulated, gated MLP residual.
        txt_ff = self.txt_norm2(txt) * (1 + txt_scale[:, None]) + txt_shift[:, None]
        img_ff = self.img_norm2(img) * (1 + img_scale[:, None]) + img_shift[:, None]
        txt = txt + txt_mlp_gate[:, None] * self.txt_mlp(txt_ff)
        img = img + img_mlp_gate[:, None] * self.img_mlp(img_ff)
        return txt, img
class SingleStreamBlock(nn.Module):
    """Single-stream transformer block.

    Text and image tokens are concatenated into one sequence. Self-attention
    runs on the AdaLN-modulated sequence with a gated residual. An ungated
    MLP residual on an RMS-normalized input follows, and the sequence is split
    back into its text and image parts.
    """

    def __init__(self, config: TinyFluxDeepConfig):
        super().__init__()
        width = config.hidden_size
        self.norm = AdaLayerNormZeroSingle(width)
        self.attn = Attention(
            width, config.num_attention_heads, config.attention_head_dim, use_bias=False
        )
        self.mlp = MLP(width, config.mlp_ratio)
        self.norm2 = RMSNorm(width)

    def forward(
        self,
        txt: torch.Tensor,
        img: torch.Tensor,
        vec: torch.Tensor,
        rope: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_txt = txt.shape[1]
        seq = torch.cat([txt, img], dim=1)
        modulated, gate = self.norm(seq, vec)
        seq = seq + gate[:, None] * self.attn(modulated, rope)
        seq = seq + self.mlp(self.norm2(seq))
        txt_part, img_part = seq.split([n_txt, seq.shape[1] - n_txt], dim=1)
        return txt_part, img_part
# =============================================================================
# Main Model
# =============================================================================
class TinyFluxDeep(nn.Module):
    """TinyFlux-Deep: 15 double + 25 single blocks.

    Forward pipeline:
      1. project image tokens (``img_in``) and text tokens (``txt_in``) to ``hidden_size``;
      2. build a conditioning vector from timestep, pooled text embedding and,
         when enabled and provided, guidance;
      3. run the double-stream blocks with RoPE on image tokens only;
      4. run the single-stream blocks on the concatenated [text, image]
         sequence with RoPE over all tokens;
      5. RMS-normalize the image tokens and project them back to ``in_channels``.

    Module attribute names determine the state-dict keys, so renaming any of
    them changes which checkpoint keys load.
    """

    def __init__(self, config: Optional[TinyFluxDeepConfig] = None):
        super().__init__()
        self.config = config or TinyFluxDeepConfig()
        cfg = self.config
        # Input projections (with bias to match checkpoint)
        self.img_in = nn.Linear(cfg.in_channels, cfg.hidden_size, bias=True)
        self.txt_in = nn.Linear(cfg.joint_attention_dim, cfg.hidden_size, bias=True)
        # Conditioning
        self.time_in = MLPEmbedder(cfg.hidden_size)
        self.vector_in = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cfg.pooled_projection_dim, cfg.hidden_size, bias=True)
        )
        if cfg.guidance_embeds:
            self.guidance_in = MLPEmbedder(cfg.hidden_size)
        # RoPE (old format with cached freqs)
        self.rope = EmbedND(theta=10000.0, axes_dim=cfg.axes_dims_rope)
        # Transformer blocks
        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(cfg) for _ in range(cfg.num_double_layers)
        ])
        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(cfg) for _ in range(cfg.num_single_layers)
        ])
        # Output
        self.final_norm = RMSNorm(cfg.hidden_size)
        self.final_linear = nn.Linear(cfg.hidden_size, cfg.in_channels, bias=True)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform every Linear weight and zero its bias.

        ``final_linear.weight`` is then zeroed, so a freshly constructed model
        outputs only the (zero) final bias.
        """
        def _init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        self.apply(_init)
        nn.init.zeros_(self.final_linear.weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.Tensor,
        img_ids: torch.Tensor,
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the transformer and return the per-token output for the image tokens.

        Args:
            hidden_states: image tokens, (B, N, in_channels).
            encoder_hidden_states: text tokens, (B, L, joint_attention_dim).
            pooled_projections: pooled text embedding, (B, pooled_projection_dim).
            timestep: per-sample timesteps, embedded by ``time_in``.
            img_ids: (N, 3) or (B, N, 3) image position ids. When batched,
                only the first batch entry is used; ids are shared across the batch.
            txt_ids: optional (L, 3) or (B, L, 3) text position ids, handled
                the same way. Defaults to all zeros.
            guidance: optional guidance values. Ignored unless
                ``config.guidance_embeds`` is set.

        Returns:
            Tensor of shape (B, N, in_channels).
        """
        L = encoder_hidden_states.shape[1]
        # Input projections
        img = self.img_in(hidden_states)
        txt = self.txt_in(encoder_hidden_states)
        # Conditioning vector shared by every block
        vec = self.time_in(timestep)
        vec = vec + self.vector_in(pooled_projections)
        if self.config.guidance_embeds and guidance is not None:
            vec = vec + self.guidance_in(guidance)
        # Position ids are shared across the batch: keep only the first entry.
        if img_ids.ndim == 3:
            img_ids = img_ids[0]  # (N, 3)
        # Double-stream blocks rotate image tokens only.
        img_rope = self.rope(img_ids)  # (N, 1, head_dim)
        for block in self.double_blocks:
            txt, img = block(txt, img, vec, img_rope)
        # Single-stream blocks see [text, image], so build RoPE for the full sequence.
        if txt_ids is None:
            txt_ids = torch.zeros(L, 3, device=img_ids.device, dtype=img_ids.dtype)
        elif txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        all_ids = torch.cat([txt_ids, img_ids], dim=0)
        full_rope = self.rope(all_ids)
        for block in self.single_blocks:
            txt, img = block(txt, img, vec, full_rope)
        # Output: only image tokens are projected back.
        img = self.final_norm(img)
        img = self.final_linear(img)
        return img

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Create image position IDs for RoPE.

        Returns a float tensor of shape (height * width, 3) in row-major
        order. Row ``i * width + j`` is ``[0, i, j]``: temporal id 0, row i,
        column j.

        ``batch_size`` is accepted for API compatibility but unused, because
        the ids are shared across the batch.
        """
        # Vectorized: writing element by element costs one tiny kernel per
        # write (about 3*H*W launches on CUDA).
        rows = torch.arange(height, device=device).repeat_interleave(width)
        cols = torch.arange(width, device=device).repeat(height)
        img_ids = torch.zeros(height * width, 3, device=device)
        img_ids[:, 1] = rows
        img_ids[:, 2] = cols
        return img_ids

    @staticmethod
    def create_txt_ids(text_len: int, device: torch.device) -> torch.Tensor:
        """Create text position IDs of shape (text_len, 3).

        Axis 0 holds 0..text_len-1; the other two axes are zero.
        """
        txt_ids = torch.zeros(text_len, 3, device=device)
        txt_ids[:, 0] = torch.arange(text_len, device=device)
        return txt_ids

    def count_parameters(self) -> dict:
        """Count parameters by component.

        Keys in insertion order: img_in, txt_in, time_in, vector_in,
        guidance_in (only if present), double_blocks, single_blocks,
        final (final_norm + final_linear), total.
        """
        def numel(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {}
        counts['img_in'] = numel(self.img_in)
        counts['txt_in'] = numel(self.txt_in)
        counts['time_in'] = numel(self.time_in)
        counts['vector_in'] = numel(self.vector_in)
        if hasattr(self, 'guidance_in'):
            counts['guidance_in'] = numel(self.guidance_in)
        counts['double_blocks'] = numel(self.double_blocks)
        counts['single_blocks'] = numel(self.single_blocks)
        counts['final'] = numel(self.final_norm) + numel(self.final_linear)
        counts['total'] = numel(self)
        return counts
# =============================================================================
# Test
# =============================================================================
def test_model():
    """Smoke-test TinyFlux-Deep.

    Prints the config and per-component parameter counts, then runs one
    no-grad forward pass on random inputs (CUDA if available, else CPU) and
    prints the output shape and value range.
    """
    rule = "=" * 60
    print(rule)
    print("TinyFlux-Deep Test")
    print(rule)

    config = TinyFluxDeepConfig()
    model = TinyFluxDeep(config)
    counts = model.count_parameters()

    print("\nConfig:")
    shown_fields = (
        "hidden_size",
        "num_attention_heads",
        "attention_head_dim",
        "num_double_layers",
        "num_single_layers",
    )
    for field_name in shown_fields:
        print(f" {field_name}: {getattr(config, field_name)}")

    print("\nParameters:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    batch, height, width = 2, 64, 64
    text_len = 77
    # Dict literal preserves the original RNG draw order.
    inputs = dict(
        hidden_states=torch.randn(batch, height * width, config.in_channels, device=device),
        encoder_hidden_states=torch.randn(batch, text_len, config.joint_attention_dim, device=device),
        pooled_projections=torch.randn(batch, config.pooled_projection_dim, device=device),
        timestep=torch.rand(batch, device=device),
        img_ids=TinyFluxDeep.create_img_ids(batch, height, width, device),
        txt_ids=TinyFluxDeep.create_txt_ids(text_len, device),
        guidance=torch.ones(batch, device=device) * 3.5,
    )

    with torch.no_grad():
        output = model(**inputs)

    print(f"\nOutput shape: {output.shape}")
    print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
    print("\n✓ Forward pass successful!")
if __name__ == "__main__":
    # Executing this module as a script runs the smoke test (config/param
    # printout plus one forward pass); importing it does not.
    test_model()