# Source: tiny-flux-deep / port_tiny_to_deep.py
# Author: AbstractPhil
# Commit: 457e2ff (verified) - "Update port_tiny_to_deep.py"
# Size: 35.6 kB
# ============================================================================
# TinyFlux → TinyFlux-Deep Porting Script
# ============================================================================
# Expands: 3 single + 3 double → 25 single + 15 double
# Heads: doubled (e.g. 2 → 4, hidden 256 → 512); the actual sizes are
#        detected from the checkpoint being ported
# Freezes ported layers, trains new ones
# ============================================================================
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from huggingface_hub import hf_hub_download, HfApi
from dataclasses import dataclass
from copy import deepcopy
from typing import Tuple
# Prefer the GPU whenever CUDA is usable; otherwise stay on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# Dtype the deep model is cast to before weights are ported into it.
DTYPE = torch.bfloat16
# ============================================================================
# CONFIGS
# ============================================================================
@dataclass
class TinyFluxConfig:
    """Original small config - matches TinyFlux model on hub (hidden=768, 6 heads).

    Raises:
        ValueError: if ``num_attention_heads * attention_head_dim`` does not
            equal ``hidden_size``.
    """

    # Core dimensions (detected from hub: 768 hidden, 6 heads)
    hidden_size: int = 768
    num_attention_heads: int = 6
    attention_head_dim: int = 128  # 6 * 128 = 768

    # Input/output
    in_channels: int = 16
    patch_size: int = 1

    # Text encoder interfaces
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768

    # Layers
    num_double_layers: int = 3
    num_single_layers: int = 3

    # MLP
    mlp_ratio: float = 4.0

    # RoPE
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Misc
    guidance_embeds: bool = True

    def __post_init__(self):
        # Same consistency rule the deep config enforces: the heads must tile
        # the hidden dimension exactly.
        if self.num_attention_heads * self.attention_head_dim != self.hidden_size:
            raise ValueError(
                f"heads ({self.num_attention_heads}) * head_dim ({self.attention_head_dim}) "
                f"!= hidden ({self.hidden_size})"
            )
@dataclass
class TinyFluxDeepConfig:
    """
    Expanded deep config - matches TinyFlux model attribute names exactly.

    The defaults describe a 512-wide model with 4 heads of 128 dims, i.e. what
    doubling the heads of a 256-wide, 2-head model yields. The porting script
    in this file replaces ``hidden_size`` and ``num_attention_heads`` with
    values derived from the checkpoint it loads.

    Raises:
        ValueError: if ``num_attention_heads * attention_head_dim`` does not
            equal ``hidden_size``, or if ``axes_dims_rope`` does not sum to
            ``attention_head_dim``.
    """

    # Core dimensions
    hidden_size: int = 512  # 4 heads * 128 head_dim
    num_attention_heads: int = 4  # 2 → 4 (double the heads)
    attention_head_dim: int = 128  # Same (required for RoPE)

    # Input/output
    in_channels: int = 16
    patch_size: int = 1

    # Text encoder interfaces
    joint_attention_dim: int = 768  # T5 embed dim
    pooled_projection_dim: int = 768  # CLIP embed dim

    # Layers (uses _layers not _blocks)
    num_double_layers: int = 15  # 3 → 15
    num_single_layers: int = 25  # 3 → 25 (more singles like original Flux)

    # MLP
    mlp_ratio: float = 4.0

    # RoPE (must sum to head_dim=128)
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Misc
    guidance_embeds: bool = True

    def __post_init__(self):
        # Explicit raises instead of assert: asserts vanish under `python -O`.
        if self.num_attention_heads * self.attention_head_dim != self.hidden_size:
            raise ValueError(
                f"heads ({self.num_attention_heads}) * head_dim ({self.attention_head_dim}) "
                f"!= hidden ({self.hidden_size})"
            )
        # Enforce the RoPE rule stated on the field above.
        rope_total = sum(self.axes_dims_rope)
        if rope_total != self.attention_head_dim:
            raise ValueError(
                f"axes_dims_rope {tuple(self.axes_dims_rope)} sums to {rope_total}, "
                f"expected head_dim ({self.attention_head_dim})"
            )
# ============================================================================
# LAYER MAPPING
# ============================================================================
# Single blocks: 3 → 25
# - old block 0 → new position 0 (frozen)
# - old block 1 → new positions 8, 12, 16 (center, spaced, frozen)
# - old block 2 → new position 24 (frozen)
# - every remaining position is new (trainable)
SINGLE_MAPPING = {
    0: [0],
    1: [8, 12, 16],
    2: [24],
}
# The frozen positions are exactly the targets named in the mapping.
SINGLE_FROZEN = {slot for targets in SINGLE_MAPPING.values() for slot in targets}
# Double blocks: 3 → 15
# - old block 0 → new position 0 (frozen)
# - old block 1 → new positions 4, 7, 10 (3 copies, spaced, frozen)
# - old block 2 → new position 14 (frozen)
# - every remaining position is new (trainable)
DOUBLE_MAPPING = {
    0: [0],
    1: [4, 7, 10],
    2: [14],
}
# The frozen positions are exactly the targets named in the mapping.
DOUBLE_FROZEN = {slot for targets in DOUBLE_MAPPING.values() for slot in targets}
# ============================================================================
# WEIGHT EXPANSION UTILITIES
# ============================================================================
def expand_qkv_weights(old_weight, old_hidden=768, new_hidden=1536, head_dim=128):
    """
    Expand QKV projection weights when increasing hidden size / head count.

    The weight is treated as three stacked sections (Q, K, V) of ``hidden``
    rows each, i.e. shape ``(3 * hidden, hidden)``. Every old section is copied
    into the leading rows of the matching new section and into the first
    ``old_hidden`` input columns, so old heads keep the first head slots.
    All remaining entries keep a Xavier-uniform init scaled by 0.02.

    Args:
        old_weight: 2-D tensor holding the old fused QKV weight.
        old_hidden: hidden size of the source model.
        new_hidden: hidden size of the target model.
        head_dim: accepted for backward compatibility and otherwise unused.
            Heads are contiguous inside each section, so the section is copied
            whole; this also avoids dropping trailing rows when ``old_hidden``
            is not a multiple of ``head_dim``.

    Returns:
        Tensor of shape ``(3 * new_hidden, new_hidden)`` with the dtype and
        device of ``old_weight``.
    """
    new_weight = torch.zeros(3 * new_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.02  # Scale down random init

    # Q, K and V sections in turn: old rows land at the top of each new section.
    for section in range(3):
        src = section * old_hidden
        dst = section * new_hidden
        new_weight[dst:dst + old_hidden, :old_hidden] = old_weight[src:src + old_hidden, :]
    return new_weight
def expand_qkv_bias(old_bias, old_hidden=768, new_hidden=1536, head_dim=128):
    """Grow a fused QKV bias from 3 * old_hidden to 3 * new_hidden entries.

    Each of the Q, K and V sections keeps its old values at the start of the
    corresponding new section; every added entry is zero. ``head_dim`` is
    accepted but not used.
    """
    grown = torch.zeros(3 * new_hidden, dtype=old_bias.dtype, device=old_bias.device)
    # (source offset, destination offset) for the Q, K and V sections.
    offsets = [(part * old_hidden, part * new_hidden) for part in range(3)]
    for src, dst in offsets:
        grown[dst:dst + old_hidden] = old_bias[src:src + old_hidden]
    return grown
def expand_out_proj_weights(old_weight, old_hidden=768, new_hidden=1536, head_dim=128):
    """
    Grow an attention output projection to ``(new_hidden, new_hidden)``.

    The old matrix occupies the top-left ``old_hidden x old_hidden`` corner;
    the rest keeps a Xavier-uniform init scaled by 0.02. ``head_dim`` is
    accepted but not used.
    """
    like_old = {"dtype": old_weight.dtype, "device": old_weight.device}
    grown = torch.zeros((new_hidden, new_hidden), **like_old)
    nn.init.xavier_uniform_(grown)
    grown *= 0.02
    # Old weights keep the leading rows and columns.
    grown[:old_hidden, :old_hidden] = old_weight
    return grown
def expand_out_proj_bias(old_bias, old_hidden=768, new_hidden=1536):
    """Zero-pad an output projection bias from old_hidden to new_hidden entries."""
    like_old = {"dtype": old_bias.dtype, "device": old_bias.device}
    padded = torch.zeros((new_hidden,), **like_old)
    # Old values stay at the front; the tail remains zero.
    padded[0:old_hidden] = old_bias
    return padded
def expand_linear_hidden(old_weight, old_hidden=768, new_hidden=1536, expand_in=True, expand_out=True):
    """
    Grow a linear weight along the axes that scale with the hidden size.

    An expanded axis gets ``new_hidden`` entries and keeps the first
    ``old_hidden`` old ones; an axis that is not expanded keeps its size and
    is copied whole. Entries that are not copied keep a Xavier-uniform init
    scaled by 0.02.
    """
    rows_old, cols_old = old_weight.shape

    # Per axis: (size of the new tensor, how many old entries to carry over).
    if expand_out:
        rows_new, rows_kept = new_hidden, old_hidden
    else:
        rows_new = rows_kept = rows_old
    if expand_in:
        cols_new, cols_kept = new_hidden, old_hidden
    else:
        cols_new = cols_kept = cols_old

    grown = torch.zeros(rows_new, cols_new, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(grown)
    grown *= 0.02
    grown[:rows_kept, :cols_kept] = old_weight[:rows_kept, :cols_kept]
    return grown
def expand_bias(old_bias, old_hidden=768, new_hidden=1536):
    """Return a length-new_hidden bias: old values first, zeros for the rest."""
    result = torch.zeros(size=(new_hidden,), dtype=old_bias.dtype, device=old_bias.device)
    keep = slice(0, old_hidden)
    result[keep] = old_bias
    return result
def expand_norm(old_weight, old_hidden=768, new_hidden=1536):
    """Grow an RMSNorm scale to new_hidden entries; added entries are 1."""
    scale = torch.ones((new_hidden,), dtype=old_weight.dtype, device=old_weight.device)
    # Carry the learned scales over to the leading positions.
    scale[0:old_hidden] = old_weight
    return scale
def port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
    """Port weights from old single block to new single block with dimension expansion.

    Every key of ``old_state`` under ``single_blocks.{old_idx}.`` is renamed to
    ``single_blocks.{new_idx}.`` and written into ``new_state`` (mutated in
    place). Tensors are padded from ``old_hidden`` to ``new_hidden``: old
    values keep the leading positions, added matrix entries keep a
    Xavier-uniform init scaled by 0.02, added bias entries are 0 and added
    norm scales are 1. Progress is reported with ``print``.

    Args:
        old_state: mapping of old parameter names to tensors.
        old_idx: index of the source single block.
        new_state: mapping receiving the ported tensors; existing entries are
            also used as shape templates for generic norm and fallback keys.
        new_idx: index of the destination single block.
        old_hidden: hidden size of the source model.
        new_hidden: hidden size of the target model.
    """
    # The trailing dot matters: without it block 1 would also match 10, 11, ...
    old_prefix = f"single_blocks.{old_idx}."
    new_prefix = f"single_blocks.{new_idx}."

    def grow_matrix(weight, out_mult, in_mult):
        """Pad ``weight`` to (out_mult * new_hidden, in_mult * new_hidden)."""
        grown = torch.zeros(out_mult * new_hidden, in_mult * new_hidden, dtype=weight.dtype, device=weight.device)
        nn.init.xavier_uniform_(grown)
        grown *= 0.02
        grown[:out_mult * old_hidden, :in_mult * old_hidden] = weight
        return grown

    def grow_vector(vec, mult):
        """Zero-pad ``vec`` to mult * new_hidden entries."""
        grown = torch.zeros(mult * new_hidden, dtype=vec.dtype, device=vec.device)
        grown[:mult * old_hidden] = vec
        return grown

    def resize_norm_param(old_key, new_key, param, make_default, label):
        """Fit a norm parameter to the size of its template in ``new_state``."""
        template = new_state.get(new_key)
        if template is None or template.dim() == 0:
            # No usable target: never insert a zero-length placeholder tensor.
            print(f" SKIP (shape mismatch): {old_key}")
            return
        old_size = param.shape[0]
        new_size = template.shape[0]
        if old_size == new_size:
            new_state[new_key] = param.clone()
            print(f" Direct copy norm {label}: {old_key} ({old_size})")
            return
        resized = make_default(new_size, dtype=param.dtype, device=param.device)
        copy_size = min(old_size, new_size)
        resized[:copy_size] = param[:copy_size]
        new_state[new_key] = resized
        print(f" Padded norm {label}: {old_key} ({old_size}β†’{new_size})")

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue
        # Swap only the leading block prefix, not later occurrences.
        new_key = new_prefix + old_key[len(old_prefix):]
        old_weight = old_state[old_key]

        # Attention QKV
        if "attn.qkv.weight" in old_key:
            new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded QKV weight: {old_key}")
        elif "attn.qkv.bias" in old_key:
            # Forward the hidden sizes; the helper defaults are 768/1536.
            new_state[new_key] = expand_qkv_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded QKV bias: {old_key}")
        # Attention output projection
        elif "attn.out_proj.weight" in old_key:
            new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded out_proj weight: {old_key}")
        elif "attn.out_proj.bias" in old_key:
            new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded out_proj bias: {old_key}")
        # MLP layers (hidden → 4*hidden → hidden)
        elif "mlp.fc1.weight" in old_key:
            new_state[new_key] = grow_matrix(old_weight, 4, 1)
            print(f" Expanded MLP fc1 weight: {old_key}")
        elif "mlp.fc1.bias" in old_key:
            new_state[new_key] = grow_vector(old_weight, 4)
            print(f" Expanded MLP fc1 bias: {old_key}")
        elif "mlp.fc2.weight" in old_key:
            new_state[new_key] = grow_matrix(old_weight, 1, 4)
            print(f" Expanded MLP fc2 weight: {old_key}")
        elif "mlp.fc2.bias" in old_key:
            new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded MLP fc2 bias: {old_key}")
        # AdaLayerNorm modulation linear (norm.linear) - outputs 3*hidden for single blocks
        elif "norm.linear.weight" in old_key:
            old_out = old_hidden * 3
            new_out = new_hidden * 3
            new_state[new_key] = grow_matrix(old_weight, 3, 1)
            print(f" Expanded AdaLN linear weight: {old_key} ({old_out},{old_hidden})β†’({new_out},{new_hidden})")
        elif "norm.linear.bias" in old_key:
            old_out = old_hidden * 3
            new_out = new_hidden * 3
            new_state[new_key] = grow_vector(old_weight, 3)
            print(f" Expanded AdaLN linear bias: {old_key} ({old_out})β†’({new_out})")
        # RMSNorm inside AdaLN (norm.norm.weight) or standalone norm
        elif "norm.norm.weight" in old_key or "norm2.weight" in old_key:
            new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
            print(f" Expanded RMSNorm weight: {old_key}")
        # Generic normalization layers - sized from the template in new_state
        elif "norm" in old_key and "weight" in old_key:
            resize_norm_param(old_key, new_key, old_weight, torch.ones, "weight")
        elif "norm" in old_key and "bias" in old_key:
            resize_norm_param(old_key, new_key, old_weight, torch.zeros, "bias")
        # Direct copy for anything else (shouldn't be much)
        else:
            template = new_state.get(new_key)
            if template is not None and old_weight.shape == template.shape:
                new_state[new_key] = old_weight.clone()
                print(f" Direct copy: {old_key}")
            else:
                print(f" SKIP (shape mismatch): {old_key}")
def port_double_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
    """Port weights from old double block to new double block with dimension expansion.

    Every key of ``old_state`` under ``double_blocks.{old_idx}.`` is renamed to
    ``double_blocks.{new_idx}.`` and written into ``new_state`` (mutated in
    place). Tensors are padded from ``old_hidden`` to ``new_hidden``: old
    values keep the leading positions, added matrix entries keep a
    Xavier-uniform init scaled by 0.02, added bias entries are 0 and added
    norm scales are 1. Progress is reported with ``print``.

    Args:
        old_state: mapping of old parameter names to tensors.
        old_idx: index of the source double block.
        new_state: mapping receiving the ported tensors; existing entries are
            also used as shape templates for generic norm and fallback keys.
        new_idx: index of the destination double block.
        old_hidden: hidden size of the source model.
        new_hidden: hidden size of the target model.
    """
    # The trailing dot matters: without it block 1 would also match 10, 11, ...
    old_prefix = f"double_blocks.{old_idx}."
    new_prefix = f"double_blocks.{new_idx}."

    def grow_matrix(weight, out_mult, in_mult):
        """Pad ``weight`` to (out_mult * new_hidden, in_mult * new_hidden)."""
        grown = torch.zeros(out_mult * new_hidden, in_mult * new_hidden, dtype=weight.dtype, device=weight.device)
        nn.init.xavier_uniform_(grown)
        grown *= 0.02
        grown[:out_mult * old_hidden, :in_mult * old_hidden] = weight
        return grown

    def grow_vector(vec, mult):
        """Zero-pad ``vec`` to mult * new_hidden entries."""
        grown = torch.zeros(mult * new_hidden, dtype=vec.dtype, device=vec.device)
        grown[:mult * old_hidden] = vec
        return grown

    def resize_norm_param(old_key, new_key, param, make_default, label):
        """Fit a norm parameter to the size of its template in ``new_state``."""
        template = new_state.get(new_key)
        if template is None or template.dim() == 0:
            # No usable target: never insert a zero-length placeholder tensor.
            print(f" SKIP (shape mismatch): {old_key}")
            return
        old_size = param.shape[0]
        new_size = template.shape[0]
        if old_size == new_size:
            new_state[new_key] = param.clone()
            print(f" Direct copy norm {label}: {old_key} ({old_size})")
            return
        resized = make_default(new_size, dtype=param.dtype, device=param.device)
        copy_size = min(old_size, new_size)
        resized[:copy_size] = param[:copy_size]
        new_state[new_key] = resized
        print(f" Padded norm {label}: {old_key} ({old_size}β†’{new_size})")

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue
        # Swap only the leading block prefix, not later occurrences.
        new_key = new_prefix + old_key[len(old_prefix):]
        old_weight = old_state[old_key]

        # Joint attention QKV (img and txt)
        if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
            new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded QKV weight: {old_key}")
        elif any(x in old_key for x in ["img_qkv.bias", "txt_qkv.bias"]):
            # Forward the hidden sizes; the helper defaults are 768/1536.
            new_state[new_key] = expand_qkv_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded QKV bias: {old_key}")
        # Joint attention output projections
        elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
            new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded out_proj weight: {old_key}")
        elif any(x in old_key for x in ["img_out.bias", "txt_out.bias"]):
            new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded out_proj bias: {old_key}")
        # MLP layers
        elif "mlp" in old_key and "fc1.weight" in old_key:
            new_state[new_key] = grow_matrix(old_weight, 4, 1)
            print(f" Expanded MLP fc1 weight: {old_key}")
        elif "mlp" in old_key and "fc1.bias" in old_key:
            new_state[new_key] = grow_vector(old_weight, 4)
            print(f" Expanded MLP fc1 bias: {old_key}")
        elif "mlp" in old_key and "fc2.weight" in old_key:
            new_state[new_key] = grow_matrix(old_weight, 1, 4)
            print(f" Expanded MLP fc2 weight: {old_key}")
        elif "mlp" in old_key and "fc2.bias" in old_key:
            new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f" Expanded MLP fc2 bias: {old_key}")
        # AdaLayerNormZero modulation linear - outputs 6*hidden (img_norm1, txt_norm1)
        elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "weight" in old_key:
            new_state[new_key] = grow_matrix(old_weight, 6, 1)
            print(f" Expanded AdaLN linear weight: {old_key}")
        elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "bias" in old_key:
            new_state[new_key] = grow_vector(old_weight, 6)
            print(f" Expanded AdaLN linear bias: {old_key}")
        # RMSNorm inside AdaLN (img_norm1.norm, txt_norm1.norm) or standalone (img_norm2, txt_norm2)
        elif any(x in old_key for x in ["_norm1.norm.weight", "_norm2.weight"]):
            new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
            print(f" Expanded RMSNorm weight: {old_key}")
        # Generic normalization layers - sized from the template in new_state
        elif "norm" in old_key and "weight" in old_key:
            resize_norm_param(old_key, new_key, old_weight, torch.ones, "weight")
        elif "norm" in old_key and "bias" in old_key:
            resize_norm_param(old_key, new_key, old_weight, torch.zeros, "bias")
        # Direct copy for matching shapes
        else:
            template = new_state.get(new_key)
            if template is not None and old_weight.shape == template.shape:
                new_state[new_key] = old_weight.clone()
                print(f" Direct copy: {old_key}")
            else:
                print(f" SKIP (shape mismatch): {old_key}")
def port_non_block_weights(old_state, new_state, old_hidden=256, new_hidden=1024):
    """Port every weight that lives outside single/double blocks.

    Recognised keys are padded from ``old_hidden`` to ``new_hidden`` and stored
    in ``new_state`` under their original name. Block weights are ignored,
    buffers and RoPE entries are skipped with a message, and unrecognised keys
    are only reported. Progress is reported with ``print``.
    """

    def scaled_xavier(rows, cols, like):
        # Fresh matrix with a Xavier-uniform init scaled by 0.02 (default device).
        fresh = torch.zeros(rows, cols, dtype=like.dtype)
        nn.init.xavier_uniform_(fresh)
        fresh *= 0.02
        return fresh

    def widen_output(weight):
        # (old_hidden, in) -> (new_hidden, in): only the output axis grows.
        grown = scaled_xavier(new_hidden, weight.shape[1], weight)
        grown[:old_hidden, :] = weight
        return grown

    def widen_both(weight):
        # (old_hidden, old_hidden) -> (new_hidden, new_hidden).
        grown = scaled_xavier(new_hidden, new_hidden, weight)
        grown[:old_hidden, :old_hidden] = weight
        return grown

    def widen_input(weight):
        # (out, old_hidden) -> (out, new_hidden): only the input axis grows.
        grown = scaled_xavier(weight.shape[0], new_hidden, weight)
        grown[:, :old_hidden] = weight
        return grown

    for old_key, old_weight in old_state.items():
        # Block weights are handled by the block-specific porting functions.
        if "single_blocks" in old_key or "double_blocks" in old_key:
            continue
        # Buffers that will be recomputed.
        if "sin_basis" in old_key or "freqs_" in old_key:
            print(f" Skip buffer: {old_key}")
            continue

        ported = None
        # img_in: in_channels → hidden
        if "img_in.weight" in old_key:
            ported = widen_output(old_weight)
        elif "img_in.bias" in old_key:
            ported = expand_bias(old_weight, old_hidden, new_hidden)
        # txt_in: joint_attention_dim → hidden
        elif "txt_in.weight" in old_key:
            ported = widen_output(old_weight)
        elif "txt_in.bias" in old_key:
            ported = expand_bias(old_weight, old_hidden, new_hidden)
        # time_in, guidance_in: MLPEmbedder (hidden → hidden)
        elif "time_in" in old_key or "guidance_in" in old_key:
            if "fc1.weight" in old_key:
                ported = widen_both(old_weight)
            elif "fc1.bias" in old_key:
                ported = expand_bias(old_weight, old_hidden, new_hidden)
            elif "fc2.weight" in old_key:
                ported = widen_both(old_weight)
            elif "fc2.bias" in old_key:
                ported = expand_bias(old_weight, old_hidden, new_hidden)
        # vector_in: pooled_projection_dim → hidden
        elif "vector_in" in old_key:
            if "weight" in old_key:
                ported = widen_output(old_weight)
            elif "bias" in old_key:
                ported = expand_bias(old_weight, old_hidden, new_hidden)
        # final_norm: RMSNorm(hidden)
        elif "final_norm" in old_key:
            if "weight" in old_key:
                ported = expand_norm(old_weight, old_hidden, new_hidden)
        # final_linear: hidden → in_channels
        elif "final_linear.weight" in old_key:
            ported = widen_input(old_weight)
        elif "final_linear.bias" in old_key:
            new_state[old_key] = old_weight.clone()  # output dim unchanged
            print(f" Direct copy: {old_key}")
        # RoPE - skip, will be recomputed
        elif "rope" in old_key:
            print(f" Skip RoPE: {old_key}")
        else:
            print(f" Unknown non-block key: {old_key}")

        if ported is not None:
            new_state[old_key] = ported
            print(f" Expanded: {old_key}")
# ============================================================================
# MAIN PORTING FUNCTION
# ============================================================================
def port_tinyflux_to_deep(old_weights_path, new_model):
    """
    Port TinyFlux weights to TinyFlux-Deep.

    Loads the old checkpoint, uses ``new_model.state_dict()`` as the template
    for the new weights, ports the non-block weights and then copies each old
    block into the positions given by SINGLE_MAPPING / DOUBLE_MAPPING. Every
    state-dict key belonging to one of those positions is reported as frozen.

    Args:
        old_weights_path: path (str or os.PathLike) of the old checkpoint. A
            ``.safetensors`` file is read with ``load_file``; anything else
            goes through ``torch.load``.
        new_model: object exposing ``state_dict()``.

    Returns:
        new_state_dict: Ported weights
        frozen_params: Set of parameter names to freeze
    """
    import os

    # Accept pathlib.Path as well as str; .endswith below needs a str.
    old_weights_path = os.fspath(old_weights_path)

    print("Loading old weights...")
    if old_weights_path.endswith(".safetensors"):
        old_state = load_file(old_weights_path)
    else:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only use this branch with checkpoints from a trusted source.
        old_state = torch.load(old_weights_path, map_location="cpu")
        if "model" in old_state:
            old_state = old_state["model"]

    # Strip _orig_mod prefix if present
    if any(k.startswith("_orig_mod.") for k in old_state.keys()):
        print("Stripping _orig_mod prefix...")
        old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}

    # Get new model's state dict as template FIRST
    new_state = new_model.state_dict()
    frozen_params = set()

    # Auto-detect old hidden size from weights
    if "final_norm.weight" in old_state:
        old_hidden = old_state["final_norm.weight"].shape[0]
    elif "img_in.weight" in old_state:
        old_hidden = old_state["img_in.weight"].shape[0]
    else:
        old_hidden = 256  # Default for TinyFlux

    # Get new hidden size from new model's state dict
    if "final_norm.weight" in new_state:
        new_hidden = new_state["final_norm.weight"].shape[0]
    else:
        new_hidden = 512  # Default for TinyFlux-Deep

    print(f"Detected old hidden size: {old_hidden}")
    print(f"New hidden size: {new_hidden}")

    print("\n" + "="*60)
    print("Porting non-block weights...")
    print("="*60)
    port_non_block_weights(old_state, new_state, old_hidden=old_hidden, new_hidden=new_hidden)

    print("\n" + "="*60)
    print("Porting single blocks (3 β†’ 25)...")
    print("="*60)
    for old_idx, new_positions in SINGLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nSingle block {old_idx} β†’ {new_idx}:")
            port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=old_hidden, new_hidden=new_hidden)
            # Mark every key of the destination block as frozen
            marker = f"single_blocks.{new_idx}."
            frozen_params.update(key for key in new_state if marker in key)

    print("\n" + "="*60)
    print("Porting double blocks (3 β†’ 15)...")
    print("="*60)
    for old_idx, new_positions in DOUBLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nDouble block {old_idx} β†’ {new_idx}:")
            port_double_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=old_hidden, new_hidden=new_hidden)
            # Mark every key of the destination block as frozen
            marker = f"double_blocks.{new_idx}."
            frozen_params.update(key for key in new_state if marker in key)

    print("\n" + "="*60)
    print("Summary")
    print("="*60)
    # These counts are state-dict entries (tensors), not scalar elements.
    print(f"Total parameters in new model: {len(new_state)}")
    print(f"Frozen parameters: {len(frozen_params)}")
    print(f"Trainable parameters: {len(new_state) - len(frozen_params)}")
    print(f"\nFrozen single block positions: {sorted(SINGLE_FROZEN)}")
    print(f"Frozen double block positions: {sorted(DOUBLE_FROZEN)}")
    return new_state, frozen_params
# ============================================================================
# FREEZE HELPER
# ============================================================================
def freeze_ported_layers(model, frozen_params):
    """Freeze the ported layers, keep new layers trainable.

    Args:
        model: object exposing ``named_parameters()``.
        frozen_params: container of parameter names; a parameter whose name is
            in it gets ``requires_grad = False``, every other parameter gets
            ``requires_grad = True``.

    Returns:
        The same ``model`` object, modified in place. Element counts and the
        trainable ratio are printed; the ratio is reported as 0.0% for a model
        without parameters instead of dividing by zero.
    """
    frozen_count = 0
    trainable_count = 0
    for name, param in model.named_parameters():
        is_frozen = name in frozen_params
        param.requires_grad = not is_frozen
        if is_frozen:
            frozen_count += param.numel()
        else:
            trainable_count += param.numel()

    total = frozen_count + trainable_count
    # Guard the empty-model case: the ratio is undefined when total is 0.
    ratio = trainable_count / total * 100 if total else 0.0

    print(f"\nFrozen params: {frozen_count:,}")
    print(f"Trainable params: {trainable_count:,}")
    print(f"Total params: {total:,}")
    print(f"Trainable ratio: {ratio:.1f}%")
    return model
# ============================================================================
# MAIN SCRIPT
# ============================================================================
if __name__ == "__main__":
    # Script entry point: download the TinyFlux checkpoint, build a deep model
    # with twice as many heads, port and freeze the weights, save the results
    # locally and upload them to the hub.
    import json

    print("="*60)
    print("TinyFlux β†’ TinyFlux-Deep Porting")
    print("="*60)

    # Load old weights from hub FIRST to detect dimensions
    print("\nDownloading TinyFlux weights from hub...")
    old_weights_path = hf_hub_download(
        repo_id="AbstractPhil/tiny-flux",
        filename="model.safetensors"
    )

    # Load and detect old dimensions
    print("Detecting old model dimensions...")
    old_state = load_file(old_weights_path)
    if any(k.startswith("_orig_mod.") for k in old_state.keys()):
        old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}

    # Detect old hidden size
    old_hidden = old_state["final_norm.weight"].shape[0]
    head_dim = 128  # Fixed for RoPE
    # Floor division below would silently hide a partial head.
    if old_hidden % head_dim != 0:
        raise ValueError(
            f"old hidden size ({old_hidden}) is not a multiple of head_dim ({head_dim})"
        )
    old_heads = old_hidden // head_dim
    print(f" Old hidden size: {old_hidden}")
    print(f" Old attention heads: {old_heads}")
    print(f" Head dim: {head_dim}")

    # Calculate new dimensions (double the heads)
    new_heads = old_heads * 2
    new_hidden = new_heads * head_dim
    print(f"\nNew dimensions:")
    print(f" New hidden size: {new_hidden}")
    print(f" New attention heads: {new_heads}")

    # Create deep config with detected dimensions. Passing them to the
    # constructor (rather than assigning afterwards) lets the config's own
    # __post_init__ validation see the final values.
    deep_config = TinyFluxDeepConfig(
        hidden_size=new_hidden,
        num_attention_heads=new_heads,
        attention_head_dim=head_dim,
    )

    print("\nCreating TinyFlux-Deep model...")
    # You need to define TinyFlux class first (run model cell)
    deep_model = TinyFlux(deep_config).to(DTYPE)

    print(f"\nDeep model config:")
    print(f" Hidden size: {deep_config.hidden_size}")
    print(f" Attention heads: {deep_config.num_attention_heads}")
    print(f" Single layers: {deep_config.num_single_layers}")
    print(f" Double layers: {deep_config.num_double_layers}")

    # Port weights
    new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)

    # Load ported weights
    print("\nLoading ported weights into model...")
    missing, unexpected = deep_model.load_state_dict(new_state, strict=False)
    if missing:
        print(f" Missing keys: {missing[:5]}..." if len(missing) > 5 else f" Missing keys: {missing}")
    if unexpected:
        print(f" Unexpected keys: {unexpected}")

    # Freeze ported layers
    print("\nFreezing ported layers...")
    deep_model = freeze_ported_layers(deep_model, frozen_params)

    # Save
    print("\nSaving ported model...")
    save_path = "tinyflux_deep_ported.safetensors"
    # Strip any _orig_mod prefix before saving
    state_to_save = deep_model.state_dict()
    if any(k.startswith("_orig_mod.") for k in state_to_save.keys()):
        state_to_save = {k.replace("_orig_mod.", ""): v for k, v in state_to_save.items()}
    save_file(state_to_save, save_path)
    print(f"βœ“ Saved to {save_path}")

    # Save frozen params list (sorted so the file is reproducible across runs)
    with open("frozen_params.json", "w") as f:
        json.dump(sorted(frozen_params), f)
    print("βœ“ Saved frozen_params.json")

    # Save config
    config_dict = {
        "hidden_size": deep_config.hidden_size,
        "num_attention_heads": deep_config.num_attention_heads,
        "attention_head_dim": deep_config.attention_head_dim,
        "num_single_layers": deep_config.num_single_layers,
        "num_double_layers": deep_config.num_double_layers,
        "mlp_ratio": deep_config.mlp_ratio,
        "joint_attention_dim": deep_config.joint_attention_dim,
        "pooled_projection_dim": deep_config.pooled_projection_dim,
        "in_channels": deep_config.in_channels,
        "axes_dims_rope": list(deep_config.axes_dims_rope),
        "guidance_embeds": deep_config.guidance_embeds,
    }
    with open("config_deep.json", "w") as f:
        json.dump(config_dict, f, indent=2)
    print("βœ“ Saved config_deep.json")

    # Upload to hub. Best effort: a failed upload is reported but does not
    # abort the script, since the artifacts are already saved locally.
    print("\nUploading to AbstractPhil/tiny-flux-deep...")
    api = HfApi()
    try:
        api.create_repo(repo_id="AbstractPhil/tiny-flux-deep", exist_ok=True, repo_type="model")
        api.upload_file(path_or_fileobj=save_path, path_in_repo="model.safetensors", repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="config_deep.json", path_in_repo="config.json", repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="frozen_params.json", path_in_repo="frozen_params.json", repo_id="AbstractPhil/tiny-flux-deep")
        print("βœ“ Uploaded to hub!")
    except Exception as e:
        print(f"⚠ Upload failed: {e}")

    print("\n" + "="*60)
    print("Porting complete!")
    print("="*60)
    print("\nNext steps:")
    print("1. Update TinyFlux model definition to accept TinyFluxDeepConfig")
    print("2. Use the frozen_params.json to freeze layers during training")
    print("3. Train on AbstractPhil/tiny-flux-deep repo")