akrao9 commited on
Commit
098686c
Β·
verified Β·
1 Parent(s): 659ddad

Add Boomer FLA checkpoint (step 055000, ema weights, bf16)

Browse files
modeling_boomer_fla.py CHANGED
@@ -1261,7 +1261,8 @@ class BoomerFLADiT(nn.Module):
1261
  sd = load_file(str(path / "diffusion_pytorch_model.safetensors"))
1262
  model.load_state_dict(sd, strict=False)
1263
 
1264
- # Attach inference metadata (latent stats, component repos, etc.)
1265
- # so BoomerPipeline.__init__ can read them without a separate config file.
1266
  model._boomer_cfg = {k: v for k, v in cfg_raw.items() if k.startswith("_")}
 
 
1267
  return model
 
1261
  sd = load_file(str(path / "diffusion_pytorch_model.safetensors"))
1262
  model.load_state_dict(sd, strict=False)
1263
 
1264
+ # Attach inference metadata so BoomerPipeline.__init__ can read without a separate file.
 
1265
  model._boomer_cfg = {k: v for k, v in cfg_raw.items() if k.startswith("_")}
1266
+ # Store snapshot root so the pipeline can add it to sys.path for sibling imports.
1267
+ model._snapshot_dir = str(path.parent if subfolder else path)
1268
  return model
pipeline_boomer.py CHANGED
@@ -112,6 +112,8 @@ class BoomerPipeline(DiffusionPipeline):
112
  self._vae_repo = cfg.get("_vae_repo", "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
113
  self._te_repo = cfg.get("_te_repo", "google/gemma-4-E2B-it")
114
  self._hf_token = None
 
 
115
 
116
  # ── lazy component loading ─────────────────────────────────────────────────
117
  def _ensure_vae(self):
@@ -164,6 +166,7 @@ class BoomerPipeline(DiffusionPipeline):
164
 
165
  transformer = BoomerFLADiT.from_pretrained(str(local), subfolder="transformer")
166
  transformer = transformer.to(dtype=dtype)
 
167
 
168
  pipe = cls(transformer=transformer)
169
  pipe._hf_token = token
@@ -193,6 +196,10 @@ class BoomerPipeline(DiffusionPipeline):
193
  substeps : STORK-2 internal RK micro-steps
194
  offload_text_encoder : unload text encoder after encoding to free VRAM
195
  """
 
 
 
 
196
  from scheduling_boomer_stork import make_stork_scheduler # noqa: PLC0415
197
 
198
  prompts = [prompt] if isinstance(prompt, str) else list(prompt)
 
112
  self._vae_repo = cfg.get("_vae_repo", "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
113
  self._te_repo = cfg.get("_te_repo", "google/gemma-4-E2B-it")
114
  self._hf_token = None
115
+ # Snapshot root β€” added to sys.path so sibling .py files (scheduler, etc.) are importable
116
+ self._snapshot_dir = getattr(transformer, "_snapshot_dir", None)
117
 
118
  # ── lazy component loading ─────────────────────────────────────────────────
119
  def _ensure_vae(self):
 
166
 
167
  transformer = BoomerFLADiT.from_pretrained(str(local), subfolder="transformer")
168
  transformer = transformer.to(dtype=dtype)
169
+ transformer._snapshot_dir = str(local) # carry snapshot path for sibling imports
170
 
171
  pipe = cls(transformer=transformer)
172
  pipe._hf_token = token
 
196
  substeps : STORK-2 internal RK micro-steps
197
  offload_text_encoder : unload text encoder after encoding to free VRAM
198
  """
199
+ # Add snapshot dir to sys.path so scheduling_boomer_stork (and STORKScheduler)
200
+ # are findable β€” diffusers only caches pipeline_boomer.py itself, not sibling files.
201
+ if self._snapshot_dir and self._snapshot_dir not in sys.path:
202
+ sys.path.insert(0, self._snapshot_dir)
203
  from scheduling_boomer_stork import make_stork_scheduler # noqa: PLC0415
204
 
205
  prompts = [prompt] if isinstance(prompt, str) else list(prompt)
transformer/modeling_boomer_fla.py CHANGED
@@ -1261,7 +1261,8 @@ class BoomerFLADiT(nn.Module):
1261
  sd = load_file(str(path / "diffusion_pytorch_model.safetensors"))
1262
  model.load_state_dict(sd, strict=False)
1263
 
1264
- # Attach inference metadata (latent stats, component repos, etc.)
1265
- # so BoomerPipeline.__init__ can read them without a separate config file.
1266
  model._boomer_cfg = {k: v for k, v in cfg_raw.items() if k.startswith("_")}
 
 
1267
  return model
 
1261
  sd = load_file(str(path / "diffusion_pytorch_model.safetensors"))
1262
  model.load_state_dict(sd, strict=False)
1263
 
1264
+ # Attach inference metadata so BoomerPipeline.__init__ can read without a separate file.
 
1265
  model._boomer_cfg = {k: v for k, v in cfg_raw.items() if k.startswith("_")}
1266
+ # Store snapshot root so the pipeline can add it to sys.path for sibling imports.
1267
+ model._snapshot_dir = str(path.parent if subfolder else path)
1268
  return model