AbstractPhil commited on
Commit
67cc511
·
verified ·
1 Parent(s): a624b61

Create convert_v3_to_v4.py

Browse files
Files changed (1) hide show
  1. convert_v3_to_v4.py +641 -0
convert_v3_to_v4.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux-Deep Weight Converter: v3 → v4
3
+
4
+ Converts v3 checkpoints to v4 architecture without destroying pretrain.
5
+
6
+ Key changes:
7
+ - expert_predictor → lune_predictor (rename)
8
+ - expert_gate value: 0.5 → 0.0 (logit space, sigmoid(0)=0.5)
9
+ - New modules initialized to zero-effect:
10
+ - sol_prior: geometric priors dominate initially
11
+ - t5_pool: 50/50 balance with CLIP
12
+ - spatial_to_mod: exp(0)=1 identity
13
+
14
+ Colab Usage:
15
+ from convert_v3_to_v4 import run
16
+ run(401434) # Downloads, converts, saves to ./converted/
17
+
18
+ API Usage:
19
+ from convert_v3_to_v4 import convert_checkpoint, analyze_checkpoint
20
+ result = convert_checkpoint(step=401434)
21
+
22
+ CLI Usage:
23
+ python convert_v3_to_v4.py --step 401434
24
+ """
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import math
29
+ import os
30
+ import re
31
+ from typing import Dict, Tuple, Optional
32
+ from dataclasses import dataclass
33
+
34
+
35
+ # =============================================================================
36
+ # Quick Entry Point (Colab)
37
+ # =============================================================================
38
+
39
+ def run(
40
+ step: int = 401434,
41
+ name: str = "lailah",
42
+ output_dir: str = "converted",
43
+ ):
44
+ """
45
+ One-liner for Colab. Downloads, converts, saves.
46
+
47
+ Usage:
48
+ from convert_v3_to_v4 import run
49
+ run(401434)
50
+ """
51
+ result = convert_checkpoint(
52
+ step=step,
53
+ model_name=name,
54
+ output_dir=output_dir,
55
+ verbose=True,
56
+ )
57
+
58
+ if result.success:
59
+ print(f"\n✅ Done! Files in ./{output_dir}/")
60
+ else:
61
+ print(f"\n❌ Failed: {result.error}")
62
+
63
+ return result
64
+
65
+
66
+ # =============================================================================
67
+ # Data Classes
68
+ # =============================================================================
69
+
70
+ @dataclass
71
+ class CheckpointInfo:
72
+ """Analysis results for a checkpoint."""
73
+ version: str = "unknown"
74
+ has_expert_predictor: bool = False
75
+ has_lune_predictor: bool = False
76
+ has_sol_prior: bool = False
77
+ has_t5_pool: bool = False
78
+ has_spatial_to_mod: bool = False
79
+ num_double_blocks: int = 0
80
+ num_single_blocks: int = 0
81
+ total_params: int = 0
82
+
83
+
84
+ @dataclass
85
+ class ConversionResult:
86
+ """Results from a conversion operation."""
87
+ success: bool
88
+ model_path: Optional[str] = None
89
+ ema_path: Optional[str] = None
90
+ ema_secondary_path: Optional[str] = None
91
+ source_version: str = "unknown"
92
+ source_params: int = 0
93
+ target_params: int = 0
94
+ params_added: int = 0
95
+ renamed_keys: int = 0
96
+ initialized_keys: int = 0
97
+ error: Optional[str] = None
98
+
99
+
100
+ @dataclass
101
+ class ConversionConfig:
102
+ """Configuration for conversion."""
103
+ hidden_size: int = 512
104
+ time_dim: int = 512
105
+ clip_dim: int = 768
106
+ joint_attention_dim: int = 768
107
+ num_heads: int = 4
108
+ sol_hidden_dim: int = 256
109
+ sol_spatial_size: int = 8
110
+ sol_geometric_weight: float = 0.7
111
+ num_double_blocks: int = 15
112
+ num_single_blocks: int = 25
113
+
114
+
115
+ # =============================================================================
116
+ # Core Functions
117
+ # =============================================================================
118
+
119
+ def to_logit(p: float) -> float:
120
+ """Convert probability to logit for sigmoid init."""
121
+ p = max(1e-4, min(p, 1 - 1e-4))
122
+ return math.log(p / (1 - p))
123
+
124
+
125
+ def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
126
+ """
127
+ Analyze a checkpoint to determine version and contents.
128
+
129
+ Args:
130
+ state_dict: Model state dictionary
131
+
132
+ Returns:
133
+ CheckpointInfo with analysis results
134
+ """
135
+ info = CheckpointInfo()
136
+ info.total_params = sum(p.numel() for p in state_dict.values())
137
+
138
+ for key in state_dict.keys():
139
+ if key.startswith('expert_predictor.'):
140
+ info.has_expert_predictor = True
141
+ if key.startswith('lune_predictor.'):
142
+ info.has_lune_predictor = True
143
+ if key.startswith('sol_prior.'):
144
+ info.has_sol_prior = True
145
+ if key.startswith('t5_pool.'):
146
+ info.has_t5_pool = True
147
+ if 'spatial_to_mod' in key:
148
+ info.has_spatial_to_mod = True
149
+ if key.startswith('double_blocks.'):
150
+ idx = int(key.split('.')[1])
151
+ info.num_double_blocks = max(info.num_double_blocks, idx + 1)
152
+ if key.startswith('single_blocks.'):
153
+ idx = int(key.split('.')[1])
154
+ info.num_single_blocks = max(info.num_single_blocks, idx + 1)
155
+
156
+ # Determine version
157
+ if info.has_lune_predictor and info.has_sol_prior:
158
+ info.version = "v4"
159
+ elif info.has_expert_predictor:
160
+ info.version = "v3"
161
+ elif info.has_lune_predictor and not info.has_sol_prior:
162
+ info.version = "v3.5"
163
+ else:
164
+ info.version = "v2_or_earlier"
165
+
166
+ return info
167
+
168
+
169
+ def create_sol_prior_init(
170
+ config: ConversionConfig,
171
+ dtype: torch.dtype = torch.float32,
172
+ ) -> Dict[str, torch.Tensor]:
173
+ """Create zero-effect initialization for SolAttentionPrior."""
174
+ init = {}
175
+ hidden_dim = config.sol_hidden_dim
176
+ time_dim = config.time_dim
177
+ clip_dim = config.clip_dim
178
+ num_heads = config.num_heads
179
+ spatial_size = config.sol_spatial_size
180
+
181
+ # stat_predictor
182
+ w0 = torch.empty(hidden_dim, time_dim + clip_dim, dtype=dtype)
183
+ nn.init.xavier_uniform_(w0, gain=0.1)
184
+ init['sol_prior.stat_predictor.0.weight'] = w0
185
+ init['sol_prior.stat_predictor.0.bias'] = torch.zeros(hidden_dim, dtype=dtype)
186
+
187
+ w1 = torch.empty(hidden_dim, hidden_dim, dtype=dtype)
188
+ nn.init.xavier_uniform_(w1, gain=0.1)
189
+ init['sol_prior.stat_predictor.2.weight'] = w1
190
+ init['sol_prior.stat_predictor.2.bias'] = torch.zeros(hidden_dim, dtype=dtype)
191
+
192
+ w2 = torch.empty(3, hidden_dim, dtype=dtype)
193
+ nn.init.xavier_uniform_(w2, gain=0.1)
194
+ init['sol_prior.stat_predictor.4.weight'] = w2
195
+ init['sol_prior.stat_predictor.4.bias'] = torch.zeros(3, dtype=dtype)
196
+
197
+ # spatial_predictor
198
+ w0 = torch.empty(hidden_dim, time_dim + clip_dim, dtype=dtype)
199
+ nn.init.xavier_uniform_(w0, gain=0.1)
200
+ init['sol_prior.spatial_predictor.0.weight'] = w0
201
+ init['sol_prior.spatial_predictor.0.bias'] = torch.zeros(hidden_dim, dtype=dtype)
202
+
203
+ w1 = torch.empty(hidden_dim, hidden_dim, dtype=dtype)
204
+ nn.init.xavier_uniform_(w1, gain=0.1)
205
+ init['sol_prior.spatial_predictor.2.weight'] = w1
206
+ init['sol_prior.spatial_predictor.2.bias'] = torch.zeros(hidden_dim, dtype=dtype)
207
+
208
+ w2 = torch.empty(spatial_size * spatial_size, hidden_dim, dtype=dtype)
209
+ nn.init.xavier_uniform_(w2, gain=0.1)
210
+ init['sol_prior.spatial_predictor.4.weight'] = w2
211
+ init['sol_prior.spatial_predictor.4.bias'] = torch.zeros(spatial_size * spatial_size, dtype=dtype)
212
+
213
+ # stat_to_temperature
214
+ w0 = torch.empty(hidden_dim // 2, 3, dtype=dtype)
215
+ nn.init.xavier_uniform_(w0, gain=0.1)
216
+ init['sol_prior.stat_to_temperature.0.weight'] = w0
217
+ init['sol_prior.stat_to_temperature.0.bias'] = torch.zeros(hidden_dim // 2, dtype=dtype)
218
+
219
+ w1 = torch.empty(num_heads, hidden_dim // 2, dtype=dtype)
220
+ nn.init.xavier_uniform_(w1, gain=0.1)
221
+ init['sol_prior.stat_to_temperature.2.weight'] = w1
222
+ init['sol_prior.stat_to_temperature.2.bias'] = torch.full((num_heads,), 0.54, dtype=dtype)
223
+
224
+ # spatial_to_qk_scale
225
+ init['sol_prior.spatial_to_qk_scale.weight'] = torch.zeros(num_heads, 1, dtype=dtype)
226
+ init['sol_prior.spatial_to_qk_scale.bias'] = torch.ones(num_heads, dtype=dtype)
227
+
228
+ # blend_gate
229
+ init['sol_prior.blend_gate'] = torch.tensor(to_logit(config.sol_geometric_weight), dtype=dtype)
230
+
231
+ return init
232
+
233
+
234
+ def create_t5_pool_init(
235
+ config: ConversionConfig,
236
+ dtype: torch.dtype = torch.float32,
237
+ ) -> Dict[str, torch.Tensor]:
238
+ """Create initialization for T5 pool pathway."""
239
+ init = {}
240
+ hidden_size = config.hidden_size
241
+ joint_attention_dim = config.joint_attention_dim
242
+
243
+ w1 = torch.empty(hidden_size, joint_attention_dim, dtype=dtype)
244
+ nn.init.xavier_uniform_(w1)
245
+ init['t5_pool.0.weight'] = w1
246
+ init['t5_pool.0.bias'] = torch.zeros(hidden_size, dtype=dtype)
247
+
248
+ w2 = torch.empty(hidden_size, hidden_size, dtype=dtype)
249
+ nn.init.xavier_uniform_(w2)
250
+ init['t5_pool.2.weight'] = w2
251
+ init['t5_pool.2.bias'] = torch.zeros(hidden_size, dtype=dtype)
252
+
253
+ init['text_balance'] = torch.tensor(0.0, dtype=dtype)
254
+
255
+ return init
256
+
257
+
258
+ def create_spatial_to_mod_init(
259
+ num_heads: int = 4,
260
+ dtype: torch.dtype = torch.float32,
261
+ ) -> Dict[str, torch.Tensor]:
262
+ """Create zero-init for spatial_to_mod Conv2d layers."""
263
+ return {
264
+ 'weight': torch.zeros(num_heads, 1, 1, 1, dtype=dtype),
265
+ 'bias': torch.zeros(num_heads, dtype=dtype),
266
+ }
267
+
268
+
269
+ def convert_state_dict(
270
+ v3_state: Dict[str, torch.Tensor],
271
+ config: Optional[ConversionConfig] = None,
272
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, any]]:
273
+ """
274
+ Convert v3 state dict to v4 format.
275
+
276
+ Args:
277
+ v3_state: v3 state dictionary
278
+ config: Conversion configuration (uses defaults if None)
279
+
280
+ Returns:
281
+ Tuple of (v4_state_dict, report_dict)
282
+ """
283
+ cfg = config or ConversionConfig()
284
+ v3_info = analyze_checkpoint(v3_state)
285
+
286
+ if v3_info.version == "v4":
287
+ return v3_state, {'status': 'already_v4', 'source_version': 'v4'}
288
+
289
+ sample_key = list(v3_state.keys())[0]
290
+ dtype = v3_state[sample_key].dtype
291
+
292
+ report = {
293
+ 'status': 'converted',
294
+ 'source_version': v3_info.version,
295
+ 'source_params': v3_info.total_params,
296
+ 'renamed': [],
297
+ 'initialized': [],
298
+ 'modified': [],
299
+ }
300
+
301
+ v4_state = {}
302
+
303
+ # Step 1: Rename expert_predictor → lune_predictor
304
+ for key, value in v3_state.items():
305
+ if key.startswith('expert_predictor.'):
306
+ new_key = key.replace('expert_predictor.', 'lune_predictor.')
307
+ v4_state[new_key] = value
308
+ report['renamed'].append((key, new_key))
309
+ else:
310
+ v4_state[key] = value
311
+
312
+ # Step 2: Fix expert_gate value
313
+ gate_key = 'lune_predictor.expert_gate'
314
+ if gate_key in v4_state:
315
+ old_val = v4_state[gate_key].item()
316
+ if abs(old_val - 0.5) < 0.3:
317
+ new_val = to_logit(old_val)
318
+ v4_state[gate_key] = torch.tensor(new_val, dtype=dtype)
319
+ report['modified'].append((gate_key, f'{old_val:.4f} → {new_val:.4f}'))
320
+
321
+ # Step 3: Initialize SolAttentionPrior
322
+ if not v3_info.has_sol_prior:
323
+ sol_init = create_sol_prior_init(cfg, dtype)
324
+ v4_state.update(sol_init)
325
+ report['initialized'].extend(list(sol_init.keys()))
326
+
327
+ # Step 4: Initialize T5 pool
328
+ if not v3_info.has_t5_pool:
329
+ t5_init = create_t5_pool_init(cfg, dtype)
330
+ v4_state.update(t5_init)
331
+ report['initialized'].extend(list(t5_init.keys()))
332
+
333
+ # Step 5: Initialize spatial_to_mod
334
+ if not v3_info.has_spatial_to_mod:
335
+ spatial_init = create_spatial_to_mod_init(cfg.num_heads, dtype)
336
+
337
+ for i in range(cfg.num_double_blocks):
338
+ prefix = f'double_blocks.{i}.attn.spatial_to_mod.'
339
+ v4_state[prefix + 'weight'] = spatial_init['weight'].clone()
340
+ v4_state[prefix + 'bias'] = spatial_init['bias'].clone()
341
+ report['initialized'].extend([prefix + 'weight', prefix + 'bias'])
342
+
343
+ for i in range(cfg.num_single_blocks):
344
+ prefix = f'single_blocks.{i}.attn.spatial_to_mod.'
345
+ v4_state[prefix + 'weight'] = spatial_init['weight'].clone()
346
+ v4_state[prefix + 'bias'] = spatial_init['bias'].clone()
347
+ report['initialized'].extend([prefix + 'weight', prefix + 'bias'])
348
+
349
+ report['target_params'] = sum(p.numel() for p in v4_state.values())
350
+ report['params_added'] = report['target_params'] - report['source_params']
351
+
352
+ return v4_state, report
353
+
354
+
355
+ # =============================================================================
356
+ # High-Level API
357
+ # =============================================================================
358
+
359
+ def download_from_hf(
360
+ step: int,
361
+ repo_id: str = "AbstractPhil/tiny-flux-deep",
362
+ checkpoint_dir: str = "checkpoints",
363
+ local_dir: str = "./downloads",
364
+ include_ema: bool = True,
365
+ ) -> Tuple[str, Optional[str]]:
366
+ """
367
+ Download checkpoint from HuggingFace.
368
+
369
+ Args:
370
+ step: Step number to download
371
+ repo_id: HuggingFace repository ID
372
+ checkpoint_dir: Subdirectory in repo containing checkpoints
373
+ local_dir: Local directory to download to
374
+ include_ema: Whether to also download EMA weights
375
+
376
+ Returns:
377
+ Tuple of (model_path, ema_path). ema_path may be None.
378
+ """
379
+ from huggingface_hub import hf_hub_download
380
+
381
+ model_filename = f"{checkpoint_dir}/step_{step}.safetensors"
382
+ model_path = hf_hub_download(
383
+ repo_id=repo_id,
384
+ filename=model_filename,
385
+ local_dir=local_dir,
386
+ )
387
+
388
+ ema_path = None
389
+ if include_ema:
390
+ ema_filename = f"{checkpoint_dir}/step_{step}_ema.safetensors"
391
+ try:
392
+ ema_path = hf_hub_download(
393
+ repo_id=repo_id,
394
+ filename=ema_filename,
395
+ local_dir=local_dir,
396
+ )
397
+ except Exception:
398
+ pass
399
+
400
+ return model_path, ema_path
401
+
402
+
403
+ def convert_checkpoint(
404
+ step: Optional[int] = None,
405
+ input_path: Optional[str] = None,
406
+ ema_input_path: Optional[str] = None,
407
+ output_dir: str = "converted",
408
+ model_name: str = "lailah",
409
+ repo_id: str = "AbstractPhil/tiny-flux-deep",
410
+ checkpoint_dir: str = "checkpoints",
411
+ create_fresh_ema: bool = True,
412
+ preserve_secondary_ema: bool = True,
413
+ config: Optional[ConversionConfig] = None,
414
+ verbose: bool = True,
415
+ ) -> ConversionResult:
416
+ """
417
+ Convert a v3 checkpoint to v4 format.
418
+
419
+ Either `step` (to download from HF) or `input_path` (for local file) must be provided.
420
+
421
+ Args:
422
+ step: Step number to download from HuggingFace
423
+ input_path: Path to local v3 checkpoint
424
+ ema_input_path: Path to local v3 EMA checkpoint
425
+ output_dir: Directory to save converted checkpoints
426
+ model_name: Prefix for output filenames
427
+ repo_id: HuggingFace repository ID (if using step)
428
+ checkpoint_dir: Subdirectory in repo (if using step)
429
+ create_fresh_ema: Create a fresh EMA from converted weights
430
+ preserve_secondary_ema: Convert and preserve old EMA as secondary
431
+ config: Conversion configuration
432
+ verbose: Print progress messages
433
+
434
+ Returns:
435
+ ConversionResult with paths and statistics
436
+ """
437
+ from safetensors.torch import load_file, save_file
438
+
439
+ result = ConversionResult(success=False)
440
+
441
+ try:
442
+ # Get checkpoint paths
443
+ if step is not None:
444
+ if verbose:
445
+ print(f"📥 Downloading step_{step} from {repo_id}...")
446
+ model_path, ema_path = download_from_hf(
447
+ step=step,
448
+ repo_id=repo_id,
449
+ checkpoint_dir=checkpoint_dir,
450
+ )
451
+ if verbose:
452
+ print(f" ✓ Model: {model_path}")
453
+ if ema_path:
454
+ print(f" ✓ EMA: {ema_path}")
455
+ elif input_path is not None:
456
+ model_path = input_path
457
+ ema_path = ema_input_path
458
+ match = re.search(r'step_(\d+)', model_path)
459
+ step = int(match.group(1)) if match else 0
460
+ else:
461
+ result.error = "Must provide either step or input_path"
462
+ return result
463
+
464
+ # Load and convert
465
+ if verbose:
466
+ print(f"\n🔄 Converting to v4...")
467
+
468
+ v3_state = load_file(model_path)
469
+ v4_state, report = convert_state_dict(v3_state, config)
470
+
471
+ result.source_version = report['source_version']
472
+ result.source_params = report.get('source_params', 0)
473
+ result.target_params = report.get('target_params', 0)
474
+ result.params_added = report.get('params_added', 0)
475
+ result.renamed_keys = len(report.get('renamed', []))
476
+ result.initialized_keys = len(report.get('initialized', []))
477
+
478
+ if verbose:
479
+ print(f" Source: {result.source_version} ({result.source_params:,} params)")
480
+ print(f" Target: v4 ({result.target_params:,} params)")
481
+ print(f" Added: {result.params_added:,} params")
482
+
483
+ # Save outputs
484
+ os.makedirs(output_dir, exist_ok=True)
485
+
486
+ # Main model
487
+ model_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init.safetensors")
488
+ save_file(v4_state, model_out)
489
+ result.model_path = model_out
490
+ if verbose:
491
+ print(f"\n💾 Model: {model_out}")
492
+
493
+ # Fresh EMA
494
+ if create_fresh_ema:
495
+ ema_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema.safetensors")
496
+ save_file(v4_state, ema_out)
497
+ result.ema_path = ema_out
498
+ if verbose:
499
+ print(f"💾 EMA (fresh): {ema_out}")
500
+
501
+ # Secondary EMA
502
+ if preserve_secondary_ema and ema_path and os.path.exists(ema_path):
503
+ if verbose:
504
+ print(f"\n🔄 Converting old EMA...")
505
+ try:
506
+ old_ema_state = load_file(ema_path)
507
+ old_ema_v4, _ = convert_state_dict(old_ema_state, config)
508
+ ema_secondary_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema_secondary.safetensors")
509
+ save_file(old_ema_v4, ema_secondary_out)
510
+ result.ema_secondary_path = ema_secondary_out
511
+ if verbose:
512
+ print(f"💾 EMA (secondary): {ema_secondary_out}")
513
+ except Exception as e:
514
+ if verbose:
515
+ print(f"⚠ Failed to convert old EMA: {e}")
516
+
517
+ result.success = True
518
+
519
+ except Exception as e:
520
+ result.error = str(e)
521
+ if verbose:
522
+ print(f"❌ Error: {e}")
523
+
524
+ return result
525
+
526
+
527
+ # =============================================================================
528
+ # CLI Interface
529
+ # =============================================================================
530
+
531
+ def create_parser():
532
+ """Create argument parser for CLI."""
533
+ import argparse
534
+
535
+ parser = argparse.ArgumentParser(
536
+ description='Convert TinyFlux-Deep v3 checkpoints to v4 format',
537
+ formatter_class=argparse.RawDescriptionHelpFormatter,
538
+ epilog="""
539
+ Examples:
540
+ python convert_v3_to_v4.py --step 401434
541
+ python convert_v3_to_v4.py --input model_v3.safetensors
542
+ python convert_v3_to_v4.py --step 401434 --analyze-only
543
+ python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel
544
+ """
545
+ )
546
+
547
+ # Input
548
+ input_group = parser.add_argument_group('Input (one required)')
549
+ input_group.add_argument('--step', type=int, help='Step number to download from HuggingFace')
550
+ input_group.add_argument('--input', '-i', dest='input_path', help='Path to local v3 checkpoint')
551
+ input_group.add_argument('--ema-input', dest='ema_input_path', help='Path to local v3 EMA checkpoint')
552
+
553
+ # HuggingFace
554
+ hf_group = parser.add_argument_group('HuggingFace options')
555
+ hf_group.add_argument('--repo', default='AbstractPhil/tiny-flux-deep', help='HuggingFace repo ID')
556
+ hf_group.add_argument('--checkpoint-dir', default='checkpoints', help='Subdirectory in repo')
557
+
558
+ # Output
559
+ output_group = parser.add_argument_group('Output options')
560
+ output_group.add_argument('--output-dir', '-o', default='converted', help='Output directory')
561
+ output_group.add_argument('--name', default='lailah', help='Model name prefix')
562
+
563
+ # Conversion
564
+ conv_group = parser.add_argument_group('Conversion options')
565
+ conv_group.add_argument('--no-fresh-ema', action='store_true', help='Do not create fresh EMA')
566
+ conv_group.add_argument('--no-secondary-ema', action='store_true', help='Do not preserve old EMA')
567
+ conv_group.add_argument('--analyze-only', action='store_true', help='Only analyze, do not convert')
568
+ conv_group.add_argument('--quiet', '-q', action='store_true', help='Suppress progress messages')
569
+
570
+ return parser
571
+
572
+
573
+ def cli_main():
574
+ """CLI entry point."""
575
+ parser = create_parser()
576
+ args = parser.parse_args()
577
+
578
+ if not args.step and not args.input_path:
579
+ parser.error("Must specify either --step or --input")
580
+
581
+ # Analyze only
582
+ if args.analyze_only:
583
+ from safetensors.torch import load_file
584
+
585
+ if args.step:
586
+ model_path, _ = download_from_hf(
587
+ step=args.step,
588
+ repo_id=args.repo,
589
+ checkpoint_dir=args.checkpoint_dir,
590
+ )
591
+ else:
592
+ model_path = args.input_path
593
+
594
+ state = load_file(model_path)
595
+ info = analyze_checkpoint(state)
596
+
597
+ print(f"\nCheckpoint: {model_path}")
598
+ print(f" Version: {info.version}")
599
+ print(f" Total params: {info.total_params:,}")
600
+ print(f" Double blocks: {info.num_double_blocks}")
601
+ print(f" Single blocks: {info.num_single_blocks}")
602
+ print(f" Has expert_predictor: {info.has_expert_predictor}")
603
+ print(f" Has lune_predictor: {info.has_lune_predictor}")
604
+ print(f" Has sol_prior: {info.has_sol_prior}")
605
+ print(f" Has t5_pool: {info.has_t5_pool}")
606
+ print(f" Has spatial_to_mod: {info.has_spatial_to_mod}")
607
+ return
608
+
609
+ # Convert
610
+ result = convert_checkpoint(
611
+ step=args.step,
612
+ input_path=args.input_path,
613
+ ema_input_path=args.ema_input_path,
614
+ output_dir=args.output_dir,
615
+ model_name=args.name,
616
+ repo_id=args.repo,
617
+ checkpoint_dir=args.checkpoint_dir,
618
+ create_fresh_ema=not args.no_fresh_ema,
619
+ preserve_secondary_ema=not args.no_secondary_ema,
620
+ verbose=not args.quiet,
621
+ )
622
+
623
+ if result.success:
624
+ if not args.quiet:
625
+ print("\n" + "=" * 60)
626
+ print("✅ Conversion complete!")
627
+ print("=" * 60)
628
+ print(f"\nOutput files:")
629
+ if result.model_path:
630
+ print(f" Model: {result.model_path}")
631
+ if result.ema_path:
632
+ print(f" EMA: {result.ema_path}")
633
+ if result.ema_secondary_path:
634
+ print(f" EMA (secondary): {result.ema_secondary_path}")
635
+ else:
636
+ print(f"\n❌ Conversion failed: {result.error}")
637
+ exit(1)
638
+
639
+
640
+ if __name__ == '__main__':
641
+ cli_main()