BiliSakura commited on
Commit
343b8d8
·
verified ·
1 Parent(s): f53722e

Fix generator determinism: forward generator through scheduler steps and seeded noise

Browse files
DiT-MoE-B-8E2A/pipeline.py CHANGED
@@ -1,23 +1,15 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
 
15
  import importlib
16
  import inspect
17
  import json
18
  import sys
19
  from pathlib import Path
20
- from typing import Dict, List, Optional, Tuple, Union
21
 
22
  import torch
23
 
@@ -72,6 +64,20 @@ class DiTMoEPipeline(DiffusionPipeline):
72
  Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
73
  """
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  model_cpu_offload_seq = "transformer->vae"
76
  _optional_components = ["vae"]
77
 
@@ -286,19 +292,6 @@ class DiTMoEPipeline(DiffusionPipeline):
286
  dtype=dtype,
287
  )
288
 
289
- @staticmethod
290
- def prepare_extra_step_kwargs(
291
- scheduler: KarrasDiffusionSchedulers,
292
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
293
- eta: float,
294
- ) -> Dict[str, object]:
295
- kwargs: Dict[str, object] = {}
296
- step_params = set(inspect.signature(scheduler.step).parameters.keys())
297
- if "eta" in step_params:
298
- kwargs["eta"] = eta
299
- if "generator" in step_params:
300
- kwargs["generator"] = generator
301
- return kwargs
302
 
303
  def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
304
  if guidance_scale <= 1.0:
@@ -402,6 +395,7 @@ class DiTMoEPipeline(DiffusionPipeline):
402
  timestep_batch[:batch_size] if do_cfg else timestep_batch,
403
  latents_cfg,
404
  next_timestep=next_timestep,
 
405
  ).prev_sample
406
  latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
407
  latents = latents[:batch_size]
 
1
+ """Hub custom pipeline: DiTMoEPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
 
 
 
 
 
 
 
 
6
 
7
  import importlib
8
  import inspect
9
  import json
10
  import sys
11
  from pathlib import Path
12
+ from typing import Dict, List, Optional, Tuple, Union, Any
13
 
14
  import torch
15
 
 
64
  Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
65
  """
66
 
67
+ @staticmethod
68
+ def prepare_extra_step_kwargs(
69
+ scheduler,
70
+ generator=None,
71
+ eta: float | None = None,
72
+ ):
73
+ kwargs = {}
74
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
75
+ if "generator" in step_params:
76
+ kwargs["generator"] = generator
77
+ if eta is not None and "eta" in step_params:
78
+ kwargs["eta"] = eta
79
+ return kwargs
80
+
81
  model_cpu_offload_seq = "transformer->vae"
82
  _optional_components = ["vae"]
83
 
 
292
  dtype=dtype,
293
  )
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
297
  if guidance_scale <= 1.0:
 
395
  timestep_batch[:batch_size] if do_cfg else timestep_batch,
396
  latents_cfg,
397
  next_timestep=next_timestep,
398
+ **extra_step_kwargs,
399
  ).prev_sample
400
  latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
401
  latents = latents[:batch_size]
DiT-MoE-S-8E2A/pipeline.py CHANGED
@@ -1,23 +1,15 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
 
15
  import importlib
16
  import inspect
17
  import json
18
  import sys
19
  from pathlib import Path
20
- from typing import Dict, List, Optional, Tuple, Union
21
 
22
  import torch
23
 
@@ -72,6 +64,20 @@ class DiTMoEPipeline(DiffusionPipeline):
72
  Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
73
  """
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  model_cpu_offload_seq = "transformer->vae"
76
  _optional_components = ["vae"]
77
 
@@ -286,19 +292,6 @@ class DiTMoEPipeline(DiffusionPipeline):
286
  dtype=dtype,
287
  )
288
 
289
- @staticmethod
290
- def prepare_extra_step_kwargs(
291
- scheduler: KarrasDiffusionSchedulers,
292
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
293
- eta: float,
294
- ) -> Dict[str, object]:
295
- kwargs: Dict[str, object] = {}
296
- step_params = set(inspect.signature(scheduler.step).parameters.keys())
297
- if "eta" in step_params:
298
- kwargs["eta"] = eta
299
- if "generator" in step_params:
300
- kwargs["generator"] = generator
301
- return kwargs
302
 
303
  def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
304
  if guidance_scale <= 1.0:
@@ -402,6 +395,7 @@ class DiTMoEPipeline(DiffusionPipeline):
402
  timestep_batch[:batch_size] if do_cfg else timestep_batch,
403
  latents_cfg,
404
  next_timestep=next_timestep,
 
405
  ).prev_sample
406
  latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
407
  latents = latents[:batch_size]
 
1
+ """Hub custom pipeline: DiTMoEPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
 
 
 
 
 
 
 
 
6
 
7
  import importlib
8
  import inspect
9
  import json
10
  import sys
11
  from pathlib import Path
12
+ from typing import Dict, List, Optional, Tuple, Union, Any
13
 
14
  import torch
15
 
 
64
  Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
65
  """
66
 
67
+ @staticmethod
68
+ def prepare_extra_step_kwargs(
69
+ scheduler,
70
+ generator=None,
71
+ eta: float | None = None,
72
+ ):
73
+ kwargs = {}
74
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
75
+ if "generator" in step_params:
76
+ kwargs["generator"] = generator
77
+ if eta is not None and "eta" in step_params:
78
+ kwargs["eta"] = eta
79
+ return kwargs
80
+
81
  model_cpu_offload_seq = "transformer->vae"
82
  _optional_components = ["vae"]
83
 
 
292
  dtype=dtype,
293
  )
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
297
  if guidance_scale <= 1.0:
 
395
  timestep_batch[:batch_size] if do_cfg else timestep_batch,
396
  latents_cfg,
397
  next_timestep=next_timestep,
398
+ **extra_step_kwargs,
399
  ).prev_sample
400
  latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
401
  latents = latents[:batch_size]
DiT-MoE-XL-8E2A/pipeline.py CHANGED
@@ -1,23 +1,15 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
 
15
  import importlib
16
  import inspect
17
  import json
18
  import sys
19
  from pathlib import Path
20
- from typing import Dict, List, Optional, Tuple, Union
21
 
22
  import torch
23
 
@@ -72,6 +64,20 @@ class DiTMoEPipeline(DiffusionPipeline):
72
  Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
73
  """
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  model_cpu_offload_seq = "transformer->vae"
76
  _optional_components = ["vae"]
77
 
@@ -133,12 +139,7 @@ class DiTMoEPipeline(DiffusionPipeline):
133
 
134
  id2label_override = kwargs.pop("id2label", None)
135
  null_class_id_override = kwargs.pop("null_class_id", None)
136
- use_flash_attn_override = kwargs.pop("use_flash_attn", None)
137
  model_kwargs = dict(kwargs)
138
- if use_flash_attn_override is not None:
139
- model_kwargs["use_flash_attn"] = use_flash_attn_override
140
- elif torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8:
141
- model_kwargs["use_flash_attn"] = False
142
  inserted: List[str] = []
143
 
144
  def _load_component(folder: str, module_name: str, class_name: str):
@@ -291,19 +292,6 @@ class DiTMoEPipeline(DiffusionPipeline):
291
  dtype=dtype,
292
  )
293
 
294
- @staticmethod
295
- def prepare_extra_step_kwargs(
296
- scheduler: KarrasDiffusionSchedulers,
297
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
298
- eta: float,
299
- ) -> Dict[str, object]:
300
- kwargs: Dict[str, object] = {}
301
- step_params = set(inspect.signature(scheduler.step).parameters.keys())
302
- if "eta" in step_params:
303
- kwargs["eta"] = eta
304
- if "generator" in step_params:
305
- kwargs["generator"] = generator
306
- return kwargs
307
 
308
  def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
309
  if guidance_scale <= 1.0:
@@ -407,6 +395,7 @@ class DiTMoEPipeline(DiffusionPipeline):
407
  timestep_batch[:batch_size] if do_cfg else timestep_batch,
408
  latents_cfg,
409
  next_timestep=next_timestep,
 
410
  ).prev_sample
411
  latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
412
  latents = latents[:batch_size]
 
1
+ """Hub custom pipeline: DiTMoEPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
 
 
 
 
 
 
 
 
6
 
7
  import importlib
8
  import inspect
9
  import json
10
  import sys
11
  from pathlib import Path
12
+ from typing import Dict, List, Optional, Tuple, Union, Any
13
 
14
  import torch
15
 
 
64
  Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
65
  """
66
 
67
+ @staticmethod
68
+ def prepare_extra_step_kwargs(
69
+ scheduler,
70
+ generator=None,
71
+ eta: float | None = None,
72
+ ):
73
+ kwargs = {}
74
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
75
+ if "generator" in step_params:
76
+ kwargs["generator"] = generator
77
+ if eta is not None and "eta" in step_params:
78
+ kwargs["eta"] = eta
79
+ return kwargs
80
+
81
  model_cpu_offload_seq = "transformer->vae"
82
  _optional_components = ["vae"]
83
 
 
139
 
140
  id2label_override = kwargs.pop("id2label", None)
141
  null_class_id_override = kwargs.pop("null_class_id", None)
 
142
  model_kwargs = dict(kwargs)
 
 
 
 
143
  inserted: List[str] = []
144
 
145
  def _load_component(folder: str, module_name: str, class_name: str):
 
292
  dtype=dtype,
293
  )
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
297
  if guidance_scale <= 1.0:
 
395
  timestep_batch[:batch_size] if do_cfg else timestep_batch,
396
  latents_cfg,
397
  next_timestep=next_timestep,
398
+ **extra_step_kwargs,
399
  ).prev_sample
400
  latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
401
  latents = latents[:batch_size]
README.md CHANGED
@@ -9,9 +9,6 @@ tags:
9
  - class-conditional
10
  - dit-moe
11
  pipeline_tag: unconditional-image-generation
12
- widget:
13
- - output:
14
- url: DiT-MoE-XL-8E2A/demo.png
15
  ---
16
 
17
  # DiT-MoE-diffusers
 
9
  - class-conditional
10
  - dit-moe
11
  pipeline_tag: unconditional-image-generation
 
 
 
12
  ---
13
 
14
  # DiT-MoE-diffusers