"""
TinyFlux-Deep Weight Converter: v3 → v4

Converts v3 checkpoints to v4.1 architecture without destroying pretrained weights.

Changes from v3 → v4:
  - expert_predictor → lune_predictor (rename)
  - expert_gate: raw value → logit space (sigmoid(0)=0.5 preserved)
  - NEW: sol_prior (attention statistics predictor, 70% geometric prior)
  - NEW: t5_pool + text_balance (T5 vec pathway, 50/50 init)
  - NEW: spatial_to_mod per attention layer (zero-init = identity)

All new modules are initialized to be zero-effect, so the converted model is
intended to behave identically to v3 on the first forward pass.

Colab:
    from convert_v3_to_v4 import run
    run(401434)

API:
    from convert_v3_to_v4 import convert_checkpoint, load_config
    config = load_config("path/to/config.json")
    result = convert_checkpoint(step=401434, config=config)

CLI:
    python convert_v3_to_v4.py --step 401434
    python convert_v3_to_v4.py --step 401434 --config my_config.json
"""

__version__ = "4.1.0"

import json
import math
import os
import re
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn


# =============================================================================
# Internal helpers
# =============================================================================

def _count_blocks(state_dict: Dict[str, torch.Tensor], prefix: str) -> int:
    """Return 1 + the highest block index found under `prefix` (0 if none).

    Keys are expected to look like ``"<prefix><index>.<rest>"``, e.g.
    ``"double_blocks.3.attn.qkv.weight"`` for prefix ``"double_blocks."``.
    """
    count = 0
    for key in state_dict:
        if key.startswith(prefix):
            count = max(count, int(key.split(".")[1]) + 1)
    return count


def _infer_float_dtype(state_dict: Dict[str, torch.Tensor]) -> torch.dtype:
    """Return the dtype of the first floating-point tensor, else float32.

    Integer buffers (if any) are skipped so that new weights are never created
    with a dtype that ``nn.init`` cannot fill.
    """
    for tensor in state_dict.values():
        if tensor.is_floating_point():
            return tensor.dtype
    return torch.float32


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class TinyFluxConfig:
    """
    TinyFlux-Deep v4.1 model configuration.

    This config fully defines the model architecture and can be used to:
      1. Initialize a new model
      2. Convert checkpoints between versions
      3. Validate checkpoint compatibility

    All dimension constraints are validated on creation (ValueError on failure).
    """

    # Core architecture
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768  # T5 sequence dim
    pooled_projection_dim: int = 768  # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Lune expert predictor (trajectory guidance)
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280  # SD1.5 mid-block dimension
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior (structural guidance)
    use_sol_prior: bool = True
    sol_spatial_size: int = 8
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7  # 70% geometric, 30% learned

    # T5 vec enhancement
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"  # "attention", "mean", "cls"

    # Loss configuration (for training)
    lune_distill_mode: str = "cosine"  # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # Legacy
    guidance_embeds: bool = False

    def __post_init__(self):
        """Validate configuration constraints.

        Raises:
            ValueError: if hidden_size != heads * head_dim, if the RoPE axis
                dims do not sum to head_dim, or if sol_geometric_weight is
                outside [0, 1].
        """
        # Validate attention dimensions
        expected_hidden = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != expected_hidden:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({expected_hidden})"
            )

        # JSON round-trips tuples as lists; normalize back to a tuple.
        if isinstance(self.axes_dims_rope, list):
            self.axes_dims_rope = tuple(self.axes_dims_rope)

        # Validate RoPE dimensions
        rope_sum = sum(self.axes_dims_rope)
        if rope_sum != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({rope_sum}) must equal "
                f"attention_head_dim ({self.attention_head_dim})"
            )

        # Validate sol_geometric_weight
        if not 0.0 <= self.sol_geometric_weight <= 1.0:
            raise ValueError(f"sol_geometric_weight must be in [0, 1], got {self.sol_geometric_weight}")

    # Derived properties for converter compatibility
    @property
    def time_dim(self) -> int:
        return self.hidden_size

    @property
    def clip_dim(self) -> int:
        return self.pooled_projection_dim

    @property
    def num_heads(self) -> int:
        return self.num_attention_heads

    @property
    def num_double_blocks(self) -> int:
        return self.num_double_layers

    @property
    def num_single_blocks(self) -> int:
        return self.num_single_layers

    def to_dict(self) -> Dict:
        """Convert to a JSON-serializable dict (axes_dims_rope as a list)."""
        d = asdict(self)
        d["axes_dims_rope"] = list(d["axes_dims_rope"])
        return d

    @classmethod
    def from_dict(cls, d: Dict) -> "TinyFluxConfig":
        """Create from dict, ignoring unknown keys and keys starting with '_'."""
        known_fields = {f.name for f in cls.__dataclass_fields__.values()}
        filtered = {k: v for k, v in d.items() if k in known_fields and not k.startswith("_")}
        return cls(**filtered)

    def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
        """
        Validate that a checkpoint matches this config.

        Checks double/single block counts and, when ``img_embed.proj.weight``
        is present, its output dimension against ``hidden_size``.

        Returns:
            List of human-readable warnings (empty if everything matches).
        """
        warnings = []

        max_double = _count_blocks(state_dict, "double_blocks.")
        if max_double != self.num_double_layers:
            warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")

        max_single = _count_blocks(state_dict, "single_blocks.")
        if max_single != self.num_single_layers:
            warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")

        # Check hidden size from a known weight
        if "img_embed.proj.weight" in state_dict:
            w = state_dict["img_embed.proj.weight"]
            if w.shape[0] != self.hidden_size:
                warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")

        return warnings


def load_config(path: Union[str, Path]) -> TinyFluxConfig:
    """
    Load config from JSON file.

    Args:
        path: Path to config JSON file

    Returns:
        TinyFluxConfig instance (unknown keys in the file are ignored)
    """
    with open(path, encoding="utf-8") as f:
        d = json.load(f)
    return TinyFluxConfig.from_dict(d)


def save_config(config: TinyFluxConfig, path: Union[str, Path], conversion_info: Optional[Dict] = None):
    """
    Save config to JSON file.

    Args:
        config: TinyFluxConfig instance
        path: Output path
        conversion_info: Optional metadata about conversion, stored under the
            "_conversion_info" key (ignored again by ``from_dict``)
    """
    d = config.to_dict()
    if conversion_info:
        d["_conversion_info"] = conversion_info

    with open(path, "w", encoding="utf-8") as f:
        json.dump(d, f, indent=2)


# Default configuration
DEFAULT_CONFIG = TinyFluxConfig()


def _resolve_config(config: Optional[Union[TinyFluxConfig, Dict, str, Path]]) -> TinyFluxConfig:
    """Normalize None / TinyFluxConfig / dict / JSON path into a TinyFluxConfig.

    Raises:
        TypeError: for any other input type.
    """
    if config is None:
        return DEFAULT_CONFIG
    if isinstance(config, TinyFluxConfig):
        return config
    if isinstance(config, dict):
        return TinyFluxConfig.from_dict(config)
    if isinstance(config, (str, Path)):
        return load_config(config)
    raise TypeError(f"config must be TinyFluxConfig, dict, path, or None, got {type(config)}")


# =============================================================================
# Checkpoint Analysis
# =============================================================================

@dataclass
class CheckpointInfo:
    """Analysis results for a checkpoint."""
    version: str = "unknown"
    has_expert_predictor: bool = False
    has_lune_predictor: bool = False
    has_sol_prior: bool = False
    has_t5_pool: bool = False
    has_spatial_to_mod: bool = False
    num_double_blocks: int = 0
    num_single_blocks: int = 0
    total_params: int = 0  # element count over every tensor in the state dict
    dtype: str = "float32"  # dtype of the first floating-point tensor


def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
    """
    Analyze a checkpoint to determine version and contents.

    Version is inferred purely from which key prefixes are present:
    lune+sol+t5_pool → v4.1, lune+sol → v4.0, expert_predictor → v3,
    lune only → v3.5, otherwise v2_or_earlier.

    Args:
        state_dict: Model state dictionary

    Returns:
        CheckpointInfo with analysis results
    """
    info = CheckpointInfo()
    info.total_params = sum(p.numel() for p in state_dict.values())
    info.dtype = str(_infer_float_dtype(state_dict)).replace("torch.", "")

    for key in state_dict:
        if key.startswith("expert_predictor."):
            info.has_expert_predictor = True
        if key.startswith("lune_predictor."):
            info.has_lune_predictor = True
        if key.startswith("sol_prior."):
            info.has_sol_prior = True
        if key.startswith("t5_pool."):
            info.has_t5_pool = True
        if "spatial_to_mod" in key:
            info.has_spatial_to_mod = True

    info.num_double_blocks = _count_blocks(state_dict, "double_blocks.")
    info.num_single_blocks = _count_blocks(state_dict, "single_blocks.")

    # Determine version
    if info.has_lune_predictor and info.has_sol_prior and info.has_t5_pool:
        info.version = "v4.1"
    elif info.has_lune_predictor and info.has_sol_prior:
        info.version = "v4.0"
    elif info.has_expert_predictor:
        info.version = "v3"
    elif info.has_lune_predictor:
        info.version = "v3.5"
    else:
        info.version = "v2_or_earlier"

    return info


# =============================================================================
# Conversion Result
# =============================================================================

@dataclass
class ConversionResult:
    """Results from a conversion operation."""
    success: bool
    model_path: Optional[str] = None
    ema_path: Optional[str] = None
    ema_secondary_path: Optional[str] = None
    config_path: Optional[str] = None
    source_version: str = "unknown"
    target_version: str = "v4.1"
    source_params: int = 0
    target_params: int = 0
    params_added: int = 0
    error: Optional[str] = None


# =============================================================================
# Colab Entry Point
# =============================================================================

def _basename_or_none(path: Optional[str]) -> Optional[str]:
    """Return os.path.basename(path), or None when path is falsy."""
    return os.path.basename(path) if path else None


def run(
    step: int = 401434,
    name: str = "lailah",
    output_dir: str = "checkpoint_runs/v4_init",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    upload_repo: str = "AbstractPhil/tiny-flux-deep",
    upload_subdir: str = "checkpoint_runs/v4_init",
    config: Optional[Union[TinyFluxConfig, Dict, str]] = None,
):
    """
    One-liner for Colab. Downloads, converts, saves locally, uploads to HF.

    Args:
        step: Checkpoint step number to download
        name: Model name prefix for output files
        output_dir: Local output directory
        repo_id: HuggingFace repo to download from
        upload_repo: HuggingFace repo to upload to
        upload_subdir: Subdirectory in upload repo
        config: Model config - can be:
            - None (use default)
            - TinyFluxConfig instance
            - Dict with config values
            - Path to config JSON file

    Returns:
        ConversionResult (config_path is set on success).

    Raises:
        TypeError: if `config` is of an unsupported type.

    Usage:
        from convert_v3_to_v4 import run
        run(401434)

        # With custom config
        run(401434, config={"hidden_size": 768, ...})
        run(401434, config="path/to/config.json")
    """
    cfg = _resolve_config(config)

    print("TinyFlux-Deep v3 → v4.1 Converter")
    print("=" * 50)
    print(f"Config: hidden_size={cfg.hidden_size}, heads={cfg.num_attention_heads}")
    print(f" double_layers={cfg.num_double_layers}, single_layers={cfg.num_single_layers}")

    result = convert_checkpoint(
        step=step,
        model_name=name,
        output_dir=output_dir,
        repo_id=repo_id,
        checkpoint_dir="checkpoints",
        config=cfg,
        verbose=True,
    )

    if not result.success:
        print(f"\n❌ Conversion failed: {result.error}")
        return result

    print("\n✅ Conversion complete!")
    print(f" Source: {result.source_version} ({result.source_params:,} params)")
    print(f" Target: {result.target_version} ({result.target_params:,} params)")
    print(f" Added: {result.params_added:,} params")

    # Save config next to the converted weights, with provenance metadata.
    config_path = os.path.join(output_dir, f"{name}_{step}_v4_config.json")
    conversion_info = {
        "source_step": step,
        "source_repo": repo_id,
        "source_version": result.source_version,
        "target_version": result.target_version,
        "source_params": result.source_params,
        "target_params": result.target_params,
        "params_added": result.params_added,
        "converter_version": __version__,
        "files": {
            "model": _basename_or_none(result.model_path),
            "ema": _basename_or_none(result.ema_path),
            "ema_secondary": _basename_or_none(result.ema_secondary_path),
        },
    }
    save_config(cfg, config_path, conversion_info)
    result.config_path = config_path
    print(f"💾 Config: {config_path}")

    # Upload to HuggingFace (imported lazily so the module works offline).
    from huggingface_hub import HfApi

    api = HfApi()
    print(f"\n📤 Uploading to {upload_repo}/{upload_subdir}/...")

    files_to_upload = [
        result.model_path,
        result.ema_path,
        result.ema_secondary_path,
        config_path,
    ]
    for local_path in files_to_upload:
        if local_path and os.path.exists(local_path):
            remote_path = f"{upload_subdir}/{os.path.basename(local_path)}"
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=remote_path,
                repo_id=upload_repo,
            )
            print(f" ✓ {remote_path}")

    print(f"\n✅ Uploaded to {upload_repo}/{upload_subdir}/")
    return result


# =============================================================================
# Weight Initialization Functions
# =============================================================================

def to_logit(p: float) -> float:
    """Convert probability to logit for sigmoid init.

    `p` is clamped to [1e-4, 1 - 1e-4] so the result is always finite.
    """
    p = max(1e-4, min(p, 1 - 1e-4))
    return math.log(p / (1 - p))


def _xavier(rows: int, cols: int, dtype: torch.dtype, gain: float = 1.0) -> torch.Tensor:
    """Return a fresh (rows, cols) tensor filled with Xavier-uniform values."""
    weight = torch.empty(rows, cols, dtype=dtype)
    nn.init.xavier_uniform_(weight, gain=gain)
    return weight


def create_sol_prior_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Create zero-effect initialization for SolAttentionPrior.

    Produces small-gain (0.1) Xavier weights and zero biases for the
    stat/spatial predictor MLPs, plus the temperature head, qk-scale layer and
    blend gate. Tensors are created in a fixed order so RNG consumption is
    deterministic under a fixed seed.
    """
    init: Dict[str, torch.Tensor] = {}

    hidden_dim = config.sol_hidden_dim
    cond_dim = config.time_dim + config.clip_dim
    num_heads = config.num_heads
    num_cells = config.sol_spatial_size * config.sol_spatial_size

    # Two MLPs with Linear layers at Sequential indices 0, 2 and 4.
    # stat_predictor outputs 3 values; spatial_predictor outputs one per cell.
    for module_name, out_dim in (("stat_predictor", 3), ("spatial_predictor", num_cells)):
        prefix = f'sol_prior.{module_name}.'
        init[prefix + '0.weight'] = _xavier(hidden_dim, cond_dim, dtype, gain=0.1)
        init[prefix + '0.bias'] = torch.zeros(hidden_dim, dtype=dtype)
        init[prefix + '2.weight'] = _xavier(hidden_dim, hidden_dim, dtype, gain=0.1)
        init[prefix + '2.bias'] = torch.zeros(hidden_dim, dtype=dtype)
        init[prefix + '4.weight'] = _xavier(out_dim, hidden_dim, dtype, gain=0.1)
        init[prefix + '4.bias'] = torch.zeros(out_dim, dtype=dtype)

    # stat_to_temperature: 3 stats → hidden_dim // 2 → one value per head
    init['sol_prior.stat_to_temperature.0.weight'] = _xavier(hidden_dim // 2, 3, dtype, gain=0.1)
    init['sol_prior.stat_to_temperature.0.bias'] = torch.zeros(hidden_dim // 2, dtype=dtype)
    init['sol_prior.stat_to_temperature.2.weight'] = _xavier(num_heads, hidden_dim // 2, dtype, gain=0.1)
    # NOTE(review): 0.54 ≈ softplus⁻¹(1.0); presumably gives temperature ≈ 1
    # if the model applies softplus here — confirm against the model code.
    init['sol_prior.stat_to_temperature.2.bias'] = torch.full((num_heads,), 0.54, dtype=dtype)

    # spatial_to_qk_scale: zero weight + unit bias → constant scale of 1 per head
    init['sol_prior.spatial_to_qk_scale.weight'] = torch.zeros(num_heads, 1, dtype=dtype)
    init['sol_prior.spatial_to_qk_scale.bias'] = torch.ones(num_heads, dtype=dtype)

    # blend_gate stored as a logit of the configured geometric weight
    init['sol_prior.blend_gate'] = torch.tensor(to_logit(config.sol_geometric_weight), dtype=dtype)

    return init


def create_t5_pool_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Create initialization for the T5 pool pathway.

    Two Xavier-initialized Linear layers (joint_attention_dim → hidden_size →
    hidden_size) with zero biases, plus a scalar ``text_balance`` of 0.0
    (presumably passed through a sigmoid → 50/50 blend; confirm in model).
    """
    init: Dict[str, torch.Tensor] = {}
    hidden_size = config.hidden_size

    init['t5_pool.0.weight'] = _xavier(hidden_size, config.joint_attention_dim, dtype)
    init['t5_pool.0.bias'] = torch.zeros(hidden_size, dtype=dtype)
    init['t5_pool.2.weight'] = _xavier(hidden_size, hidden_size, dtype)
    init['t5_pool.2.bias'] = torch.zeros(hidden_size, dtype=dtype)
    init['text_balance'] = torch.tensor(0.0, dtype=dtype)

    return init


def create_spatial_to_mod_init(
    num_heads: int = 4,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Create zero-init for spatial_to_mod Conv2d layers (1→num_heads, 1x1 kernel)."""
    return {
        'weight': torch.zeros(num_heads, 1, 1, 1, dtype=dtype),
        'bias': torch.zeros(num_heads, dtype=dtype),
    }


def convert_state_dict(
    v3_state: Dict[str, torch.Tensor],
    config: Optional[TinyFluxConfig] = None,
    verbose: bool = True,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """
    Convert v3 state dict to v4.1 format.

    Checkpoints already detected as v4.0/v4.1 are returned unchanged (same
    dict object) with ``report['status'] == 'already_v4'``.

    Args:
        v3_state: v3 state dictionary
        config: TinyFluxConfig (uses DEFAULT_CONFIG if None)
        verbose: Print config-validation warnings

    Returns:
        Tuple of (v4_state_dict, report_dict). The report always contains
        status, source/target version, source/target params, params_added,
        renamed, initialized, modified and warnings.

    Raises:
        ValueError: if `v3_state` is empty.
    """
    if not v3_state:
        raise ValueError("Cannot convert an empty state dict")

    cfg = _resolve_config(config)
    v3_info = analyze_checkpoint(v3_state)

    if v3_info.version in ("v4.0", "v4.1"):
        # Nothing to do; report in the same shape as a real conversion so
        # callers can read params/versions uniformly.
        return v3_state, {
            'status': 'already_v4',
            'source_version': v3_info.version,
            'target_version': v3_info.version,
            'source_params': v3_info.total_params,
            'target_params': v3_info.total_params,
            'params_added': 0,
            'renamed': [],
            'initialized': [],
            'modified': [],
            'warnings': [],
        }

    # Validate config matches checkpoint structure (non-fatal)
    warnings = cfg.validate_checkpoint(v3_state)
    if warnings and verbose:
        print("⚠️ Config validation warnings:")
        for w in warnings:
            print(f" - {w}")

    # New weights follow the checkpoint's floating-point precision.
    dtype = _infer_float_dtype(v3_state)

    report = {
        'status': 'converted',
        'source_version': v3_info.version,
        'target_version': 'v4.1',
        'source_params': v3_info.total_params,
        'renamed': [],
        'initialized': [],
        'modified': [],
        'warnings': warnings,
    }

    v4_state = {}

    # Step 1: Rename expert_predictor → lune_predictor
    for key, value in v3_state.items():
        if key.startswith('expert_predictor.'):
            new_key = key.replace('expert_predictor.', 'lune_predictor.')
            v4_state[new_key] = value
            report['renamed'].append((key, new_key))
        else:
            v4_state[key] = value

    # Step 2: Fix expert_gate value (raw → logit space)
    gate_key = 'lune_predictor.expert_gate'
    if gate_key in v4_state:
        old_gate = v4_state[gate_key]
        old_val = old_gate.item()
        # Heuristic: a value within 0.3 of 0.5 looks like a raw probability,
        # not a logit.
        if abs(old_val - 0.5) < 0.3:
            new_val = to_logit(old_val)
            # full_like keeps the gate's original shape and dtype so the
            # result still loads into the same parameter.
            v4_state[gate_key] = torch.full_like(old_gate, new_val)
            report['modified'].append((gate_key, f'{old_val:.4f} → {new_val:.4f}'))

    # Step 3: Initialize SolAttentionPrior (if missing)
    if not v3_info.has_sol_prior and cfg.use_sol_prior:
        sol_init = create_sol_prior_init(cfg, dtype)
        v4_state.update(sol_init)
        report['initialized'].extend(sol_init.keys())

    # Step 4: Initialize T5 pool (if missing)
    if not v3_info.has_t5_pool and cfg.use_t5_vec:
        t5_init = create_t5_pool_init(cfg, dtype)
        v4_state.update(t5_init)
        report['initialized'].extend(t5_init.keys())

    # Step 5: Initialize spatial_to_mod in attention layers (if missing)
    if not v3_info.has_spatial_to_mod and cfg.use_sol_prior:
        spatial_init = create_spatial_to_mod_init(cfg.num_heads, dtype)
        block_prefixes = (
            [f'double_blocks.{i}.attn.spatial_to_mod.' for i in range(cfg.num_double_blocks)]
            + [f'single_blocks.{i}.attn.spatial_to_mod.' for i in range(cfg.num_single_blocks)]
        )
        for prefix in block_prefixes:
            # Clone so each layer owns its storage (safetensors rejects shared tensors).
            v4_state[prefix + 'weight'] = spatial_init['weight'].clone()
            v4_state[prefix + 'bias'] = spatial_init['bias'].clone()
            report['initialized'].extend([prefix + 'weight', prefix + 'bias'])

    report['target_params'] = sum(p.numel() for p in v4_state.values())
    report['params_added'] = report['target_params'] - report['source_params']

    return v4_state, report


# =============================================================================
# High-Level API
# =============================================================================

def download_from_hf(
    step: int,
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    local_dir: str = "./downloads",
    include_ema: bool = True,
) -> Tuple[str, Optional[str]]:
    """
    Download checkpoint from HuggingFace.

    Args:
        step: Step number to download
        repo_id: HuggingFace repository ID
        checkpoint_dir: Subdirectory in repo containing checkpoints
        local_dir: Local directory to download to
        include_ema: Whether to also download EMA weights

    Returns:
        Tuple of (model_path, ema_path). ema_path is None when EMA was not
        requested or could not be downloaded.
    """
    from huggingface_hub import hf_hub_download

    model_filename = f"{checkpoint_dir}/step_{step}.safetensors"
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=model_filename,
        local_dir=local_dir,
    )

    ema_path = None
    if include_ema:
        ema_filename = f"{checkpoint_dir}/step_{step}_ema.safetensors"
        try:
            ema_path = hf_hub_download(
                repo_id=repo_id,
                filename=ema_filename,
                local_dir=local_dir,
            )
        except Exception:
            # Best-effort: EMA is optional; callers handle ema_path=None.
            ema_path = None

    return model_path, ema_path


def convert_checkpoint(
    step: Optional[int] = None,
    input_path: Optional[str] = None,
    ema_input_path: Optional[str] = None,
    output_dir: str = "checkpoint_runs/v4_init",
    model_name: str = "lailah",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    create_fresh_ema: bool = True,
    preserve_secondary_ema: bool = True,
    config: Optional[Union[TinyFluxConfig, Dict, str, Path]] = None,
    verbose: bool = True,
) -> ConversionResult:
    """
    Convert a v3 checkpoint to v4.1 format.

    Either `step` (to download from HF) or `input_path` (for local file)
    must be provided. Errors are not raised; they are reported through
    ``ConversionResult.error`` with ``success=False``.

    Args:
        step: Step number to download from HuggingFace
        input_path: Path to local v3 checkpoint
        ema_input_path: Path to local v3 EMA checkpoint
        output_dir: Directory to save converted checkpoints
        model_name: Prefix for output filenames
        repo_id: HuggingFace repository ID (if using step)
        checkpoint_dir: Subdirectory in repo (if using step)
        create_fresh_ema: Create a fresh EMA from converted weights
        preserve_secondary_ema: Convert and preserve old EMA as secondary
        config: TinyFluxConfig, dict, or JSON path (DEFAULT_CONFIG if None)
        verbose: Print progress messages

    Returns:
        ConversionResult with paths and statistics
    """
    from safetensors.torch import load_file, save_file

    result = ConversionResult(success=False)

    try:
        cfg = _resolve_config(config)

        # Get checkpoint paths
        if step is not None:
            if verbose:
                print(f"📥 Downloading step_{step} from {repo_id}...")
            model_path, ema_path = download_from_hf(
                step=step,
                repo_id=repo_id,
                checkpoint_dir=checkpoint_dir,
            )
            if verbose:
                print(f" ✓ Model: {model_path}")
                if ema_path:
                    print(f" ✓ EMA: {ema_path}")
        elif input_path is not None:
            model_path = input_path
            ema_path = ema_input_path
            # Recover the step number from the filename for output naming.
            match = re.search(r'step_(\d+)', model_path)
            step = int(match.group(1)) if match else 0
        else:
            result.error = "Must provide either step or input_path"
            return result

        # Load and convert
        if verbose:
            print("\n🔄 Converting to v4.1...")

        v3_state = load_file(model_path)
        v4_state, report = convert_state_dict(v3_state, cfg, verbose=verbose)

        result.source_version = report['source_version']
        result.target_version = report.get('target_version', 'v4.1')
        result.source_params = report.get('source_params', 0)
        result.target_params = report.get('target_params', 0)
        result.params_added = report.get('params_added', 0)

        if verbose:
            if report.get('status') == 'already_v4':
                print(f" Checkpoint is already {result.source_version}; weights copied unchanged")
            print(f" Source: {result.source_version} ({result.source_params:,} params)")
            print(f" Target: {result.target_version} ({result.target_params:,} params)")
            print(f" Added: {result.params_added:,} params")

        # Save outputs
        os.makedirs(output_dir, exist_ok=True)

        # Main model
        model_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init.safetensors")
        save_file(v4_state, model_out)
        result.model_path = model_out
        if verbose:
            print(f"\n💾 Model: {model_out}")

        # Fresh EMA: a copy of the converted weights
        if create_fresh_ema:
            ema_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema.safetensors")
            save_file(v4_state, ema_out)
            result.ema_path = ema_out
            if verbose:
                print(f"💾 EMA (fresh): {ema_out}")

        # Secondary EMA: best-effort conversion of the old EMA
        if preserve_secondary_ema and ema_path and os.path.exists(ema_path):
            if verbose:
                print("\n🔄 Converting old EMA...")
            try:
                old_ema_state = load_file(ema_path)
                old_ema_v4, _ = convert_state_dict(old_ema_state, cfg, verbose=verbose)
                ema_secondary_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema_secondary.safetensors")
                save_file(old_ema_v4, ema_secondary_out)
                result.ema_secondary_path = ema_secondary_out
                if verbose:
                    print(f"💾 EMA (secondary): {ema_secondary_out}")
            except Exception as e:
                if verbose:
                    print(f"⚠ Failed to convert old EMA: {e}")

        result.success = True

    except Exception as e:
        # Top-level boundary: surface the failure through the result object.
        result.error = str(e)
        if verbose:
            print(f"❌ Error: {e}")

    return result


# =============================================================================
# CLI Interface
# =============================================================================

def create_parser():
    """Create argument parser for CLI."""
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert TinyFlux-Deep v3 checkpoints to v4 format',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python convert_v3_to_v4.py --step 401434
  python convert_v3_to_v4.py --input model_v3.safetensors
  python convert_v3_to_v4.py --step 401434 --analyze-only
  python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel
  python convert_v3_to_v4.py --step 401434 --config my_config.json
"""
    )

    # Input
    input_group = parser.add_argument_group('Input (one required)')
    input_group.add_argument('--step', type=int, help='Step number to download from HuggingFace')
    input_group.add_argument('--input', '-i', dest='input_path', help='Path to local v3 checkpoint')
    input_group.add_argument('--ema-input', dest='ema_input_path', help='Path to local v3 EMA checkpoint')

    # HuggingFace
    hf_group = parser.add_argument_group('HuggingFace options')
    hf_group.add_argument('--repo', default='AbstractPhil/tiny-flux-deep', help='HuggingFace repo ID')
    hf_group.add_argument('--checkpoint-dir', default='checkpoints', help='Subdirectory in repo')

    # Output
    output_group = parser.add_argument_group('Output options')
    output_group.add_argument('--output-dir', '-o', default='checkpoint_runs/v4_init', help='Output directory')
    output_group.add_argument('--name', default='lailah', help='Model name prefix')

    # Conversion
    conv_group = parser.add_argument_group('Conversion options')
    conv_group.add_argument('--config', '-c', dest='config', default=None,
                            help='Path to model config JSON (default: built-in TinyFluxConfig)')
    conv_group.add_argument('--no-fresh-ema', action='store_true', help='Do not create fresh EMA')
    conv_group.add_argument('--no-secondary-ema', action='store_true', help='Do not preserve old EMA')
    conv_group.add_argument('--analyze-only', action='store_true', help='Only analyze, do not convert')
    conv_group.add_argument('--quiet', '-q', action='store_true', help='Suppress progress messages')

    return parser


def cli_main():
    """CLI entry point. Exits with status 1 when conversion fails."""
    parser = create_parser()
    args = parser.parse_args()

    # `is None` so that --step 0 is accepted as a real step.
    if args.step is None and not args.input_path:
        parser.error("Must specify either --step or --input")

    config = None
    if args.config is not None:
        try:
            config = load_config(args.config)
        except (OSError, ValueError, TypeError) as e:
            parser.error(f"Could not load config {args.config!r}: {e}")

    # Analyze only
    if args.analyze_only:
        from safetensors.torch import load_file

        if args.step is not None:
            model_path, _ = download_from_hf(
                step=args.step,
                repo_id=args.repo,
                checkpoint_dir=args.checkpoint_dir,
            )
        else:
            model_path = args.input_path

        state = load_file(model_path)
        info = analyze_checkpoint(state)

        print(f"\nCheckpoint: {model_path}")
        print(f" Version: {info.version}")
        print(f" Total params: {info.total_params:,}")
        print(f" Double blocks: {info.num_double_blocks}")
        print(f" Single blocks: {info.num_single_blocks}")
        print(f" Has expert_predictor: {info.has_expert_predictor}")
        print(f" Has lune_predictor: {info.has_lune_predictor}")
        print(f" Has sol_prior: {info.has_sol_prior}")
        print(f" Has t5_pool: {info.has_t5_pool}")
        print(f" Has spatial_to_mod: {info.has_spatial_to_mod}")
        if config is not None:
            for warning in config.validate_checkpoint(state):
                print(f" ⚠️ {warning}")
        return

    # Convert
    result = convert_checkpoint(
        step=args.step,
        input_path=args.input_path,
        ema_input_path=args.ema_input_path,
        output_dir=args.output_dir,
        model_name=args.name,
        repo_id=args.repo,
        checkpoint_dir=args.checkpoint_dir,
        create_fresh_ema=not args.no_fresh_ema,
        preserve_secondary_ema=not args.no_secondary_ema,
        config=config,
        verbose=not args.quiet,
    )

    if result.success:
        if not args.quiet:
            print("\n" + "=" * 60)
            print("✅ Conversion complete!")
            print("=" * 60)
            print("\nOutput files:")
            if result.model_path:
                print(f" Model: {result.model_path}")
            if result.ema_path:
                print(f" EMA: {result.ema_path}")
            if result.ema_secondary_path:
                print(f" EMA (secondary): {result.ema_secondary_path}")
    else:
        print(f"\n❌ Conversion failed: {result.error}")
        sys.exit(1)


if __name__ == '__main__':
    cli_main()