import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from huggingface_hub import hf_hub_download, HfApi
from dataclasses import dataclass
from typing import Tuple

# NOTE: the TinyFlux model class instantiated in __main__ is assumed to be
# importable from this project's own model module, e.g.:
# from tinyflux import TinyFlux

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16


# ============================================================
# Configs
# ============================================================
@dataclass
class TinyFluxConfig:
    """Original small config - matches the TinyFlux model on the hub (hidden=768, 6 heads)."""

    # Attention: num_attention_heads * attention_head_dim == hidden_size
    hidden_size: int = 768
    num_attention_heads: int = 6
    attention_head_dim: int = 128

    # Latent I/O
    in_channels: int = 16
    patch_size: int = 1

    # Conditioning dims
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768

    # Depth
    num_double_layers: int = 3
    num_single_layers: int = 3

    # MLP expansion factor
    mlp_ratio: float = 4.0

    # RoPE axis dims
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    guidance_embeds: bool = True


@dataclass
class TinyFluxDeepConfig:
    """
    Expanded deep config - matches TinyFlux model attribute names exactly.

    Defaults follow the double-the-heads scheme (e.g. a 256-hidden / 2-head
    checkpoint expands to 512 hidden / 4 heads at head_dim=128). The actual
    old dimensions are detected from the checkpoint in __main__ and these
    fields are overridden there, so the defaults are placeholders only.
    """

    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128

    # Latent I/O
    in_channels: int = 16
    patch_size: int = 1

    # Conditioning dims
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768

    # Depth: expanded from 3 double / 3 single in the original
    num_double_layers: int = 15
    num_single_layers: int = 25

    # MLP expansion factor
    mlp_ratio: float = 4.0

    # RoPE axis dims
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    guidance_embeds: bool = True

    def __post_init__(self):
        assert self.num_attention_heads * self.attention_head_dim == self.hidden_size, \
            f"heads ({self.num_attention_heads}) * head_dim ({self.attention_head_dim}) != hidden ({self.hidden_size})"
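
# Minimal sketch (not called anywhere) of the consistency check that
# __post_init__ enforces; the mismatched dims below are made-up examples.
def _demo_deep_config():
    cfg = TinyFluxDeepConfig()  # 4 heads * 128 head_dim == 512 hidden -> OK
    assert cfg.hidden_size == cfg.num_attention_heads * cfg.attention_head_dim
    try:
        TinyFluxDeepConfig(hidden_size=512, num_attention_heads=3)
    except AssertionError:
        pass  # 3 * 128 != 512, rejected in __post_init__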


# ============================================================
# Block index mappings
# ============================================================
# Each old block index is copied into one or more positions of the deeper
# model; those target positions are then frozen after porting so that only
# the freshly initialized blocks train.

# Old single block -> new single block positions (3 blocks -> 25 layers)
SINGLE_MAPPING = {
    0: [0],
    1: [8, 12, 16],
    2: [24],
}
SINGLE_FROZEN = {0, 8, 12, 16, 24}

# Old double block -> new double block positions (3 blocks -> 15 layers)
DOUBLE_MAPPING = {
    0: [0],
    1: [4, 7, 10],
    2: [14],
}
DOUBLE_FROZEN = {0, 4, 7, 10, 14}
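
# Sanity check: the frozen sets must be exactly the positions that receive
# ported weights, otherwise freezing and porting would disagree.
assert {i for targets in SINGLE_MAPPING.values() for i in targets} == SINGLE_FROZEN
assert {i for targets in DOUBLE_MAPPING.values() for i in targets} == DOUBLE_FROZEN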


# ============================================================
# Dimension-expansion helpers
# ============================================================

def expand_qkv_weights(old_weight, old_hidden=768, new_hidden=1536, head_dim=128):
    """
    Expand fused QKV projection weights when increasing hidden size / head count.
    QKV weight shape: (3 * num_heads * head_dim, hidden_size) = (3 * hidden_size, hidden_size)

    Strategy: randomly initialize the full new tensor at small scale, then copy
    each old head into the same head position within its Q/K/V slice, so old
    heads occupy the leading positions and the added heads stay randomly
    initialized.
    """
    old_heads = old_hidden // head_dim

    # Small-scale Xavier init for the new (untrained) positions.
    new_weight = torch.zeros(3 * new_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.02

    for qkv_idx in range(3):
        old_start = qkv_idx * old_hidden
        new_start = qkv_idx * new_hidden

        for h in range(old_heads):
            old_h_start = old_start + h * head_dim
            old_h_end = old_h_start + head_dim
            new_h_start = new_start + h * head_dim
            new_h_end = new_h_start + head_dim

            # Copy the old head's rows; only the first old_hidden input
            # columns are overwritten, the rest keep the random init.
            new_weight[new_h_start:new_h_end, :old_hidden] = old_weight[old_h_start:old_h_end, :]

    return new_weight
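
# Illustrative self-test for the QKV expansion (not called automatically);
# the 256 -> 512 dims here are examples, not the script's detected values.
def _sanity_check_qkv_expansion():
    w = torch.randn(3 * 256, 256)
    out = expand_qkv_weights(w, old_hidden=256, new_hidden=512)
    assert out.shape == (3 * 512, 512)
    # Q slice of head 0 is copied verbatim into the first input columns.
    assert torch.equal(out[0:128, :256], w[0:128, :])
    # K slice of head 0: old rows 256..384 land at new rows 512..640.
    assert torch.equal(out[512:640, :256], w[256:384, :])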


def expand_qkv_bias(old_bias, old_hidden=768, new_hidden=1536):
    """Expand fused QKV bias from old_hidden to new_hidden (new slots stay zero)."""
    new_bias = torch.zeros(3 * new_hidden, dtype=old_bias.dtype, device=old_bias.device)

    for qkv_idx in range(3):
        old_start = qkv_idx * old_hidden
        new_start = qkv_idx * new_hidden
        new_bias[new_start:new_start + old_hidden] = old_bias[old_start:old_start + old_hidden]

    return new_bias


def expand_out_proj_weights(old_weight, old_hidden=768, new_hidden=1536):
    """
    Expand output projection weights.
    Out proj weight shape: (hidden_size, num_heads * head_dim) = (hidden_size, hidden_size)
    """
    new_weight = torch.zeros(new_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.02

    # Old projection occupies the top-left block; new rows/columns keep the
    # small-scale random init.
    new_weight[:old_hidden, :old_hidden] = old_weight

    return new_weight


def expand_out_proj_bias(old_bias, old_hidden=768, new_hidden=1536):
    """Expand output projection bias (new slots stay zero)."""
    new_bias = torch.zeros(new_hidden, dtype=old_bias.dtype, device=old_bias.device)
    new_bias[:old_hidden] = old_bias
    return new_bias


def expand_linear_hidden(old_weight, old_hidden=768, new_hidden=1536, expand_in=True, expand_out=True):
    """
    Expand a linear layer weight from old_hidden to new_hidden, on the input
    side, the output side, or both.
    """
    old_out, old_in = old_weight.shape

    new_out = new_hidden if expand_out else old_out
    new_in = new_hidden if expand_in else old_in

    new_weight = torch.zeros(new_out, new_in, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.02

    # Copy the old weights into the top-left block of whichever dims expand.
    copy_out = old_hidden if expand_out else old_out
    copy_in = old_hidden if expand_in else old_in
    new_weight[:copy_out, :copy_in] = old_weight[:copy_out, :copy_in]

    return new_weight
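
# Usage sketch for the flags above (example dims): expanding only the input
# side is what you'd want for a projection whose output dim is fixed.
def _demo_expand_linear_hidden():
    w = torch.randn(64, 256)  # fixed 64 outputs, hidden-sized inputs
    out = expand_linear_hidden(w, old_hidden=256, new_hidden=512,
                               expand_in=True, expand_out=False)
    assert out.shape == (64, 512)
    assert torch.equal(out[:, :256], w)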


def expand_bias(old_bias, old_hidden=768, new_hidden=1536):
    """Expand bias from old_hidden to new_hidden (new slots stay zero)."""
    new_bias = torch.zeros(new_hidden, dtype=old_bias.dtype, device=old_bias.device)
    new_bias[:old_hidden] = old_bias
    return new_bias


def expand_norm(old_weight, old_hidden=768, new_hidden=1536):
    """Expand RMSNorm weight from old_hidden to new_hidden (new slots start at 1.0)."""
    new_weight = torch.ones(new_hidden, dtype=old_weight.dtype, device=old_weight.device)
    new_weight[:old_hidden] = old_weight
    return new_weight
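
# Illustrative check (example dims): ported norm slots copy over, new slots
# start at 1.0 so they act as identity until trained.
def _sanity_check_norm_expansion():
    w = torch.full((256,), 0.5)
    out = expand_norm(w, old_hidden=256, new_hidden=512)
    assert torch.equal(out[:256], w)
    assert torch.all(out[256:] == 1.0)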


# ============================================================
# Per-block porting
# ============================================================

def port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
    """Port weights from old single block to new single block with dimension expansion."""
    old_prefix = f"single_blocks.{old_idx}"
    new_prefix = f"single_blocks.{new_idx}"

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue

        new_key = old_key.replace(old_prefix, new_prefix)
        old_weight = old_state[old_key]

        # Attention QKV
        if "attn.qkv.weight" in old_key:
            new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV weight: {old_key}")
        elif "attn.qkv.bias" in old_key:
            new_state[new_key] = expand_qkv_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV bias: {old_key}")

        # Attention output projection
        elif "attn.out_proj.weight" in old_key:
            new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj weight: {old_key}")
        elif "attn.out_proj.bias" in old_key:
            new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj bias: {old_key}")

        # MLP (mlp_ratio = 4)
        elif "mlp.fc1.weight" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_mlp_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_mlp_hidden, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc1 weight: {old_key}")
        elif "mlp.fc1.bias" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_bias = torch.zeros(new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_mlp_hidden] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded MLP fc1 bias: {old_key}")
        elif "mlp.fc2.weight" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_hidden, new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :old_mlp_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc2 weight: {old_key}")
        elif "mlp.fc2.bias" in old_key:
            new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded MLP fc2 bias: {old_key}")

        # AdaLN modulation (3 * hidden outputs for single blocks)
        elif "norm.linear.weight" in old_key:
            old_out = old_hidden * 3
            new_out = new_hidden * 3
            new_weight = torch.zeros(new_out, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_out, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded AdaLN linear weight: {old_key} ({old_out},{old_hidden})→({new_out},{new_hidden})")
        elif "norm.linear.bias" in old_key:
            old_out = old_hidden * 3
            new_out = new_hidden * 3
            new_bias = torch.zeros(new_out, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_out] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded AdaLN linear bias: {old_key} ({old_out})→({new_out})")

        # RMSNorm weights at hidden size
        elif "norm.norm.weight" in old_key or "norm2.weight" in old_key:
            new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
            print(f"  Expanded RMSNorm weight: {old_key}")

        # Any remaining norm weights: pad to whatever size the new model expects
        elif "norm" in old_key and "weight" in old_key:
            if new_key in new_state:
                old_size = old_weight.shape[0]
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm weight: {old_key} ({old_size})")
                else:
                    new_weight = torch.ones(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_weight[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_weight
                    print(f"  Padded norm weight: {old_key} ({old_size}→{new_size})")
        elif "norm" in old_key and "bias" in old_key:
            if new_key in new_state:
                old_size = old_weight.shape[0]
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm bias: {old_key} ({old_size})")
                else:
                    new_bias = torch.zeros(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_bias[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_bias
                    print(f"  Padded norm bias: {old_key} ({old_size}→{new_size})")

        # Everything else: copy only when shapes already match
        else:
            if new_key in new_state and old_weight.shape == new_state[new_key].shape:
                new_state[new_key] = old_weight.clone()
                print(f"  Direct copy: {old_key}")
            else:
                print(f"  SKIP (shape mismatch): {old_key}")


def port_double_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
    """Port weights from old double block to new double block with dimension expansion."""
    old_prefix = f"double_blocks.{old_idx}"
    new_prefix = f"double_blocks.{new_idx}"

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue

        new_key = old_key.replace(old_prefix, new_prefix)
        old_weight = old_state[old_key]

        # Attention QKV (separate image / text streams)
        if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
            new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV weight: {old_key}")
        elif any(x in old_key for x in ["img_qkv.bias", "txt_qkv.bias"]):
            new_state[new_key] = expand_qkv_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded QKV bias: {old_key}")

        # Attention output projections
        elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
            new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj weight: {old_key}")
        elif any(x in old_key for x in ["img_out.bias", "txt_out.bias"]):
            new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded out_proj bias: {old_key}")

        # MLP (mlp_ratio = 4)
        elif "mlp" in old_key and "fc1.weight" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_mlp_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_mlp_hidden, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc1 weight: {old_key}")
        elif "mlp" in old_key and "fc1.bias" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_bias = torch.zeros(new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_mlp_hidden] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded MLP fc1 bias: {old_key}")
        elif "mlp" in old_key and "fc2.weight" in old_key:
            old_mlp_hidden = old_hidden * 4
            new_mlp_hidden = new_hidden * 4
            new_weight = torch.zeros(new_hidden, new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :old_mlp_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded MLP fc2 weight: {old_key}")
        elif "mlp" in old_key and "fc2.bias" in old_key:
            new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
            print(f"  Expanded MLP fc2 bias: {old_key}")

        # AdaLN modulation (6 * hidden outputs for double blocks)
        elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "weight" in old_key:
            old_out = old_hidden * 6
            new_out = new_hidden * 6
            new_weight = torch.zeros(new_out, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_out, :old_hidden] = old_weight
            new_state[new_key] = new_weight
            print(f"  Expanded AdaLN linear weight: {old_key}")
        elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "bias" in old_key:
            old_out = old_hidden * 6
            new_out = new_hidden * 6
            new_bias = torch.zeros(new_out, dtype=old_weight.dtype, device=old_weight.device)
            new_bias[:old_out] = old_weight
            new_state[new_key] = new_bias
            print(f"  Expanded AdaLN linear bias: {old_key}")

        # RMSNorm weights at hidden size
        elif any(x in old_key for x in ["_norm1.norm.weight", "_norm2.weight"]):
            new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
            print(f"  Expanded RMSNorm weight: {old_key}")

        # Any remaining norm weights: pad to whatever size the new model expects
        elif "norm" in old_key and "weight" in old_key:
            if new_key in new_state:
                old_size = old_weight.shape[0]
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm weight: {old_key} ({old_size})")
                else:
                    new_weight = torch.ones(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_weight[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_weight
                    print(f"  Padded norm weight: {old_key} ({old_size}→{new_size})")
        elif "norm" in old_key and "bias" in old_key:
            if new_key in new_state:
                old_size = old_weight.shape[0]
                new_size = new_state[new_key].shape[0]
                if old_size == new_size:
                    new_state[new_key] = old_weight.clone()
                    print(f"  Direct copy norm bias: {old_key} ({old_size})")
                else:
                    new_bias = torch.zeros(new_size, dtype=old_weight.dtype, device=old_weight.device)
                    copy_size = min(old_size, new_size)
                    new_bias[:copy_size] = old_weight[:copy_size]
                    new_state[new_key] = new_bias
                    print(f"  Padded norm bias: {old_key} ({old_size}→{new_size})")

        # Everything else: copy only when shapes already match
        else:
            if new_key in new_state and old_weight.shape == new_state[new_key].shape:
                new_state[new_key] = old_weight.clone()
                print(f"  Direct copy: {old_key}")
            else:
                print(f"  SKIP (shape mismatch): {old_key}")


def port_non_block_weights(old_state, new_state, old_hidden=256, new_hidden=1024):
    """Port weights that aren't in single/double blocks, with dimension expansion."""
    for old_key, old_weight in old_state.items():
        if "single_blocks" in old_key or "double_blocks" in old_key:
            continue

        # Skip precomputed buffers; they are rebuilt from the new config
        if any(x in old_key for x in ["sin_basis", "freqs_"]):
            print(f"  Skip buffer: {old_key}")
            continue

        # Image input projection: expand output rows only
        if "img_in.weight" in old_key:
            new_weight = torch.zeros(new_hidden, old_weight.shape[1], dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :] = old_weight
            new_state[old_key] = new_weight
            print(f"  Expanded: {old_key}")
        elif "img_in.bias" in old_key:
            new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
            print(f"  Expanded: {old_key}")

        # Text input projection: expand output rows only
        elif "txt_in.weight" in old_key:
            new_weight = torch.zeros(new_hidden, old_weight.shape[1], dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:old_hidden, :] = old_weight
            new_state[old_key] = new_weight
            print(f"  Expanded: {old_key}")
        elif "txt_in.bias" in old_key:
            new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
            print(f"  Expanded: {old_key}")

        # Timestep / guidance embedders: two-layer MLPs at hidden size
        # (fc1 and fc2 are expanded identically, so they share a branch)
        elif any(x in old_key for x in ["time_in", "guidance_in"]):
            if "fc1.weight" in old_key or "fc2.weight" in old_key:
                new_weight = torch.zeros(new_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
                nn.init.xavier_uniform_(new_weight)
                new_weight *= 0.02
                new_weight[:old_hidden, :old_hidden] = old_weight
                new_state[old_key] = new_weight
                print(f"  Expanded: {old_key}")
            elif "fc1.bias" in old_key or "fc2.bias" in old_key:
                new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
                print(f"  Expanded: {old_key}")

        # Pooled text vector projection: expand output rows only
        elif "vector_in" in old_key:
            if "weight" in old_key:
                new_weight = torch.zeros(new_hidden, old_weight.shape[1], dtype=old_weight.dtype, device=old_weight.device)
                nn.init.xavier_uniform_(new_weight)
                new_weight *= 0.02
                new_weight[:old_hidden, :] = old_weight
                new_state[old_key] = new_weight
                print(f"  Expanded: {old_key}")
            elif "bias" in old_key:
                new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
                print(f"  Expanded: {old_key}")

        # Final norm at hidden size
        elif "final_norm" in old_key:
            if "weight" in old_key:
                new_state[old_key] = expand_norm(old_weight, old_hidden, new_hidden)
                print(f"  Expanded: {old_key}")

        # Final projection: expand input columns only, output dim is fixed
        elif "final_linear.weight" in old_key:
            new_weight = torch.zeros(old_weight.shape[0], new_hidden, dtype=old_weight.dtype, device=old_weight.device)
            nn.init.xavier_uniform_(new_weight)
            new_weight *= 0.02
            new_weight[:, :old_hidden] = old_weight
            new_state[old_key] = new_weight
            print(f"  Expanded: {old_key}")
        elif "final_linear.bias" in old_key:
            new_state[old_key] = old_weight.clone()
            print(f"  Direct copy: {old_key}")

        # RoPE buffers are recomputed by the new model
        elif "rope" in old_key:
            print(f"  Skip RoPE: {old_key}")

        else:
            print(f"  Unknown non-block key: {old_key}")


# ============================================================
# Main porting
# ============================================================

def port_tinyflux_to_deep(old_weights_path, new_model):
    """
    Port TinyFlux weights to TinyFlux-Deep.

    Returns:
        new_state_dict: Ported weights
        frozen_params: Set of parameter names to freeze
    """
    print("Loading old weights...")
    if old_weights_path.endswith(".safetensors"):
        old_state = load_file(old_weights_path)
    else:
        old_state = torch.load(old_weights_path, map_location="cpu")
        if "model" in old_state:
            old_state = old_state["model"]

    # torch.compile checkpoints wrap module names in _orig_mod.
    if any(k.startswith("_orig_mod.") for k in old_state.keys()):
        print("Stripping _orig_mod prefix...")
        old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}

    new_state = new_model.state_dict()
    frozen_params = set()

    # Detect the old hidden size from the checkpoint itself
    if "final_norm.weight" in old_state:
        old_hidden = old_state["final_norm.weight"].shape[0]
    elif "img_in.weight" in old_state:
        old_hidden = old_state["img_in.weight"].shape[0]
    else:
        old_hidden = 256  # fallback

    # New hidden size comes from the freshly constructed deep model
    if "final_norm.weight" in new_state:
        new_hidden = new_state["final_norm.weight"].shape[0]
    else:
        new_hidden = 512  # fallback

    print(f"Detected old hidden size: {old_hidden}")
    print(f"New hidden size: {new_hidden}")

    print("\n" + "="*60)
    print("Porting non-block weights...")
    print("="*60)
    port_non_block_weights(old_state, new_state, old_hidden=old_hidden, new_hidden=new_hidden)

    print("\n" + "="*60)
    print("Porting single blocks (3 → 25)...")
    print("="*60)
    for old_idx, new_positions in SINGLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nSingle block {old_idx} → {new_idx}:")
            port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=old_hidden, new_hidden=new_hidden)

            # Mark every parameter of the ported block as frozen
            for key in new_state.keys():
                if f"single_blocks.{new_idx}." in key:
                    frozen_params.add(key)

    print("\n" + "="*60)
    print("Porting double blocks (3 → 15)...")
    print("="*60)
    for old_idx, new_positions in DOUBLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nDouble block {old_idx} → {new_idx}:")
            port_double_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=old_hidden, new_hidden=new_hidden)

            for key in new_state.keys():
                if f"double_blocks.{new_idx}." in key:
                    frozen_params.add(key)

    print("\n" + "="*60)
    print("Summary")
    print("="*60)
    print(f"Total parameter tensors in new model: {len(new_state)}")
    print(f"Frozen tensors: {len(frozen_params)}")
    print(f"Trainable tensors: {len(new_state) - len(frozen_params)}")

    print(f"\nFrozen single block positions: {sorted(SINGLE_FROZEN)}")
    print(f"Frozen double block positions: {sorted(DOUBLE_FROZEN)}")

    return new_state, frozen_params


# ============================================================
# Freezing
# ============================================================

def freeze_ported_layers(model, frozen_params):
    """Freeze the ported layers; keep new layers trainable."""
    frozen_count = 0
    trainable_count = 0

    for name, param in model.named_parameters():
        if name in frozen_params:
            param.requires_grad = False
            frozen_count += param.numel()
        else:
            param.requires_grad = True
            trainable_count += param.numel()

    print(f"\nFrozen params: {frozen_count:,}")
    print(f"Trainable params: {trainable_count:,}")
    print(f"Total params: {frozen_count + trainable_count:,}")
    print(f"Trainable ratio: {trainable_count / (frozen_count + trainable_count) * 100:.1f}%")

    return model
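
# Sketch of the typical follow-up: build the optimizer over only the params
# left trainable by freeze_ported_layers. The AdamW settings here are
# placeholder assumptions, not values prescribed by this repo.
def make_optimizer(model, lr=1e-4, weight_decay=0.01):
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)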


# ============================================================
# Main
# ============================================================

if __name__ == "__main__":
    print("="*60)
    print("TinyFlux → TinyFlux-Deep Porting")
    print("="*60)

    # Download the original checkpoint
    print("\nDownloading TinyFlux weights from hub...")
    old_weights_path = hf_hub_download(
        repo_id="AbstractPhil/tiny-flux",
        filename="model.safetensors"
    )

    # Detect the old model's dimensions from the checkpoint
    print("Detecting old model dimensions...")
    old_state = load_file(old_weights_path)
    if any(k.startswith("_orig_mod.") for k in old_state.keys()):
        old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}

    old_hidden = old_state["final_norm.weight"].shape[0]
    head_dim = 128
    old_heads = old_hidden // head_dim

    print(f"  Old hidden size: {old_hidden}")
    print(f"  Old attention heads: {old_heads}")
    print(f"  Head dim: {head_dim}")

    # Double the head count for the deep variant
    new_heads = old_heads * 2
    new_hidden = new_heads * head_dim

    print("\nNew dimensions:")
    print(f"  New hidden size: {new_hidden}")
    print(f"  New attention heads: {new_heads}")

    # Construct with the detected dims so __post_init__ re-validates them
    deep_config = TinyFluxDeepConfig(
        hidden_size=new_hidden,
        num_attention_heads=new_heads,
    )

    print("\nCreating TinyFlux-Deep model...")
    # TinyFlux is this project's model class; see the import note at the top.
    deep_model = TinyFlux(deep_config).to(DTYPE)

    print("\nDeep model config:")
    print(f"  Hidden size: {deep_config.hidden_size}")
    print(f"  Attention heads: {deep_config.num_attention_heads}")
    print(f"  Single layers: {deep_config.num_single_layers}")
    print(f"  Double layers: {deep_config.num_double_layers}")

    # Port weights and collect the frozen-parameter set
    new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)

    print("\nLoading ported weights into model...")
    missing, unexpected = deep_model.load_state_dict(new_state, strict=False)
    if missing:
        print(f"  Missing keys: {missing[:5]}..." if len(missing) > 5 else f"  Missing keys: {missing}")
    if unexpected:
        print(f"  Unexpected keys: {unexpected}")

    print("\nFreezing ported layers...")
    deep_model = freeze_ported_layers(deep_model, frozen_params)

    print("\nSaving ported model...")
    save_path = "tinyflux_deep_ported.safetensors"

    state_to_save = deep_model.state_dict()
    if any(k.startswith("_orig_mod.") for k in state_to_save.keys()):
        state_to_save = {k.replace("_orig_mod.", ""): v for k, v in state_to_save.items()}

    save_file(state_to_save, save_path)
    print(f"✓ Saved to {save_path}")

    # Record which parameters were frozen so training can reuse the list
    import json
    with open("frozen_params.json", "w") as f:
        json.dump(list(frozen_params), f)
    print("✓ Saved frozen_params.json")

    # Write the deep config alongside the weights
    config_dict = {
        "hidden_size": deep_config.hidden_size,
        "num_attention_heads": deep_config.num_attention_heads,
        "attention_head_dim": deep_config.attention_head_dim,
        "num_single_layers": deep_config.num_single_layers,
        "num_double_layers": deep_config.num_double_layers,
        "mlp_ratio": deep_config.mlp_ratio,
        "joint_attention_dim": deep_config.joint_attention_dim,
        "pooled_projection_dim": deep_config.pooled_projection_dim,
        "in_channels": deep_config.in_channels,
        "axes_dims_rope": list(deep_config.axes_dims_rope),
        "guidance_embeds": deep_config.guidance_embeds,
    }
    with open("config_deep.json", "w") as f:
        json.dump(config_dict, f, indent=2)
    print("✓ Saved config_deep.json")

    # Upload everything to the hub
    print("\nUploading to AbstractPhil/tiny-flux-deep...")
    api = HfApi()
    try:
        api.create_repo(repo_id="AbstractPhil/tiny-flux-deep", exist_ok=True, repo_type="model")
        api.upload_file(path_or_fileobj=save_path, path_in_repo="model.safetensors", repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="config_deep.json", path_in_repo="config.json", repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="frozen_params.json", path_in_repo="frozen_params.json", repo_id="AbstractPhil/tiny-flux-deep")
        print("✓ Uploaded to hub!")
    except Exception as e:
        print(f"✗ Upload failed: {e}")

    print("\n" + "="*60)
    print("Porting complete!")
    print("="*60)
    print("\nNext steps:")
    print("1. Update the TinyFlux model definition to accept TinyFluxDeepConfig")
    print("2. Use frozen_params.json to freeze layers during training")
    print("3. Train from the AbstractPhil/tiny-flux-deep repo")