AbstractPhil commited on
Commit
457e2ff
Β·
verified Β·
1 Parent(s): bc95c6d

Update port_tiny_to_deep.py

Browse files
Files changed (1) hide show
  1. port_tiny_to_deep.py +511 -119
port_tiny_to_deep.py CHANGED
@@ -2,7 +2,7 @@
2
  # TinyFlux β†’ TinyFlux-Deep Porting Script
3
  # ============================================================================
4
  # Expands: 3 single + 3 double β†’ 25 single + 15 double
5
- # Heads: 2 β†’ 8 (old heads become first and last)
6
  # Freezes ported layers, trains new ones
7
  # ============================================================================
8
 
@@ -12,6 +12,7 @@ from safetensors.torch import load_file, save_file
12
  from huggingface_hub import hf_hub_download, HfApi
13
  from dataclasses import dataclass
14
  from copy import deepcopy
 
15
 
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
  DTYPE = torch.bfloat16
@@ -21,34 +22,71 @@ DTYPE = torch.bfloat16
21
  # ============================================================================
22
  @dataclass
23
  class TinyFluxConfig:
24
- """Original small config"""
 
25
  hidden_size: int = 768
26
- num_attention_heads: int = 2
27
- attention_head_dim: int = 128
28
- num_single_blocks: int = 3
29
- num_double_blocks: int = 3
30
- mlp_ratio: float = 4.0
31
- t5_embed_dim: int = 768
32
- clip_embed_dim: int = 768
33
  in_channels: int = 16
34
- axes_dims: tuple = (16, 24, 24)
35
- theta: int = 10000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  @dataclass
39
  class TinyFluxDeepConfig:
40
- """Expanded deep config"""
41
- hidden_size: int = 768 # Same
42
- num_attention_heads: int = 8 # 2 β†’ 8 (6 new heads)
43
- attention_head_dim: int = 128 # Same (so attention dim = 8*128 = 1024)
44
- num_single_blocks: int = 25 # 3 β†’ 25 (more singles like original Flux)
45
- num_double_blocks: int = 15 # 3 β†’ 15
46
- mlp_ratio: float = 4.0 # Same
47
- t5_embed_dim: int = 768 # Same
48
- clip_embed_dim: int = 768 # Same
49
- in_channels: int = 16 # Same
50
- axes_dims: tuple = (16, 24, 24) # Same
51
- theta: int = 10000 # Same
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  # ============================================================================
@@ -84,71 +122,118 @@ DOUBLE_FROZEN = {0, 4, 7, 10, 14} # These positions are frozen
84
  # ============================================================================
85
  # WEIGHT EXPANSION UTILITIES
86
  # ============================================================================
87
- def expand_qkv_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
88
  """
89
- Expand QKV projection weights from 2 heads to 8 heads.
90
- Old heads go to positions 0 and 7, middle heads initialized randomly.
91
 
92
- QKV weight shape: (in_features, 3 * num_heads * head_dim)
 
93
  """
94
- in_features = old_weight.shape[0]
95
- old_qkv_dim = 3 * old_heads * head_dim # 3 * 2 * 128 = 768
96
- new_qkv_dim = 3 * new_heads * head_dim # 3 * 8 * 128 = 3072
 
 
97
 
98
  # Initialize new weights
99
- new_weight = torch.zeros(in_features, new_qkv_dim, dtype=old_weight.dtype, device=old_weight.device)
100
- # Small random init for new heads
101
  nn.init.xavier_uniform_(new_weight)
102
- new_weight *= 0.1 # Scale down random init
103
 
104
- # For each of Q, K, V
105
  for qkv_idx in range(3):
106
- old_start = qkv_idx * old_heads * head_dim
107
- new_start = qkv_idx * new_heads * head_dim
108
 
109
- # Copy old head 0 β†’ new head 0
110
- old_h0_start = old_start
111
- old_h0_end = old_start + head_dim
112
- new_h0_start = new_start
113
- new_h0_end = new_start + head_dim
114
- new_weight[:, new_h0_start:new_h0_end] = old_weight[:, old_h0_start:old_h0_end]
115
-
116
- # Copy old head 1 β†’ new head 7 (last)
117
- old_h1_start = old_start + head_dim
118
- old_h1_end = old_start + 2 * head_dim
119
- new_h7_start = new_start + 7 * head_dim
120
- new_h7_end = new_start + 8 * head_dim
121
- new_weight[:, new_h7_start:new_h7_end] = old_weight[:, old_h1_start:old_h1_end]
122
 
123
  return new_weight
124
 
125
 
126
- def expand_out_proj_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
127
- """
128
- Expand output projection weights from 2 heads to 8 heads.
 
129
 
130
- Out proj weight shape: (num_heads * head_dim, out_features)
131
- """
132
- out_features = old_weight.shape[1]
133
- old_attn_dim = old_heads * head_dim # 2 * 128 = 256
134
- new_attn_dim = new_heads * head_dim # 8 * 128 = 1024
 
 
135
 
 
 
 
 
 
 
 
 
136
  # Initialize new weights
137
- new_weight = torch.zeros(new_attn_dim, out_features, dtype=old_weight.dtype, device=old_weight.device)
138
  nn.init.xavier_uniform_(new_weight)
139
- new_weight *= 0.1
 
 
 
140
 
141
- # Copy old head 0 β†’ new head 0
142
- new_weight[0:head_dim, :] = old_weight[0:head_dim, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- # Copy old head 1 β†’ new head 7
145
- new_weight[7*head_dim:8*head_dim, :] = old_weight[head_dim:2*head_dim, :]
 
 
146
 
147
  return new_weight
148
 
149
 
150
- def port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
151
- """Port weights from old single block to new single block."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  old_prefix = f"single_blocks.{old_idx}"
153
  new_prefix = f"single_blocks.{new_idx}"
154
 
@@ -159,24 +244,119 @@ def port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_hea
159
  new_key = old_key.replace(old_prefix, new_prefix)
160
  old_weight = old_state[old_key]
161
 
162
- # Handle attention head expansion
163
- if expand_heads:
164
- if "attn.qkv.weight" in old_key:
165
- new_state[new_key] = expand_qkv_weights(old_weight)
166
- print(f" Expanded QKV: {old_key} β†’ {new_key}")
167
- continue
168
- elif "attn.out_proj.weight" in old_key:
169
- new_state[new_key] = expand_out_proj_weights(old_weight)
170
- print(f" Expanded out_proj: {old_key} β†’ {new_key}")
171
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
- # Direct copy for other weights
174
- new_state[new_key] = old_weight.clone()
175
- print(f" Copied: {old_key} β†’ {new_key}")
 
 
 
 
176
 
177
 
178
- def port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
179
- """Port weights from old double block to new double block."""
180
  old_prefix = f"double_blocks.{old_idx}"
181
  new_prefix = f"double_blocks.{new_idx}"
182
 
@@ -187,40 +367,211 @@ def port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_hea
187
  new_key = old_key.replace(old_prefix, new_prefix)
188
  old_weight = old_state[old_key]
189
 
190
- # Handle attention head expansion for joint attention
191
- if expand_heads:
192
- if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
193
- new_state[new_key] = expand_qkv_weights(old_weight)
194
- print(f" Expanded QKV: {old_key} β†’ {new_key}")
195
- continue
196
- elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
197
- new_state[new_key] = expand_out_proj_weights(old_weight)
198
- print(f" Expanded out_proj: {old_key} β†’ {new_key}")
199
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- # Direct copy
202
- new_state[new_key] = old_weight.clone()
203
- print(f" Copied: {old_key} β†’ {new_key}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
 
206
- def port_non_block_weights(old_state, new_state, old_heads=2, new_heads=8):
207
- """Port weights that aren't in single/double blocks."""
208
- head_dim = 128
209
 
210
  for old_key, old_weight in old_state.items():
211
  # Skip block weights (handled separately)
212
  if "single_blocks" in old_key or "double_blocks" in old_key:
213
  continue
214
 
215
- # These can be copied directly (same dimensions)
216
- direct_copy_keys = [
217
- "img_in", "txt_in", "time_in", "vector_in", "guidance_in",
218
- "final_norm", "final_linear", "rope"
219
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
- if any(k in old_key for k in direct_copy_keys):
222
- new_state[old_key] = old_weight.clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  print(f" Direct copy: {old_key}")
 
 
 
 
 
 
 
224
 
225
 
226
  # ============================================================================
@@ -247,14 +598,31 @@ def port_tinyflux_to_deep(old_weights_path, new_model):
247
  print("Stripping _orig_mod prefix...")
248
  old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}
249
 
250
- # Get new model's state dict as template
251
  new_state = new_model.state_dict()
252
  frozen_params = set()
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  print("\n" + "="*60)
255
  print("Porting non-block weights...")
256
  print("="*60)
257
- port_non_block_weights(old_state, new_state)
258
 
259
  print("\n" + "="*60)
260
  print("Porting single blocks (3 β†’ 25)...")
@@ -262,7 +630,7 @@ def port_tinyflux_to_deep(old_weights_path, new_model):
262
  for old_idx, new_positions in SINGLE_MAPPING.items():
263
  for new_idx in new_positions:
264
  print(f"\nSingle block {old_idx} β†’ {new_idx}:")
265
- port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)
266
  # Mark as frozen
267
  for key in new_state.keys():
268
  if f"single_blocks.{new_idx}." in key:
@@ -274,7 +642,7 @@ def port_tinyflux_to_deep(old_weights_path, new_model):
274
  for old_idx, new_positions in DOUBLE_MAPPING.items():
275
  for new_idx in new_positions:
276
  print(f"\nDouble block {old_idx} β†’ {new_idx}:")
277
- port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)
278
  # Mark as frozen
279
  for key in new_state.keys():
280
  if f"double_blocks.{new_idx}." in key:
@@ -325,26 +693,50 @@ if __name__ == "__main__":
325
  print("TinyFlux β†’ TinyFlux-Deep Porting")
326
  print("="*60)
327
 
328
- # Load old weights from hub
329
  print("\nDownloading TinyFlux weights from hub...")
330
  old_weights_path = hf_hub_download(
331
  repo_id="AbstractPhil/tiny-flux",
332
  filename="model.safetensors"
333
  )
334
 
335
- # Create new deep model
336
- print("\nCreating TinyFlux-Deep model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  deep_config = TinyFluxDeepConfig()
 
 
338
 
 
339
  # You need to define TinyFlux class first (run model cell)
340
- # This assumes TinyFlux accepts the config
341
  deep_model = TinyFlux(deep_config).to(DTYPE)
342
 
343
  print(f"\nDeep model config:")
344
  print(f" Hidden size: {deep_config.hidden_size}")
345
  print(f" Attention heads: {deep_config.num_attention_heads}")
346
- print(f" Single blocks: {deep_config.num_single_blocks}")
347
- print(f" Double blocks: {deep_config.num_double_blocks}")
348
 
349
  # Port weights
350
  new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)
@@ -384,14 +776,14 @@ if __name__ == "__main__":
384
  "hidden_size": deep_config.hidden_size,
385
  "num_attention_heads": deep_config.num_attention_heads,
386
  "attention_head_dim": deep_config.attention_head_dim,
387
- "num_single_blocks": deep_config.num_single_blocks,
388
- "num_double_blocks": deep_config.num_double_blocks,
389
  "mlp_ratio": deep_config.mlp_ratio,
390
- "t5_embed_dim": deep_config.t5_embed_dim,
391
- "clip_embed_dim": deep_config.clip_embed_dim,
392
  "in_channels": deep_config.in_channels,
393
- "axes_dims": list(deep_config.axes_dims),
394
- "theta": deep_config.theta,
395
  }
396
  with open("config_deep.json", "w") as f:
397
  json.dump(config_dict, f, indent=2)
 
2
  # TinyFlux β†’ TinyFlux-Deep Porting Script
3
  # ============================================================================
4
  # Expands: 3 single + 3 double β†’ 25 single + 15 double
5
+ # Heads: 2 β†’ 4 (doubles heads, hidden 256 β†’ 512)
6
  # Freezes ported layers, trains new ones
7
  # ============================================================================
8
 
 
12
  from huggingface_hub import hf_hub_download, HfApi
13
  from dataclasses import dataclass
14
  from copy import deepcopy
15
+ from typing import Tuple
16
 
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
  DTYPE = torch.bfloat16
 
22
  # ============================================================================
23
  @dataclass
24
  class TinyFluxConfig:
25
+ """Original small config - matches TinyFlux model on hub (hidden=768, 6 heads)"""
26
+ # Core dimensions (detected from hub: 768 hidden, 6 heads)
27
  hidden_size: int = 768
28
+ num_attention_heads: int = 6
29
+ attention_head_dim: int = 128 # 6 * 128 = 768
30
+
31
+ # Input/output
 
 
 
32
  in_channels: int = 16
33
+ patch_size: int = 1
34
+
35
+ # Text encoder interfaces
36
+ joint_attention_dim: int = 768
37
+ pooled_projection_dim: int = 768
38
+
39
+ # Layers
40
+ num_double_layers: int = 3
41
+ num_single_layers: int = 3
42
+
43
+ # MLP
44
+ mlp_ratio: float = 4.0
45
+
46
+ # RoPE
47
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
48
+
49
+ # Misc
50
+ guidance_embeds: bool = True
51
 
52
 
53
  @dataclass
54
  class TinyFluxDeepConfig:
55
+ """
56
+ Expanded deep config - matches TinyFlux model attribute names exactly.
57
+
58
+ Original TinyFlux: hidden_size=256, 2 heads (256/128=2)
59
+ Deep variant: hidden_size=512, 4 heads (4*128=512) - double heads
60
+ """
61
+ # Core dimensions
62
+ hidden_size: int = 512 # 4 heads * 128 head_dim
63
+ num_attention_heads: int = 4 # 2 β†’ 4 (double the heads)
64
+ attention_head_dim: int = 128 # Same (required for RoPE)
65
+
66
+ # Input/output
67
+ in_channels: int = 16
68
+ patch_size: int = 1
69
+
70
+ # Text encoder interfaces
71
+ joint_attention_dim: int = 768 # T5 embed dim
72
+ pooled_projection_dim: int = 768 # CLIP embed dim
73
+
74
+ # Layers (uses _layers not _blocks)
75
+ num_double_layers: int = 15 # 3 β†’ 15
76
+ num_single_layers: int = 25 # 3 β†’ 25 (more singles like original Flux)
77
+
78
+ # MLP
79
+ mlp_ratio: float = 4.0
80
+
81
+ # RoPE (must sum to head_dim=128)
82
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
83
+
84
+ # Misc
85
+ guidance_embeds: bool = True
86
+
87
+ def __post_init__(self):
88
+ assert self.num_attention_heads * self.attention_head_dim == self.hidden_size, \
89
+ f"heads ({self.num_attention_heads}) * head_dim ({self.attention_head_dim}) != hidden ({self.hidden_size})"
90
 
91
 
92
  # ============================================================================
 
122
  # ============================================================================
123
  # WEIGHT EXPANSION UTILITIES
124
  # ============================================================================
125
+ def expand_qkv_weights(old_weight, old_hidden=768, new_hidden=1536, head_dim=128):
126
  """
127
+ Expand QKV projection weights when increasing hidden size / head count.
128
+ QKV weight shape: (3 * num_heads * head_dim, hidden_size) = (3 * hidden_size, hidden_size)
129
 
130
+ Strategy: Copy old weights to corresponding positions, random init new heads.
131
+ Old heads are spread evenly across new head positions.
132
  """
133
+ old_qkv_dim = old_weight.shape[0] # 3 * old_hidden
134
+ new_qkv_dim = 3 * new_hidden
135
+
136
+ old_heads = old_hidden // head_dim
137
+ new_heads = new_hidden // head_dim
138
 
139
  # Initialize new weights
140
+ new_weight = torch.zeros(new_qkv_dim, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
 
141
  nn.init.xavier_uniform_(new_weight)
142
+ new_weight *= 0.02 # Scale down random init
143
 
144
+ # For each of Q, K, V: copy old heads to first N positions
145
  for qkv_idx in range(3):
146
+ old_start = qkv_idx * old_hidden
147
+ new_start = qkv_idx * new_hidden
148
 
149
+ # Copy all old heads to first old_heads positions of new
150
+ for h in range(old_heads):
151
+ old_h_start = old_start + h * head_dim
152
+ old_h_end = old_h_start + head_dim
153
+ new_h_start = new_start + h * head_dim
154
+ new_h_end = new_h_start + head_dim
155
+ # Copy weights, input dim goes to first old_hidden columns
156
+ new_weight[new_h_start:new_h_end, :old_hidden] = old_weight[old_h_start:old_h_end, :]
 
 
 
 
 
157
 
158
  return new_weight
159
 
160
 
161
+ def expand_qkv_bias(old_bias, old_hidden=768, new_hidden=1536, head_dim=128):
162
+ """Expand QKV bias from old_hidden to new_hidden."""
163
+ new_qkv_dim = 3 * new_hidden
164
+ new_bias = torch.zeros(new_qkv_dim, dtype=old_bias.dtype, device=old_bias.device)
165
 
166
+ old_heads = old_hidden // head_dim
167
+
168
+ # Copy old biases to first old_heads positions for each of Q, K, V
169
+ for qkv_idx in range(3):
170
+ old_start = qkv_idx * old_hidden
171
+ new_start = qkv_idx * new_hidden
172
+ new_bias[new_start:new_start + old_hidden] = old_bias[old_start:old_start + old_hidden]
173
 
174
+ return new_bias
175
+
176
+
177
+ def expand_out_proj_weights(old_weight, old_hidden=768, new_hidden=1536, head_dim=128):
178
+ """
179
+ Expand output projection weights.
180
+ Out proj weight shape: (hidden_size, num_heads * head_dim) = (hidden_size, hidden_size)
181
+ """
182
  # Initialize new weights
183
+ new_weight = torch.zeros(new_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
184
  nn.init.xavier_uniform_(new_weight)
185
+ new_weight *= 0.02
186
+
187
+ # Copy old weights to top-left corner
188
+ new_weight[:old_hidden, :old_hidden] = old_weight
189
 
190
+ return new_weight
191
+
192
+
193
+ def expand_out_proj_bias(old_bias, old_hidden=768, new_hidden=1536):
194
+ """Expand output projection bias."""
195
+ new_bias = torch.zeros(new_hidden, dtype=old_bias.dtype, device=old_bias.device)
196
+ new_bias[:old_hidden] = old_bias
197
+ return new_bias
198
+
199
+
200
+ def expand_linear_hidden(old_weight, old_hidden=768, new_hidden=1536, expand_in=True, expand_out=True):
201
+ """
202
+ Expand a linear layer weight from old_hidden to new_hidden.
203
+ """
204
+ old_out, old_in = old_weight.shape
205
+
206
+ new_out = new_hidden if expand_out else old_out
207
+ new_in = new_hidden if expand_in else old_in
208
+
209
+ new_weight = torch.zeros(new_out, new_in, dtype=old_weight.dtype, device=old_weight.device)
210
+ nn.init.xavier_uniform_(new_weight)
211
+ new_weight *= 0.02
212
 
213
+ # Copy old weights to top-left corner
214
+ copy_out = old_hidden if expand_out else old_out
215
+ copy_in = old_hidden if expand_in else old_in
216
+ new_weight[:copy_out, :copy_in] = old_weight[:copy_out, :copy_in]
217
 
218
  return new_weight
219
 
220
 
221
+ def expand_bias(old_bias, old_hidden=768, new_hidden=1536):
222
+ """Expand bias from old_hidden to new_hidden."""
223
+ new_bias = torch.zeros(new_hidden, dtype=old_bias.dtype, device=old_bias.device)
224
+ new_bias[:old_hidden] = old_bias
225
+ return new_bias
226
+
227
+
228
+ def expand_norm(old_weight, old_hidden=768, new_hidden=1536):
229
+ """Expand RMSNorm weight from old_hidden to new_hidden."""
230
+ new_weight = torch.ones(new_hidden, dtype=old_weight.dtype, device=old_weight.device)
231
+ new_weight[:old_hidden] = old_weight
232
+ return new_weight
233
+
234
+
235
+ def port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
236
+ """Port weights from old single block to new single block with dimension expansion."""
237
  old_prefix = f"single_blocks.{old_idx}"
238
  new_prefix = f"single_blocks.{new_idx}"
239
 
 
244
  new_key = old_key.replace(old_prefix, new_prefix)
245
  old_weight = old_state[old_key]
246
 
247
+ # Attention QKV
248
+ if "attn.qkv.weight" in old_key:
249
+ new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
250
+ print(f" Expanded QKV weight: {old_key}")
251
+ elif "attn.qkv.bias" in old_key:
252
+ new_state[new_key] = expand_qkv_bias(old_weight)
253
+ print(f" Expanded QKV bias: {old_key}")
254
+
255
+ # Attention output projection
256
+ elif "attn.out_proj.weight" in old_key:
257
+ new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
258
+ print(f" Expanded out_proj weight: {old_key}")
259
+ elif "attn.out_proj.bias" in old_key:
260
+ new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
261
+ print(f" Expanded out_proj bias: {old_key}")
262
+
263
+ # MLP layers (hidden β†’ 4*hidden β†’ hidden)
264
+ elif "mlp.fc1.weight" in old_key:
265
+ # fc1: hidden β†’ 4*hidden
266
+ old_mlp_hidden = old_hidden * 4
267
+ new_mlp_hidden = new_hidden * 4
268
+ new_weight = torch.zeros(new_mlp_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
269
+ nn.init.xavier_uniform_(new_weight)
270
+ new_weight *= 0.02
271
+ new_weight[:old_mlp_hidden, :old_hidden] = old_weight
272
+ new_state[new_key] = new_weight
273
+ print(f" Expanded MLP fc1 weight: {old_key}")
274
+ elif "mlp.fc1.bias" in old_key:
275
+ old_mlp_hidden = old_hidden * 4
276
+ new_mlp_hidden = new_hidden * 4
277
+ new_bias = torch.zeros(new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
278
+ new_bias[:old_mlp_hidden] = old_weight
279
+ new_state[new_key] = new_bias
280
+ print(f" Expanded MLP fc1 bias: {old_key}")
281
+ elif "mlp.fc2.weight" in old_key:
282
+ # fc2: 4*hidden β†’ hidden
283
+ old_mlp_hidden = old_hidden * 4
284
+ new_mlp_hidden = new_hidden * 4
285
+ new_weight = torch.zeros(new_hidden, new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
286
+ nn.init.xavier_uniform_(new_weight)
287
+ new_weight *= 0.02
288
+ new_weight[:old_hidden, :old_mlp_hidden] = old_weight
289
+ new_state[new_key] = new_weight
290
+ print(f" Expanded MLP fc2 weight: {old_key}")
291
+ elif "mlp.fc2.bias" in old_key:
292
+ new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
293
+ print(f" Expanded MLP fc2 bias: {old_key}")
294
+
295
+ # AdaLayerNorm modulation linear (norm.linear) - outputs 3*hidden for single blocks
296
+ elif "norm.linear.weight" in old_key:
297
+ # Shape: (3*old_hidden, old_hidden) β†’ (3*new_hidden, new_hidden)
298
+ old_out = old_hidden * 3
299
+ new_out = new_hidden * 3
300
+ new_weight = torch.zeros(new_out, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
301
+ nn.init.xavier_uniform_(new_weight)
302
+ new_weight *= 0.02
303
+ new_weight[:old_out, :old_hidden] = old_weight
304
+ new_state[new_key] = new_weight
305
+ print(f" Expanded AdaLN linear weight: {old_key} ({old_out},{old_hidden})β†’({new_out},{new_hidden})")
306
+ elif "norm.linear.bias" in old_key:
307
+ old_out = old_hidden * 3
308
+ new_out = new_hidden * 3
309
+ new_bias = torch.zeros(new_out, dtype=old_weight.dtype, device=old_weight.device)
310
+ new_bias[:old_out] = old_weight
311
+ new_state[new_key] = new_bias
312
+ print(f" Expanded AdaLN linear bias: {old_key} ({old_out})β†’({new_out})")
313
+
314
+ # RMSNorm inside AdaLN (norm.norm.weight) or standalone norm
315
+ elif "norm.norm.weight" in old_key or "norm2.weight" in old_key:
316
+ new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
317
+ print(f" Expanded RMSNorm weight: {old_key}")
318
+
319
+ # Generic normalization layers - check actual sizes
320
+ elif "norm" in old_key and "weight" in old_key:
321
+ old_size = old_weight.shape[0]
322
+ new_key_shape = new_state.get(new_key, torch.empty(0)).shape
323
+ if len(new_key_shape) > 0:
324
+ new_size = new_key_shape[0]
325
+ if old_size == new_size:
326
+ new_state[new_key] = old_weight.clone()
327
+ print(f" Direct copy norm weight: {old_key} ({old_size})")
328
+ else:
329
+ new_weight = torch.ones(new_size, dtype=old_weight.dtype, device=old_weight.device)
330
+ copy_size = min(old_size, new_size)
331
+ new_weight[:copy_size] = old_weight[:copy_size]
332
+ new_state[new_key] = new_weight
333
+ print(f" Padded norm weight: {old_key} ({old_size}β†’{new_size})")
334
+ elif "norm" in old_key and "bias" in old_key:
335
+ old_size = old_weight.shape[0]
336
+ new_key_shape = new_state.get(new_key, torch.empty(0)).shape
337
+ if len(new_key_shape) > 0:
338
+ new_size = new_key_shape[0]
339
+ if old_size == new_size:
340
+ new_state[new_key] = old_weight.clone()
341
+ print(f" Direct copy norm bias: {old_key} ({old_size})")
342
+ else:
343
+ new_bias = torch.zeros(new_size, dtype=old_weight.dtype, device=old_weight.device)
344
+ copy_size = min(old_size, new_size)
345
+ new_bias[:copy_size] = old_weight[:copy_size]
346
+ new_state[new_key] = new_bias
347
+ print(f" Padded norm bias: {old_key} ({old_size}β†’{new_size})")
348
 
349
+ # Direct copy for anything else (shouldn't be much)
350
+ else:
351
+ if old_weight.shape == new_state.get(new_key, torch.empty(0)).shape:
352
+ new_state[new_key] = old_weight.clone()
353
+ print(f" Direct copy: {old_key}")
354
+ else:
355
+ print(f" SKIP (shape mismatch): {old_key}")
356
 
357
 
358
+ def port_double_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=256, new_hidden=1024):
359
+ """Port weights from old double block to new double block with dimension expansion."""
360
  old_prefix = f"double_blocks.{old_idx}"
361
  new_prefix = f"double_blocks.{new_idx}"
362
 
 
367
  new_key = old_key.replace(old_prefix, new_prefix)
368
  old_weight = old_state[old_key]
369
 
370
+ # Joint attention QKV (img and txt)
371
+ if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
372
+ new_state[new_key] = expand_qkv_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
373
+ print(f" Expanded QKV weight: {old_key}")
374
+ elif any(x in old_key for x in ["img_qkv.bias", "txt_qkv.bias"]):
375
+ new_state[new_key] = expand_qkv_bias(old_weight)
376
+ print(f" Expanded QKV bias: {old_key}")
377
+
378
+ # Joint attention output projections
379
+ elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
380
+ new_state[new_key] = expand_out_proj_weights(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
381
+ print(f" Expanded out_proj weight: {old_key}")
382
+ elif any(x in old_key for x in ["img_out.bias", "txt_out.bias"]):
383
+ new_state[new_key] = expand_out_proj_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
384
+ print(f" Expanded out_proj bias: {old_key}")
385
+
386
+ # MLP layers
387
+ elif "mlp" in old_key and "fc1.weight" in old_key:
388
+ old_mlp_hidden = old_hidden * 4
389
+ new_mlp_hidden = new_hidden * 4
390
+ new_weight = torch.zeros(new_mlp_hidden, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
391
+ nn.init.xavier_uniform_(new_weight)
392
+ new_weight *= 0.02
393
+ new_weight[:old_mlp_hidden, :old_hidden] = old_weight
394
+ new_state[new_key] = new_weight
395
+ print(f" Expanded MLP fc1 weight: {old_key}")
396
+ elif "mlp" in old_key and "fc1.bias" in old_key:
397
+ old_mlp_hidden = old_hidden * 4
398
+ new_mlp_hidden = new_hidden * 4
399
+ new_bias = torch.zeros(new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
400
+ new_bias[:old_mlp_hidden] = old_weight
401
+ new_state[new_key] = new_bias
402
+ print(f" Expanded MLP fc1 bias: {old_key}")
403
+ elif "mlp" in old_key and "fc2.weight" in old_key:
404
+ old_mlp_hidden = old_hidden * 4
405
+ new_mlp_hidden = new_hidden * 4
406
+ new_weight = torch.zeros(new_hidden, new_mlp_hidden, dtype=old_weight.dtype, device=old_weight.device)
407
+ nn.init.xavier_uniform_(new_weight)
408
+ new_weight *= 0.02
409
+ new_weight[:old_hidden, :old_mlp_hidden] = old_weight
410
+ new_state[new_key] = new_weight
411
+ print(f" Expanded MLP fc2 weight: {old_key}")
412
+ elif "mlp" in old_key and "fc2.bias" in old_key:
413
+ new_state[new_key] = expand_bias(old_weight, old_hidden=old_hidden, new_hidden=new_hidden)
414
+ print(f" Expanded MLP fc2 bias: {old_key}")
415
+
416
+ # AdaLayerNormZero modulation linear - outputs 6*hidden (img_norm1, txt_norm1)
417
+ elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "weight" in old_key:
418
+ old_out = old_hidden * 6
419
+ new_out = new_hidden * 6
420
+ new_weight = torch.zeros(new_out, new_hidden, dtype=old_weight.dtype, device=old_weight.device)
421
+ nn.init.xavier_uniform_(new_weight)
422
+ new_weight *= 0.02
423
+ new_weight[:old_out, :old_hidden] = old_weight
424
+ new_state[new_key] = new_weight
425
+ print(f" Expanded AdaLN linear weight: {old_key}")
426
+ elif ("img_norm1.linear" in old_key or "txt_norm1.linear" in old_key) and "bias" in old_key:
427
+ old_out = old_hidden * 6
428
+ new_out = new_hidden * 6
429
+ new_bias = torch.zeros(new_out, dtype=old_weight.dtype, device=old_weight.device)
430
+ new_bias[:old_out] = old_weight
431
+ new_state[new_key] = new_bias
432
+ print(f" Expanded AdaLN linear bias: {old_key}")
433
+
434
+ # RMSNorm inside AdaLN (img_norm1.norm, txt_norm1.norm) or standalone (img_norm2, txt_norm2)
435
+ elif any(x in old_key for x in ["_norm1.norm.weight", "_norm2.weight"]):
436
+ new_state[new_key] = expand_norm(old_weight, old_hidden, new_hidden)
437
+ print(f" Expanded RMSNorm weight: {old_key}")
438
 
439
+ # Generic normalization layers - check actual sizes
440
+ elif "norm" in old_key and "weight" in old_key:
441
+ old_size = old_weight.shape[0]
442
+ new_key_shape = new_state.get(new_key, torch.empty(0)).shape
443
+ if len(new_key_shape) > 0:
444
+ new_size = new_key_shape[0]
445
+ if old_size == new_size:
446
+ new_state[new_key] = old_weight.clone()
447
+ print(f" Direct copy norm weight: {old_key} ({old_size})")
448
+ else:
449
+ new_weight = torch.ones(new_size, dtype=old_weight.dtype, device=old_weight.device)
450
+ copy_size = min(old_size, new_size)
451
+ new_weight[:copy_size] = old_weight[:copy_size]
452
+ new_state[new_key] = new_weight
453
+ print(f" Padded norm weight: {old_key} ({old_size}β†’{new_size})")
454
+ elif "norm" in old_key and "bias" in old_key:
455
+ old_size = old_weight.shape[0]
456
+ new_key_shape = new_state.get(new_key, torch.empty(0)).shape
457
+ if len(new_key_shape) > 0:
458
+ new_size = new_key_shape[0]
459
+ if old_size == new_size:
460
+ new_state[new_key] = old_weight.clone()
461
+ print(f" Direct copy norm bias: {old_key} ({old_size})")
462
+ else:
463
+ new_bias = torch.zeros(new_size, dtype=old_weight.dtype, device=old_weight.device)
464
+ copy_size = min(old_size, new_size)
465
+ new_bias[:copy_size] = old_weight[:copy_size]
466
+ new_state[new_key] = new_bias
467
+ print(f" Padded norm bias: {old_key} ({old_size}β†’{new_size})")
468
+
469
+ # Direct copy for matching shapes
470
+ else:
471
+ if old_weight.shape == new_state.get(new_key, torch.empty(0)).shape:
472
+ new_state[new_key] = old_weight.clone()
473
+ print(f" Direct copy: {old_key}")
474
+ else:
475
+ print(f" SKIP (shape mismatch): {old_key}")
476
 
477
 
478
def _expanded_matrix(old_weight, out_dim, in_dim, old_out, old_in):
    """Return an (out_dim, in_dim) matrix with ``old_weight`` ported into it.

    The new matrix is seeded with small (0.02-scaled) xavier noise so the
    freshly added rows/columns start near zero, then the old tensor is copied
    into the top-left (old_out, old_in) corner so ported behavior dominates.
    Created with the old weight's dtype AND device (matches the norm-bias
    padding path elsewhere in this file, which also passes device=).
    """
    new_weight = torch.zeros(
        out_dim, in_dim, dtype=old_weight.dtype, device=old_weight.device
    )
    nn.init.xavier_uniform_(new_weight)
    new_weight *= 0.02
    new_weight[:old_out, :old_in] = old_weight
    return new_weight


def port_non_block_weights(old_state, new_state, old_hidden=256, new_hidden=1024):
    """Port weights that aren't in single/double blocks with dimension expansion.

    Walks every key in ``old_state`` and, for the known non-block tensors,
    writes an expanded copy into ``new_state`` under the same key:

    * projection weights are grown to ``new_hidden`` via ``_expanded_matrix``;
    * biases and norm gains are delegated to the file-level ``expand_bias`` /
      ``expand_norm`` helpers;
    * single/double block weights are skipped (ported separately), as are
      RoPE tables and sin/freq buffers (recomputed by the new model).

    Args:
        old_state: flat source state dict (key -> tensor).
        new_state: destination state dict; mutated in place.
        old_hidden: hidden size of the source model.
        new_hidden: hidden size of the destination model.
    """
    for old_key, old_weight in old_state.items():
        # Skip block weights (handled separately)
        if "single_blocks" in old_key or "double_blocks" in old_key:
            continue

        # Skip buffers that will be recomputed
        if any(x in old_key for x in ["sin_basis", "freqs_"]):
            print(f" Skip buffer: {old_key}")
            continue

        # img_in: in_channels -> hidden (input columns unchanged, rows expanded)
        if "img_in.weight" in old_key:
            new_state[old_key] = _expanded_matrix(
                old_weight, new_hidden, old_weight.shape[1],
                old_hidden, old_weight.shape[1],
            )
            print(f" Expanded: {old_key}")
        elif "img_in.bias" in old_key:
            new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
            print(f" Expanded: {old_key}")

        # txt_in: joint_attention_dim -> hidden (same shape pattern as img_in)
        elif "txt_in.weight" in old_key:
            new_state[old_key] = _expanded_matrix(
                old_weight, new_hidden, old_weight.shape[1],
                old_hidden, old_weight.shape[1],
            )
            print(f" Expanded: {old_key}")
        elif "txt_in.bias" in old_key:
            new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
            print(f" Expanded: {old_key}")

        # time_in, guidance_in: MLPEmbedder — fc1 and fc2 get identical
        # hidden x hidden expansion, so the two weight (and bias) branches
        # are merged.
        # NOTE(review): this assumes fc1's input dim equals old_hidden — if
        # fc1 maps a fixed sinusoidal dim to hidden, its columns should NOT
        # be expanded; confirm against the model definition.
        elif any(x in old_key for x in ["time_in", "guidance_in"]):
            if "fc1.weight" in old_key or "fc2.weight" in old_key:
                new_state[old_key] = _expanded_matrix(
                    old_weight, new_hidden, new_hidden, old_hidden, old_hidden
                )
                print(f" Expanded: {old_key}")
            elif "fc1.bias" in old_key or "fc2.bias" in old_key:
                new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
                print(f" Expanded: {old_key}")

        # vector_in: pooled_projection_dim -> hidden
        # NOTE(review): only output rows are expanded here; if vector_in
        # contains a hidden->hidden fc2, its input columns would need
        # expansion too — verify vector_in's structure.
        elif "vector_in" in old_key:
            if "weight" in old_key:
                new_state[old_key] = _expanded_matrix(
                    old_weight, new_hidden, old_weight.shape[1],
                    old_hidden, old_weight.shape[1],
                )
                print(f" Expanded: {old_key}")
            elif "bias" in old_key:
                new_state[old_key] = expand_bias(old_weight, old_hidden, new_hidden)
                print(f" Expanded: {old_key}")

        # final_norm: RMSNorm(hidden) gain vector
        elif "final_norm" in old_key:
            if "weight" in old_key:
                new_state[old_key] = expand_norm(old_weight, old_hidden, new_hidden)
                print(f" Expanded: {old_key}")

        # final_linear: hidden -> in_channels (output rows unchanged,
        # input columns expanded)
        elif "final_linear.weight" in old_key:
            new_state[old_key] = _expanded_matrix(
                old_weight, old_weight.shape[0], new_hidden,
                old_weight.shape[0], old_hidden,
            )
            print(f" Expanded: {old_key}")
        elif "final_linear.bias" in old_key:
            new_state[old_key] = old_weight.clone()  # output dim unchanged
            print(f" Direct copy: {old_key}")

        # RoPE - skip, will be recomputed
        elif "rope" in old_key:
            print(f" Skip RoPE: {old_key}")

        else:
            print(f" Unknown non-block key: {old_key}")
575
 
576
 
577
  # ============================================================================
 
598
  print("Stripping _orig_mod prefix...")
599
  old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}
600
 
601
+ # Get new model's state dict as template FIRST
602
  new_state = new_model.state_dict()
603
  frozen_params = set()
604
 
605
+ # Auto-detect old hidden size from weights
606
+ if "final_norm.weight" in old_state:
607
+ old_hidden = old_state["final_norm.weight"].shape[0]
608
+ elif "img_in.weight" in old_state:
609
+ old_hidden = old_state["img_in.weight"].shape[0]
610
+ else:
611
+ old_hidden = 256 # Default for TinyFlux
612
+
613
+ # Get new hidden size from new model's state dict
614
+ if "final_norm.weight" in new_state:
615
+ new_hidden = new_state["final_norm.weight"].shape[0]
616
+ else:
617
+ new_hidden = 512 # Default for TinyFlux-Deep
618
+
619
+ print(f"Detected old hidden size: {old_hidden}")
620
+ print(f"New hidden size: {new_hidden}")
621
+
622
  print("\n" + "="*60)
623
  print("Porting non-block weights...")
624
  print("="*60)
625
+ port_non_block_weights(old_state, new_state, old_hidden=old_hidden, new_hidden=new_hidden)
626
 
627
  print("\n" + "="*60)
628
  print("Porting single blocks (3 β†’ 25)...")
 
630
  for old_idx, new_positions in SINGLE_MAPPING.items():
631
  for new_idx in new_positions:
632
  print(f"\nSingle block {old_idx} β†’ {new_idx}:")
633
+ port_single_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=old_hidden, new_hidden=new_hidden)
634
  # Mark as frozen
635
  for key in new_state.keys():
636
  if f"single_blocks.{new_idx}." in key:
 
642
  for old_idx, new_positions in DOUBLE_MAPPING.items():
643
  for new_idx in new_positions:
644
  print(f"\nDouble block {old_idx} β†’ {new_idx}:")
645
+ port_double_block_weights(old_state, old_idx, new_state, new_idx, old_hidden=old_hidden, new_hidden=new_hidden)
646
  # Mark as frozen
647
  for key in new_state.keys():
648
  if f"double_blocks.{new_idx}." in key:
 
693
  print("TinyFlux β†’ TinyFlux-Deep Porting")
694
  print("="*60)
695
 
696
+ # Load old weights from hub FIRST to detect dimensions
697
  print("\nDownloading TinyFlux weights from hub...")
698
  old_weights_path = hf_hub_download(
699
  repo_id="AbstractPhil/tiny-flux",
700
  filename="model.safetensors"
701
  )
702
 
703
+ # Load and detect old dimensions
704
+ print("Detecting old model dimensions...")
705
+ old_state = load_file(old_weights_path)
706
+ if any(k.startswith("_orig_mod.") for k in old_state.keys()):
707
+ old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}
708
+
709
+ # Detect old hidden size
710
+ old_hidden = old_state["final_norm.weight"].shape[0]
711
+ head_dim = 128 # Fixed for RoPE
712
+ old_heads = old_hidden // head_dim
713
+
714
+ print(f" Old hidden size: {old_hidden}")
715
+ print(f" Old attention heads: {old_heads}")
716
+ print(f" Head dim: {head_dim}")
717
+
718
+ # Calculate new dimensions (double the heads)
719
+ new_heads = old_heads * 2 # 6 β†’ 12
720
+ new_hidden = new_heads * head_dim # 12 * 128 = 1536
721
+
722
+ print(f"\nNew dimensions:")
723
+ print(f" New hidden size: {new_hidden}")
724
+ print(f" New attention heads: {new_heads}")
725
+
726
+ # Create deep config with detected dimensions
727
  deep_config = TinyFluxDeepConfig()
728
+ deep_config.hidden_size = new_hidden
729
+ deep_config.num_attention_heads = new_heads
730
 
731
+ print("\nCreating TinyFlux-Deep model...")
732
  # You need to define TinyFlux class first (run model cell)
 
733
  deep_model = TinyFlux(deep_config).to(DTYPE)
734
 
735
  print(f"\nDeep model config:")
736
  print(f" Hidden size: {deep_config.hidden_size}")
737
  print(f" Attention heads: {deep_config.num_attention_heads}")
738
+ print(f" Single layers: {deep_config.num_single_layers}")
739
+ print(f" Double layers: {deep_config.num_double_layers}")
740
 
741
  # Port weights
742
  new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)
 
776
  "hidden_size": deep_config.hidden_size,
777
  "num_attention_heads": deep_config.num_attention_heads,
778
  "attention_head_dim": deep_config.attention_head_dim,
779
+ "num_single_layers": deep_config.num_single_layers,
780
+ "num_double_layers": deep_config.num_double_layers,
781
  "mlp_ratio": deep_config.mlp_ratio,
782
+ "joint_attention_dim": deep_config.joint_attention_dim,
783
+ "pooled_projection_dim": deep_config.pooled_projection_dim,
784
  "in_channels": deep_config.in_channels,
785
+ "axes_dims_rope": list(deep_config.axes_dims_rope),
786
+ "guidance_embeds": deep_config.guidance_embeds,
787
  }
788
  with open("config_deep.json", "w") as f:
789
  json.dump(config_dict, f, indent=2)