"""Port TinyFlux weights into a deeper TinyFlux-Deep model.

Loads the small TinyFlux checkpoint, expands attention from 2 to 8 heads,
interleaves the original blocks into the deeper stack, freezes the ported
layers, and saves/uploads the result.
"""
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from huggingface_hub import hf_hub_download, HfApi
from dataclasses import dataclass


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16


@dataclass
class TinyFluxConfig:
    """Original small config."""
    hidden_size: int = 768
    num_attention_heads: int = 2
    attention_head_dim: int = 128
    num_single_blocks: int = 3
    num_double_blocks: int = 3
    mlp_ratio: float = 4.0
    t5_embed_dim: int = 768
    clip_embed_dim: int = 768
    in_channels: int = 16
    axes_dims: tuple = (16, 24, 24)
    theta: int = 10000


@dataclass
class TinyFluxDeepConfig:
    """Expanded deep config."""
    hidden_size: int = 768
    num_attention_heads: int = 8
    attention_head_dim: int = 128
    num_single_blocks: int = 25
    num_double_blocks: int = 15
    mlp_ratio: float = 4.0
    t5_embed_dim: int = 768
    clip_embed_dim: int = 768
    in_channels: int = 16
    axes_dims: tuple = (16, 24, 24)
    theta: int = 10000


# Block interleave maps: old block index -> positions it is copied into in the
# deeper stack. Old block 0 anchors the front, old block 2 anchors the back,
# and old block 1 is repeated through the middle. Positions not listed keep
# fresh (trainable) weights.
SINGLE_MAPPING = {
    0: [0],
    1: [8, 12, 16],
    2: [24],
}
# Positions that receive ported weights and stay frozen during training.
SINGLE_FROZEN = {0, 8, 12, 16, 24}


DOUBLE_MAPPING = {
    0: [0],
    1: [4, 7, 10],
    2: [14],
}
DOUBLE_FROZEN = {0, 4, 7, 10, 14}
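

# Optional sanity check (not in the original script): the frozen sets should be
# exactly the positions produced by the maps, and every target position must
# fit inside the deeper configs defined above. Call it before porting if desired.
def _check_mappings():
    single_targets = {i for targets in SINGLE_MAPPING.values() for i in targets}
    double_targets = {i for targets in DOUBLE_MAPPING.values() for i in targets}
    assert single_targets == SINGLE_FROZEN
    assert double_targets == DOUBLE_FROZEN
    cfg = TinyFluxDeepConfig()
    assert max(single_targets) < cfg.num_single_blocks
    assert max(double_targets) < cfg.num_double_blocks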


def expand_qkv_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
    """
    Expand QKV projection weights from 2 heads to 8 heads.
    The old heads go to the first and last new head positions (0 and 7);
    the heads in between are initialized randomly.

    Assumes the QKV weight is stored as (in_features, 3 * num_heads * head_dim).
    """
    in_features = old_weight.shape[0]
    old_qkv_dim = 3 * old_heads * head_dim
    new_qkv_dim = 3 * new_heads * head_dim
    assert old_weight.shape[1] == old_qkv_dim, f"unexpected QKV weight shape {tuple(old_weight.shape)}"

    # Fresh weights for the new heads: small xavier init so they start near-silent.
    new_weight = torch.zeros(in_features, new_qkv_dim, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.1

    for qkv_idx in range(3):  # Q, K, V slices
        old_start = qkv_idx * old_heads * head_dim
        new_start = qkv_idx * new_heads * head_dim

        # Old head 0 -> new head 0
        new_weight[:, new_start:new_start + head_dim] = old_weight[:, old_start:old_start + head_dim]

        # Old head 1 -> last new head (index new_heads - 1)
        old_h1_start = old_start + head_dim
        new_h_last = new_start + (new_heads - 1) * head_dim
        new_weight[:, new_h_last:new_h_last + head_dim] = old_weight[:, old_h1_start:old_h1_start + head_dim]

    return new_weight


def expand_out_proj_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
    """
    Expand output projection weights from 2 heads to 8 heads.

    Assumes the out-proj weight is stored as (num_heads * head_dim, out_features).
    """
    out_features = old_weight.shape[1]
    old_attn_dim = old_heads * head_dim
    new_attn_dim = new_heads * head_dim
    assert old_weight.shape[0] == old_attn_dim, f"unexpected out_proj weight shape {tuple(old_weight.shape)}"

    # Fresh rows for the new heads: small xavier init so they start near-silent.
    new_weight = torch.zeros(new_attn_dim, out_features, dtype=old_weight.dtype, device=old_weight.device)
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.1

    # Old head 0 -> new head 0
    new_weight[0:head_dim, :] = old_weight[0:head_dim, :]

    # Old head 1 -> last new head (index new_heads - 1)
    new_weight[(new_heads - 1) * head_dim:new_heads * head_dim, :] = old_weight[head_dim:2 * head_dim, :]

    return new_weight
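

# A quick, optional check (not in the original script) that the expansion keeps
# the old heads intact: old head 0 should land in new head 0 and old head 1 in
# the last new head, for both the QKV and out-proj expansions.
def _check_head_expansion(head_dim=128):
    old_qkv = torch.randn(768, 3 * 2 * head_dim)
    new_qkv = expand_qkv_weights(old_qkv, old_heads=2, new_heads=8, head_dim=head_dim)
    assert torch.equal(new_qkv[:, :head_dim], old_qkv[:, :head_dim])
    assert torch.equal(new_qkv[:, 7 * head_dim:8 * head_dim], old_qkv[:, head_dim:2 * head_dim])

    old_out = torch.randn(2 * head_dim, 768)
    new_out = expand_out_proj_weights(old_out, old_heads=2, new_heads=8, head_dim=head_dim)
    assert torch.equal(new_out[:head_dim], old_out[:head_dim])
    assert torch.equal(new_out[7 * head_dim:], old_out[head_dim:])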


def port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
    """Port weights from an old single block to a new single block."""
    old_prefix = f"single_blocks.{old_idx}."
    new_prefix = f"single_blocks.{new_idx}."

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue

        new_key = old_key.replace(old_prefix, new_prefix)
        old_weight = old_state[old_key]

        # Attention projections need the 2 -> 8 head expansion.
        if expand_heads:
            if "attn.qkv.weight" in old_key:
                new_state[new_key] = expand_qkv_weights(old_weight)
                print(f"  Expanded QKV: {old_key} → {new_key}")
                continue
            elif "attn.out_proj.weight" in old_key:
                new_state[new_key] = expand_out_proj_weights(old_weight)
                print(f"  Expanded out_proj: {old_key} → {new_key}")
                continue

        # Everything else keeps its shape and is copied as-is.
        new_state[new_key] = old_weight.clone()
        print(f"  Copied: {old_key} → {new_key}")


def port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
    """Port weights from an old double block to a new double block."""
    old_prefix = f"double_blocks.{old_idx}."
    new_prefix = f"double_blocks.{new_idx}."

    for old_key in list(old_state.keys()):
        if not old_key.startswith(old_prefix):
            continue

        new_key = old_key.replace(old_prefix, new_prefix)
        old_weight = old_state[old_key]

        # Image- and text-stream attention projections need the 2 -> 8 head expansion.
        if expand_heads:
            if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
                new_state[new_key] = expand_qkv_weights(old_weight)
                print(f"  Expanded QKV: {old_key} → {new_key}")
                continue
            elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
                new_state[new_key] = expand_out_proj_weights(old_weight)
                print(f"  Expanded out_proj: {old_key} → {new_key}")
                continue

        # Everything else keeps its shape and is copied as-is.
        new_state[new_key] = old_weight.clone()
        print(f"  Copied: {old_key} → {new_key}")


def port_non_block_weights(old_state, new_state):
    """Port weights that aren't in single/double blocks (embeddings, final layers, RoPE)."""
    # Hidden size is unchanged (768), so these tensors copy over directly.
    direct_copy_keys = [
        "img_in", "txt_in", "time_in", "vector_in", "guidance_in",
        "final_norm", "final_linear", "rope",
    ]

    for old_key, old_weight in old_state.items():
        if "single_blocks" in old_key or "double_blocks" in old_key:
            continue

        if any(k in old_key for k in direct_copy_keys):
            new_state[old_key] = old_weight.clone()
            print(f"  Direct copy: {old_key}")


def port_tinyflux_to_deep(old_weights_path, new_model):
    """
    Port TinyFlux weights to TinyFlux-Deep.

    Returns:
        new_state_dict: Ported weights
        frozen_params: Set of parameter names to freeze
    """
    print("Loading old weights...")
    if old_weights_path.endswith(".safetensors"):
        old_state = load_file(old_weights_path)
    else:
        old_state = torch.load(old_weights_path, map_location="cpu")
        if "model" in old_state:
            old_state = old_state["model"]

    # torch.compile wraps parameter names with "_orig_mod."; strip it if present.
    if any(k.startswith("_orig_mod.") for k in old_state.keys()):
        print("Stripping _orig_mod prefix...")
        old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}

    # Start from the new model's randomly initialized state and overwrite the ported keys.
    new_state = new_model.state_dict()
    frozen_params = set()

    print("\n" + "=" * 60)
    print("Porting non-block weights...")
    print("=" * 60)
    port_non_block_weights(old_state, new_state)

    print("\n" + "=" * 60)
    print("Porting single blocks (3 → 25)...")
    print("=" * 60)
    for old_idx, new_positions in SINGLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nSingle block {old_idx} → {new_idx}:")
            port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)

            for key in new_state.keys():
                if f"single_blocks.{new_idx}." in key:
                    frozen_params.add(key)

    print("\n" + "=" * 60)
    print("Porting double blocks (3 → 15)...")
    print("=" * 60)
    for old_idx, new_positions in DOUBLE_MAPPING.items():
        for new_idx in new_positions:
            print(f"\nDouble block {old_idx} → {new_idx}:")
            port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)

            for key in new_state.keys():
                if f"double_blocks.{new_idx}." in key:
                    frozen_params.add(key)

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Total parameter tensors in new model: {len(new_state)}")
    print(f"Frozen parameter tensors: {len(frozen_params)}")
    print(f"Trainable parameter tensors: {len(new_state) - len(frozen_params)}")

    print(f"\nFrozen single block positions: {sorted(SINGLE_FROZEN)}")
    print(f"Frozen double block positions: {sorted(DOUBLE_FROZEN)}")

    return new_state, frozen_params
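

# Optional spot-check (not in the original script): after porting, a frozen
# block's non-attention tensors should be bit-identical to the source block
# they were copied from. The default indices follow SINGLE_MAPPING (1 -> 12).
def _spot_check_port(old_state, new_state, old_idx=1, new_idx=12):
    prefix = f"single_blocks.{old_idx}."
    for key, value in old_state.items():
        if key.startswith(prefix) and "qkv" not in key and "out_proj" not in key:
            new_key = key.replace(prefix, f"single_blocks.{new_idx}.")
            assert torch.equal(new_state[new_key], value), f"mismatch at {new_key}"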


def freeze_ported_layers(model, frozen_params):
    """Freeze the ported layers, keep new layers trainable."""
    frozen_count = 0
    trainable_count = 0

    for name, param in model.named_parameters():
        if name in frozen_params:
            param.requires_grad = False
            frozen_count += param.numel()
        else:
            param.requires_grad = True
            trainable_count += param.numel()

    print(f"\nFrozen params: {frozen_count:,}")
    print(f"Trainable params: {trainable_count:,}")
    print(f"Total params: {frozen_count + trainable_count:,}")
    print(f"Trainable ratio: {trainable_count / (frozen_count + trainable_count) * 100:.1f}%")

    return model
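

# A minimal sketch (not part of the original script) of how a training run might
# consume the frozen/trainable split: pass only the parameters that still require
# grad to the optimizer. Optimizer choice and learning rate are placeholders;
# in practice you may also want to keep fp32 master weights when training in bf16.
def make_optimizer(model, lr=1e-4):
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)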


if __name__ == "__main__":
    print("=" * 60)
    print("TinyFlux → TinyFlux-Deep Porting")
    print("=" * 60)

    # Download the small TinyFlux checkpoint from the hub.
    print("\nDownloading TinyFlux weights from hub...")
    old_weights_path = hf_hub_download(
        repo_id="AbstractPhil/tiny-flux",
        filename="model.safetensors"
    )

    # Build the deeper model. NOTE: TinyFlux (the model class) is not defined in
    # this file; it is assumed to be importable from the project's model definition.
    print("\nCreating TinyFlux-Deep model...")
    deep_config = TinyFluxDeepConfig()
    deep_model = TinyFlux(deep_config).to(DTYPE)

    print("\nDeep model config:")
    print(f"  Hidden size: {deep_config.hidden_size}")
    print(f"  Attention heads: {deep_config.num_attention_heads}")
    print(f"  Single blocks: {deep_config.num_single_blocks}")
    print(f"  Double blocks: {deep_config.num_double_blocks}")

    # Port the weights and collect the names of parameters to freeze.
    new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)

    print("\nLoading ported weights into model...")
    missing, unexpected = deep_model.load_state_dict(new_state, strict=False)
    if missing:
        print(f"  Missing keys: {missing[:5]}..." if len(missing) > 5 else f"  Missing keys: {missing}")
    if unexpected:
        print(f"  Unexpected keys: {unexpected}")

    print("\nFreezing ported layers...")
    deep_model = freeze_ported_layers(deep_model, frozen_params)

    print("\nSaving ported model...")
    save_path = "tinyflux_deep_ported.safetensors"

    # Strip any torch.compile prefix before saving.
    state_to_save = deep_model.state_dict()
    if any(k.startswith("_orig_mod.") for k in state_to_save.keys()):
        state_to_save = {k.replace("_orig_mod.", ""): v for k, v in state_to_save.items()}

    save_file(state_to_save, save_path)
    print(f"✓ Saved to {save_path}")

    # Record which parameters were ported (and should stay frozen during training).
    import json
    with open("frozen_params.json", "w") as f:
        json.dump(sorted(frozen_params), f)
    print("✓ Saved frozen_params.json")

    # Save the deep config alongside the weights.
    config_dict = {
        "hidden_size": deep_config.hidden_size,
        "num_attention_heads": deep_config.num_attention_heads,
        "attention_head_dim": deep_config.attention_head_dim,
        "num_single_blocks": deep_config.num_single_blocks,
        "num_double_blocks": deep_config.num_double_blocks,
        "mlp_ratio": deep_config.mlp_ratio,
        "t5_embed_dim": deep_config.t5_embed_dim,
        "clip_embed_dim": deep_config.clip_embed_dim,
        "in_channels": deep_config.in_channels,
        "axes_dims": list(deep_config.axes_dims),
        "theta": deep_config.theta,
    }
    with open("config_deep.json", "w") as f:
        json.dump(config_dict, f, indent=2)
    print("✓ Saved config_deep.json")

    # Upload weights, config, and frozen-param list to the hub.
    print("\nUploading to AbstractPhil/tiny-flux-deep...")
    api = HfApi()
    try:
        api.create_repo(repo_id="AbstractPhil/tiny-flux-deep", exist_ok=True, repo_type="model")
        api.upload_file(path_or_fileobj=save_path, path_in_repo="model.safetensors", repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="config_deep.json", path_in_repo="config.json", repo_id="AbstractPhil/tiny-flux-deep")
        api.upload_file(path_or_fileobj="frozen_params.json", path_in_repo="frozen_params.json", repo_id="AbstractPhil/tiny-flux-deep")
        print("✓ Uploaded to hub!")
    except Exception as e:
        print(f"⚠ Upload failed: {e}")

    print("\n" + "=" * 60)
    print("Porting complete!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Update the TinyFlux model definition to accept TinyFluxDeepConfig")
    print("2. Use frozen_params.json to freeze layers during training")
    print("3. Train from the AbstractPhil/tiny-flux-deep repo")