BiliSakura committed on
Commit
bd22846
·
verified ·
1 Parent(s): ed2880b

Fix generator determinism: forward generator through scheduler steps and seeded noise

Browse files
LightningDit-XL-1-256/model_index.json CHANGED
@@ -5,8 +5,8 @@
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
- "diffusers",
9
- "FlowMatchHeunDiscreteScheduler"
10
  ],
11
  "vae": [
12
  "diffusers",
 
5
  ],
6
  "_diffusers_version": "0.36.0",
7
  "scheduler": [
8
+ "scheduling_flow_match_lightningdit",
9
+ "LightningDiTFlowMatchScheduler"
10
  ],
11
  "vae": [
12
  "diffusers",
LightningDit-XL-1-256/pipeline.py CHANGED
@@ -1,131 +1,89 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Hub custom pipeline for class-conditional LightningDiT image generation.
16
-
17
- Load with native Hugging Face diffusers via ``DiffusionPipeline.from_pretrained`` and
18
- ``trust_remote_code=True``.
19
  """
20
 
21
  from __future__ import annotations
22
 
23
  import inspect
24
- from typing import Any, Dict, List, Optional, Tuple, Union
25
 
26
- import torch
 
 
 
27
 
28
- from diffusers.image_processor import VaeImageProcessor
29
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
30
- from diffusers.schedulers import KarrasDiffusionSchedulers
31
- from diffusers.utils import replace_example_docstring
32
- from diffusers.utils.torch_utils import randn_tensor
33
-
34
- EXAMPLE_DOC_STRING = """
35
- Examples:
36
- ```py
37
- >>> from pathlib import Path
38
- >>> import torch
39
- >>> from diffusers import DiffusionPipeline
40
-
41
- >>> model_dir = Path("BiliSakura/LightningDiT-diffusers/LightningDit-XL-1-256")
42
- >>> pipe = DiffusionPipeline.from_pretrained(
43
- ... str(model_dir),
44
- ... local_files_only=True,
45
- ... custom_pipeline=str(model_dir / "pipeline.py"),
46
- ... trust_remote_code=True,
47
- ... torch_dtype=torch.bfloat16,
48
- ... ).to("cuda")
49
-
50
- >>> class_id = pipe.get_label_ids("golden retriever")[0]
51
- >>> image = pipe(
52
- ... class_labels=class_id,
53
- ... num_inference_steps=250,
54
- ... guidance_scale=6.7,
55
- ... cfg_interval_start=0.125,
56
- ... generator=torch.Generator(device="cuda").manual_seed(0),
57
- ... ).images[0]
58
- ```
59
- """
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- def _uses_explicit_next_timestep_scheduler(scheduler: KarrasDiffusionSchedulers) -> bool:
63
- """True for LightningDiTFlowMatchScheduler (explicit t, t_next); False for built-in FlowMatch schedulers."""
64
- try:
65
- return "next_timestep" in inspect.signature(scheduler.step).parameters
66
- except (TypeError, ValueError):
67
- return False
68
 
 
 
 
69
 
70
  class LightningDiTPipeline(DiffusionPipeline):
71
  r"""
72
- Pipeline for class-conditional image generation with [LightningDiT](https://github.com/hustvl/LightningDiT).
73
-
74
- Uses VA-VAE latents and flow-matching velocity prediction. The bundled checkpoint defaults to
75
- [`FlowMatchHeunDiscreteScheduler`] with `shift=0.3` (2nd-order Heun). Flow time passed to the
76
- transformer is `1 - sigma` (`t=0` noise, `t=1` data). Latents are denormalized from VAE
77
- `latents_mean` / `latents_std` before decode.
78
-
79
- Recommended settings for `LightningDiT-XL/1` ImageNet-256 (800 epochs), matching official inference:
80
-
81
- - `num_inference_steps=250`
82
- - `guidance_scale=6.7`
83
- - `cfg_interval_start=0.125`
84
- - `cfg_channels=3`
85
- - `timestep_shift=0.3` (only when the scheduler supports `set_shift`; otherwise set `shift` in
86
- `scheduler/scheduler_config.json`)
87
-
88
- Parameters:
89
- transformer ([`LightningDiTTransformer2DModel`]):
90
- LightningDiT transformer predicting flow-matching velocity in latent space.
91
- scheduler ([`FlowMatchHeunDiscreteScheduler`]):
92
- Flow-matching scheduler. Other [`KarrasDiffusionSchedulers`] may be swapped at load time.
93
- vae ([`AutoencoderKL`]):
94
- VA-VAE used to decode latents to pixels.
95
- id2label (`dict[int, str]`, *optional*):
96
- ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
97
- """
98
 
99
- model_cpu_offload_seq = "transformer->vae"
 
 
100
 
101
- def __init__(
102
- self,
103
- transformer,
104
- vae,
105
  scheduler,
106
- id2label=None,
107
- null_class_id=None,
108
  ):
109
- super().__init__()
110
- self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)
111
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
 
 
 
 
112
 
113
- if null_class_id is None:
114
- null_class_id = int(getattr(self.transformer.config, "num_classes", 1000))
115
- self.register_to_config(null_class_id=int(null_class_id))
116
 
 
 
 
 
117
  self._id2label = self._normalize_id2label(id2label)
118
  self.labels = self._build_label2id(self._id2label)
119
-
120
- @property
121
- def vae_scale_factor(self) -> int:
122
- block_out_channels = getattr(self.vae.config, "block_out_channels", None)
123
- if block_out_channels:
124
- return int(2 ** (len(block_out_channels) - 1))
125
- downsample_ratio = getattr(self.vae.config, "downsample_ratio", None)
126
- if downsample_ratio is not None:
127
- return int(downsample_ratio)
128
- return 16
129
 
130
  @staticmethod
131
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
@@ -133,6 +91,19 @@ class LightningDiTPipeline(DiffusionPipeline):
133
  return {}
134
  return {int(key): value for key, value in id2label.items()}
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  @staticmethod
137
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
138
  label2id: Dict[str, int] = {}
@@ -143,81 +114,77 @@ class LightningDiTPipeline(DiffusionPipeline):
143
  label2id[synonym] = int(class_id)
144
  return dict(sorted(label2id.items()))
145
 
 
 
 
 
 
 
 
 
 
146
  @property
147
  def id2label(self) -> Dict[int, str]:
 
148
  return self._id2label
149
 
150
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
151
- r"""Map English ImageNet labels to class ids."""
152
- labels = [label] if isinstance(label, str) else label
153
  if not self.labels:
154
- raise ValueError("No id2label mapping is available in this checkpoint.")
 
155
  missing = [item for item in labels if item not in self.labels]
156
  if missing:
157
  preview = ", ".join(list(self.labels.keys())[:8])
158
- raise ValueError(f"Unknown labels: {missing}. Example valid labels: {preview}, ...")
159
  return [self.labels[item] for item in labels]
160
 
161
  def _normalize_class_labels(
162
  self,
163
- class_labels: Union[int, str, List[Union[int, str]], torch.Tensor],
164
  ) -> torch.LongTensor:
165
  if isinstance(class_labels, torch.Tensor):
166
- return class_labels.to(device=self._execution_device, dtype=torch.long).reshape(-1)
167
  if isinstance(class_labels, int):
168
- class_label_ids = [class_labels]
169
  elif isinstance(class_labels, str):
170
- class_label_ids = self.get_label_ids(class_labels)
171
  elif class_labels and isinstance(class_labels[0], str):
172
- class_label_ids = self.get_label_ids(class_labels) # type: ignore[arg-type]
173
- else:
174
- class_label_ids = [int(class_id) for class_id in class_labels] # type: ignore[union-attr]
175
- return torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
176
 
177
- def _default_image_size(self) -> int:
178
- return int(self.transformer.config.input_size) * self.vae_scale_factor
179
-
180
- def check_inputs(
181
  self,
 
182
  height: int,
183
  width: int,
184
- num_inference_steps: int,
185
- output_type: str,
186
- ) -> None:
187
- if num_inference_steps < 1:
188
- raise ValueError("num_inference_steps must be >= 1.")
189
- if output_type not in {"pil", "np", "pt", "latent"}:
190
- raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
191
- if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
192
- raise ValueError(
193
- f"height and width must be divisible by the VAE downsample factor {self.vae_scale_factor}."
194
- )
195
- latent_height = height // self.vae_scale_factor
196
- latent_width = width // self.vae_scale_factor
197
- expected_size = int(self.transformer.config.input_size)
 
 
 
198
  patch_size = int(self.transformer.config.patch_size)
199
- if latent_height != expected_size or latent_width != expected_size:
200
- raise ValueError(
201
- f"Requested latent size {(latent_height, latent_width)} does not match transformer "
202
- f"input_size={expected_size}. Use height=width={self._default_image_size()}."
203
- )
204
  if latent_height % patch_size != 0 or latent_width % patch_size != 0:
205
- raise ValueError("Latent height and width must be divisible by transformer patch_size.")
206
-
207
- @staticmethod
208
- def prepare_extra_step_kwargs(
209
- scheduler: KarrasDiffusionSchedulers,
210
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
211
- ) -> Dict[str, Any]:
212
- extra_step_kwargs: Dict[str, Any] = {}
213
- if "generator" in inspect.signature(scheduler.step).parameters:
214
- extra_step_kwargs["generator"] = generator
215
- return extra_step_kwargs
216
 
217
- @staticmethod
218
- def _flow_time_from_sigma_timestep(timestep: torch.Tensor, num_train_timesteps: int) -> torch.Tensor:
219
- """Map FlowMatch scheduler timestep (sigma * N) to LightningDiT flow time in [0, 1]."""
220
- return 1.0 - timestep.to(dtype=torch.float32) / float(num_train_timesteps)
 
 
221
 
222
  @staticmethod
223
  def _apply_cfg(
@@ -230,10 +197,11 @@ class LightningDiTPipeline(DiffusionPipeline):
230
  return model_output
231
  eps, rest = model_output[:, :cfg_channels], model_output[:, cfg_channels:]
232
  cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
233
- guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
234
  if rest.numel() == 0:
235
- return guided_eps
236
- return torch.cat([guided_eps, rest[: cond_eps.shape[0]]], dim=1)
 
237
 
238
  def _resolve_latent_stats(
239
  self,
@@ -263,58 +231,30 @@ class LightningDiTPipeline(DiffusionPipeline):
263
  ) -> torch.Tensor:
264
  return (latents * latent_std) / latent_multiplier + latent_mean
265
 
266
- def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
267
- if output_type == "latent":
268
  return latents
 
269
  vae_dtype = next(self.vae.parameters()).dtype
270
  latents = latents.to(dtype=vae_dtype)
271
  scaling_factor = getattr(self.vae.config, "scaling_factor", None)
272
  if scaling_factor not in (None, 0):
273
  latents = latents / scaling_factor
274
- image = self.vae.decode(latents).sample
275
- if output_type == "pt":
276
- return image
277
- return self.image_processor.postprocess(image, output_type=output_type)
278
-
279
- def _configure_scheduler(self, num_inference_steps: int, device: torch.device, timestep_shift: float):
280
- if hasattr(self.scheduler, "set_shift"):
281
- self.scheduler.set_shift(float(timestep_shift))
282
- if _uses_explicit_next_timestep_scheduler(self.scheduler):
283
- return self.scheduler.set_timesteps(
284
- num_inference_steps,
285
- device=device,
286
- timestep_shift=float(timestep_shift),
287
- )
288
- if getattr(self.scheduler.config, "stochastic_sampling", False):
289
- raise ValueError(
290
- "LightningDiT expects deterministic FlowMatch scheduler stepping "
291
- "(scheduler.config.stochastic_sampling=False)."
292
- )
293
- self.scheduler.set_timesteps(num_inference_steps, device=device)
294
- return self.scheduler.timesteps
295
-
296
- def _guidance_active(
297
- self,
298
- flow_time: float,
299
- guidance_interval: Tuple[float, float],
300
- cfg_interval_start: float,
301
- ) -> bool:
302
- if flow_time < float(cfg_interval_start):
303
- return False
304
- return guidance_interval[0] <= flow_time <= guidance_interval[1]
305
 
306
  @torch.no_grad()
307
- @replace_example_docstring(EXAMPLE_DOC_STRING)
308
  def __call__(
309
  self,
310
- class_labels: Union[int, str, List[Union[int, str]], torch.Tensor],
311
- height: Optional[int] = None,
312
- width: Optional[int] = None,
313
  num_inference_steps: int = 250,
314
- guidance_scale: float = 6.7,
315
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
316
  cfg_interval_start: float = 0.125,
317
- timestep_shift: Optional[float] = None,
 
318
  cfg_channels: int = 3,
319
  latent_mean: Optional[torch.Tensor] = None,
320
  latent_std: Optional[torch.Tensor] = None,
@@ -322,138 +262,69 @@ class LightningDiTPipeline(DiffusionPipeline):
322
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
323
  output_type: str = "pil",
324
  return_dict: bool = True,
325
- ) -> Union[ImagePipelineOutput, Tuple]:
326
- r"""
327
- Generate class-conditional images at the transformer's native latent resolution.
328
-
329
- Args:
330
- class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`):
331
- ImageNet class indices or human-readable English label strings (comma-separated synonyms
332
- in `id2label` are supported).
333
- height (`int`, *optional*):
334
- Output image height in pixels. Defaults to `input_size * vae_scale_factor` (256 for XL/1-256).
335
- width (`int`, *optional*):
336
- Output image width in pixels. Defaults to the same value as `height`.
337
- num_inference_steps (`int`, defaults to `250`):
338
- Number of flow-matching steps. With [`FlowMatchHeunDiscreteScheduler`], each step may use
339
- two model evaluations (2nd-order Heun).
340
- guidance_scale (`float`, defaults to `6.7`):
341
- Classifier-free guidance scale on the first `cfg_channels` latent channels. CFG is active when
342
- `guidance_scale > 1.0` and flow time is at least `cfg_interval_start`.
343
- guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
344
- Flow-time interval `[low, high]` where CFG is allowed (in addition to `cfg_interval_start`).
345
- cfg_interval_start (`float`, defaults to `0.125`):
346
- Minimum flow time before CFG is applied (official LightningDiT XL/1 setting).
347
- timestep_shift (`float`, *optional*):
348
- Timestep schedule shift. Defaults to `scheduler.config.shift`. Only applied at runtime if the
349
- scheduler implements `set_shift` (e.g. [`FlowMatchEulerDiscreteScheduler`]); for
350
- [`FlowMatchHeunDiscreteScheduler`], set `shift` in `scheduler_config.json` when loading.
351
- cfg_channels (`int`, defaults to `3`):
352
- Number of latent channels to apply CFG on.
353
- latent_mean (`torch.Tensor`, *optional*):
354
- Per-channel latent mean for denormalization before VAE decode. Read from the VAE config when omitted.
355
- latent_std (`torch.Tensor`, *optional*):
356
- Per-channel latent std for denormalization before VAE decode. Read from the VAE config when omitted.
357
- latent_multiplier (`float`, defaults to `1.0`):
358
- Divisor applied with `latent_std` during denormalization (`latents * std / multiplier + mean`).
359
- generator (`torch.Generator`, *optional*):
360
- RNG for reproducible noise initialization (and scheduler stochastic paths if enabled).
361
- output_type (`str`, defaults to `"pil"`):
362
- `"pil"`, `"np"`, `"pt"`, or `"latent"`.
363
- return_dict (`bool`, defaults to `True`):
364
- Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if `True`, else a `(images,)` tuple.
365
-
366
- Examples:
367
- <!-- this section is replaced by replace_example_docstring -->
368
-
369
- Returns:
370
- [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
371
- Generated images.
372
- """
373
- default_size = self._default_image_size()
374
- height = int(height or default_size)
375
- width = int(width or default_size)
376
- self.check_inputs(height, width, num_inference_steps, output_type)
377
-
378
  device = self._execution_device
379
  model_dtype = next(self.transformer.parameters()).dtype
380
- class_labels_tensor = self._normalize_class_labels(class_labels)
381
- batch_size = class_labels_tensor.numel()
382
- null_labels = torch.full_like(class_labels_tensor, int(self.config.null_class_id))
383
-
384
- if timestep_shift is None:
385
- timestep_shift = float(getattr(self.scheduler.config, "shift", 0.3))
386
-
387
- schedule = self._configure_scheduler(num_inference_steps, device, timestep_shift)
388
- num_train_timesteps = int(self.scheduler.config.num_train_timesteps)
389
- use_builtin_flow_match = not _uses_explicit_next_timestep_scheduler(self.scheduler)
390
- extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator) if use_builtin_flow_match else {}
391
-
392
- latents = randn_tensor(
393
- (
394
- batch_size,
395
- int(self.transformer.config.in_channels),
396
- height // self.vae_scale_factor,
397
- width // self.vae_scale_factor,
398
- ),
399
- generator=generator,
400
- device=device,
401
- dtype=model_dtype,
402
- )
403
-
404
- if use_builtin_flow_match:
405
- for timestep in self.progress_bar(schedule):
406
- flow_time = float(self._flow_time_from_sigma_timestep(timestep, num_train_timesteps))
407
- guidance_active = self._guidance_active(flow_time, guidance_interval, cfg_interval_start)
408
- do_cfg = guidance_scale > 1.0 and guidance_active
409
-
410
- if do_cfg:
411
- model_input = torch.cat([latents, latents], dim=0)
412
- labels = torch.cat([class_labels_tensor, null_labels], dim=0)
413
- else:
414
- model_input = latents
415
- labels = class_labels_tensor
416
-
417
- flow_time_batch = torch.full((labels.shape[0],), flow_time, device=device, dtype=model_dtype)
418
- velocity = self.transformer(
419
- hidden_states=model_input,
420
- timestep=flow_time_batch,
421
- class_labels=labels,
422
- return_dict=True,
423
- ).sample
424
- velocity = self._apply_cfg(velocity, guidance_scale, guidance_active, cfg_channels)
425
 
426
- # FlowMatchEuler/Heun: integrate in sigma space; model expects -velocity
427
- latents = self.scheduler.step(
428
- -velocity,
429
- timestep,
430
- latents,
431
- **extra_step_kwargs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  ).prev_sample
433
- else:
434
- for index, timestep in enumerate(self.progress_bar(schedule[:-1])):
435
- next_timestep = schedule[index + 1]
436
- flow_time = float(timestep)
437
- guidance_active = self._guidance_active(flow_time, guidance_interval, cfg_interval_start)
438
-
439
  if guidance_scale > 1.0 and guidance_active:
440
- model_input = torch.cat([latents, latents], dim=0)
441
- labels = torch.cat([class_labels_tensor, null_labels], dim=0)
442
  else:
443
- model_input = latents
444
- labels = class_labels_tensor
445
-
446
- flow_time_batch = torch.full((labels.shape[0],), flow_time, device=device, dtype=model_dtype)
447
- velocity = self.transformer(
448
- hidden_states=model_input,
449
- timestep=flow_time_batch,
450
- class_labels=labels,
 
451
  return_dict=True,
452
  ).sample
453
- velocity = self._apply_cfg(velocity, guidance_scale, guidance_active, cfg_channels)
454
-
 
 
 
 
 
455
  latents = self.scheduler.step(
456
- velocity, timestep[None], latents, next_timestep[None]
457
  ).prev_sample
458
 
459
  latent_mean, latent_std = self._resolve_latent_stats(
@@ -466,12 +337,12 @@ class LightningDiTPipeline(DiffusionPipeline):
466
  )
467
  latents = self._denormalize_latents(latents, latent_mean, latent_std, latent_multiplier)
468
 
469
- image = self.decode_latents(latents, output_type=output_type)
470
- self.maybe_free_model_hooks()
 
 
471
 
 
472
  if not return_dict:
473
  return (image,)
474
- return ImagePipelineOutput(images=image)
475
-
476
-
477
- __all__ = ["LightningDiTPipeline"]
 
1
+ """Hub custom pipeline: LightningDiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  from __future__ import annotations
6
 
7
  import inspect
 
8
 
9
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
 
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Tuple, Union, Any
17
+
18
+ import json
19
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ try:
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
24
+ from diffusers.utils import BaseOutput
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ except Exception: # pragma: no cover
27
+ class BaseOutput(dict):
28
+ def __post_init__(self):
29
+ self.update(self.__dict__)
30
+
31
+ class DiffusionPipeline:
32
+ def register_modules(self, **kwargs):
33
+ for name, module in kwargs.items():
34
+ setattr(self, name, module)
35
+
36
+ @property
37
+ def _execution_device(self):
38
+ return torch.device("cpu")
39
+
40
+ def maybe_free_model_hooks(self):
41
+ pass
42
+
43
+ class VaeImageProcessor:
44
+ def postprocess(self, image, output_type="pil"):
45
+ return image
46
 
47
+ def randn_tensor(shape, generator=None, device=None, dtype=None):
48
+ return torch.randn(shape, generator=generator, device=device, dtype=dtype)
 
 
 
 
49
 
50
+ @dataclass
51
+ class LightningDiTPipelineOutput(BaseOutput):
52
+ images: Union[torch.FloatTensor, List]
53
 
54
  class LightningDiTPipeline(DiffusionPipeline):
55
  r"""
56
+ Class-conditional image generation with LightningDiT and a flow-matching scheduler.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ Components are stored in separate subfolders (`transformer`, `scheduler`, optional `vae`) for
59
+ `DiffusionPipeline.from_pretrained` compatibility.
60
+ """
61
 
62
+ @staticmethod
63
+ def prepare_extra_step_kwargs(
 
 
64
  scheduler,
65
+ generator=None,
66
+ eta: float | None = None,
67
  ):
68
+ kwargs = {}
69
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
70
+ if "generator" in step_params:
71
+ kwargs["generator"] = generator
72
+ if eta is not None and "eta" in step_params:
73
+ kwargs["eta"] = eta
74
+ return kwargs
75
+
76
 
77
+ model_cpu_offload_seq = "transformer->vae"
78
+ _optional_components = ["vae"]
 
79
 
80
+ def __init__(self, transformer, scheduler, vae=None, id2label: Optional[Dict[Union[int, str], str]] = None):
81
+ super().__init__()
82
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
83
+ self.image_processor = VaeImageProcessor()
84
  self._id2label = self._normalize_id2label(id2label)
85
  self.labels = self._build_label2id(self._id2label)
86
+ self._labels_loaded_from_model_index = bool(self._id2label)
 
 
 
 
 
 
 
 
 
87
 
88
  @staticmethod
89
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
 
91
  return {}
92
  return {int(key): value for key, value in id2label.items()}
93
 
94
+ @staticmethod
95
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
96
+ if not variant_path:
97
+ return {}
98
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
99
+ if not model_index_path.exists():
100
+ return {}
101
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
102
+ id2label = raw.get("id2label")
103
+ if not isinstance(id2label, dict):
104
+ return {}
105
+ return {int(key): value for key, value in id2label.items()}
106
+
107
  @staticmethod
108
  def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
109
  label2id: Dict[str, int] = {}
 
114
  label2id[synonym] = int(class_id)
115
  return dict(sorted(label2id.items()))
116
 
117
+ def _ensure_labels_loaded(self) -> None:
118
+ if self._labels_loaded_from_model_index:
119
+ return
120
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
121
+ if loaded:
122
+ self._id2label = loaded
123
+ self.labels = self._build_label2id(self._id2label)
124
+ self._labels_loaded_from_model_index = True
125
+
126
  @property
127
  def id2label(self) -> Dict[int, str]:
128
+ self._ensure_labels_loaded()
129
  return self._id2label
130
 
131
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
132
+ self._ensure_labels_loaded()
 
133
  if not self.labels:
134
+ raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
135
+ labels = [label] if isinstance(label, str) else label
136
  missing = [item for item in labels if item not in self.labels]
137
  if missing:
138
  preview = ", ".join(list(self.labels.keys())[:8])
139
+ raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
140
  return [self.labels[item] for item in labels]
141
 
142
  def _normalize_class_labels(
143
  self,
144
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
145
  ) -> torch.LongTensor:
146
  if isinstance(class_labels, torch.Tensor):
147
+ return class_labels.to(dtype=torch.long).reshape(-1)
148
  if isinstance(class_labels, int):
149
+ class_labels = [class_labels]
150
  elif isinstance(class_labels, str):
151
+ class_labels = self.get_label_ids(class_labels)
152
  elif class_labels and isinstance(class_labels[0], str):
153
+ class_labels = self.get_label_ids(class_labels) # type: ignore[arg-type]
154
+ return torch.tensor(class_labels, dtype=torch.long).reshape(-1)
 
 
155
 
156
+ def _prepare_latents(
 
 
 
157
  self,
158
+ batch_size: int,
159
  height: int,
160
  width: int,
161
+ dtype: torch.dtype,
162
+ device: torch.device,
163
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
164
+ ) -> torch.Tensor:
165
+ downsample = 16
166
+ if self.vae is not None:
167
+ block_out = getattr(self.vae.config, "block_out_channels", None)
168
+ if block_out is not None:
169
+ downsample = 2 ** (len(block_out) - 1)
170
+ elif hasattr(self.vae.config, "downsample_ratio"):
171
+ downsample = int(self.vae.config.downsample_ratio)
172
+
173
+ if height % downsample != 0 or width % downsample != 0:
174
+ raise ValueError(f"height and width must be divisible by the VAE downsample factor {downsample}.")
175
+
176
+ latent_height = height // downsample
177
+ latent_width = width // downsample
178
  patch_size = int(self.transformer.config.patch_size)
 
 
 
 
 
179
  if latent_height % patch_size != 0 or latent_width % patch_size != 0:
180
+ raise ValueError("Latent height and width must be divisible by the transformer patch_size.")
 
 
 
 
 
 
 
 
 
 
181
 
182
+ return randn_tensor(
183
+ (batch_size, self.transformer.config.in_channels, latent_height, latent_width),
184
+ generator=generator,
185
+ device=device,
186
+ dtype=dtype,
187
+ )
188
 
189
  @staticmethod
190
  def _apply_cfg(
 
197
  return model_output
198
  eps, rest = model_output[:, :cfg_channels], model_output[:, cfg_channels:]
199
  cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
200
+ half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
201
  if rest.numel() == 0:
202
+ return half_eps
203
+ cond_rest, _ = torch.chunk(rest, 2, dim=0)
204
+ return torch.cat([half_eps, cond_rest], dim=1)
205
 
206
  def _resolve_latent_stats(
207
  self,
 
231
  ) -> torch.Tensor:
232
  return (latents * latent_std) / latent_multiplier + latent_mean
233
 
234
+ def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
235
+ if self.vae is None:
236
  return latents
237
+
238
  vae_dtype = next(self.vae.parameters()).dtype
239
  latents = latents.to(dtype=vae_dtype)
240
  scaling_factor = getattr(self.vae.config, "scaling_factor", None)
241
  if scaling_factor not in (None, 0):
242
  latents = latents / scaling_factor
243
+ image = self.vae.decode(latents)
244
+ return image.sample if hasattr(image, "sample") else image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  @torch.no_grad()
 
247
  def __call__(
248
  self,
249
+ class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
250
+ height: int = 256,
251
+ width: int = 256,
252
  num_inference_steps: int = 250,
253
+ guidance_scale: float = 1.0,
254
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
255
  cfg_interval_start: float = 0.125,
256
+ timestep_shift: float = 0.3,
257
+ heun: bool = False,
258
  cfg_channels: int = 3,
259
  latent_mean: Optional[torch.Tensor] = None,
260
  latent_std: Optional[torch.Tensor] = None,
 
262
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
263
  output_type: str = "pil",
264
  return_dict: bool = True,
265
+ ) -> Union[LightningDiTPipelineOutput, Tuple]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  device = self._execution_device
267
  model_dtype = next(self.transformer.parameters()).dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
+ class_labels = self._normalize_class_labels(class_labels).to(device=device)
270
+ batch_size = class_labels.numel()
271
+
272
+ latents = self._prepare_latents(batch_size, height, width, model_dtype, device, generator)
273
+ timesteps = self.scheduler.set_timesteps(num_inference_steps, device=device, timestep_shift=timestep_shift)
274
+
275
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
276
+
277
+ null_labels = torch.full_like(class_labels, self.transformer.config.num_classes)
278
+ for index, timestep in enumerate(timesteps[:-1]):
279
+ next_timestep = timesteps[index + 1]
280
+ guidance_active = guidance_interval[0] <= float(timestep) <= guidance_interval[1]
281
+ if cfg_interval_start is not None and float(timestep) < cfg_interval_start:
282
+ guidance_active = False
283
+
284
+ if guidance_scale > 1.0 and guidance_active:
285
+ model_input = torch.cat([latents, latents], dim=0)
286
+ labels = torch.cat([class_labels, null_labels], dim=0)
287
+ else:
288
+ model_input = latents
289
+ labels = class_labels
290
+
291
+ timestep_batch = torch.full((labels.shape[0],), float(timestep), device=device, dtype=model_dtype)
292
+ model_output = self.transformer(
293
+ model_input.to(dtype=model_dtype),
294
+ timestep_batch,
295
+ labels,
296
+ return_dict=True,
297
+ ).sample
298
+ model_output = self._apply_cfg(model_output, guidance_scale, guidance_active, cfg_channels)
299
+
300
+ if heun and index < len(timesteps) - 2:
301
+ provisional = self.scheduler.step(
302
+ model_output, timestep[None], latents, next_timestep[None], **extra_step_kwargs
303
  ).prev_sample
 
 
 
 
 
 
304
  if guidance_scale > 1.0 and guidance_active:
305
+ prime_input = torch.cat([provisional, provisional], dim=0)
306
+ prime_labels = torch.cat([class_labels, null_labels], dim=0)
307
  else:
308
+ prime_input = provisional
309
+ prime_labels = class_labels
310
+ next_timestep_batch = torch.full(
311
+ (prime_labels.shape[0],), float(next_timestep), device=device, dtype=model_dtype
312
+ )
313
+ next_model_output = self.transformer(
314
+ prime_input.to(dtype=model_dtype),
315
+ next_timestep_batch,
316
+ prime_labels,
317
  return_dict=True,
318
  ).sample
319
+ next_model_output = self._apply_cfg(
320
+ next_model_output, guidance_scale, guidance_active, cfg_channels
321
+ )
322
+ latents = self.scheduler.step_heun(
323
+ model_output, next_model_output, timestep[None], latents, next_timestep[None]
324
+ ).prev_sample
325
+ else:
326
  latents = self.scheduler.step(
327
+ model_output, timestep[None], latents, next_timestep[None], **extra_step_kwargs
328
  ).prev_sample
329
 
330
  latent_mean, latent_std = self._resolve_latent_stats(
 
337
  )
338
  latents = self._denormalize_latents(latents, latent_mean, latent_std, latent_multiplier)
339
 
340
+ image = self._decode_latents(latents)
341
+ if self.vae is not None:
342
+ image = (image / 2 + 0.5).clamp(0, 1)
343
+ image = self.image_processor.postprocess(image, output_type=output_type)
344
 
345
+ self.maybe_free_model_hooks()
346
  if not return_dict:
347
  return (image,)
348
+ return LightningDiTPipelineOutput(images=image)
 
 
 
LightningDit-XL-1-256/scheduler/scheduler_config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_class_name": "FlowMatchHeunDiscreteScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
  "shift": 0.3
 
1
  {
2
+ "_class_name": "LightningDiTFlowMatchScheduler",
3
  "_diffusers_version": "0.36.0",
4
  "num_train_timesteps": 1000,
5
  "shift": 0.3
LightningDit-XL-1-256/scheduler/scheduling_flow_match_lightningdit.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+
11
try:
    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.schedulers.scheduling_utils import SchedulerMixin
    from diffusers.utils import BaseOutput
except Exception:  # pragma: no cover
    # diffusers could not be imported: define minimal stand-ins under the same
    # names so the rest of this module can still be imported and used.

    def register_to_config(init):
        """Return ``init`` unchanged (no-op stand-in for the diffusers decorator)."""
        return init

    class SchedulerMixin:
        """Empty placeholder base class."""

    class ConfigMixin:
        """Placeholder base class that only carries ``config_name``."""

        config_name = "scheduler_config.json"

    class BaseOutput(dict):
        """Dict-based output container whose instance attributes are mirrored as keys."""

        def __post_init__(self):
            # Hook invoked by a dataclass-generated __init__: copy the
            # instance attributes into the mapping itself.
            self.update(vars(self))
28
+
29
+
30
@dataclass
class LightningDiTFlowMatchSchedulerOutput(BaseOutput):
    """
    Output container returned by the scheduler's `step` and `step_heun` methods when `return_dict=True`.
    """

    # Sample after the update, cast back to the dtype of the sample that was passed in.
    prev_sample: torch.FloatTensor
33
+
34
+
35
class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
    """
    Flow-matching ODE scheduler for LightningDiT (linear path, velocity prediction).

    Integrates from t=0 (noise) to t=1 (data) with optional timestep shifting used in LightningDiT sampling.
    `step` performs an explicit Euler update; `step_heun` performs a Heun (trapezoidal) update that averages the
    velocities evaluated at both ends of the interval.

    Args:
        path_type (`str`, defaults to `"linear"`):
            Either `"linear"` or `"cosine"`. The value is validated and stored, but the update rules in this class
            do not branch on it.
        num_train_timesteps (`int`, defaults to 1000):
            Resolution of the default timestep grid created at construction time.
    """

    config_name = "scheduler_config.json"
    order = 1

    @register_to_config
    def __init__(self, path_type: str = "linear", num_train_timesteps: int = 1000):
        if path_type not in {"linear", "cosine"}:
            raise ValueError("path_type must be either 'linear' or 'cosine'.")
        self.path_type = path_type
        self.num_train_timesteps = num_train_timesteps
        # Default grid; replaced by `set_timesteps` before sampling.
        self.timesteps = torch.linspace(0.0, 1.0, num_train_timesteps + 1, dtype=torch.float64)

    @staticmethod
    def _compute_dtype(device) -> torch.dtype:
        """Return the dtype used for timestep/update math: float64, or float32 on MPS (no float64 support)."""
        if device is not None and torch.device(device).type == "mps":
            return torch.float32
        return torch.float64

    @staticmethod
    def _apply_timestep_shift(timesteps: torch.Tensor, timestep_shift: float) -> torch.Tensor:
        """
        Warp `timesteps` with `shift * t / (1 + (shift - 1) * t)`.

        A non-positive `timestep_shift` disables the warp and returns `timesteps` unchanged.
        """
        if timestep_shift <= 0:
            return timesteps
        return timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[torch.device] = None,
        timestep_shift: float = 0.0,
    ):
        """
        Build the inference grid of `num_inference_steps + 1` points from 0 to 1 and store it in `self.timesteps`.

        Args:
            num_inference_steps (`int`):
                Number of integration intervals; must be at least 1.
            device (`torch.device`, *optional*):
                Device the grid is moved to.
            timestep_shift (`float`, defaults to 0.0):
                Shift factor passed to `_apply_timestep_shift`; values `<= 0` leave the grid uniform.

        Returns:
            `torch.Tensor`: The grid that was stored in `self.timesteps`.

        Raises:
            ValueError: If `num_inference_steps` is smaller than 1.
        """
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be at least 1, got {num_inference_steps}.")
        # Build and shift the grid in float64 on the default device, then move/cast once.
        grid = torch.linspace(0.0, 1.0, num_inference_steps + 1, dtype=torch.float64)
        grid = self._apply_timestep_shift(grid, timestep_shift)
        self.timesteps = grid.to(device=device, dtype=self._compute_dtype(device))
        return self.timesteps

    def _integrate(
        self,
        sample: torch.Tensor,
        velocity: torch.Tensor,
        timestep: torch.Tensor,
        next_timestep: torch.Tensor,
        out_dtype: torch.dtype,
        return_dict: bool,
    ):
        """Apply `sample + (next_timestep - timestep) * velocity`, cast to `out_dtype` and package the result."""
        # Only the first element of each timestep tensor is used, so one step size applies to the whole batch.
        t_now = timestep.to(device=sample.device, dtype=sample.dtype).flatten()[0]
        t_next = next_timestep.to(device=sample.device, dtype=sample.dtype).flatten()[0]
        prev_sample = (sample + (t_next - t_now) * velocity).to(out_dtype)
        if not return_dict:
            return (prev_sample,)
        return LightningDiTFlowMatchSchedulerOutput(prev_sample=prev_sample)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: torch.Tensor,
        sample: torch.Tensor,
        next_timestep: torch.Tensor,
        return_dict: bool = True,
    ) -> LightningDiTFlowMatchSchedulerOutput:
        """
        Advance `sample` from `timestep` to `next_timestep` with one explicit Euler step.

        Args:
            model_output (`torch.Tensor`):
                Velocity predicted at `timestep`.
            timestep (`torch.Tensor`):
                Current time; only its first element is used.
            sample (`torch.Tensor`):
                Current sample.
            next_timestep (`torch.Tensor`):
                Target time; only its first element is used.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `LightningDiTFlowMatchSchedulerOutput` instead of a plain tuple.

        Returns:
            `LightningDiTFlowMatchSchedulerOutput` with `prev_sample` in the dtype of `sample`, or the one-element
            tuple `(prev_sample,)` when `return_dict` is `False`.
        """
        work_dtype = self._compute_dtype(sample.device)
        return self._integrate(
            sample.to(dtype=work_dtype),
            model_output.to(dtype=work_dtype),
            timestep,
            next_timestep,
            sample.dtype,
            return_dict,
        )

    def step_heun(
        self,
        model_output: torch.Tensor,
        next_model_output: torch.Tensor,
        timestep: torch.Tensor,
        sample: torch.Tensor,
        next_timestep: torch.Tensor,
        return_dict: bool = True,
    ) -> LightningDiTFlowMatchSchedulerOutput:
        """
        Advance `sample` from `timestep` to `next_timestep` using the average of two velocities (Heun update).

        Args:
            model_output (`torch.Tensor`):
                Velocity predicted at `timestep`.
            next_model_output (`torch.Tensor`):
                Velocity predicted at `next_timestep` (by the caller, on a provisional sample).
            timestep (`torch.Tensor`):
                Current time; only its first element is used.
            sample (`torch.Tensor`):
                Current sample (the one from before the provisional step).
            next_timestep (`torch.Tensor`):
                Target time; only its first element is used.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `LightningDiTFlowMatchSchedulerOutput` instead of a plain tuple.

        Returns:
            `LightningDiTFlowMatchSchedulerOutput` with `prev_sample` in the dtype of `sample`, or the one-element
            tuple `(prev_sample,)` when `return_dict` is `False`.
        """
        work_dtype = self._compute_dtype(sample.device)
        mean_velocity = 0.5 * model_output.to(dtype=work_dtype) + 0.5 * next_model_output.to(dtype=work_dtype)
        return self._integrate(
            sample.to(dtype=work_dtype),
            mean_velocity,
            timestep,
            next_timestep,
            sample.dtype,
            return_dict,
        )