# ============================================================================
# TinyFlux → TinyFlux-Deep Porting Script
# ============================================================================
# Expands: 3 single + 3 double blocks → 25 single + 15 double blocks
# Heads: doubled (hub checkpoint 6 → 12, hidden 768 → 1536, head_dim 128)
# Freezes ported layers, trains new ones
# ============================================================================

import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from huggingface_hub import hf_hub_download, HfApi
from dataclasses import dataclass
from copy import deepcopy
from typing import Tuple

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

# ============================================================================
# CONFIGS
# ============================================================================

@dataclass
class TinyFluxConfig:
    """Original small config - matches TinyFlux model on hub (hidden=768, 6 heads)"""
    # Core dimensions (detected from hub: 768 hidden, 6 heads)
    hidden_size: int = 768
    num_attention_heads: int = 6
    attention_head_dim: int = 128  # 6 * 128 = 768

    # Input/output
    in_channels: int = 16
    patch_size: int = 1

    # Text encoder interfaces
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768

    # Layers
    num_double_layers: int = 3
    num_single_layers: int = 3

    # MLP
    mlp_ratio: float = 4.0

    # RoPE
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Misc
    guidance_embeds: bool = True


@dataclass
class TinyFluxDeepConfig:
    """
    Expanded deep config - matches TinyFlux model attribute names exactly.

    The defaults below (hidden_size=512, 4 heads) are placeholders: the main
    script overrides hidden_size / num_attention_heads with values derived
    from the detected TinyFlux checkpoint (head count doubled, head_dim kept
    at 128 for RoPE).
    """
    # Core dimensions
    hidden_size: int = 512          # placeholder: 4 heads * 128 head_dim (overridden at runtime)
    num_attention_heads: int = 4    # placeholder: doubled from the detected head count at runtime
    attention_head_dim: int = 128   # Same (required for RoPE)

    # Input/output
    in_channels: int = 16
    patch_size: int = 1

    # Text encoder interfaces
    joint_attention_dim: int = 768     # T5 embed dim
    pooled_projection_dim: int = 768   # CLIP embed dim

    # Layers (uses _layers not _blocks)
    num_double_layers: int = 15   # 3 → 15
    num_single_layers: int = 25   # 3 → 25 (more singles like original Flux)

    # MLP
    mlp_ratio: float = 4.0

    # RoPE (must sum to head_dim=128)
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Misc
    guidance_embeds: bool = True

    def __post_init__(self):
        assert self.num_attention_heads * self.attention_head_dim == self.hidden_size, \
            f"heads ({self.num_attention_heads}) * head_dim ({self.attention_head_dim}) != hidden ({self.hidden_size})"


# ============================================================================
# LAYER MAPPING
# ============================================================================
# Single blocks: 3 → 25
#   - Layer 0 → position 0 (frozen)
#   - Layer 1 → positions 8, 12, 16 (center, spaced, frozen)
#   - Layer 2 → position 24 (frozen)
#   - Rest → new (trainable)
SINGLE_MAPPING = {
    0: [0],           # Old layer 0 → new position 0
    1: [8, 12, 16],   # Old layer 1 → new positions 8, 12, 16
    2: [24],          # Old layer 2 → new position 24
}
SINGLE_FROZEN = {0, 8, 12, 16, 24}  # These positions are frozen

# Double blocks: 3 → 15
#   - Layer 0 → position 0 (frozen)
#   - Layer 1 → positions 4, 7, 10 (3 copies, spaced, frozen)
#   - Layer 2 → position 14 (frozen)
#   - Rest → new (trainable)
DOUBLE_MAPPING = {
    0: [0],          # Old layer 0 → new position 0
    1: [4, 7, 10],   # Old layer 1 → 3 positions
    2: [14],         # Old layer 2 → new position 14
}
DOUBLE_FROZEN = {0, 4, 7, 10, 14}  # These positions are frozen
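
# ----------------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not part of the original flow):
# confirms the mapping constants above agree with the frozen sets and stay in
# range for the 25-single / 15-double layout used below.
# ----------------------------------------------------------------------------
def _check_layer_mapping(mapping, frozen, num_layers):
    """Every mapped target position should be unique, frozen, and in range."""
    targets = [pos for positions in mapping.values() for pos in positions]
    assert len(targets) == len(set(targets)), "duplicate target positions"
    assert set(targets) == frozen, "frozen set does not match mapping targets"
    assert all(0 <= pos < num_layers for pos in targets), "target position out of range"


_check_layer_mapping(SINGLE_MAPPING, SINGLE_FROZEN, num_layers=25)
_check_layer_mapping(DOUBLE_MAPPING, DOUBLE_FROZEN, num_layers=15)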

# ============================================================================
# WEIGHT EXPANSION UTILITIES
# ============================================================================

def expand_qkv_weights(old_weight, old_hidden=768, new_hidden=1536, head_dim=128):
    """
    Expand QKV projection weights when increasing hidden size / head count.

    QKV weight shape: (3 * num_heads * head_dim, hidden_size) = (3 * hidden_size, hidden_size)

    Strategy: Copy old weights to corresponding positions, random init new heads.
    Old heads are copied into the first head slots of each Q/K/V section;
    the remaining heads keep the scaled-down random init.
    """
    old_qkv_dim = old_weight.shape[0]  # 3 * old_hidden
    new_qkv_dim = 3 * new_hidden
    old_heads = old_hidden // head_dim
    new_heads = new_hidden // head_dim

    # Initialize new weights
    new_weight = torch.zeros(new_qkv_dim, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.02  # Scale down random init

    # For each of Q, K, V: copy old heads to first N positions
    for qkv_idx in range(3):
        old_start = qkv_idx * old_hidden
        new_start = qkv_idx * new_hidden
        # Copy all old heads to first old_heads positions of new
        for h in range(old_heads):
            old_h_start = old_start + h * head_dim
            old_h_end = old_h_start + head_dim
            new_h_start = new_start + h * head_dim
            new_h_end = new_h_start + head_dim
            # Copy weights, input dim goes to first old_hidden columns
            new_weight[new_h_start:new_h_end, :old_hidden] = old_weight[old_h_start:old_h_end, :]

    return new_weight


def expand_qkv_bias(old_bias, old_hidden=768, new_hidden=1536, head_dim=128):
    """Expand QKV bias from old_hidden to new_hidden."""
    new_qkv_dim = 3 * new_hidden
    new_bias = torch.zeros(new_qkv_dim, dtype=old_bias.dtype, device=old_bias.device)
    old_heads = old_hidden // head_dim

    # Copy old biases to first old_heads positions for each of Q, K, V
    for qkv_idx in range(3):
        old_start = qkv_idx * old_hidden
        new_start = qkv_idx * new_hidden
        new_bias[new_start:new_start + old_hidden] = old_bias[old_start:old_start + old_hidden]

    return new_bias


def expand_out_proj_weights(old_weight, old_hidden=768, new_hidden=1536, head_dim=128):
    """
    Expand output projection weights.

    Out proj weight shape: (hidden_size, num_heads * head_dim) = (hidden_size, hidden_size)
    """
    # Initialize new weights
    new_weight = torch.zeros(new_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.02

    # Copy old weights to top-left corner
    new_weight[:old_hidden, :old_hidden] = old_weight
    return new_weight


def expand_out_proj_bias(old_bias, old_hidden=768, new_hidden=1536):
    """Expand output projection bias."""
    new_bias = torch.zeros(new_hidden, dtype=old_bias.dtype, device=old_bias.device)
    new_bias[:old_hidden] = old_bias
    return new_bias


def expand_linear_hidden(old_weight, old_hidden=768, new_hidden=1536, expand_in=True, expand_out=True):
    """
    Expand a linear layer weight from old_hidden to new_hidden.
""" old_out, old_in = old_weight.shape new_out = new_hidden if expand_out else old_out new_in = new_hidden if expand_in else old_in new_weight = torch.zeros(new_out, new_in, dtype=old_weight.dtype, device=old_weight.device) nn.init.xavier_uniform_(new_weight) new_weight *= 0.02 # Copy old weights to top-left corner copy_out = old_hidden if expand_out else old_out copy_in = old_hidden if expand_in else old_in new_weight[:copy_out, :copy_in] = old_weight[:copy_out, :copy_in] return new_weight def expand_bias(old_bias, old_hidden=768, new_hidden=1536): """Expand bias from old_hidden to new_hidden.""" new_bias = torch.zeros(new_hidden, dtype=old_bias.dtype, device=old_bias.device) new_bias[:old_hidden] = old_bias return new_bias def expand_norm(old_weight, old_hidden=768, new_hidden=1536): """Expand RMSNorm weight from old_hidden to new_hidden.""" new_weight = torch.ones(new_hidden, dtype=old_weight.dtype, device=old_weight.device) new_weight[:old_hidden] = old_weight return new_weight def port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024): """Port weights from old single block to new single block with dimension expansion.""" old_prefix = f"single_blocks.{old_idx}" new_prefix = f"single_blocks.{new_idx}" for old_key in list(old_state.keys()): if not old_key.startswith(old_prefix): continue new_key = old_key.replace(old_prefix, new_prefix) old_weight = old_state[old_key] # Attention QKV if "attn.qkv.weight" in old_key: new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden) print(f" Expanded QKV weight: {old_key}") elif "attn.qkv.bias" in old_key: new_state[new_key] = expand_qkv_bias(old_weight) print(f" Expanded QKV bias: {old_key}") # Attention output projection elif "attn.out_proj.weight" in old_key: new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden) print(f" Expanded out_proj weight: {old_key}") elif "attn.out_proj.bias" in old_key: new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden) print(f" Expanded out_proj bias: {old_key}") # MLP layers (hidden → 4*hidden → hidden) elif "mlp.fc1.weight" in old_key: # fc1: hidden → 4*hidden old_mlp_hidden = old_hidden * 4 new_mlp_hidden = new_hidden * 4 new_weight = torch.zeros(new_mlp_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device) nn.init.xavier_uniform_(new_weight) new_weight *= 0.02 new_weight[:old_mlp_hidden, :old_hidden] = old_weight new_state[new_key] = new_weight print(f" Expanded MLP fc1 weight: {old_key}") elif "mlp.fc1.bias" in old_key: old_mlp_hidden = old_hidden * 4 new_mlp_hidden = new_hidden * 4 new_bias = torch.zeros(new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device) new_bias[:old_mlp_hidden] = old_weight new_state[new_key] = new_bias print(f" Expanded MLP fc1 bias: {old_key}") elif "mlp.fc2.weight" in old_key: # fc2: 4*hidden → hidden old_mlp_hidden = old_hidden * 4 new_mlp_hidden = new_hidden * 4 new_weight = torch.zeros(new_hidden, new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device) nn.init.xavier_uniform_(new_weight) new_weight *= 0.02 new_weight[:old_hidden, :old_mlp_hidden] = old_weight new_state[new_key] = new_weight print(f" Expanded MLP fc2 weight: {old_key}") elif "mlp.fc2.bias" in old_key: new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden) print(f" Expanded MLP fc2 bias: {old_key}") # AdaLayerNorm modulation linear (norm.linear) - outputs 3*hidden for 

def port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
    """Port weights from old single block to new single block with dimension expansion."""
    old_prefix = f"single_blocks.{old_idx}"
    new_prefix = f"single_blocks.{new_idx}"

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue

        new_key = old_key.replace(old_prefix, new_prefix)
        old_weight = old_state[old_key]

        # Attention QKV
        if "attn.qkv.weight" in old_key:
            new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV weight: {old_key}")
        elif "attn.qkv.bias" in old_key:
            new_state[new_key] = expand_qkv_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV bias: {old_key}")
        # Attention output projection
        elif "attn.out_proj.weight" in old_key:
            new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj weight: {old_key}")
        elif "attn.out_proj.bias" in old_key:
            new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj bias: {old_key}")
        # MLP layers (hidden → 4*hidden → hidden)
        elif "mlp.fc1.weight" in old_key:
            # fc1: hidden → 4*hidden
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_mlp_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_mlp_hidden, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc1 weight: {old_key}")
        elif "mlp.fc1.bias" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_bias = torch.zeros(new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_mlp_hidden] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded MLP fc1 bias: {old_key}")
        elif "mlp.fc2.weight" in old_key:
            # fc2: 4*hidden → hidden
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_hidden, new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :old_mlp_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc2 weight: {old_key}")
        elif "mlp.fc2.bias" in old_key:
            new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded MLP fc2 bias: {old_key}")
        # AdaLayerNorm modulation linear (norm.linear) - outputs 3*hidden for single blocks
        elif "norm.linear.weight" in old_key:
            # Shape: (3*old_hidden, old_hidden) → (3*new_hidden, new_hidden)
            old_out = old_hidden * 3
            new_out = new_hidden * 3
            new_weight = torch.zeros(new_out, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_out, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded AdaLN linear weight: {old_key} ({old_out},{old_hidden})→({new_out},{new_hidden})")
        elif "norm.linear.bias" in old_key:
            old_out = old_hidden * 3
            new_out = new_hidden * 3
            new_bias = torch.zeros(new_out, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_out] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded AdaLN linear bias: {old_key} ({old_out})→({new_out})")
        # RMSNorm inside AdaLN (norm.norm.weight) or standalone norm
        elif "norm.norm.weight" in old_key or "norm2.weight" in old_key:
            new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
            print(f"  Expanded RMSNorm weight: {old_key}")
        # Generic normalization layers - check actual sizes
        elif "norm" in old_key and "weight" in old_key:
            old_size = old_weight.shape[0]
            if new_key in new_state:  # only port if the target exists in the new model
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm weight: {old_key} ({old_size})")
                else:
                    new_weight = torch.ones(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_weight[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_weight
                    print(f"  Padded norm weight: {old_key} ({old_size}→{new_size})")
        elif "norm" in old_key and "bias" in old_key:
            old_size = old_weight.shape[0]
            if new_key in new_state:
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm bias: {old_key} ({old_size})")
                else:
                    new_bias = torch.zeros(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_bias[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_bias
                    print(f"  Padded norm bias: {old_key} ({old_size}→{new_size})")
        # Direct copy for anything else (shouldn't be much)
        else:
            if old_weight.shape == new_state.get(new_key, torch.empty(0)).shape:
                new_state[new_key] = old_weight.clone()
                print(f"  Direct copy: {old_key}")
            else:
                print(f"  SKIP (shape mismatch): {old_key}")


def port_double_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
    """Port weights from old double block to new double block with dimension expansion."""
    old_prefix = f"double_blocks.{old_idx}"
    new_prefix = f"double_blocks.{new_idx}"

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue

        new_key = old_key.replace(old_prefix, new_prefix)
        old_weight = old_state[old_key]

        # Joint attention QKV (img and txt)
        if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
            new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV weight: {old_key}")
        elif any(x in old_key for x in ["img_qkv.bias", "txt_qkv.bias"]):
            new_state[new_key] = expand_qkv_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV bias: {old_key}")
        # Joint attention output projections
        elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
            new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj weight: {old_key}")
        elif any(x in old_key for x in ["img_out.bias", "txt_out.bias"]):
            new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj bias: {old_key}")
        # MLP layers
        elif "mlp" in old_key and "fc1.weight" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_mlp_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_mlp_hidden, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc1 weight: {old_key}")
        elif "mlp" in old_key and "fc1.bias" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_bias = torch.zeros(new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_mlp_hidden] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded MLP fc1 bias: {old_key}")
        elif "mlp" in old_key and "fc2.weight" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_hidden, new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :old_mlp_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc2 weight: {old_key}")
        elif "mlp" in old_key and "fc2.bias" in old_key:
            new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded MLP fc2 bias: {old_key}")
        # AdaLayerNormZero modulation linear - outputs 6*hidden (img_norm1, txt_norm1)
        elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "weight" in old_key:
            old_out = old_hidden * 6
            new_out = new_hidden * 6
            new_weight = torch.zeros(new_out, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_out, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded AdaLN linear weight: {old_key}")
        elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "bias" in old_key:
            old_out = old_hidden * 6
            new_out = new_hidden * 6
            new_bias = torch.zeros(new_out, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_out] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded AdaLN linear bias: {old_key}")
        # RMSNorm inside AdaLN (img_norm1.norm, txt_norm1.norm) or standalone (img_norm2, txt_norm2)
        elif any(x in old_key for x in ["_norm1.norm.weight", "_norm2.weight"]):
            new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
            print(f"  Expanded RMSNorm weight: {old_key}")
        # Generic normalization layers - check actual sizes
        elif "norm" in old_key and "weight" in old_key:
            old_size = old_weight.shape[0]
            if new_key in new_state:  # only port if the target exists in the new model
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm weight: {old_key} ({old_size})")
                else:
                    new_weight = torch.ones(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_weight[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_weight
                    print(f"  Padded norm weight: {old_key} ({old_size}→{new_size})")
        elif "norm" in old_key and "bias" in old_key:
            old_size = old_weight.shape[0]
            if new_key in new_state:
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm bias: {old_key} ({old_size})")
                else:
                    new_bias = torch.zeros(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_bias[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_bias
                    print(f"  Padded norm bias: {old_key} ({old_size}→{new_size})")
        # Direct copy for matching shapes
        else:
            if old_weight.shape == new_state.get(new_key, torch.empty(0)).shape:
                new_state[new_key] = old_weight.clone()
                print(f"  Direct copy: {old_key}")
            else:
                print(f"  SKIP (shape mismatch): {old_key}")


def port_non_block_weights(old_state, new_state, old_hidden=256, new_hidden=1024):
    """Port weights that aren't in single/double blocks with dimension expansion."""
    for old_key, old_weight in old_state.items():
        # Skip block weights (handled separately)
        if "single_blocks" in old_key or "double_blocks" in old_key:
            continue
        # Skip buffers that will be recomputed
        if any(x in old_key for x in ["sin_basis", "freqs_"]):
            print(f"  Skip buffer: {old_key}")
            continue

        # img_in: in_channels → hidden
        if "img_in.weight" in old_key:
            new_weight = torch.zeros(new_hidden, old_weight.shape[1], dtype=old_weight.dtype)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :] = old_weight
            new_state[old_key] = new_weight
            print(f"  Expanded: {old_key}")
        elif "img_in.bias" in old_key:
            new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
            print(f"  Expanded: {old_key}")
        # txt_in: joint_attention_dim → hidden
        elif "txt_in.weight" in old_key:
            new_weight = torch.zeros(new_hidden, old_weight.shape[1], dtype=old_weight.dtype)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :] = old_weight
            new_state[old_key] = new_weight
            print(f"  Expanded: {old_key}")
        elif "txt_in.bias" in old_key:
            new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
            print(f"  Expanded: {old_key}")
        # time_in, guidance_in: MLPEmbedder (hidden → hidden)
        elif any(x in old_key for x in ["time_in", "guidance_in"]):
            if "fc1.weight" in old_key:
                new_weight = torch.zeros(new_hidden, new_hidden, dtype=old_weight.dtype)
                nn.init.xavier_uniform_(new_weight)
                new_weight *= 0.02
                new_weight[:old_hidden, :old_hidden] = old_weight
                new_state[old_key] = new_weight
                print(f"  Expanded: {old_key}")
            elif "fc1.bias" in old_key:
                new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
                print(f"  Expanded: {old_key}")
            elif "fc2.weight" in old_key:
                new_weight = torch.zeros(new_hidden, new_hidden, dtype=old_weight.dtype)
                nn.init.xavier_uniform_(new_weight)
                new_weight *= 0.02
                new_weight[:old_hidden, :old_hidden] = old_weight
                new_state[old_key] = new_weight
                print(f"  Expanded: {old_key}")
            elif "fc2.bias" in old_key:
                new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
                print(f"  Expanded: {old_key}")
        # vector_in: pooled_projection_dim → hidden
        elif "vector_in" in old_key:
            if "weight" in old_key:
                new_weight = torch.zeros(new_hidden, old_weight.shape[1], dtype=old_weight.dtype)
                nn.init.xavier_uniform_(new_weight)
                new_weight *= 0.02
                new_weight[:old_hidden, :] = old_weight
                new_state[old_key] = new_weight
                print(f"  Expanded: {old_key}")
            elif "bias" in old_key:
                new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
                print(f"  Expanded: {old_key}")
        # final_norm: RMSNorm(hidden)
        elif "final_norm" in old_key:
            if "weight" in old_key:
                new_state[old_key] = expand_norm(old_weight, old_hidden, new_hidden)
                print(f"  Expanded: {old_key}")
        # final_linear: hidden → in_channels
        elif "final_linear.weight" in old_key:
            new_weight = torch.zeros(old_weight.shape[0], new_hidden, dtype=old_weight.dtype)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:, :old_hidden] = old_weight
            new_state[old_key] = new_weight
            print(f"  Expanded: {old_key}")
        elif "final_linear.bias" in old_key:
            new_state[old_key] = old_weight.clone()  # output dim unchanged
            print(f"  Direct copy: {old_key}")
        # RoPE - skip, will be recomputed
        elif "rope" in old_key:
            print(f"  Skip RoPE: {old_key}")
        else:
            print(f"  Unknown non-block key: {old_key}")


# ============================================================================
# MAIN PORTING FUNCTION
# ============================================================================

def port_tinyflux_to_deep(old_weights_path, new_model):
    """
    Port TinyFlux weights to TinyFlux-Deep.

    Returns:
        new_state_dict: Ported weights
        frozen_params: Set of parameter names to freeze
    """
    print("Loading old weights...")
    if old_weights_path.endswith(".safetensors"):
        old_state = load_file(old_weights_path)
    else:
        old_state = torch.load(old_weights_path, map_location="cpu")
        if "model" in old_state:  # checkpoint dicts sometimes wrap weights under "model"
            old_state = old_state["model"]

    # Strip _orig_mod prefix if present
    if any(k.startswith("_orig_mod.") for k in old_state.keys()):
        print("Stripping _orig_mod prefix...")
        old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}

    # Get new model's state dict as template FIRST
    new_state = new_model.state_dict()
    frozen_params = set()

    # Auto-detect old hidden size from weights
    if "final_norm.weight" in old_state:
        old_hidden = old_state["final_norm.weight"].shape[0]
    elif "img_in.weight" in old_state:
        old_hidden = old_state["img_in.weight"].shape[0]
    else:
        old_hidden = 256  # Default for TinyFlux

    # Get new hidden size from new model's state dict
    if "final_norm.weight" in new_state:
        new_hidden = new_state["final_norm.weight"].shape[0]
    else:
        new_hidden = 512  # Default for TinyFlux-Deep

    print(f"Detected old hidden size: {old_hidden}")
    print(f"New hidden size: {new_hidden}")

    print("\n" + "=" * 60)
    print("Porting non-block weights...")
    print("=" * 60)
    port_non_block_weights(old_state, new_state, old_hidden=old_hidden, new_hidden=new_hidden)

    print("\n" + "=" * 60)
    print("Porting single blocks (3 → 25)...")
    print("=" * 60)
    for old_idx, new_positions in SINGLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nSingle block {old_idx} → {new_idx}:")
            port_single_block_weights(old_state, old_idx, new_state, new_idx,
                                      old_hidden=old_hidden, new_hidden=new_hidden)
            # Mark as frozen
            for key in new_state.keys():
                if f"single_blocks.{new_idx}." in key:
                    frozen_params.add(key)

    print("\n" + "=" * 60)
    print("Porting double blocks (3 → 15)...")
    print("=" * 60)
    for old_idx, new_positions in DOUBLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nDouble block {old_idx} → {new_idx}:")
            port_double_block_weights(old_state, old_idx, new_state, new_idx,
                                      old_hidden=old_hidden, new_hidden=new_hidden)
            # Mark as frozen
            for key in new_state.keys():
                if f"double_blocks.{new_idx}." in key:
                    frozen_params.add(key)

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Total parameter tensors in new model: {len(new_state)}")
    print(f"Frozen parameter tensors: {len(frozen_params)}")
    print(f"Trainable parameter tensors: {len(new_state) - len(frozen_params)}")
    print(f"\nFrozen single block positions: {sorted(SINGLE_FROZEN)}")
    print(f"Frozen double block positions: {sorted(DOUBLE_FROZEN)}")

    return new_state, frozen_params
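
# ----------------------------------------------------------------------------
# Optional verification sketch (illustrative, not part of the original flow):
# after porting, the old block weights should survive in the copied corner of
# each frozen position. The "attn.qkv.weight" key name mirrors the pattern
# handled in port_single_block_weights; call this with the old_state /
# new_state used inside port_tinyflux_to_deep if desired.
# ----------------------------------------------------------------------------
def verify_single_block_port(old_state, new_state, old_idx, new_idx, old_hidden):
    """Check that the old QKV block is preserved at the mapped new position."""
    old_key = f"single_blocks.{old_idx}.attn.qkv.weight"
    new_key = f"single_blocks.{new_idx}.attn.qkv.weight"
    if old_key not in old_state or new_key not in new_state:
        print(f"  verify: missing {old_key} or {new_key}, skipping")
        return
    old_w, new_w = old_state[old_key], new_state[new_key]
    # The old Q section should be bit-identical in the expanded tensor
    ok = torch.equal(new_w[:old_hidden, :old_hidden], old_w[:old_hidden, :])
    print(f"  verify single block {old_idx} → {new_idx}: {'OK' if ok else 'MISMATCH'}")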

# ============================================================================
# FREEZE HELPER
# ============================================================================

def freeze_ported_layers(model, frozen_params):
    """Freeze the ported layers, keep new layers trainable."""
    frozen_count = 0
    trainable_count = 0

    for name, param in model.named_parameters():
        if name in frozen_params:
            param.requires_grad = False
            frozen_count += param.numel()
        else:
            param.requires_grad = True
            trainable_count += param.numel()

    print(f"\nFrozen params: {frozen_count:,}")
    print(f"Trainable params: {trainable_count:,}")
    print(f"Total params: {frozen_count + trainable_count:,}")
    print(f"Trainable ratio: {trainable_count / (frozen_count + trainable_count) * 100:.1f}%")
    return model


# ============================================================================
# MAIN SCRIPT
# ============================================================================

if __name__ == "__main__":
    print("=" * 60)
    print("TinyFlux → TinyFlux-Deep Porting")
    print("=" * 60)

    # Load old weights from hub FIRST to detect dimensions
    print("\nDownloading TinyFlux weights from hub...")
    old_weights_path = hf_hub_download(
        repo_id="AbstractPhil/tiny-flux",
        filename="model.safetensors",
    )

    # Load and detect old dimensions
    print("Detecting old model dimensions...")
    old_state = load_file(old_weights_path)
    if any(k.startswith("_orig_mod.") for k in old_state.keys()):
        old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}

    # Detect old hidden size
    old_hidden = old_state["final_norm.weight"].shape[0]
    head_dim = 128  # Fixed for RoPE
    old_heads = old_hidden // head_dim
    print(f"  Old hidden size: {old_hidden}")
    print(f"  Old attention heads: {old_heads}")
    print(f"  Head dim: {head_dim}")

    # Calculate new dimensions (double the heads)
    new_heads = old_heads * 2          # 6 → 12
    new_hidden = new_heads * head_dim  # 12 * 128 = 1536
    print(f"\nNew dimensions:")
    print(f"  New hidden size: {new_hidden}")
    print(f"  New attention heads: {new_heads}")

    # Create deep config with detected dimensions
    deep_config = TinyFluxDeepConfig()
    deep_config.hidden_size = new_hidden
    deep_config.num_attention_heads = new_heads

    print("\nCreating TinyFlux-Deep model...")
    # You need to define the TinyFlux class first (run the model cell)
    deep_model = TinyFlux(deep_config).to(DTYPE)

    print(f"\nDeep model config:")
    print(f"  Hidden size: {deep_config.hidden_size}")
    print(f"  Attention heads: {deep_config.num_attention_heads}")
    print(f"  Single layers: {deep_config.num_single_layers}")
    print(f"  Double layers: {deep_config.num_double_layers}")

    # Port weights
    new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)

    # Load ported weights
    print("\nLoading ported weights into model...")
    missing, unexpected = deep_model.load_state_dict(new_state, strict=False)
    if missing:
        print(f"  Missing keys: {missing[:5]}..." if len(missing) > 5 else f"  Missing keys: {missing}")
    if unexpected:
        print(f"  Unexpected keys: {unexpected}")

    # Freeze ported layers
    print("\nFreezing ported layers...")
    deep_model = freeze_ported_layers(deep_model, frozen_params)

    # Save
    print("\nSaving ported model...")
    save_path = "tinyflux_deep_ported.safetensors"
    # Strip any _orig_mod prefix before saving
    state_to_save = deep_model.state_dict()
    if any(k.startswith("_orig_mod.") for k in state_to_save.keys()):
        state_to_save = {k.replace("_orig_mod.", ""): v for k, v in state_to_save.items()}
    save_file(state_to_save, save_path)
    print(f"✓ Saved to {save_path}")

    # Save frozen params list
    import json
    with open("frozen_params.json", "w") as f:
        json.dump(list(frozen_params), f)
    print("✓ Saved frozen_params.json")

    # Save config
    config_dict = {
        "hidden_size": deep_config.hidden_size,
        "num_attention_heads": deep_config.num_attention_heads,
        "attention_head_dim": deep_config.attention_head_dim,
        "num_single_layers": deep_config.num_single_layers,
        "num_double_layers": deep_config.num_double_layers,
        "mlp_ratio": deep_config.mlp_ratio,
        "joint_attention_dim": deep_config.joint_attention_dim,
        "pooled_projection_dim": deep_config.pooled_projection_dim,
        "in_channels": deep_config.in_channels,
        "axes_dims_rope": list(deep_config.axes_dims_rope),
        "guidance_embeds": deep_config.guidance_embeds,
    }
    with open("config_deep.json", "w") as f:
        json.dump(config_dict, f, indent=2)
    print("✓ Saved config_deep.json")

    # Upload to hub
    print("\nUploading to AbstractPhil/tiny-flux-deep...")
    api = HfApi()
    try:
        api.create_repo(repo_id="AbstractPhil/tiny-flux-deep", exist_ok=True, repo_type="model")
        api.upload_file(path_or_fileobj=save_path, path_in_repo="model.safetensors",
                        repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="config_deep.json", path_in_repo="config.json",
                        repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="frozen_params.json", path_in_repo="frozen_params.json",
                        repo_id="AbstractPhil/tiny-flux-deep")
        print("✓ Uploaded to hub!")
    except Exception as e:
        print(f"⚠ Upload failed: {e}")

    print("\n" + "=" * 60)
    print("Porting complete!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Update TinyFlux model definition to accept TinyFluxDeepConfig")
    print("2. Use the frozen_params.json to freeze layers during training")
    print("3. Train on AbstractPhil/tiny-flux-deep repo")
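
# ----------------------------------------------------------------------------
# Illustrative follow-up (a minimal sketch, not invoked by this script): how a
# training run might consume the artifacts saved above. It assumes the same
# TinyFlux class used in __main__ is importable, that model.safetensors /
# frozen_params.json were produced by this script, and that the 12-head /
# 1536-hidden dimensions match the doubled 6-head checkpoint detected above;
# the learning rate is a placeholder.
# ----------------------------------------------------------------------------
def load_ported_for_training(weights_path="tinyflux_deep_ported.safetensors",
                             frozen_path="frozen_params.json",
                             lr=1e-4):
    """Rebuild the deep model, re-apply the freeze mask, and build an optimizer."""
    import json

    deep_config = TinyFluxDeepConfig(hidden_size=1536, num_attention_heads=12)
    model = TinyFlux(deep_config).to(DTYPE)
    model.load_state_dict(load_file(weights_path), strict=False)

    # Re-freeze the ported positions recorded by this script
    with open(frozen_path) as f:
        frozen = set(json.load(f))
    for name, param in model.named_parameters():
        param.requires_grad = name not in frozen

    # Optimize only the new (trainable) parameters
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=lr
    )
    return model, optimizer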