# tiny-flux-deep / port_tiny_to_deep.py
# ============================================================================
# TinyFlux → TinyFlux-Deep Porting Script
# ============================================================================
# Expands: 3 single + 3 double → 25 single + 15 double
# Heads: 2 → 8 (old heads become first and last)
# Freezes ported layers, trains new ones
# ============================================================================
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from huggingface_hub import hf_hub_download, HfApi
from dataclasses import dataclass
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16
# ============================================================================
# CONFIGS
# ============================================================================
@dataclass
class TinyFluxConfig:
"""Original small config"""
hidden_size: int = 768
num_attention_heads: int = 2
attention_head_dim: int = 128
num_single_blocks: int = 3
num_double_blocks: int = 3
mlp_ratio: float = 4.0
t5_embed_dim: int = 768
clip_embed_dim: int = 768
in_channels: int = 16
axes_dims: tuple = (16, 24, 24)
theta: int = 10000
@dataclass
class TinyFluxDeepConfig:
"""Expanded deep config"""
hidden_size: int = 768 # Same
num_attention_heads: int = 8 # 2 → 8 (6 new heads)
attention_head_dim: int = 128 # Same (so attention dim = 8*128 = 1024)
num_single_blocks: int = 25 # 3 → 25 (more singles like original Flux)
num_double_blocks: int = 15 # 3 → 15
mlp_ratio: float = 4.0 # Same
t5_embed_dim: int = 768 # Same
clip_embed_dim: int = 768 # Same
in_channels: int = 16 # Same
axes_dims: tuple = (16, 24, 24) # Same
theta: int = 10000 # Same
# ============================================================================
# LAYER MAPPING
# ============================================================================
# Single blocks: 3 → 25
# - Layer 0 → position 0 (frozen)
# - Layer 1 → positions 8, 12, 16 (center, spaced, frozen)
# - Layer 2 → position 24 (frozen)
# - Rest → new (trainable)
SINGLE_MAPPING = {
0: [0], # Old layer 0 → new position 0
1: [8, 12, 16], # Old layer 1 → new positions 8, 12, 16
2: [24], # Old layer 2 → new position 24
}
SINGLE_FROZEN = {0, 8, 12, 16, 24} # These positions are frozen
# Double blocks: 3 → 15
# - Layer 0 → position 0 (frozen)
# - Layer 1 → positions 4, 7, 10 (3 copies, spaced, frozen)
# - Layer 2 → position 14 (frozen)
# - Rest → new (trainable)
DOUBLE_MAPPING = {
0: [0], # Old layer 0 → new position 0
1: [4, 7, 10], # Old layer 1 → 3 positions
2: [14], # Old layer 2 → new position 14
}
DOUBLE_FROZEN = {0, 4, 7, 10, 14} # These positions are frozen
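# Lightweight sanity checks on the mapping tables (added here as an illustrative
# guard, not part of the original porting logic; cheap enough to run at import):
# the frozen sets must be exactly the targets of the mappings, and every target
# index must fit inside the new depth declared in TinyFluxDeepConfig.
assert SINGLE_FROZEN == {i for targets in SINGLE_MAPPING.values() for i in targets}
assert DOUBLE_FROZEN == {i for targets in DOUBLE_MAPPING.values() for i in targets}
assert max(SINGLE_FROZEN) < TinyFluxDeepConfig.num_single_blocks  # 25 single blocks
assert max(DOUBLE_FROZEN) < TinyFluxDeepConfig.num_double_blocks  # 15 double blocks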
# ============================================================================
# WEIGHT EXPANSION UTILITIES
# ============================================================================
def expand_qkv_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
"""
    Expand QKV projection weights from `old_heads` (default 2) to `new_heads` (default 8).
    Old head 0 stays at the first head position, old head 1 moves to the last head
    position, and the new heads in between keep a small random init.
QKV weight shape: (in_features, 3 * num_heads * head_dim)
"""
in_features = old_weight.shape[0]
old_qkv_dim = 3 * old_heads * head_dim # 3 * 2 * 128 = 768
new_qkv_dim = 3 * new_heads * head_dim # 3 * 8 * 128 = 3072
# Initialize new weights
new_weight = torch.zeros(in_features, new_qkv_dim, dtype=old_weight.dtype, device=old_weight.device)
# Small random init for new heads
nn.init.xavier_uniform_(new_weight)
new_weight *= 0.1 # Scale down random init
# For each of Q, K, V
for qkv_idx in range(3):
old_start = qkv_idx * old_heads * head_dim
new_start = qkv_idx * new_heads * head_dim
# Copy old head 0 → new head 0
old_h0_start = old_start
old_h0_end = old_start + head_dim
new_h0_start = new_start
new_h0_end = new_start + head_dim
new_weight[:, new_h0_start:new_h0_end] = old_weight[:, old_h0_start:old_h0_end]
        # Copy old head 1 → new last head (index new_heads - 1)
        old_h1_start = old_start + head_dim
        old_h1_end = old_start + 2 * head_dim
        new_hlast_start = new_start + (new_heads - 1) * head_dim
        new_hlast_end = new_start + new_heads * head_dim
        new_weight[:, new_hlast_start:new_hlast_end] = old_weight[:, old_h1_start:old_h1_end]
return new_weight
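def _selftest_expand_qkv(in_features=768, old_heads=2, new_heads=8, head_dim=128):
    """Optional self-check (an illustrative sketch, not invoked by this script):
    verifies that expand_qkv_weights keeps old head 0 at the first position and
    old head 1 at the last position, under the (in_features, 3 * heads * head_dim)
    layout assumed in the docstring above. The helper name is ours, not part of
    the original TinyFlux code."""
    old = torch.randn(in_features, 3 * old_heads * head_dim)
    new = expand_qkv_weights(old, old_heads, new_heads, head_dim)
    for qkv_idx in range(3):
        o = qkv_idx * old_heads * head_dim
        n = qkv_idx * new_heads * head_dim
        assert torch.equal(new[:, n:n + head_dim], old[:, o:o + head_dim])
        assert torch.equal(new[:, n + (new_heads - 1) * head_dim:n + new_heads * head_dim],
                           old[:, o + head_dim:o + 2 * head_dim])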
def expand_out_proj_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
"""
Expand output projection weights from 2 heads to 8 heads.
Out proj weight shape: (num_heads * head_dim, out_features)
"""
out_features = old_weight.shape[1]
old_attn_dim = old_heads * head_dim # 2 * 128 = 256
new_attn_dim = new_heads * head_dim # 8 * 128 = 1024
# Initialize new weights
new_weight = torch.zeros(new_attn_dim, out_features, dtype=old_weight.dtype, device=old_weight.device)
nn.init.xavier_uniform_(new_weight)
new_weight *= 0.1
# Copy old head 0 → new head 0
new_weight[0:head_dim, :] = old_weight[0:head_dim, :]
    # Copy old head 1 → new last head (index new_heads - 1)
    new_weight[(new_heads - 1) * head_dim:new_heads * head_dim, :] = old_weight[head_dim:2 * head_dim, :]
return new_weight
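def _selftest_expand_out_proj(out_features=768, old_heads=2, new_heads=8, head_dim=128):
    """Companion self-check for expand_out_proj_weights (illustrative, not invoked
    by this script): old head rows must land at the first and last head rows."""
    old = torch.randn(old_heads * head_dim, out_features)
    new = expand_out_proj_weights(old, old_heads, new_heads, head_dim)
    assert torch.equal(new[:head_dim], old[:head_dim])
    assert torch.equal(new[(new_heads - 1) * head_dim:new_heads * head_dim],
                       old[head_dim:2 * head_dim])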
def port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
"""Port weights from old single block to new single block."""
old_prefix = f"single_blocks.{old_idx}"
new_prefix = f"single_blocks.{new_idx}"
for old_key in list(old_state.keys()):
if not old_key.startswith(old_prefix):
continue
new_key = old_key.replace(old_prefix, new_prefix)
old_weight = old_state[old_key]
# Handle attention head expansion
if expand_heads:
if "attn.qkv.weight" in old_key:
new_state[new_key] = expand_qkv_weights(old_weight)
                print(f" Expanded QKV: {old_key} → {new_key}")
continue
elif "attn.out_proj.weight" in old_key:
new_state[new_key] = expand_out_proj_weights(old_weight)
                print(f" Expanded out_proj: {old_key} → {new_key}")
continue
# Direct copy for other weights
new_state[new_key] = old_weight.clone()
        print(f" Copied: {old_key} → {new_key}")
def port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
"""Port weights from old double block to new double block."""
old_prefix = f"double_blocks.{old_idx}"
new_prefix = f"double_blocks.{new_idx}"
for old_key in list(old_state.keys()):
if not old_key.startswith(old_prefix):
continue
new_key = old_key.replace(old_prefix, new_prefix)
old_weight = old_state[old_key]
# Handle attention head expansion for joint attention
if expand_heads:
if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
new_state[new_key] = expand_qkv_weights(old_weight)
                print(f" Expanded QKV: {old_key} → {new_key}")
continue
elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
new_state[new_key] = expand_out_proj_weights(old_weight)
                print(f" Expanded out_proj: {old_key} → {new_key}")
continue
# Direct copy
new_state[new_key] = old_weight.clone()
        print(f" Copied: {old_key} → {new_key}")
def port_non_block_weights(old_state, new_state):
    """Port weights that aren't in single/double blocks (their shapes are unchanged)."""
for old_key, old_weight in old_state.items():
# Skip block weights (handled separately)
if "single_blocks" in old_key or "double_blocks" in old_key:
continue
# These can be copied directly (same dimensions)
direct_copy_keys = [
"img_in", "txt_in", "time_in", "vector_in", "guidance_in",
"final_norm", "final_linear", "rope"
]
if any(k in old_key for k in direct_copy_keys):
new_state[old_key] = old_weight.clone()
print(f" Direct copy: {old_key}")
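def report_unported_keys(old_state):
    """Optional diagnostic (an illustrative helper of ours, not called by this
    script): lists old non-block keys that the direct-copy allowlist in
    port_non_block_weights would skip, so nothing silently stays at random init."""
    # Mirror of the allowlist used above.
    allow = ["img_in", "txt_in", "time_in", "vector_in", "guidance_in",
             "final_norm", "final_linear", "rope"]
    skipped = [k for k in old_state
               if "single_blocks" not in k and "double_blocks" not in k
               and not any(a in k for a in allow)]
    if skipped:
        print(f"  WARNING: non-block keys not covered by the allowlist: {skipped}")
    return skipped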
# ============================================================================
# MAIN PORTING FUNCTION
# ============================================================================
def port_tinyflux_to_deep(old_weights_path, new_model):
"""
Port TinyFlux weights to TinyFlux-Deep.
Returns:
new_state_dict: Ported weights
frozen_params: Set of parameter names to freeze
"""
print("Loading old weights...")
if old_weights_path.endswith(".safetensors"):
old_state = load_file(old_weights_path)
else:
old_state = torch.load(old_weights_path, map_location="cpu")
if "model" in old_state:
old_state = old_state["model"]
# Strip _orig_mod prefix if present
if any(k.startswith("_orig_mod.") for k in old_state.keys()):
print("Stripping _orig_mod prefix...")
old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}
# Get new model's state dict as template
new_state = new_model.state_dict()
frozen_params = set()
print("\n" + "="*60)
print("Porting non-block weights...")
print("="*60)
port_non_block_weights(old_state, new_state)
print("\n" + "="*60)
print("Porting single blocks (3 → 25)...")
print("="*60)
for old_idx, new_positions in SINGLE_MAPPING.items():
for new_idx in new_positions:
            print(f"\nSingle block {old_idx} → {new_idx}:")
port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)
# Mark as frozen
for key in new_state.keys():
if f"single_blocks.{new_idx}." in key:
frozen_params.add(key)
print("\n" + "="*60)
print("Porting double blocks (3 → 15)...")
print("="*60)
for old_idx, new_positions in DOUBLE_MAPPING.items():
for new_idx in new_positions:
            print(f"\nDouble block {old_idx} → {new_idx}:")
port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)
# Mark as frozen
for key in new_state.keys():
if f"double_blocks.{new_idx}." in key:
frozen_params.add(key)
print("\n" + "="*60)
print("Summary")
print("="*60)
    print(f"Total parameter tensors in new model: {len(new_state)}")
    print(f"Frozen parameter tensors: {len(frozen_params)}")
    print(f"Trainable parameter tensors: {len(new_state) - len(frozen_params)}")
print(f"\nFrozen single block positions: {sorted(SINGLE_FROZEN)}")
print(f"Frozen double block positions: {sorted(DOUBLE_FROZEN)}")
return new_state, frozen_params
# ============================================================================
# FREEZE HELPER
# ============================================================================
def freeze_ported_layers(model, frozen_params):
"""Freeze the ported layers, keep new layers trainable."""
frozen_count = 0
trainable_count = 0
for name, param in model.named_parameters():
if name in frozen_params:
param.requires_grad = False
frozen_count += param.numel()
else:
param.requires_grad = True
trainable_count += param.numel()
print(f"\nFrozen params: {frozen_count:,}")
print(f"Trainable params: {trainable_count:,}")
print(f"Total params: {frozen_count + trainable_count:,}")
print(f"Trainable ratio: {trainable_count / (frozen_count + trainable_count) * 100:.1f}%")
return model
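# Illustrative sketch (not executed by the porting run): once the ported layers
# are frozen, a training script would typically build its optimizer over the
# still-trainable parameters only. The helper name and learning rate below are
# placeholders of ours, not part of the original training code.
def make_optimizer_for_new_layers(model, lr=1e-4):
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)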
# ============================================================================
# MAIN SCRIPT
# ============================================================================
if __name__ == "__main__":
print("="*60)
print("TinyFlux → TinyFlux-Deep Porting")
print("="*60)
# Load old weights from hub
print("\nDownloading TinyFlux weights from hub...")
old_weights_path = hf_hub_download(
repo_id="AbstractPhil/tiny-flux",
filename="model.safetensors"
)
# Create new deep model
print("\nCreating TinyFlux-Deep model...")
deep_config = TinyFluxDeepConfig()
    # NOTE: the TinyFlux class must already be defined (run the model cell first);
    # this assumes TinyFlux also accepts a TinyFluxDeepConfig.
deep_model = TinyFlux(deep_config).to(DTYPE)
print(f"\nDeep model config:")
print(f" Hidden size: {deep_config.hidden_size}")
print(f" Attention heads: {deep_config.num_attention_heads}")
print(f" Single blocks: {deep_config.num_single_blocks}")
print(f" Double blocks: {deep_config.num_double_blocks}")
# Port weights
new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)
# Load ported weights
print("\nLoading ported weights into model...")
missing, unexpected = deep_model.load_state_dict(new_state, strict=False)
if missing:
print(f" Missing keys: {missing[:5]}..." if len(missing) > 5 else f" Missing keys: {missing}")
if unexpected:
print(f" Unexpected keys: {unexpected}")
# Freeze ported layers
print("\nFreezing ported layers...")
deep_model = freeze_ported_layers(deep_model, frozen_params)
# Save
print("\nSaving ported model...")
save_path = "tinyflux_deep_ported.safetensors"
# Strip any _orig_mod prefix before saving
state_to_save = deep_model.state_dict()
if any(k.startswith("_orig_mod.") for k in state_to_save.keys()):
state_to_save = {k.replace("_orig_mod.", ""): v for k, v in state_to_save.items()}
save_file(state_to_save, save_path)
print(f"✓ Saved to {save_path}")
# Save frozen params list
import json
with open("frozen_params.json", "w") as f:
json.dump(list(frozen_params), f)
print("✓ Saved frozen_params.json")
# Save config
config_dict = {
"hidden_size": deep_config.hidden_size,
"num_attention_heads": deep_config.num_attention_heads,
"attention_head_dim": deep_config.attention_head_dim,
"num_single_blocks": deep_config.num_single_blocks,
"num_double_blocks": deep_config.num_double_blocks,
"mlp_ratio": deep_config.mlp_ratio,
"t5_embed_dim": deep_config.t5_embed_dim,
"clip_embed_dim": deep_config.clip_embed_dim,
"in_channels": deep_config.in_channels,
"axes_dims": list(deep_config.axes_dims),
"theta": deep_config.theta,
}
with open("config_deep.json", "w") as f:
json.dump(config_dict, f, indent=2)
print("✓ Saved config_deep.json")
# Upload to hub
print("\nUploading to AbstractPhil/tiny-flux-deep...")
api = HfApi()
try:
api.create_repo(repo_id="AbstractPhil/tiny-flux-deep", exist_ok=True, repo_type="model")
api.upload_file(path_or_fileobj=save_path, path_in_repo="model.safetensors", repo_id="AbstractPhil/tiny-flux-deep")
api.upload_file(path_or_fileobj="config_deep.json", path_in_repo="config.json", repo_id="AbstractPhil/tiny-flux-deep")
api.upload_file(path_or_fileobj="frozen_params.json", path_in_repo="frozen_params.json", repo_id="AbstractPhil/tiny-flux-deep")
print("✓ Uploaded to hub!")
except Exception as e:
print(f"⚠ Upload failed: {e}")
print("\n" + "="*60)
print("Porting complete!")
print("="*60)
print("\nNext steps:")
print("1. Update TinyFlux model definition to accept TinyFluxDeepConfig")
print("2. Use the frozen_params.json to freeze layers during training")
print("3. Train on AbstractPhil/tiny-flux-deep repo")
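    # ------------------------------------------------------------------------
    # Sketch of step 2 (illustrative, left commented out): a separate training
    # script could re-apply the freeze from frozen_params.json roughly like
    # this, assuming the same TinyFlux / TinyFluxDeepConfig definitions are in
    # scope. Paths and names mirror what this script saved above.
    #
    #   with open("frozen_params.json") as f:
    #       frozen = set(json.load(f))
    #   model = TinyFlux(TinyFluxDeepConfig()).to(DTYPE)
    #   model.load_state_dict(load_file("tinyflux_deep_ported.safetensors"))
    #   freeze_ported_layers(model, frozen)
    # ------------------------------------------------------------------------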