Update scripts/model_v4.py
Browse files
- scripts/model_v4.py +18 -17
scripts/model_v4.py
CHANGED
|
@@ -41,12 +41,12 @@ from pathlib import Path
|
|
| 41 |
class TinyFluxConfig:
|
| 42 |
"""
|
| 43 |
Configuration for TinyFlux-Deep v4.1 model.
|
| 44 |
-
|
| 45 |
This config fully defines the model architecture and can be used to:
|
| 46 |
1. Initialize a new model
|
| 47 |
-
2. Convert checkpoints between versions
|
| 48 |
3. Validate checkpoint compatibility
|
| 49 |
-
|
| 50 |
All dimension constraints are validated on creation.
|
| 51 |
"""
|
| 52 |
|
|
@@ -105,10 +105,10 @@ class TinyFluxConfig:
|
|
| 105 |
f"num_attention_heads * attention_head_dim ({expected_hidden})"
|
| 106 |
)
|
| 107 |
|
| 108 |
-
# Validate RoPE dimensions
|
| 109 |
if isinstance(self.axes_dims_rope, list):
|
| 110 |
self.axes_dims_rope = tuple(self.axes_dims_rope)
|
| 111 |
-
|
| 112 |
rope_sum = sum(self.axes_dims_rope)
|
| 113 |
if rope_sum != self.attention_head_dim:
|
| 114 |
raise ValueError(
|
|
@@ -158,11 +158,11 @@ class TinyFluxConfig:
|
|
| 158 |
def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
|
| 159 |
"""
|
| 160 |
Validate that a checkpoint matches this config.
|
| 161 |
-
|
| 162 |
Returns list of warnings (empty if perfect match).
|
| 163 |
"""
|
| 164 |
warnings = []
|
| 165 |
-
|
| 166 |
# Check double block count
|
| 167 |
max_double = 0
|
| 168 |
for key in state_dict:
|
|
@@ -171,7 +171,7 @@ class TinyFluxConfig:
|
|
| 171 |
max_double = max(max_double, idx + 1)
|
| 172 |
if max_double != self.num_double_layers:
|
| 173 |
warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")
|
| 174 |
-
|
| 175 |
# Check single block count
|
| 176 |
max_single = 0
|
| 177 |
for key in state_dict:
|
|
@@ -180,25 +180,25 @@ class TinyFluxConfig:
|
|
| 180 |
max_single = max(max_single, idx + 1)
|
| 181 |
if max_single != self.num_single_layers:
|
| 182 |
warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")
|
| 183 |
-
|
| 184 |
# Check hidden size from a known weight
|
| 185 |
if "img_in.weight" in state_dict:
|
| 186 |
w = state_dict["img_in.weight"]
|
| 187 |
if w.shape[0] != self.hidden_size:
|
| 188 |
warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")
|
| 189 |
-
|
| 190 |
# Check for v4.1 components
|
| 191 |
has_sol = any(k.startswith("sol_prior.") for k in state_dict)
|
| 192 |
has_t5 = any(k.startswith("t5_pool.") for k in state_dict)
|
| 193 |
has_lune = any(k.startswith("lune_predictor.") for k in state_dict)
|
| 194 |
-
|
| 195 |
if self.use_sol_prior and not has_sol:
|
| 196 |
warnings.append("config expects sol_prior but checkpoint missing it")
|
| 197 |
if self.use_t5_vec and not has_t5:
|
| 198 |
warnings.append("config expects t5_pool but checkpoint missing it")
|
| 199 |
if self.use_lune_expert and not has_lune:
|
| 200 |
warnings.append("config expects lune_predictor but checkpoint missing it")
|
| 201 |
-
|
| 202 |
return warnings
|
| 203 |
|
| 204 |
|
|
@@ -1024,10 +1024,6 @@ class TinyFluxDeep(nn.Module):
|
|
| 1024 |
else:
|
| 1025 |
self.sol_prior = None
|
| 1026 |
|
| 1027 |
-
# === Legacy support ===
|
| 1028 |
-
# Map old expert_predictor API to lune_predictor
|
| 1029 |
-
self.expert_predictor = self.lune_predictor
|
| 1030 |
-
|
| 1031 |
# Legacy guidance
|
| 1032 |
if cfg.guidance_embeds:
|
| 1033 |
self.guidance_in = MLPEmbedder(cfg.hidden_size)
|
|
@@ -1060,6 +1056,11 @@ class TinyFluxDeep(nn.Module):
|
|
| 1060 |
self.apply(_init)
|
| 1061 |
nn.init.zeros_(self.final_linear.weight)
|
| 1062 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1063 |
def forward(
|
| 1064 |
self,
|
| 1065 |
hidden_states: torch.Tensor,
|
|
@@ -1118,7 +1119,7 @@ class TinyFluxDeep(nn.Module):
|
|
| 1118 |
encoder_hidden_states = encoder_hidden_states.to(dtype=model_dtype)
|
| 1119 |
pooled_projections = pooled_projections.to(dtype=model_dtype)
|
| 1120 |
timestep = timestep.to(dtype=model_dtype)
|
| 1121 |
-
|
| 1122 |
# Cast optional expert inputs if provided
|
| 1123 |
if lune_features is not None:
|
| 1124 |
lune_features = lune_features.to(dtype=model_dtype)
|
|
|
|
| 41 |
class TinyFluxConfig:
|
| 42 |
"""
|
| 43 |
Configuration for TinyFlux-Deep v4.1 model.
|
| 44 |
+
|
| 45 |
This config fully defines the model architecture and can be used to:
|
| 46 |
1. Initialize a new model
|
| 47 |
+
2. Convert checkpoints between versions
|
| 48 |
3. Validate checkpoint compatibility
|
| 49 |
+
|
| 50 |
All dimension constraints are validated on creation.
|
| 51 |
"""
|
| 52 |
|
|
|
|
| 105 |
f"num_attention_heads * attention_head_dim ({expected_hidden})"
|
| 106 |
)
|
| 107 |
|
| 108 |
+
# Validate RoPE dimensions
|
| 109 |
if isinstance(self.axes_dims_rope, list):
|
| 110 |
self.axes_dims_rope = tuple(self.axes_dims_rope)
|
| 111 |
+
|
| 112 |
rope_sum = sum(self.axes_dims_rope)
|
| 113 |
if rope_sum != self.attention_head_dim:
|
| 114 |
raise ValueError(
|
|
|
|
| 158 |
def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
|
| 159 |
"""
|
| 160 |
Validate that a checkpoint matches this config.
|
| 161 |
+
|
| 162 |
Returns list of warnings (empty if perfect match).
|
| 163 |
"""
|
| 164 |
warnings = []
|
| 165 |
+
|
| 166 |
# Check double block count
|
| 167 |
max_double = 0
|
| 168 |
for key in state_dict:
|
|
|
|
| 171 |
max_double = max(max_double, idx + 1)
|
| 172 |
if max_double != self.num_double_layers:
|
| 173 |
warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")
|
| 174 |
+
|
| 175 |
# Check single block count
|
| 176 |
max_single = 0
|
| 177 |
for key in state_dict:
|
|
|
|
| 180 |
max_single = max(max_single, idx + 1)
|
| 181 |
if max_single != self.num_single_layers:
|
| 182 |
warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")
|
| 183 |
+
|
| 184 |
# Check hidden size from a known weight
|
| 185 |
if "img_in.weight" in state_dict:
|
| 186 |
w = state_dict["img_in.weight"]
|
| 187 |
if w.shape[0] != self.hidden_size:
|
| 188 |
warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")
|
| 189 |
+
|
| 190 |
# Check for v4.1 components
|
| 191 |
has_sol = any(k.startswith("sol_prior.") for k in state_dict)
|
| 192 |
has_t5 = any(k.startswith("t5_pool.") for k in state_dict)
|
| 193 |
has_lune = any(k.startswith("lune_predictor.") for k in state_dict)
|
| 194 |
+
|
| 195 |
if self.use_sol_prior and not has_sol:
|
| 196 |
warnings.append("config expects sol_prior but checkpoint missing it")
|
| 197 |
if self.use_t5_vec and not has_t5:
|
| 198 |
warnings.append("config expects t5_pool but checkpoint missing it")
|
| 199 |
if self.use_lune_expert and not has_lune:
|
| 200 |
warnings.append("config expects lune_predictor but checkpoint missing it")
|
| 201 |
+
|
| 202 |
return warnings
|
| 203 |
|
| 204 |
|
|
|
|
| 1024 |
else:
|
| 1025 |
self.sol_prior = None
|
| 1026 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1027 |
# Legacy guidance
|
| 1028 |
if cfg.guidance_embeds:
|
| 1029 |
self.guidance_in = MLPEmbedder(cfg.hidden_size)
|
|
|
|
| 1056 |
self.apply(_init)
|
| 1057 |
nn.init.zeros_(self.final_linear.weight)
|
| 1058 |
|
| 1059 |
+
@property
|
| 1060 |
+
def expert_predictor(self):
|
| 1061 |
+
"""Legacy API: alias for lune_predictor."""
|
| 1062 |
+
return self.lune_predictor
|
| 1063 |
+
|
| 1064 |
def forward(
|
| 1065 |
self,
|
| 1066 |
hidden_states: torch.Tensor,
|
|
|
|
| 1119 |
encoder_hidden_states = encoder_hidden_states.to(dtype=model_dtype)
|
| 1120 |
pooled_projections = pooled_projections.to(dtype=model_dtype)
|
| 1121 |
timestep = timestep.to(dtype=model_dtype)
|
| 1122 |
+
|
| 1123 |
# Cast optional expert inputs if provided
|
| 1124 |
if lune_features is not None:
|
| 1125 |
lune_features = lune_features.to(dtype=model_dtype)
|