AbstractPhil commited on
Commit
d0d1188
·
verified ·
1 Parent(s): 9f429ac

Update convert_v3_to_v4.py

Browse files

Extended v3 to v4 conversion including configuration flexibility and local/repo direction.

Files changed (1) hide show
  1. convert_v3_to_v4.py +407 -117
convert_v3_to_v4.py CHANGED
@@ -1,70 +1,231 @@
1
  """
2
  TinyFlux-Deep Weight Converter: v3 → v4
3
 
4
- Converts v3 checkpoints to v4 architecture without destroying pretrain.
5
 
6
- Key changes:
7
  - expert_predictor → lune_predictor (rename)
8
- - expert_gate value: 0.5 → 0.0 (logit space, sigmoid(0)=0.5)
9
- - New modules initialized to zero-effect:
10
- - sol_prior: geometric priors dominate initially
11
- - t5_pool: 50/50 balance with CLIP
12
- - spatial_to_mod: exp(0)=1 identity
13
 
14
- Colab Usage:
 
 
 
15
  from convert_v3_to_v4 import run
16
- run(401434) # Downloads, converts, saves to ./converted/
17
 
18
- API Usage:
19
- from convert_v3_to_v4 import convert_checkpoint, analyze_checkpoint
20
- result = convert_checkpoint(step=401434)
 
21
 
22
- CLI Usage:
23
  python convert_v3_to_v4.py --step 401434
 
24
  """
25
 
 
 
26
  import torch
27
  import torch.nn as nn
28
  import math
29
  import os
30
  import re
31
- from typing import Dict, Tuple, Optional
32
- from dataclasses import dataclass
 
 
33
 
34
 
35
  # =============================================================================
36
- # Quick Entry Point (Colab)
37
  # =============================================================================
38
 
39
- def run(
40
- step: int = 401434,
41
- name: str = "lailah",
42
- output_dir: str = "converted",
43
- ):
44
  """
45
- One-liner for Colab. Downloads, converts, saves.
46
 
47
- Usage:
48
- from convert_v3_to_v4 import run
49
- run(401434)
 
 
 
50
  """
51
- result = convert_checkpoint(
52
- step=step,
53
- model_name=name,
54
- output_dir=output_dir,
55
- verbose=True,
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- if result.success:
59
- print(f"\n✅ Done! Files in ./{output_dir}/")
60
- else:
61
- print(f"\n❌ Failed: {result.error}")
 
 
 
 
 
 
 
 
 
 
62
 
63
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  # =============================================================================
67
- # Data Classes
68
  # =============================================================================
69
 
70
  @dataclass
@@ -79,47 +240,7 @@ class CheckpointInfo:
79
  num_double_blocks: int = 0
80
  num_single_blocks: int = 0
81
  total_params: int = 0
82
-
83
-
84
- @dataclass
85
- class ConversionResult:
86
- """Results from a conversion operation."""
87
- success: bool
88
- model_path: Optional[str] = None
89
- ema_path: Optional[str] = None
90
- ema_secondary_path: Optional[str] = None
91
- source_version: str = "unknown"
92
- source_params: int = 0
93
- target_params: int = 0
94
- params_added: int = 0
95
- renamed_keys: int = 0
96
- initialized_keys: int = 0
97
- error: Optional[str] = None
98
-
99
-
100
- @dataclass
101
- class ConversionConfig:
102
- """Configuration for conversion."""
103
- hidden_size: int = 512
104
- time_dim: int = 512
105
- clip_dim: int = 768
106
- joint_attention_dim: int = 768
107
- num_heads: int = 4
108
- sol_hidden_dim: int = 256
109
- sol_spatial_size: int = 8
110
- sol_geometric_weight: float = 0.7
111
- num_double_blocks: int = 15
112
- num_single_blocks: int = 25
113
-
114
-
115
- # =============================================================================
116
- # Core Functions
117
- # =============================================================================
118
-
119
- def to_logit(p: float) -> float:
120
- """Convert probability to logit for sigmoid init."""
121
- p = max(1e-4, min(p, 1 - 1e-4))
122
- return math.log(p / (1 - p))
123
 
124
 
125
  def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
@@ -135,30 +256,37 @@ def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
135
  info = CheckpointInfo()
136
  info.total_params = sum(p.numel() for p in state_dict.values())
137
 
 
 
 
 
 
138
  for key in state_dict.keys():
139
- if key.startswith('expert_predictor.'):
140
  info.has_expert_predictor = True
141
- if key.startswith('lune_predictor.'):
142
  info.has_lune_predictor = True
143
- if key.startswith('sol_prior.'):
144
  info.has_sol_prior = True
145
- if key.startswith('t5_pool.'):
146
  info.has_t5_pool = True
147
- if 'spatial_to_mod' in key:
148
  info.has_spatial_to_mod = True
149
- if key.startswith('double_blocks.'):
150
- idx = int(key.split('.')[1])
151
  info.num_double_blocks = max(info.num_double_blocks, idx + 1)
152
- if key.startswith('single_blocks.'):
153
- idx = int(key.split('.')[1])
154
  info.num_single_blocks = max(info.num_single_blocks, idx + 1)
155
 
156
  # Determine version
157
- if info.has_lune_predictor and info.has_sol_prior:
158
- info.version = "v4"
 
 
159
  elif info.has_expert_predictor:
160
  info.version = "v3"
161
- elif info.has_lune_predictor and not info.has_sol_prior:
162
  info.version = "v3.5"
163
  else:
164
  info.version = "v2_or_earlier"
@@ -166,8 +294,162 @@ def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
166
  return info
167
 
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  def create_sol_prior_init(
170
- config: ConversionConfig,
171
  dtype: torch.dtype = torch.float32,
172
  ) -> Dict[str, torch.Tensor]:
173
  """Create zero-effect initialization for SolAttentionPrior."""
@@ -232,7 +514,7 @@ def create_sol_prior_init(
232
 
233
 
234
  def create_t5_pool_init(
235
- config: ConversionConfig,
236
  dtype: torch.dtype = torch.float32,
237
  ) -> Dict[str, torch.Tensor]:
238
  """Create initialization for T5 pool pathway."""
@@ -268,23 +550,30 @@ def create_spatial_to_mod_init(
268
 
269
  def convert_state_dict(
270
  v3_state: Dict[str, torch.Tensor],
271
- config: Optional[ConversionConfig] = None,
272
  ) -> Tuple[Dict[str, torch.Tensor], Dict[str, any]]:
273
  """
274
- Convert v3 state dict to v4 format.
275
 
276
  Args:
277
  v3_state: v3 state dictionary
278
- config: Conversion configuration (uses defaults if None)
279
 
280
  Returns:
281
  Tuple of (v4_state_dict, report_dict)
282
  """
283
- cfg = config or ConversionConfig()
284
  v3_info = analyze_checkpoint(v3_state)
285
 
286
- if v3_info.version == "v4":
287
- return v3_state, {'status': 'already_v4', 'source_version': 'v4'}
 
 
 
 
 
 
 
288
 
289
  sample_key = list(v3_state.keys())[0]
290
  dtype = v3_state[sample_key].dtype
@@ -296,6 +585,7 @@ def convert_state_dict(
296
  'renamed': [],
297
  'initialized': [],
298
  'modified': [],
 
299
  }
300
 
301
  v4_state = {}
@@ -309,29 +599,29 @@ def convert_state_dict(
309
  else:
310
  v4_state[key] = value
311
 
312
- # Step 2: Fix expert_gate value
313
  gate_key = 'lune_predictor.expert_gate'
314
  if gate_key in v4_state:
315
  old_val = v4_state[gate_key].item()
316
- if abs(old_val - 0.5) < 0.3:
317
  new_val = to_logit(old_val)
318
  v4_state[gate_key] = torch.tensor(new_val, dtype=dtype)
319
  report['modified'].append((gate_key, f'{old_val:.4f} → {new_val:.4f}'))
320
 
321
- # Step 3: Initialize SolAttentionPrior
322
- if not v3_info.has_sol_prior:
323
  sol_init = create_sol_prior_init(cfg, dtype)
324
  v4_state.update(sol_init)
325
  report['initialized'].extend(list(sol_init.keys()))
326
 
327
- # Step 4: Initialize T5 pool
328
- if not v3_info.has_t5_pool:
329
  t5_init = create_t5_pool_init(cfg, dtype)
330
  v4_state.update(t5_init)
331
  report['initialized'].extend(list(t5_init.keys()))
332
 
333
- # Step 5: Initialize spatial_to_mod
334
- if not v3_info.has_spatial_to_mod:
335
  spatial_init = create_spatial_to_mod_init(cfg.num_heads, dtype)
336
 
337
  for i in range(cfg.num_double_blocks):
@@ -404,17 +694,17 @@ def convert_checkpoint(
404
  step: Optional[int] = None,
405
  input_path: Optional[str] = None,
406
  ema_input_path: Optional[str] = None,
407
- output_dir: str = "converted",
408
  model_name: str = "lailah",
409
  repo_id: str = "AbstractPhil/tiny-flux-deep",
410
  checkpoint_dir: str = "checkpoints",
411
  create_fresh_ema: bool = True,
412
  preserve_secondary_ema: bool = True,
413
- config: Optional[ConversionConfig] = None,
414
  verbose: bool = True,
415
  ) -> ConversionResult:
416
  """
417
- Convert a v3 checkpoint to v4 format.
418
 
419
  Either `step` (to download from HF) or `input_path` (for local file) must be provided.
420
 
@@ -428,7 +718,7 @@ def convert_checkpoint(
428
  checkpoint_dir: Subdirectory in repo (if using step)
429
  create_fresh_ema: Create a fresh EMA from converted weights
430
  preserve_secondary_ema: Convert and preserve old EMA as secondary
431
- config: Conversion configuration
432
  verbose: Print progress messages
433
 
434
  Returns:
@@ -436,6 +726,7 @@ def convert_checkpoint(
436
  """
437
  from safetensors.torch import load_file, save_file
438
 
 
439
  result = ConversionResult(success=False)
440
 
441
  try:
@@ -463,21 +754,20 @@ def convert_checkpoint(
463
 
464
  # Load and convert
465
  if verbose:
466
- print(f"\n🔄 Converting to v4...")
467
 
468
  v3_state = load_file(model_path)
469
- v4_state, report = convert_state_dict(v3_state, config)
470
 
471
  result.source_version = report['source_version']
 
472
  result.source_params = report.get('source_params', 0)
473
  result.target_params = report.get('target_params', 0)
474
  result.params_added = report.get('params_added', 0)
475
- result.renamed_keys = len(report.get('renamed', []))
476
- result.initialized_keys = len(report.get('initialized', []))
477
 
478
  if verbose:
479
  print(f" Source: {result.source_version} ({result.source_params:,} params)")
480
- print(f" Target: v4 ({result.target_params:,} params)")
481
  print(f" Added: {result.params_added:,} params")
482
 
483
  # Save outputs
@@ -504,7 +794,7 @@ def convert_checkpoint(
504
  print(f"\n🔄 Converting old EMA...")
505
  try:
506
  old_ema_state = load_file(ema_path)
507
- old_ema_v4, _ = convert_state_dict(old_ema_state, config)
508
  ema_secondary_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema_secondary.safetensors")
509
  save_file(old_ema_v4, ema_secondary_out)
510
  result.ema_secondary_path = ema_secondary_out
@@ -557,7 +847,7 @@ Examples:
557
 
558
  # Output
559
  output_group = parser.add_argument_group('Output options')
560
- output_group.add_argument('--output-dir', '-o', default='converted', help='Output directory')
561
  output_group.add_argument('--name', default='lailah', help='Model name prefix')
562
 
563
  # Conversion
 
1
  """
2
  TinyFlux-Deep Weight Converter: v3 → v4
3
 
4
+ Converts v3 checkpoints to v4.1 architecture without destroying pretrained weights.
5
 
6
+ Changes from v3 → v4:
7
  - expert_predictor → lune_predictor (rename)
8
+ - expert_gate: raw value → logit space (sigmoid(0)=0.5 preserved)
9
+ - NEW: sol_prior (attention statistics predictor, 70% geometric prior)
10
+ - NEW: t5_pool + text_balance (T5 vec pathway, 50/50 init)
11
+ - NEW: spatial_to_mod per attention layer (zero-init = identity)
 
12
 
13
+ All new modules initialize to zero-effect, so converted model behaves
14
+ identically to v3 on first forward pass.
15
+
16
+ Colab:
17
  from convert_v3_to_v4 import run
18
+ run(401434)
19
 
20
+ API:
21
+ from convert_v3_to_v4 import convert_checkpoint, load_config
22
+ config = load_config("path/to/config.json")
23
+ result = convert_checkpoint(step=401434, config=config)
24
 
25
+ CLI:
26
  python convert_v3_to_v4.py --step 401434
27
+ python convert_v3_to_v4.py --step 401434 --config my_config.json
28
  """
29
 
30
+ __version__ = "4.1.0"
31
+
32
  import torch
33
  import torch.nn as nn
34
  import math
35
  import os
36
  import re
37
+ import json
38
+ from typing import Dict, Tuple, Optional, Union, List
39
+ from dataclasses import dataclass, field, asdict
40
+ from pathlib import Path
41
 
42
 
43
  # =============================================================================
44
+ # Configuration
45
  # =============================================================================
46
 
47
@dataclass
class TinyFluxConfig:
    """
    TinyFlux-Deep v4.1 model configuration.

    This config fully defines the model architecture and can be used to:
    1. Initialize a new model
    2. Convert checkpoints between versions
    3. Validate checkpoint compatibility

    All dimension constraints are validated on creation.
    """
    # Core architecture
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768  # T5 sequence dim
    pooled_projection_dim: int = 768  # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Lune expert predictor (trajectory guidance)
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280  # SD1.5 mid-block dimension
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior (structural guidance)
    use_sol_prior: bool = True
    sol_spatial_size: int = 8
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7  # 70% geometric, 30% learned

    # T5 vec enhancement
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"  # "attention", "mean", "cls"

    # Loss configuration (for training)
    lune_distill_mode: str = "cosine"  # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # Legacy
    guidance_embeds: bool = False

    def __post_init__(self):
        """Validate configuration constraints."""
        # hidden_size must factor exactly into heads x head_dim.
        heads_times_dim = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != heads_times_dim:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({heads_times_dim})"
            )

        # Normalize JSON-loaded lists to tuples so equality behaves predictably.
        if isinstance(self.axes_dims_rope, list):
            self.axes_dims_rope = tuple(self.axes_dims_rope)

        # RoPE axis dims partition the per-head dimension.
        rope_total = sum(self.axes_dims_rope)
        if rope_total != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({rope_total}) must equal "
                f"attention_head_dim ({self.attention_head_dim})"
            )

        # Blend weight is a convex-combination coefficient.
        if not 0.0 <= self.sol_geometric_weight <= 1.0:
            raise ValueError(f"sol_geometric_weight must be in [0, 1], got {self.sol_geometric_weight}")

    # Derived properties for converter compatibility
    @property
    def time_dim(self) -> int:
        return self.hidden_size

    @property
    def clip_dim(self) -> int:
        return self.pooled_projection_dim

    @property
    def num_heads(self) -> int:
        return self.num_attention_heads

    @property
    def num_double_blocks(self) -> int:
        return self.num_double_layers

    @property
    def num_single_blocks(self) -> int:
        return self.num_single_layers

    def to_dict(self) -> Dict:
        """Convert to JSON-serializable dict."""
        serialized = asdict(self)
        # Tuples are not distinguishable after a JSON round-trip; store as list.
        serialized["axes_dims_rope"] = list(serialized["axes_dims_rope"])
        return serialized

    @classmethod
    def from_dict(cls, d: Dict) -> "TinyFluxConfig":
        """Create from dict, ignoring unknown keys."""
        # Accept only declared dataclass fields; drop metadata keys like
        # "_conversion_info" and anything from newer/older config schemas.
        accepted = set(cls.__dataclass_fields__)
        kwargs = {
            key: val
            for key, val in d.items()
            if key in accepted and not key.startswith("_")
        }
        return cls(**kwargs)

    def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
        """
        Validate that a checkpoint matches this config.

        Returns list of warnings (empty if perfect match).
        """
        warnings = []

        # Infer block counts from the highest index present in the keys.
        max_double = 0
        max_single = 0
        for key in state_dict:
            if key.startswith("double_blocks."):
                max_double = max(max_double, int(key.split(".")[1]) + 1)
            elif key.startswith("single_blocks."):
                max_single = max(max_single, int(key.split(".")[1]) + 1)

        if max_double != self.num_double_layers:
            warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")
        if max_single != self.num_single_layers:
            warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")

        # Cross-check hidden size against a projection weight whose out-dim is known.
        proj = state_dict.get("img_embed.proj.weight")
        if proj is not None and proj.shape[0] != self.hidden_size:
            warnings.append(f"hidden_size: checkpoint has {proj.shape[0]}, config expects {self.hidden_size}")

        return warnings
189
+
190
+
191
def load_config(path: Union[str, Path]) -> TinyFluxConfig:
    """
    Load config from JSON file.

    Args:
        path: Path to config JSON file

    Returns:
        TinyFluxConfig instance
    """
    # from_dict drops unknown/underscore keys, so extra metadata in the
    # file (e.g. "_conversion_info") is tolerated.
    raw = Path(path).read_text()
    return TinyFluxConfig.from_dict(json.loads(raw))
204
+
205
+
206
def save_config(config: TinyFluxConfig, path: Union[str, Path], conversion_info: Optional[Dict] = None):
    """
    Save config to JSON file.

    Args:
        config: TinyFluxConfig instance
        path: Output path
        conversion_info: Optional metadata about conversion
    """
    payload = config.to_dict()
    # Underscore-prefixed key so from_dict ignores it on reload.
    if conversion_info:
        payload["_conversion_info"] = conversion_info

    with open(path, "w") as f:
        f.write(json.dumps(payload, indent=2))
221
+
222
+
223
+ # Default configuration
224
+ DEFAULT_CONFIG = TinyFluxConfig()
225
 
226
 
227
  # =============================================================================
228
+ # Checkpoint Analysis
229
  # =============================================================================
230
 
231
  @dataclass
 
240
  num_double_blocks: int = 0
241
  num_single_blocks: int = 0
242
  total_params: int = 0
243
+ dtype: str = "float32"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
 
246
  def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
 
256
  info = CheckpointInfo()
257
  info.total_params = sum(p.numel() for p in state_dict.values())
258
 
259
+ # Detect dtype
260
+ for v in state_dict.values():
261
+ info.dtype = str(v.dtype).replace("torch.", "")
262
+ break
263
+
264
  for key in state_dict.keys():
265
+ if key.startswith("expert_predictor."):
266
  info.has_expert_predictor = True
267
+ if key.startswith("lune_predictor."):
268
  info.has_lune_predictor = True
269
+ if key.startswith("sol_prior."):
270
  info.has_sol_prior = True
271
+ if key.startswith("t5_pool."):
272
  info.has_t5_pool = True
273
+ if "spatial_to_mod" in key:
274
  info.has_spatial_to_mod = True
275
+ if key.startswith("double_blocks."):
276
+ idx = int(key.split(".")[1])
277
  info.num_double_blocks = max(info.num_double_blocks, idx + 1)
278
+ if key.startswith("single_blocks."):
279
+ idx = int(key.split(".")[1])
280
  info.num_single_blocks = max(info.num_single_blocks, idx + 1)
281
 
282
  # Determine version
283
+ if info.has_lune_predictor and info.has_sol_prior and info.has_t5_pool:
284
+ info.version = "v4.1"
285
+ elif info.has_lune_predictor and info.has_sol_prior:
286
+ info.version = "v4.0"
287
  elif info.has_expert_predictor:
288
  info.version = "v3"
289
+ elif info.has_lune_predictor:
290
  info.version = "v3.5"
291
  else:
292
  info.version = "v2_or_earlier"
 
294
  return info
295
 
296
 
297
+ # =============================================================================
298
+ # Conversion Result
299
+ # =============================================================================
300
+
301
@dataclass
class ConversionResult:
    """Results from a conversion operation."""
    # Whether the conversion completed without error.
    success: bool
    # Output file paths (None when the corresponding artifact was not written).
    model_path: Optional[str] = None
    ema_path: Optional[str] = None
    ema_secondary_path: Optional[str] = None
    config_path: Optional[str] = None
    # Version detected in the source checkpoint (e.g. "v3", "v3.5").
    source_version: str = "unknown"
    target_version: str = "v4.1"
    # Parameter counts before/after, and the number of newly-initialized params.
    source_params: int = 0
    target_params: int = 0
    params_added: int = 0
    # Error message when success is False.
    error: Optional[str] = None
315
+
316
+
317
+ # =============================================================================
318
+ # Colab Entry Point
319
+ # =============================================================================
320
+
321
def run(
    step: int = 401434,
    name: str = "lailah",
    output_dir: str = "checkpoint_runs/v4_init",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    upload_repo: str = "AbstractPhil/tiny-flux-deep",
    upload_subdir: str = "checkpoint_runs/v4_init",
    config: Optional[Union[TinyFluxConfig, Dict, str, Path]] = None,
):
    """
    One-liner for Colab. Downloads, converts, saves locally, uploads to HF.

    Args:
        step: Checkpoint step number to download
        name: Model name prefix for output files
        output_dir: Local output directory
        repo_id: HuggingFace repo to download from
        upload_repo: HuggingFace repo to upload to
        upload_subdir: Subdirectory in upload repo
        config: Model config - can be:
            - None (use default)
            - TinyFluxConfig instance
            - Dict with config values
            - Path to config JSON file

    Returns:
        ConversionResult describing the conversion (success flag, paths, counts).

    Usage:
        from convert_v3_to_v4 import run
        run(401434)

        # With custom config
        run(401434, config={"hidden_size": 768, ...})
        run(401434, config="path/to/config.json")
    """
    # Resolve config from any of the accepted forms.
    if config is None:
        cfg = DEFAULT_CONFIG
    elif isinstance(config, TinyFluxConfig):
        cfg = config
    elif isinstance(config, dict):
        cfg = TinyFluxConfig.from_dict(config)
    elif isinstance(config, (str, Path)):
        cfg = load_config(config)
    else:
        raise TypeError(f"config must be TinyFluxConfig, dict, path, or None, got {type(config)}")

    print("TinyFlux-Deep v3 → v4.1 Converter")
    print("=" * 50)
    print(f"Config: hidden_size={cfg.hidden_size}, heads={cfg.num_attention_heads}")
    print(f" double_layers={cfg.num_double_layers}, single_layers={cfg.num_single_layers}")

    result = convert_checkpoint(
        step=step,
        model_name=name,
        output_dir=output_dir,
        repo_id=repo_id,
        checkpoint_dir="checkpoints",
        config=cfg,
        verbose=True,
    )

    if not result.success:
        print(f"\n❌ Conversion failed: {result.error}")
        return result

    print(f"\n✅ Conversion complete!")
    print(f" Source: {result.source_version} ({result.source_params:,} params)")
    print(f" Target: {result.target_version} ({result.target_params:,} params)")
    print(f" Added: {result.params_added:,} params")

    # Save config JSON next to the converted weights, with provenance metadata.
    config_path = os.path.join(output_dir, f"{name}_{step}_v4_config.json")
    conversion_info = {
        "source_step": step,
        "source_repo": repo_id,
        "source_version": result.source_version,
        "target_version": result.target_version,
        "source_params": result.source_params,
        "target_params": result.target_params,
        "params_added": result.params_added,
        "converter_version": __version__,
        "files": {
            "model": os.path.basename(result.model_path) if result.model_path else None,
            "ema": os.path.basename(result.ema_path) if result.ema_path else None,
            "ema_secondary": os.path.basename(result.ema_secondary_path) if result.ema_secondary_path else None,
        },
    }
    save_config(cfg, config_path, conversion_info)
    result.config_path = config_path
    print(f"💾 Config: {config_path}")

    # Upload to HuggingFace (local import: huggingface_hub is only needed here).
    from huggingface_hub import HfApi
    api = HfApi()

    print(f"\n📤 Uploading to {upload_repo}/{upload_subdir}/...")

    files_to_upload = [
        result.model_path,
        result.ema_path,
        result.ema_secondary_path,
        config_path,
    ]

    for local_path in files_to_upload:
        if local_path and os.path.exists(local_path):
            filename = os.path.basename(local_path)
            # BUG FIX: the remote path previously interpolated a hard-coded
            # placeholder instead of `filename`, so every file was uploaded to
            # the same bogus key (and overwrote the previous one).
            remote_path = f"{upload_subdir}/{filename}"

            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=remote_path,
                repo_id=upload_repo,
            )
            print(f" ✓ {remote_path}")

    print(f"\n✅ Uploaded to {upload_repo}/{upload_subdir}/")

    return result
439
+
440
+
441
+ # =============================================================================
442
+ # Weight Initialization Functions
443
+ # =============================================================================
444
+
445
def to_logit(p: float) -> float:
    """Return the logit (inverse sigmoid) of *p*, clamped to [1e-4, 1 - 1e-4]."""
    # Clamp so the log-odds stay finite at p = 0 or p = 1.
    eps = 1e-4
    clamped = min(max(p, eps), 1.0 - eps)
    return math.log(clamped / (1.0 - clamped))
449
+
450
+
451
  def create_sol_prior_init(
452
+ config: TinyFluxConfig,
453
  dtype: torch.dtype = torch.float32,
454
  ) -> Dict[str, torch.Tensor]:
455
  """Create zero-effect initialization for SolAttentionPrior."""
 
514
 
515
 
516
  def create_t5_pool_init(
517
+ config: TinyFluxConfig,
518
  dtype: torch.dtype = torch.float32,
519
  ) -> Dict[str, torch.Tensor]:
520
  """Create initialization for T5 pool pathway."""
 
550
 
551
  def convert_state_dict(
552
  v3_state: Dict[str, torch.Tensor],
553
+ config: Optional[TinyFluxConfig] = None,
554
  ) -> Tuple[Dict[str, torch.Tensor], Dict[str, any]]:
555
  """
556
+ Convert v3 state dict to v4.1 format.
557
 
558
  Args:
559
  v3_state: v3 state dictionary
560
+ config: TinyFluxConfig (uses DEFAULT_CONFIG if None)
561
 
562
  Returns:
563
  Tuple of (v4_state_dict, report_dict)
564
  """
565
+ cfg = config or DEFAULT_CONFIG
566
  v3_info = analyze_checkpoint(v3_state)
567
 
568
+ if v3_info.version in ("v4.0", "v4.1"):
569
+ return v3_state, {'status': 'already_v4', 'source_version': v3_info.version}
570
+
571
+ # Validate config matches checkpoint structure
572
+ warnings = cfg.validate_checkpoint(v3_state)
573
+ if warnings:
574
+ print(f"⚠️ Config validation warnings:")
575
+ for w in warnings:
576
+ print(f" - {w}")
577
 
578
  sample_key = list(v3_state.keys())[0]
579
  dtype = v3_state[sample_key].dtype
 
585
  'renamed': [],
586
  'initialized': [],
587
  'modified': [],
588
+ 'warnings': warnings,
589
  }
590
 
591
  v4_state = {}
 
599
  else:
600
  v4_state[key] = value
601
 
602
+ # Step 2: Fix expert_gate value (raw → logit space)
603
  gate_key = 'lune_predictor.expert_gate'
604
  if gate_key in v4_state:
605
  old_val = v4_state[gate_key].item()
606
+ if abs(old_val - 0.5) < 0.3: # Looks like raw probability, not logit
607
  new_val = to_logit(old_val)
608
  v4_state[gate_key] = torch.tensor(new_val, dtype=dtype)
609
  report['modified'].append((gate_key, f'{old_val:.4f} → {new_val:.4f}'))
610
 
611
+ # Step 3: Initialize SolAttentionPrior (if missing)
612
+ if not v3_info.has_sol_prior and cfg.use_sol_prior:
613
  sol_init = create_sol_prior_init(cfg, dtype)
614
  v4_state.update(sol_init)
615
  report['initialized'].extend(list(sol_init.keys()))
616
 
617
+ # Step 4: Initialize T5 pool (if missing)
618
+ if not v3_info.has_t5_pool and cfg.use_t5_vec:
619
  t5_init = create_t5_pool_init(cfg, dtype)
620
  v4_state.update(t5_init)
621
  report['initialized'].extend(list(t5_init.keys()))
622
 
623
+ # Step 5: Initialize spatial_to_mod in attention layers (if missing)
624
+ if not v3_info.has_spatial_to_mod and cfg.use_sol_prior:
625
  spatial_init = create_spatial_to_mod_init(cfg.num_heads, dtype)
626
 
627
  for i in range(cfg.num_double_blocks):
 
694
  step: Optional[int] = None,
695
  input_path: Optional[str] = None,
696
  ema_input_path: Optional[str] = None,
697
+ output_dir: str = "checkpoint_runs/v4_init",
698
  model_name: str = "lailah",
699
  repo_id: str = "AbstractPhil/tiny-flux-deep",
700
  checkpoint_dir: str = "checkpoints",
701
  create_fresh_ema: bool = True,
702
  preserve_secondary_ema: bool = True,
703
+ config: Optional[TinyFluxConfig] = None,
704
  verbose: bool = True,
705
  ) -> ConversionResult:
706
  """
707
+ Convert a v3 checkpoint to v4.1 format.
708
 
709
  Either `step` (to download from HF) or `input_path` (for local file) must be provided.
710
 
 
718
  checkpoint_dir: Subdirectory in repo (if using step)
719
  create_fresh_ema: Create a fresh EMA from converted weights
720
  preserve_secondary_ema: Convert and preserve old EMA as secondary
721
+ config: TinyFluxConfig for model architecture
722
  verbose: Print progress messages
723
 
724
  Returns:
 
726
  """
727
  from safetensors.torch import load_file, save_file
728
 
729
+ cfg = config or DEFAULT_CONFIG
730
  result = ConversionResult(success=False)
731
 
732
  try:
 
754
 
755
  # Load and convert
756
  if verbose:
757
+ print(f"\n🔄 Converting to v4.1...")
758
 
759
  v3_state = load_file(model_path)
760
+ v4_state, report = convert_state_dict(v3_state, cfg)
761
 
762
  result.source_version = report['source_version']
763
+ result.target_version = "v4.1"
764
  result.source_params = report.get('source_params', 0)
765
  result.target_params = report.get('target_params', 0)
766
  result.params_added = report.get('params_added', 0)
 
 
767
 
768
  if verbose:
769
  print(f" Source: {result.source_version} ({result.source_params:,} params)")
770
+ print(f" Target: {result.target_version} ({result.target_params:,} params)")
771
  print(f" Added: {result.params_added:,} params")
772
 
773
  # Save outputs
 
794
  print(f"\n🔄 Converting old EMA...")
795
  try:
796
  old_ema_state = load_file(ema_path)
797
+ old_ema_v4, _ = convert_state_dict(old_ema_state, cfg)
798
  ema_secondary_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema_secondary.safetensors")
799
  save_file(old_ema_v4, ema_secondary_out)
800
  result.ema_secondary_path = ema_secondary_out
 
847
 
848
  # Output
849
  output_group = parser.add_argument_group('Output options')
850
+ output_group.add_argument('--output-dir', '-o', default='checkpoint_runs/v4_init', help='Output directory')
851
  output_group.add_argument('--name', default='lailah', help='Model name prefix')
852
 
853
  # Conversion