AbstractPhil commited on
Commit
10a0fd5
·
verified ·
1 Parent(s): 99265b6

Create port_tiny_to_deep.py

Browse files
Files changed (1) hide show
  1. port_tiny_to_deep.py +418 -0
port_tiny_to_deep.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # TinyFlux → TinyFlux-Deep Porting Script
3
+ # ============================================================================
4
+ # Expands: 3 single + 3 double → 25 single + 15 double
5
+ # Heads: 2 → 8 (old heads become first and last)
6
+ # Freezes ported layers, trains new ones
7
+ # ============================================================================
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from safetensors.torch import load_file, save_file
12
+ from huggingface_hub import hf_hub_download, HfApi
13
+ from dataclasses import dataclass
14
+ from copy import deepcopy
15
+
16
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
+ DTYPE = torch.bfloat16
18
+
19
+ # ============================================================================
20
+ # CONFIGS
21
+ # ============================================================================
22
+ @dataclass
23
+ class TinyFluxConfig:
24
+ """Original small config"""
25
+ hidden_size: int = 768
26
+ num_attention_heads: int = 2
27
+ attention_head_dim: int = 128
28
+ num_single_blocks: int = 3
29
+ num_double_blocks: int = 3
30
+ mlp_ratio: float = 4.0
31
+ t5_embed_dim: int = 768
32
+ clip_embed_dim: int = 768
33
+ in_channels: int = 16
34
+ axes_dims: tuple = (16, 24, 24)
35
+ theta: int = 10000
36
+
37
+
38
+ @dataclass
39
+ class TinyFluxDeepConfig:
40
+ """Expanded deep config"""
41
+ hidden_size: int = 768 # Same
42
+ num_attention_heads: int = 8 # 2 → 8 (6 new heads)
43
+ attention_head_dim: int = 128 # Same (so attention dim = 8*128 = 1024)
44
+ num_single_blocks: int = 25 # 3 → 25 (more singles like original Flux)
45
+ num_double_blocks: int = 15 # 3 → 15
46
+ mlp_ratio: float = 4.0 # Same
47
+ t5_embed_dim: int = 768 # Same
48
+ clip_embed_dim: int = 768 # Same
49
+ in_channels: int = 16 # Same
50
+ axes_dims: tuple = (16, 24, 24) # Same
51
+ theta: int = 10000 # Same
52
+
53
+
54
+ # ============================================================================
55
+ # LAYER MAPPING
56
+ # ============================================================================
57
+ # Single blocks: 3 → 25
58
+ # - Layer 0 → position 0 (frozen)
59
+ # - Layer 1 → positions 8, 12, 16 (center, spaced, frozen)
60
+ # - Layer 2 → position 24 (frozen)
61
+ # - Rest → new (trainable)
62
+
63
+ SINGLE_MAPPING = {
64
+ 0: [0], # Old layer 0 → new position 0
65
+ 1: [8, 12, 16], # Old layer 1 → new positions 8, 12, 16
66
+ 2: [24], # Old layer 2 → new position 24
67
+ }
68
+ SINGLE_FROZEN = {0, 8, 12, 16, 24} # These positions are frozen
69
+
70
+ # Double blocks: 3 → 15
71
+ # - Layer 0 → position 0 (frozen)
72
+ # - Layer 1 → positions 4, 7, 10 (3 copies, spaced, frozen)
73
+ # - Layer 2 → position 14 (frozen)
74
+ # - Rest → new (trainable)
75
+
76
+ DOUBLE_MAPPING = {
77
+ 0: [0], # Old layer 0 → new position 0
78
+ 1: [4, 7, 10], # Old layer 1 → 3 positions
79
+ 2: [14], # Old layer 2 → new position 14
80
+ }
81
+ DOUBLE_FROZEN = {0, 4, 7, 10, 14} # These positions are frozen
82
+
83
+
84
+ # ============================================================================
85
+ # WEIGHT EXPANSION UTILITIES
86
+ # ============================================================================
87
+ def expand_qkv_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
88
+ """
89
+ Expand QKV projection weights from 2 heads to 8 heads.
90
+ Old heads go to positions 0 and 7, middle heads initialized randomly.
91
+
92
+ QKV weight shape: (in_features, 3 * num_heads * head_dim)
93
+ """
94
+ in_features = old_weight.shape[0]
95
+ old_qkv_dim = 3 * old_heads * head_dim # 3 * 2 * 128 = 768
96
+ new_qkv_dim = 3 * new_heads * head_dim # 3 * 8 * 128 = 3072
97
+
98
+ # Initialize new weights
99
+ new_weight = torch.zeros(in_features, new_qkv_dim, dtype=old_weight.dtype, device=old_weight.device)
100
+ # Small random init for new heads
101
+ nn.init.xavier_uniform_(new_weight)
102
+ new_weight *= 0.1 # Scale down random init
103
+
104
+ # For each of Q, K, V
105
+ for qkv_idx in range(3):
106
+ old_start = qkv_idx * old_heads * head_dim
107
+ new_start = qkv_idx * new_heads * head_dim
108
+
109
+ # Copy old head 0 → new head 0
110
+ old_h0_start = old_start
111
+ old_h0_end = old_start + head_dim
112
+ new_h0_start = new_start
113
+ new_h0_end = new_start + head_dim
114
+ new_weight[:, new_h0_start:new_h0_end] = old_weight[:, old_h0_start:old_h0_end]
115
+
116
+ # Copy old head 1 → new head 7 (last)
117
+ old_h1_start = old_start + head_dim
118
+ old_h1_end = old_start + 2 * head_dim
119
+ new_h7_start = new_start + 7 * head_dim
120
+ new_h7_end = new_start + 8 * head_dim
121
+ new_weight[:, new_h7_start:new_h7_end] = old_weight[:, old_h1_start:old_h1_end]
122
+
123
+ return new_weight
124
+
125
+
126
+ def expand_out_proj_weights(old_weight, old_heads=2, new_heads=8, head_dim=128):
127
+ """
128
+ Expand output projection weights from 2 heads to 8 heads.
129
+
130
+ Out proj weight shape: (num_heads * head_dim, out_features)
131
+ """
132
+ out_features = old_weight.shape[1]
133
+ old_attn_dim = old_heads * head_dim # 2 * 128 = 256
134
+ new_attn_dim = new_heads * head_dim # 8 * 128 = 1024
135
+
136
+ # Initialize new weights
137
+ new_weight = torch.zeros(new_attn_dim, out_features, dtype=old_weight.dtype, device=old_weight.device)
138
+ nn.init.xavier_uniform_(new_weight)
139
+ new_weight *= 0.1
140
+
141
+ # Copy old head 0 → new head 0
142
+ new_weight[0:head_dim, :] = old_weight[0:head_dim, :]
143
+
144
+ # Copy old head 1 → new head 7
145
+ new_weight[7*head_dim:8*head_dim, :] = old_weight[head_dim:2*head_dim, :]
146
+
147
+ return new_weight
148
+
149
+
150
+ def port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
151
+ """Port weights from old single block to new single block."""
152
+ old_prefix = f"single_blocks.{old_idx}"
153
+ new_prefix = f"single_blocks.{new_idx}"
154
+
155
+ for old_key in list(old_state.keys()):
156
+ if not old_key.startswith(old_prefix):
157
+ continue
158
+
159
+ new_key = old_key.replace(old_prefix, new_prefix)
160
+ old_weight = old_state[old_key]
161
+
162
+ # Handle attention head expansion
163
+ if expand_heads:
164
+ if "attn.qkv.weight" in old_key:
165
+ new_state[new_key] = expand_qkv_weights(old_weight)
166
+ print(f" Expanded QKV: {old_key} → {new_key}")
167
+ continue
168
+ elif "attn.out_proj.weight" in old_key:
169
+ new_state[new_key] = expand_out_proj_weights(old_weight)
170
+ print(f" Expanded out_proj: {old_key} → {new_key}")
171
+ continue
172
+
173
+ # Direct copy for other weights
174
+ new_state[new_key] = old_weight.clone()
175
+ print(f" Copied: {old_key} → {new_key}")
176
+
177
+
178
+ def port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True):
179
+ """Port weights from old double block to new double block."""
180
+ old_prefix = f"double_blocks.{old_idx}"
181
+ new_prefix = f"double_blocks.{new_idx}"
182
+
183
+ for old_key in list(old_state.keys()):
184
+ if not old_key.startswith(old_prefix):
185
+ continue
186
+
187
+ new_key = old_key.replace(old_prefix, new_prefix)
188
+ old_weight = old_state[old_key]
189
+
190
+ # Handle attention head expansion for joint attention
191
+ if expand_heads:
192
+ if any(x in old_key for x in ["img_qkv.weight", "txt_qkv.weight"]):
193
+ new_state[new_key] = expand_qkv_weights(old_weight)
194
+ print(f" Expanded QKV: {old_key} → {new_key}")
195
+ continue
196
+ elif any(x in old_key for x in ["img_out.weight", "txt_out.weight"]):
197
+ new_state[new_key] = expand_out_proj_weights(old_weight)
198
+ print(f" Expanded out_proj: {old_key} → {new_key}")
199
+ continue
200
+
201
+ # Direct copy
202
+ new_state[new_key] = old_weight.clone()
203
+ print(f" Copied: {old_key} → {new_key}")
204
+
205
+
206
+ def port_non_block_weights(old_state, new_state, old_heads=2, new_heads=8):
207
+ """Port weights that aren't in single/double blocks."""
208
+ head_dim = 128
209
+
210
+ for old_key, old_weight in old_state.items():
211
+ # Skip block weights (handled separately)
212
+ if "single_blocks" in old_key or "double_blocks" in old_key:
213
+ continue
214
+
215
+ # These can be copied directly (same dimensions)
216
+ direct_copy_keys = [
217
+ "img_in", "txt_in", "time_in", "vector_in", "guidance_in",
218
+ "final_norm", "final_linear", "rope"
219
+ ]
220
+
221
+ if any(k in old_key for k in direct_copy_keys):
222
+ new_state[old_key] = old_weight.clone()
223
+ print(f" Direct copy: {old_key}")
224
+
225
+
226
+ # ============================================================================
227
+ # MAIN PORTING FUNCTION
228
+ # ============================================================================
229
+ def port_tinyflux_to_deep(old_weights_path, new_model):
230
+ """
231
+ Port TinyFlux weights to TinyFlux-Deep.
232
+
233
+ Returns:
234
+ new_state_dict: Ported weights
235
+ frozen_params: Set of parameter names to freeze
236
+ """
237
+ print("Loading old weights...")
238
+ if old_weights_path.endswith(".safetensors"):
239
+ old_state = load_file(old_weights_path)
240
+ else:
241
+ old_state = torch.load(old_weights_path, map_location="cpu")
242
+ if "model" in old_state:
243
+ old_state = old_state["model"]
244
+
245
+ # Strip _orig_mod prefix if present
246
+ if any(k.startswith("_orig_mod.") for k in old_state.keys()):
247
+ print("Stripping _orig_mod prefix...")
248
+ old_state = {k.replace("_orig_mod.", ""): v for k, v in old_state.items()}
249
+
250
+ # Get new model's state dict as template
251
+ new_state = new_model.state_dict()
252
+ frozen_params = set()
253
+
254
+ print("\n" + "="*60)
255
+ print("Porting non-block weights...")
256
+ print("="*60)
257
+ port_non_block_weights(old_state, new_state)
258
+
259
+ print("\n" + "="*60)
260
+ print("Porting single blocks (3 → 25)...")
261
+ print("="*60)
262
+ for old_idx, new_positions in SINGLE_MAPPING.items():
263
+ for new_idx in new_positions:
264
+ print(f"\nSingle block {old_idx} → {new_idx}:")
265
+ port_single_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)
266
+ # Mark as frozen
267
+ for key in new_state.keys():
268
+ if f"single_blocks.{new_idx}." in key:
269
+ frozen_params.add(key)
270
+
271
+ print("\n" + "="*60)
272
+ print("Porting double blocks (3 → 15)...")
273
+ print("="*60)
274
+ for old_idx, new_positions in DOUBLE_MAPPING.items():
275
+ for new_idx in new_positions:
276
+ print(f"\nDouble block {old_idx} → {new_idx}:")
277
+ port_double_block_weights(old_state, old_idx, new_state, new_idx, expand_heads=True)
278
+ # Mark as frozen
279
+ for key in new_state.keys():
280
+ if f"double_blocks.{new_idx}." in key:
281
+ frozen_params.add(key)
282
+
283
+ print("\n" + "="*60)
284
+ print("Summary")
285
+ print("="*60)
286
+ print(f"Total parameters in new model: {len(new_state)}")
287
+ print(f"Frozen parameters: {len(frozen_params)}")
288
+ print(f"Trainable parameters: {len(new_state) - len(frozen_params)}")
289
+
290
+ print(f"\nFrozen single block positions: {sorted(SINGLE_FROZEN)}")
291
+ print(f"Frozen double block positions: {sorted(DOUBLE_FROZEN)}")
292
+
293
+ return new_state, frozen_params
294
+
295
+
296
+ # ============================================================================
297
+ # FREEZE HELPER
298
+ # ============================================================================
299
+ def freeze_ported_layers(model, frozen_params):
300
+ """Freeze the ported layers, keep new layers trainable."""
301
+ frozen_count = 0
302
+ trainable_count = 0
303
+
304
+ for name, param in model.named_parameters():
305
+ if name in frozen_params:
306
+ param.requires_grad = False
307
+ frozen_count += param.numel()
308
+ else:
309
+ param.requires_grad = True
310
+ trainable_count += param.numel()
311
+
312
+ print(f"\nFrozen params: {frozen_count:,}")
313
+ print(f"Trainable params: {trainable_count:,}")
314
+ print(f"Total params: {frozen_count + trainable_count:,}")
315
+ print(f"Trainable ratio: {trainable_count / (frozen_count + trainable_count) * 100:.1f}%")
316
+
317
+ return model
318
+
319
+
320
+ # ============================================================================
321
+ # MAIN SCRIPT
322
+ # ============================================================================
323
+ if __name__ == "__main__":
324
+ print("="*60)
325
+ print("TinyFlux → TinyFlux-Deep Porting")
326
+ print("="*60)
327
+
328
+ # Load old weights from hub
329
+ print("\nDownloading TinyFlux weights from hub...")
330
+ old_weights_path = hf_hub_download(
331
+ repo_id="AbstractPhil/tiny-flux",
332
+ filename="model.safetensors"
333
+ )
334
+
335
+ # Create new deep model
336
+ print("\nCreating TinyFlux-Deep model...")
337
+ deep_config = TinyFluxDeepConfig()
338
+
339
+ # You need to define TinyFlux class first (run model cell)
340
+ # This assumes TinyFlux accepts the config
341
+ deep_model = TinyFlux(deep_config).to(DTYPE)
342
+
343
+ print(f"\nDeep model config:")
344
+ print(f" Hidden size: {deep_config.hidden_size}")
345
+ print(f" Attention heads: {deep_config.num_attention_heads}")
346
+ print(f" Single blocks: {deep_config.num_single_blocks}")
347
+ print(f" Double blocks: {deep_config.num_double_blocks}")
348
+
349
+ # Port weights
350
+ new_state, frozen_params = port_tinyflux_to_deep(old_weights_path, deep_model)
351
+
352
+ # Load ported weights
353
+ print("\nLoading ported weights into model...")
354
+ missing, unexpected = deep_model.load_state_dict(new_state, strict=False)
355
+ if missing:
356
+ print(f" Missing keys: {missing[:5]}..." if len(missing) > 5 else f" Missing keys: {missing}")
357
+ if unexpected:
358
+ print(f" Unexpected keys: {unexpected}")
359
+
360
+ # Freeze ported layers
361
+ print("\nFreezing ported layers...")
362
+ deep_model = freeze_ported_layers(deep_model, frozen_params)
363
+
364
+ # Save
365
+ print("\nSaving ported model...")
366
+ save_path = "tinyflux_deep_ported.safetensors"
367
+
368
+ # Strip any _orig_mod prefix before saving
369
+ state_to_save = deep_model.state_dict()
370
+ if any(k.startswith("_orig_mod.") for k in state_to_save.keys()):
371
+ state_to_save = {k.replace("_orig_mod.", ""): v for k, v in state_to_save.items()}
372
+
373
+ save_file(state_to_save, save_path)
374
+ print(f"✓ Saved to {save_path}")
375
+
376
+ # Save frozen params list
377
+ import json
378
+ with open("frozen_params.json", "w") as f:
379
+ json.dump(list(frozen_params), f)
380
+ print("✓ Saved frozen_params.json")
381
+
382
+ # Save config
383
+ config_dict = {
384
+ "hidden_size": deep_config.hidden_size,
385
+ "num_attention_heads": deep_config.num_attention_heads,
386
+ "attention_head_dim": deep_config.attention_head_dim,
387
+ "num_single_blocks": deep_config.num_single_blocks,
388
+ "num_double_blocks": deep_config.num_double_blocks,
389
+ "mlp_ratio": deep_config.mlp_ratio,
390
+ "t5_embed_dim": deep_config.t5_embed_dim,
391
+ "clip_embed_dim": deep_config.clip_embed_dim,
392
+ "in_channels": deep_config.in_channels,
393
+ "axes_dims": list(deep_config.axes_dims),
394
+ "theta": deep_config.theta,
395
+ }
396
+ with open("config_deep.json", "w") as f:
397
+ json.dump(config_dict, f, indent=2)
398
+ print("✓ Saved config_deep.json")
399
+
400
+ # Upload to hub
401
+ print("\nUploading to AbstractPhil/tiny-flux-deep...")
402
+ api = HfApi()
403
+ try:
404
+ api.create_repo(repo_id="AbstractPhil/tiny-flux-deep", exist_ok=True, repo_type="model")
405
+ api.upload_file(path_or_fileobj=save_path, path_in_repo="model.safetensors", repo_id="AbstractPhil/tiny-flux-deep")
406
+ api.upload_file(path_or_fileobj="config_deep.json", path_in_repo="config.json", repo_id="AbstractPhil/tiny-flux-deep")
407
+ api.upload_file(path_or_fileobj="frozen_params.json", path_in_repo="frozen_params.json", repo_id="AbstractPhil/tiny-flux-deep")
408
+ print("✓ Uploaded to hub!")
409
+ except Exception as e:
410
+ print(f"âš  Upload failed: {e}")
411
+
412
+ print("\n" + "="*60)
413
+ print("Porting complete!")
414
+ print("="*60)
415
+ print("\nNext steps:")
416
+ print("1. Update TinyFlux model definition to accept TinyFluxDeepConfig")
417
+ print("2. Use the frozen_params.json to freeze layers during training")
418
+ print("3. Train on AbstractPhil/tiny-flux-deep repo")