| | """ |
| | TinyFlux-Deep Weight Converter: v3 → v4 |
| | |
| | Converts v3 checkpoints to v4.1 architecture without destroying pretrained weights. |
| | |
| | Changes from v3 → v4: |
| | - expert_predictor → lune_predictor (rename) |
| | - expert_gate: raw value → logit space (sigmoid(0)=0.5 preserved) |
| | - NEW: sol_prior (attention statistics predictor, 70% geometric prior) |
| | - NEW: t5_pool + text_balance (T5 vec pathway, 50/50 init) |
| | - NEW: spatial_to_mod per attention layer (zero-init = identity) |
| | |
| | All new modules initialize to zero-effect, so converted model behaves |
| | identically to v3 on first forward pass. |
| | |
| | Colab: |
| | from convert_v3_to_v4 import run |
| | run(401434) |
| | |
| | API: |
| | from convert_v3_to_v4 import convert_checkpoint, load_config |
| | config = load_config("path/to/config.json") |
| | result = convert_checkpoint(step=401434, config=config) |
| | |
| | CLI: |
| | python convert_v3_to_v4.py --step 401434 |
| | python convert_v3_to_v4.py --step 401434 --config my_config.json |
| | """ |
| |
|
# Converter version; recorded in the emitted config's "_conversion_info" block.
__version__ = "4.1.0"
| |
|
import json
import math
import os
import re
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class TinyFluxConfig:
    """
    TinyFlux-Deep v4.1 model configuration.

    This config fully defines the model architecture and can be used to:
    1. Initialize a new model
    2. Convert checkpoints between versions
    3. Validate checkpoint compatibility

    All dimension constraints are validated on creation.
    """

    # --- Core transformer dimensions ---
    hidden_size: int = 512                 # must equal num_attention_heads * attention_head_dim
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768         # input width of the t5_pool pathway
    pooled_projection_dim: int = 768       # pooled text (CLIP) vector width
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # must sum to attention_head_dim

    # --- Lune expert predictor (v4 rename of v3's expert_predictor) ---
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # --- Sol attention-statistics prior ---
    use_sol_prior: bool = True
    sol_spatial_size: int = 8
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7      # geometric-prior blend weight in [0, 1]

    # --- T5 pooled-vector pathway ---
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"

    # --- Distillation / loss options ---
    lune_distill_mode: str = "cosine"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # --- Misc ---
    guidance_embeds: bool = False

    def __post_init__(self):
        """Validate configuration constraints."""
        # hidden_size must factor exactly into heads x head_dim.
        expected_hidden = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != expected_hidden:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({expected_hidden})"
            )

        # JSON round-trips tuples as lists; normalize back to a tuple.
        if isinstance(self.axes_dims_rope, list):
            self.axes_dims_rope = tuple(self.axes_dims_rope)

        # RoPE axis dims partition the per-head dimension.
        rope_sum = sum(self.axes_dims_rope)
        if rope_sum != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({rope_sum}) must equal "
                f"attention_head_dim ({self.attention_head_dim})"
            )

        # Blend weight is probability-like and feeds a sigmoid-gated logit.
        if not 0.0 <= self.sol_geometric_weight <= 1.0:
            raise ValueError(f"sol_geometric_weight must be in [0, 1], got {self.sol_geometric_weight}")

    # --- Convenience aliases used by the init helpers below ---

    @property
    def time_dim(self) -> int:
        # Timestep-embedding width equals the model hidden size.
        return self.hidden_size

    @property
    def clip_dim(self) -> int:
        # Pooled text projection width.
        return self.pooled_projection_dim

    @property
    def num_heads(self) -> int:
        return self.num_attention_heads

    @property
    def num_double_blocks(self) -> int:
        return self.num_double_layers

    @property
    def num_single_blocks(self) -> int:
        return self.num_single_layers

    def to_dict(self) -> Dict:
        """Convert to JSON-serializable dict."""
        d = asdict(self)
        # Tuples are not JSON-native; store axes_dims_rope as a list.
        d["axes_dims_rope"] = list(d["axes_dims_rope"])
        return d

    @classmethod
    def from_dict(cls, d: Dict) -> "TinyFluxConfig":
        """Create from dict, ignoring unknown keys."""
        # Keep only declared dataclass fields; drops metadata keys such as
        # "_conversion_info" written by save_config().
        known_fields = {f.name for f in cls.__dataclass_fields__.values()}
        filtered = {k: v for k, v in d.items() if k in known_fields and not k.startswith("_")}
        return cls(**filtered)

    def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
        """
        Validate that a checkpoint matches this config.

        Returns list of warnings (empty if perfect match).
        """
        warnings = []

        # Infer double-block count from the highest "double_blocks.<i>." index.
        max_double = 0
        for key in state_dict:
            if key.startswith("double_blocks."):
                idx = int(key.split(".")[1])
                max_double = max(max_double, idx + 1)
        if max_double != self.num_double_layers:
            warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")

        # Same inference for single blocks.
        max_single = 0
        for key in state_dict:
            if key.startswith("single_blocks."):
                idx = int(key.split(".")[1])
                max_single = max(max_single, idx + 1)
        if max_single != self.num_single_layers:
            warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")

        # Output dim of the image patch embed reveals the checkpoint's hidden size.
        if "img_embed.proj.weight" in state_dict:
            w = state_dict["img_embed.proj.weight"]
            if w.shape[0] != self.hidden_size:
                warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")

        return warnings
| |
|
| |
|
def load_config(path: Union[str, Path]) -> TinyFluxConfig:
    """
    Load config from JSON file.

    Args:
        path: Path to config JSON file

    Returns:
        TinyFluxConfig instance
    """
    payload = json.loads(Path(path).read_text())
    return TinyFluxConfig.from_dict(payload)
| |
|
| |
|
def save_config(config: TinyFluxConfig, path: Union[str, Path], conversion_info: Optional[Dict] = None):
    """
    Save config to JSON file.

    Args:
        config: TinyFluxConfig instance
        path: Output path
        conversion_info: Optional metadata about conversion
    """
    payload = config.to_dict()
    if conversion_info:
        # Metadata keys are underscore-prefixed so from_dict() ignores them.
        payload["_conversion_info"] = conversion_info
    Path(path).write_text(json.dumps(payload, indent=2))
| |
|
| |
|
| | |
# Shared default architecture, used whenever a caller passes config=None.
DEFAULT_CONFIG = TinyFluxConfig()
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class CheckpointInfo:
    """Analysis results for a checkpoint."""
    version: str = "unknown"              # one of: v2_or_earlier, v3, v3.5, v4.0, v4.1
    has_expert_predictor: bool = False    # v3-era module name
    has_lune_predictor: bool = False      # v4 rename of expert_predictor
    has_sol_prior: bool = False
    has_t5_pool: bool = False
    has_spatial_to_mod: bool = False
    num_double_blocks: int = 0            # highest "double_blocks.<i>." index + 1
    num_single_blocks: int = 0            # highest "single_blocks.<i>." index + 1
    total_params: int = 0                 # sum of numel over all tensors
    dtype: str = "float32"                # dtype of the first tensor, e.g. "bfloat16"
| |
|
| |
|
def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
    """
    Analyze a checkpoint to determine version and contents.

    Args:
        state_dict: Model state dictionary

    Returns:
        CheckpointInfo with analysis results
    """
    info = CheckpointInfo()
    info.total_params = sum(tensor.numel() for tensor in state_dict.values())

    # Record the dtype of the first tensor as representative of the checkpoint.
    for tensor in state_dict.values():
        info.dtype = str(tensor.dtype).replace("torch.", "")
        break

    # Scan keys once, flagging each known module family and tracking the
    # highest block index seen per block type.
    for name in state_dict.keys():
        info.has_expert_predictor |= name.startswith("expert_predictor.")
        info.has_lune_predictor |= name.startswith("lune_predictor.")
        info.has_sol_prior |= name.startswith("sol_prior.")
        info.has_t5_pool |= name.startswith("t5_pool.")
        info.has_spatial_to_mod |= "spatial_to_mod" in name
        if name.startswith("double_blocks."):
            block_idx = int(name.split(".")[1])
            info.num_double_blocks = max(info.num_double_blocks, block_idx + 1)
        if name.startswith("single_blocks."):
            block_idx = int(name.split(".")[1])
            info.num_single_blocks = max(info.num_single_blocks, block_idx + 1)

    # Version heuristic: decide by which generation of modules is present.
    if info.has_lune_predictor and info.has_sol_prior and info.has_t5_pool:
        info.version = "v4.1"
    elif info.has_lune_predictor and info.has_sol_prior:
        info.version = "v4.0"
    elif info.has_expert_predictor:
        info.version = "v3"
    elif info.has_lune_predictor:
        info.version = "v3.5"
    else:
        info.version = "v2_or_earlier"

    return info
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class ConversionResult:
    """Results from a conversion operation."""
    success: bool                             # False on any failure; see `error`
    model_path: Optional[str] = None          # converted model .safetensors path
    ema_path: Optional[str] = None            # fresh EMA (copy of converted weights)
    ema_secondary_path: Optional[str] = None  # converted copy of the old v3 EMA
    config_path: Optional[str] = None         # config JSON path (set by run())
    source_version: str = "unknown"
    target_version: str = "v4.1"
    source_params: int = 0
    target_params: int = 0
    params_added: int = 0                     # target_params - source_params
    error: Optional[str] = None               # error message when success is False
| |
|
| |
|
| | |
| | |
| | |
| |
|
def run(
    step: int = 401434,
    name: str = "lailah",
    output_dir: str = "checkpoint_runs/v4_init",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    upload_repo: str = "AbstractPhil/tiny-flux-deep",
    upload_subdir: str = "checkpoint_runs/v4_init",
    config: Optional[Union[TinyFluxConfig, Dict, str]] = None,
):
    """
    One-liner for Colab. Downloads, converts, saves locally, uploads to HF.

    Args:
        step: Checkpoint step number to download
        name: Model name prefix for output files
        output_dir: Local output directory
        repo_id: HuggingFace repo to download from
        upload_repo: HuggingFace repo to upload to
        upload_subdir: Subdirectory in upload repo
        config: Model config - can be:
            - None (use default)
            - TinyFluxConfig instance
            - Dict with config values
            - Path to config JSON file

    Returns:
        ConversionResult (config_path filled in on success).

    Usage:
        from convert_v3_to_v4 import run
        run(401434)

        # With custom config
        run(401434, config={"hidden_size": 768, ...})
        run(401434, config="path/to/config.json")
    """
    # Normalize the flexible `config` argument to a TinyFluxConfig instance.
    if config is None:
        cfg = DEFAULT_CONFIG
    elif isinstance(config, TinyFluxConfig):
        cfg = config
    elif isinstance(config, dict):
        cfg = TinyFluxConfig.from_dict(config)
    elif isinstance(config, (str, Path)):
        cfg = load_config(config)
    else:
        raise TypeError(f"config must be TinyFluxConfig, dict, path, or None, got {type(config)}")

    print("TinyFlux-Deep v3 → v4.1 Converter")
    print("=" * 50)
    print(f"Config: hidden_size={cfg.hidden_size}, heads={cfg.num_attention_heads}")
    print(f"  double_layers={cfg.num_double_layers}, single_layers={cfg.num_single_layers}")

    result = convert_checkpoint(
        step=step,
        model_name=name,
        output_dir=output_dir,
        repo_id=repo_id,
        checkpoint_dir="checkpoints",
        config=cfg,
        verbose=True,
    )

    if not result.success:
        print(f"\n❌ Conversion failed: {result.error}")
        return result

    print(f"\n✅ Conversion complete!")
    print(f"  Source: {result.source_version} ({result.source_params:,} params)")
    print(f"  Target: {result.target_version} ({result.target_params:,} params)")
    print(f"  Added: {result.params_added:,} params")

    # Persist the config next to the weights, with conversion provenance so
    # a converted checkpoint is self-describing.
    config_path = os.path.join(output_dir, f"{name}_{step}_v4_config.json")
    conversion_info = {
        "source_step": step,
        "source_repo": repo_id,
        "source_version": result.source_version,
        "target_version": result.target_version,
        "source_params": result.source_params,
        "target_params": result.target_params,
        "params_added": result.params_added,
        "converter_version": __version__,
        "files": {
            "model": os.path.basename(result.model_path) if result.model_path else None,
            "ema": os.path.basename(result.ema_path) if result.ema_path else None,
            "ema_secondary": os.path.basename(result.ema_secondary_path) if result.ema_secondary_path else None,
        },
    }
    save_config(cfg, config_path, conversion_info)
    result.config_path = config_path
    print(f"💾 Config: {config_path}")

    # Upload all produced artifacts (skipping any that were not created).
    from huggingface_hub import HfApi
    api = HfApi()

    print(f"\n📤 Uploading to {upload_repo}/{upload_subdir}/...")

    files_to_upload = [
        result.model_path,
        result.ema_path,
        result.ema_secondary_path,
        config_path,
    ]

    for local_path in files_to_upload:
        if local_path and os.path.exists(local_path):
            filename = os.path.basename(local_path)
            # BUG FIX: upload each file under its own basename. Previously the
            # remote path used a literal placeholder instead of {filename}, so
            # every file was pushed to the same bogus repo path and `filename`
            # was computed but never used.
            remote_path = f"{upload_subdir}/{filename}"

            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=remote_path,
                repo_id=upload_repo,
            )
            print(f"  ✓ {remote_path}")

    print(f"\n✅ Uploaded to {upload_repo}/{upload_subdir}/")

    return result
| |
|
| |
|
| | |
| | |
| | |
| |
|
def to_logit(p: float) -> float:
    """Return the logit of probability *p*, suitable for sigmoid-gated init.

    *p* is clamped to [1e-4, 1 - 1e-4] first so the log stays finite at the
    endpoints.
    """
    eps = 1e-4
    clamped = min(max(p, eps), 1 - eps)
    return math.log(clamped / (1 - clamped))
| |
|
| |
|
def create_sol_prior_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Create zero-effect initialization for SolAttentionPrior.

    Both predictor stacks get small (gain=0.1) xavier weights with zero biases,
    the qk-scale layer is initialized as an exact multiply-by-one, and the
    blend gate starts at logit(sol_geometric_weight) so the sigmoid recovers
    the configured geometric weight.
    """
    hidden = config.sol_hidden_dim
    cond_dim = config.time_dim + config.clip_dim
    heads = config.num_heads
    grid = config.sol_spatial_size

    def _xavier(out_features: int, in_features: int) -> torch.Tensor:
        # Small gain keeps the freshly added prior close to zero effect.
        w = torch.empty(out_features, in_features, dtype=dtype)
        nn.init.xavier_uniform_(w, gain=0.1)
        return w

    init: Dict[str, torch.Tensor] = {}

    # Stat predictor MLP: cond -> hidden -> hidden -> 3 statistics.
    for layer_idx, (out_f, in_f) in zip((0, 2, 4), [(hidden, cond_dim), (hidden, hidden), (3, hidden)]):
        init[f'sol_prior.stat_predictor.{layer_idx}.weight'] = _xavier(out_f, in_f)
        init[f'sol_prior.stat_predictor.{layer_idx}.bias'] = torch.zeros(out_f, dtype=dtype)

    # Spatial predictor MLP: cond -> hidden -> hidden -> grid*grid map.
    for layer_idx, (out_f, in_f) in zip((0, 2, 4), [(hidden, cond_dim), (hidden, hidden), (grid * grid, hidden)]):
        init[f'sol_prior.spatial_predictor.{layer_idx}.weight'] = _xavier(out_f, in_f)
        init[f'sol_prior.spatial_predictor.{layer_idx}.bias'] = torch.zeros(out_f, dtype=dtype)

    # Temperature head: 3 stats -> hidden/2 -> one temperature per head,
    # with the output layer biased to 0.54 per head.
    init['sol_prior.stat_to_temperature.0.weight'] = _xavier(hidden // 2, 3)
    init['sol_prior.stat_to_temperature.0.bias'] = torch.zeros(hidden // 2, dtype=dtype)
    init['sol_prior.stat_to_temperature.2.weight'] = _xavier(heads, hidden // 2)
    init['sol_prior.stat_to_temperature.2.bias'] = torch.full((heads,), 0.54, dtype=dtype)

    # QK scale: zero weight + unit bias == multiply by exactly 1.0.
    init['sol_prior.spatial_to_qk_scale.weight'] = torch.zeros(heads, 1, dtype=dtype)
    init['sol_prior.spatial_to_qk_scale.bias'] = torch.ones(heads, dtype=dtype)

    # Blend gate lives in logit space: sigmoid(gate) == sol_geometric_weight.
    init['sol_prior.blend_gate'] = torch.tensor(to_logit(config.sol_geometric_weight), dtype=dtype)

    return init
| |
|
| |
|
def create_t5_pool_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Create initialization for T5 pool pathway.

    Standard xavier weights with zero biases for the two-layer pool, plus a
    text_balance scalar at 0.0 (per the module docstring, a 50/50 init).
    """
    width = config.hidden_size
    t5_dim = config.joint_attention_dim

    init: Dict[str, torch.Tensor] = {}
    for key_idx, in_features in ((0, t5_dim), (2, width)):
        weight = torch.empty(width, in_features, dtype=dtype)
        nn.init.xavier_uniform_(weight)
        init[f't5_pool.{key_idx}.weight'] = weight
        init[f't5_pool.{key_idx}.bias'] = torch.zeros(width, dtype=dtype)

    init['text_balance'] = torch.tensor(0.0, dtype=dtype)

    return init
| |
|
| |
|
def create_spatial_to_mod_init(
    num_heads: int = 4,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Create zero-init for spatial_to_mod Conv2d layers.

    All-zero weights and biases make the new per-layer conv contribute
    nothing, which the module docstring describes as identity behavior for
    a freshly converted model.
    """
    conv_shape = (num_heads, 1, 1, 1)
    return {
        'weight': torch.zeros(conv_shape, dtype=dtype),
        'bias': torch.zeros(num_heads, dtype=dtype),
    }
| |
|
| |
|
def convert_state_dict(
    v3_state: Dict[str, torch.Tensor],
    config: Optional[TinyFluxConfig] = None,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """
    Convert v3 state dict to v4.1 format.

    Steps:
    1. Rename expert_predictor.* -> lune_predictor.*
    2. Map expert_gate into logit space (v4 applies sigmoid to it)
    3. Add zero-effect inits for sol_prior, t5_pool, and per-layer
       spatial_to_mod tensors, matching the new tensors' dtype to the
       checkpoint's dtype.

    FIX: the return annotation previously used the builtin `any` function
    (`Dict[str, any]`) where `typing.Any` was intended; typing accepts any
    callable at runtime, so it was silently wrong rather than an error.

    Args:
        v3_state: v3 state dictionary
        config: TinyFluxConfig (uses DEFAULT_CONFIG if None)

    Returns:
        Tuple of (v4_state_dict, report_dict)
    """
    cfg = config or DEFAULT_CONFIG
    v3_info = analyze_checkpoint(v3_state)

    # Already converted — pass through untouched.
    if v3_info.version in ("v4.0", "v4.1"):
        return v3_state, {'status': 'already_v4', 'source_version': v3_info.version}

    # Architecture sanity check; warnings are reported but not fatal.
    warnings = cfg.validate_checkpoint(v3_state)
    if warnings:
        print("⚠️ Config validation warnings:")
        for w in warnings:
            print(f"  - {w}")

    # New tensors are created in the checkpoint's own dtype.
    sample_key = next(iter(v3_state))
    dtype = v3_state[sample_key].dtype

    report = {
        'status': 'converted',
        'source_version': v3_info.version,
        'source_params': v3_info.total_params,
        'renamed': [],
        'initialized': [],
        'modified': [],
        'warnings': warnings,
    }

    v4_state = {}

    # 1) Rename expert_predictor.* -> lune_predictor.* (tensors unchanged).
    for key, value in v3_state.items():
        if key.startswith('expert_predictor.'):
            new_key = key.replace('expert_predictor.', 'lune_predictor.')
            v4_state[new_key] = value
            report['renamed'].append((key, new_key))
        else:
            v4_state[key] = value

    # 2) Expert gate: v3 stored a raw blend value; v4 applies sigmoid, so
    # re-express it in logit space (sigmoid(logit(v)) == v).
    gate_key = 'lune_predictor.expert_gate'
    if gate_key in v4_state:
        old_val = v4_state[gate_key].item()
        # NOTE(review): only values near 0.5 are remapped; values outside
        # (0.2, 0.8) are presumed already in logit space — confirm upstream.
        if abs(old_val - 0.5) < 0.3:
            new_val = to_logit(old_val)
            v4_state[gate_key] = torch.tensor(new_val, dtype=dtype)
            report['modified'].append((gate_key, f'{old_val:.4f} → {new_val:.4f}'))

    # 3) New v4 modules, all initialized to zero effect.
    if not v3_info.has_sol_prior and cfg.use_sol_prior:
        sol_init = create_sol_prior_init(cfg, dtype)
        v4_state.update(sol_init)
        report['initialized'].extend(list(sol_init.keys()))

    if not v3_info.has_t5_pool and cfg.use_t5_vec:
        t5_init = create_t5_pool_init(cfg, dtype)
        v4_state.update(t5_init)
        report['initialized'].extend(list(t5_init.keys()))

    if not v3_info.has_spatial_to_mod and cfg.use_sol_prior:
        spatial_init = create_spatial_to_mod_init(cfg.num_heads, dtype)

        # Every attention layer gets its own cloned zero conv so training can
        # move them independently.
        for i in range(cfg.num_double_blocks):
            prefix = f'double_blocks.{i}.attn.spatial_to_mod.'
            v4_state[prefix + 'weight'] = spatial_init['weight'].clone()
            v4_state[prefix + 'bias'] = spatial_init['bias'].clone()
            report['initialized'].extend([prefix + 'weight', prefix + 'bias'])

        for i in range(cfg.num_single_blocks):
            prefix = f'single_blocks.{i}.attn.spatial_to_mod.'
            v4_state[prefix + 'weight'] = spatial_init['weight'].clone()
            v4_state[prefix + 'bias'] = spatial_init['bias'].clone()
            report['initialized'].extend([prefix + 'weight', prefix + 'bias'])

    report['target_params'] = sum(p.numel() for p in v4_state.values())
    report['params_added'] = report['target_params'] - report['source_params']

    return v4_state, report
| |
|
| |
|
| | |
| | |
| | |
| |
|
def download_from_hf(
    step: int,
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    local_dir: str = "./downloads",
    include_ema: bool = True,
) -> Tuple[str, Optional[str]]:
    """
    Download checkpoint from HuggingFace.

    Args:
        step: Step number to download
        repo_id: HuggingFace repository ID
        checkpoint_dir: Subdirectory in repo containing checkpoints
        local_dir: Local directory to download to
        include_ema: Whether to also download EMA weights

    Returns:
        Tuple of (model_path, ema_path). ema_path may be None.
    """
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=f"{checkpoint_dir}/step_{step}.safetensors",
        local_dir=local_dir,
    )

    ema_path = None
    if include_ema:
        try:
            ema_path = hf_hub_download(
                repo_id=repo_id,
                filename=f"{checkpoint_dir}/step_{step}_ema.safetensors",
                local_dir=local_dir,
            )
        except Exception:
            # EMA weights are optional; a missing file is not an error.
            pass

    return model_path, ema_path
| |
|
| |
|
def convert_checkpoint(
    step: Optional[int] = None,
    input_path: Optional[str] = None,
    ema_input_path: Optional[str] = None,
    output_dir: str = "checkpoint_runs/v4_init",
    model_name: str = "lailah",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    create_fresh_ema: bool = True,
    preserve_secondary_ema: bool = True,
    config: Optional[TinyFluxConfig] = None,
    verbose: bool = True,
) -> ConversionResult:
    """
    Convert a v3 checkpoint to v4.1 format.

    Either `step` (to download from HF) or `input_path` (for local file) must be provided.

    All failures are captured into the returned ConversionResult (success=False,
    error set) rather than raised.

    Args:
        step: Step number to download from HuggingFace
        input_path: Path to local v3 checkpoint
        ema_input_path: Path to local v3 EMA checkpoint
        output_dir: Directory to save converted checkpoints
        model_name: Prefix for output filenames
        repo_id: HuggingFace repository ID (if using step)
        checkpoint_dir: Subdirectory in repo (if using step)
        create_fresh_ema: Create a fresh EMA from converted weights
        preserve_secondary_ema: Convert and preserve old EMA as secondary
        config: TinyFluxConfig for model architecture
        verbose: Print progress messages

    Returns:
        ConversionResult with paths and statistics
    """
    from safetensors.torch import load_file, save_file

    cfg = config or DEFAULT_CONFIG
    result = ConversionResult(success=False)

    try:
        # --- Resolve the input checkpoint: remote (step) or local (input_path).
        if step is not None:
            if verbose:
                print(f"📥 Downloading step_{step} from {repo_id}...")
            model_path, ema_path = download_from_hf(
                step=step,
                repo_id=repo_id,
                checkpoint_dir=checkpoint_dir,
            )
            if verbose:
                print(f"  ✓ Model: {model_path}")
                if ema_path:
                    print(f"  ✓ EMA: {ema_path}")
        elif input_path is not None:
            model_path = input_path
            ema_path = ema_input_path
            # Recover the step number from the filename for output naming;
            # falls back to 0 if the name carries no "step_<N>" marker.
            match = re.search(r'step_(\d+)', model_path)
            step = int(match.group(1)) if match else 0
        else:
            result.error = "Must provide either step or input_path"
            return result

        # --- Convert the main weights.
        if verbose:
            print(f"\n🔄 Converting to v4.1...")

        v3_state = load_file(model_path)
        v4_state, report = convert_state_dict(v3_state, cfg)

        result.source_version = report['source_version']
        result.target_version = "v4.1"
        # report omits param counts in the 'already_v4' pass-through case.
        result.source_params = report.get('source_params', 0)
        result.target_params = report.get('target_params', 0)
        result.params_added = report.get('params_added', 0)

        if verbose:
            print(f"  Source: {result.source_version} ({result.source_params:,} params)")
            print(f"  Target: {result.target_version} ({result.target_params:,} params)")
            print(f"  Added: {result.params_added:,} params")

        # --- Save outputs.
        os.makedirs(output_dir, exist_ok=True)

        # Converted model weights.
        model_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init.safetensors")
        save_file(v4_state, model_out)
        result.model_path = model_out
        if verbose:
            print(f"\n💾 Model: {model_out}")

        # Fresh EMA: an exact copy of the converted weights as a new EMA seed.
        if create_fresh_ema:
            ema_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema.safetensors")
            save_file(v4_state, ema_out)
            result.ema_path = ema_out
            if verbose:
                print(f"💾 EMA (fresh): {ema_out}")

        # Secondary EMA: convert the old v3 EMA too, best-effort.
        if preserve_secondary_ema and ema_path and os.path.exists(ema_path):
            if verbose:
                print(f"\n🔄 Converting old EMA...")
            try:
                old_ema_state = load_file(ema_path)
                old_ema_v4, _ = convert_state_dict(old_ema_state, cfg)
                ema_secondary_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema_secondary.safetensors")
                save_file(old_ema_v4, ema_secondary_out)
                result.ema_secondary_path = ema_secondary_out
                if verbose:
                    print(f"💾 EMA (secondary): {ema_secondary_out}")
            except Exception as e:
                # A broken old EMA should not fail the whole conversion.
                if verbose:
                    print(f"⚠ Failed to convert old EMA: {e}")

        result.success = True

    except Exception as e:
        # Funnel any failure into the result object for the caller to inspect.
        result.error = str(e)
        if verbose:
            print(f"❌ Error: {e}")

    return result
| |
|
| |
|
| | |
| | |
| | |
| |
|
def create_parser():
    """Build and return the CLI argument parser."""
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert TinyFlux-Deep v3 checkpoints to v4 format',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python convert_v3_to_v4.py --step 401434
  python convert_v3_to_v4.py --input model_v3.safetensors
  python convert_v3_to_v4.py --step 401434 --analyze-only
  python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel
  """
    )

    # Exactly one of --step / --input is required (enforced in cli_main).
    source = parser.add_argument_group('Input (one required)')
    source.add_argument('--step', type=int, help='Step number to download from HuggingFace')
    source.add_argument('--input', '-i', dest='input_path', help='Path to local v3 checkpoint')
    source.add_argument('--ema-input', dest='ema_input_path', help='Path to local v3 EMA checkpoint')

    hub = parser.add_argument_group('HuggingFace options')
    hub.add_argument('--repo', default='AbstractPhil/tiny-flux-deep', help='HuggingFace repo ID')
    hub.add_argument('--checkpoint-dir', default='checkpoints', help='Subdirectory in repo')

    out = parser.add_argument_group('Output options')
    out.add_argument('--output-dir', '-o', default='checkpoint_runs/v4_init', help='Output directory')
    out.add_argument('--name', default='lailah', help='Model name prefix')

    conv = parser.add_argument_group('Conversion options')
    conv.add_argument('--no-fresh-ema', action='store_true', help='Do not create fresh EMA')
    conv.add_argument('--no-secondary-ema', action='store_true', help='Do not preserve old EMA')
    conv.add_argument('--analyze-only', action='store_true', help='Only analyze, do not convert')
    conv.add_argument('--quiet', '-q', action='store_true', help='Suppress progress messages')

    return parser
| |
|
| |
|
def cli_main():
    """CLI entry point: parse args, then analyze or convert.

    Exits with status 1 (SystemExit) when conversion fails.
    """
    parser = create_parser()
    args = parser.parse_args()

    # One input source is mandatory.
    if not args.step and not args.input_path:
        parser.error("Must specify either --step or --input")

    # Analyze-only mode: print checkpoint contents and stop.
    if args.analyze_only:
        from safetensors.torch import load_file

        if args.step:
            model_path, _ = download_from_hf(
                step=args.step,
                repo_id=args.repo,
                checkpoint_dir=args.checkpoint_dir,
            )
        else:
            model_path = args.input_path

        state = load_file(model_path)
        info = analyze_checkpoint(state)

        print(f"\nCheckpoint: {model_path}")
        print(f"  Version: {info.version}")
        print(f"  Total params: {info.total_params:,}")
        print(f"  Double blocks: {info.num_double_blocks}")
        print(f"  Single blocks: {info.num_single_blocks}")
        print(f"  Has expert_predictor: {info.has_expert_predictor}")
        print(f"  Has lune_predictor: {info.has_lune_predictor}")
        print(f"  Has sol_prior: {info.has_sol_prior}")
        print(f"  Has t5_pool: {info.has_t5_pool}")
        print(f"  Has spatial_to_mod: {info.has_spatial_to_mod}")
        return

    # Full conversion.
    result = convert_checkpoint(
        step=args.step,
        input_path=args.input_path,
        ema_input_path=args.ema_input_path,
        output_dir=args.output_dir,
        model_name=args.name,
        repo_id=args.repo,
        checkpoint_dir=args.checkpoint_dir,
        create_fresh_ema=not args.no_fresh_ema,
        preserve_secondary_ema=not args.no_secondary_ema,
        verbose=not args.quiet,
    )

    if result.success:
        if not args.quiet:
            print("\n" + "=" * 60)
            print("✅ Conversion complete!")
            print("=" * 60)
            print(f"\nOutput files:")
            if result.model_path:
                print(f"  Model: {result.model_path}")
            if result.ema_path:
                print(f"  EMA: {result.ema_path}")
            if result.ema_secondary_path:
                print(f"  EMA (secondary): {result.ema_secondary_path}")
    else:
        print(f"\n❌ Conversion failed: {result.error}")
        # FIX: raise SystemExit rather than calling exit(), which is a
        # site.py convenience helper not guaranteed to exist (e.g. python -S).
        raise SystemExit(1)
| |
|
| |
|
# Script entry point: delegate to the CLI handler.
if __name__ == '__main__':
    cli_main()