AbstractPhil committed on
Commit
b9c4369
·
verified ·
1 Parent(s): 3f7d369

Update scripts/model_v4.py

Browse files
Files changed (1) hide show
  1. scripts/model_v4.py +18 -17
scripts/model_v4.py CHANGED
@@ -41,12 +41,12 @@ from pathlib import Path
41
  class TinyFluxConfig:
42
  """
43
  Configuration for TinyFlux-Deep v4.1 model.
44
-
45
  This config fully defines the model architecture and can be used to:
46
  1. Initialize a new model
47
- 2. Convert checkpoints between versions
48
  3. Validate checkpoint compatibility
49
-
50
  All dimension constraints are validated on creation.
51
  """
52
 
@@ -105,10 +105,10 @@ class TinyFluxConfig:
105
  f"num_attention_heads * attention_head_dim ({expected_hidden})"
106
  )
107
 
108
- # Validate RoPE dimensions
109
  if isinstance(self.axes_dims_rope, list):
110
  self.axes_dims_rope = tuple(self.axes_dims_rope)
111
-
112
  rope_sum = sum(self.axes_dims_rope)
113
  if rope_sum != self.attention_head_dim:
114
  raise ValueError(
@@ -158,11 +158,11 @@ class TinyFluxConfig:
158
  def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
159
  """
160
  Validate that a checkpoint matches this config.
161
-
162
  Returns list of warnings (empty if perfect match).
163
  """
164
  warnings = []
165
-
166
  # Check double block count
167
  max_double = 0
168
  for key in state_dict:
@@ -171,7 +171,7 @@ class TinyFluxConfig:
171
  max_double = max(max_double, idx + 1)
172
  if max_double != self.num_double_layers:
173
  warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")
174
-
175
  # Check single block count
176
  max_single = 0
177
  for key in state_dict:
@@ -180,25 +180,25 @@ class TinyFluxConfig:
180
  max_single = max(max_single, idx + 1)
181
  if max_single != self.num_single_layers:
182
  warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")
183
-
184
  # Check hidden size from a known weight
185
  if "img_in.weight" in state_dict:
186
  w = state_dict["img_in.weight"]
187
  if w.shape[0] != self.hidden_size:
188
  warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")
189
-
190
  # Check for v4.1 components
191
  has_sol = any(k.startswith("sol_prior.") for k in state_dict)
192
  has_t5 = any(k.startswith("t5_pool.") for k in state_dict)
193
  has_lune = any(k.startswith("lune_predictor.") for k in state_dict)
194
-
195
  if self.use_sol_prior and not has_sol:
196
  warnings.append("config expects sol_prior but checkpoint missing it")
197
  if self.use_t5_vec and not has_t5:
198
  warnings.append("config expects t5_pool but checkpoint missing it")
199
  if self.use_lune_expert and not has_lune:
200
  warnings.append("config expects lune_predictor but checkpoint missing it")
201
-
202
  return warnings
203
 
204
 
@@ -1024,10 +1024,6 @@ class TinyFluxDeep(nn.Module):
1024
  else:
1025
  self.sol_prior = None
1026
 
1027
- # === Legacy support ===
1028
- # Map old expert_predictor API to lune_predictor
1029
- self.expert_predictor = self.lune_predictor
1030
-
1031
  # Legacy guidance
1032
  if cfg.guidance_embeds:
1033
  self.guidance_in = MLPEmbedder(cfg.hidden_size)
@@ -1060,6 +1056,11 @@ class TinyFluxDeep(nn.Module):
1060
  self.apply(_init)
1061
  nn.init.zeros_(self.final_linear.weight)
1062
 
 
 
 
 
 
1063
  def forward(
1064
  self,
1065
  hidden_states: torch.Tensor,
@@ -1118,7 +1119,7 @@ class TinyFluxDeep(nn.Module):
1118
  encoder_hidden_states = encoder_hidden_states.to(dtype=model_dtype)
1119
  pooled_projections = pooled_projections.to(dtype=model_dtype)
1120
  timestep = timestep.to(dtype=model_dtype)
1121
-
1122
  # Cast optional expert inputs if provided
1123
  if lune_features is not None:
1124
  lune_features = lune_features.to(dtype=model_dtype)
 
41
  class TinyFluxConfig:
42
  """
43
  Configuration for TinyFlux-Deep v4.1 model.
44
+
45
  This config fully defines the model architecture and can be used to:
46
  1. Initialize a new model
47
+ 2. Convert checkpoints between versions
48
  3. Validate checkpoint compatibility
49
+
50
  All dimension constraints are validated on creation.
51
  """
52
 
 
105
  f"num_attention_heads * attention_head_dim ({expected_hidden})"
106
  )
107
 
108
+ # Validate RoPE dimensions
109
  if isinstance(self.axes_dims_rope, list):
110
  self.axes_dims_rope = tuple(self.axes_dims_rope)
111
+
112
  rope_sum = sum(self.axes_dims_rope)
113
  if rope_sum != self.attention_head_dim:
114
  raise ValueError(
 
158
  def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
159
  """
160
  Validate that a checkpoint matches this config.
161
+
162
  Returns list of warnings (empty if perfect match).
163
  """
164
  warnings = []
165
+
166
  # Check double block count
167
  max_double = 0
168
  for key in state_dict:
 
171
  max_double = max(max_double, idx + 1)
172
  if max_double != self.num_double_layers:
173
  warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")
174
+
175
  # Check single block count
176
  max_single = 0
177
  for key in state_dict:
 
180
  max_single = max(max_single, idx + 1)
181
  if max_single != self.num_single_layers:
182
  warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")
183
+
184
  # Check hidden size from a known weight
185
  if "img_in.weight" in state_dict:
186
  w = state_dict["img_in.weight"]
187
  if w.shape[0] != self.hidden_size:
188
  warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")
189
+
190
  # Check for v4.1 components
191
  has_sol = any(k.startswith("sol_prior.") for k in state_dict)
192
  has_t5 = any(k.startswith("t5_pool.") for k in state_dict)
193
  has_lune = any(k.startswith("lune_predictor.") for k in state_dict)
194
+
195
  if self.use_sol_prior and not has_sol:
196
  warnings.append("config expects sol_prior but checkpoint missing it")
197
  if self.use_t5_vec and not has_t5:
198
  warnings.append("config expects t5_pool but checkpoint missing it")
199
  if self.use_lune_expert and not has_lune:
200
  warnings.append("config expects lune_predictor but checkpoint missing it")
201
+
202
  return warnings
203
 
204
 
 
1024
  else:
1025
  self.sol_prior = None
1026
 
 
 
 
 
1027
  # Legacy guidance
1028
  if cfg.guidance_embeds:
1029
  self.guidance_in = MLPEmbedder(cfg.hidden_size)
 
1056
  self.apply(_init)
1057
  nn.init.zeros_(self.final_linear.weight)
1058
 
1059
+ @property
1060
+ def expert_predictor(self):
1061
+ """Legacy API: alias for lune_predictor."""
1062
+ return self.lune_predictor
1063
+
1064
  def forward(
1065
  self,
1066
  hidden_states: torch.Tensor,
 
1119
  encoder_hidden_states = encoder_hidden_states.to(dtype=model_dtype)
1120
  pooled_projections = pooled_projections.to(dtype=model_dtype)
1121
  timestep = timestep.to(dtype=model_dtype)
1122
+
1123
  # Cast optional expert inputs if provided
1124
  if lune_features is not None:
1125
  lune_features = lune_features.to(dtype=model_dtype)