Unconditional Image Generation
Diffusers
Safetensors
English
image-generation
class-conditional
dit-moe
Instructions to use BiliSakura/DiT-MoE-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/DiT-MoE-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/DiT-MoE-diffusers", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
Fix generator determinism: forward generator through scheduler steps and seeded noise
Browse files
- DiT-MoE-B-8E2A/pipeline.py +21 -27
- DiT-MoE-S-8E2A/pipeline.py +21 -27
- DiT-MoE-XL-8E2A/pipeline.py +21 -32
- README.md +0 -3
DiT-MoE-B-8E2A/pipeline.py
CHANGED
|
@@ -1,23 +1,15 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
|
| 15 |
import importlib
|
| 16 |
import inspect
|
| 17 |
import json
|
| 18 |
import sys
|
| 19 |
from pathlib import Path
|
| 20 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 21 |
|
| 22 |
import torch
|
| 23 |
|
|
@@ -72,6 +64,20 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 72 |
Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
|
| 73 |
"""
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
model_cpu_offload_seq = "transformer->vae"
|
| 76 |
_optional_components = ["vae"]
|
| 77 |
|
|
@@ -286,19 +292,6 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 286 |
dtype=dtype,
|
| 287 |
)
|
| 288 |
|
| 289 |
-
@staticmethod
|
| 290 |
-
def prepare_extra_step_kwargs(
|
| 291 |
-
scheduler: KarrasDiffusionSchedulers,
|
| 292 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 293 |
-
eta: float,
|
| 294 |
-
) -> Dict[str, object]:
|
| 295 |
-
kwargs: Dict[str, object] = {}
|
| 296 |
-
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 297 |
-
if "eta" in step_params:
|
| 298 |
-
kwargs["eta"] = eta
|
| 299 |
-
if "generator" in step_params:
|
| 300 |
-
kwargs["generator"] = generator
|
| 301 |
-
return kwargs
|
| 302 |
|
| 303 |
def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
|
| 304 |
if guidance_scale <= 1.0:
|
|
@@ -402,6 +395,7 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 402 |
timestep_batch[:batch_size] if do_cfg else timestep_batch,
|
| 403 |
latents_cfg,
|
| 404 |
next_timestep=next_timestep,
|
|
|
|
| 405 |
).prev_sample
|
| 406 |
latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
|
| 407 |
latents = latents[:batch_size]
|
|
|
|
| 1 |
+
"""Hub custom pipeline: DiTMoEPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import importlib
|
| 8 |
import inspect
|
| 9 |
import json
|
| 10 |
import sys
|
| 11 |
from pathlib import Path
|
| 12 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 13 |
|
| 14 |
import torch
|
| 15 |
|
|
|
|
| 64 |
Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
|
| 65 |
"""
|
| 66 |
|
| 67 |
+
@staticmethod
|
| 68 |
+
def prepare_extra_step_kwargs(
|
| 69 |
+
scheduler,
|
| 70 |
+
generator=None,
|
| 71 |
+
eta: float | None = None,
|
| 72 |
+
):
|
| 73 |
+
kwargs = {}
|
| 74 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 75 |
+
if "generator" in step_params:
|
| 76 |
+
kwargs["generator"] = generator
|
| 77 |
+
if eta is not None and "eta" in step_params:
|
| 78 |
+
kwargs["eta"] = eta
|
| 79 |
+
return kwargs
|
| 80 |
+
|
| 81 |
model_cpu_offload_seq = "transformer->vae"
|
| 82 |
_optional_components = ["vae"]
|
| 83 |
|
|
|
|
| 292 |
dtype=dtype,
|
| 293 |
)
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
|
| 297 |
if guidance_scale <= 1.0:
|
|
|
|
| 395 |
timestep_batch[:batch_size] if do_cfg else timestep_batch,
|
| 396 |
latents_cfg,
|
| 397 |
next_timestep=next_timestep,
|
| 398 |
+
**extra_step_kwargs,
|
| 399 |
).prev_sample
|
| 400 |
latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
|
| 401 |
latents = latents[:batch_size]
|
DiT-MoE-S-8E2A/pipeline.py
CHANGED
|
@@ -1,23 +1,15 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
|
| 15 |
import importlib
|
| 16 |
import inspect
|
| 17 |
import json
|
| 18 |
import sys
|
| 19 |
from pathlib import Path
|
| 20 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 21 |
|
| 22 |
import torch
|
| 23 |
|
|
@@ -72,6 +64,20 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 72 |
Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
|
| 73 |
"""
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
model_cpu_offload_seq = "transformer->vae"
|
| 76 |
_optional_components = ["vae"]
|
| 77 |
|
|
@@ -286,19 +292,6 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 286 |
dtype=dtype,
|
| 287 |
)
|
| 288 |
|
| 289 |
-
@staticmethod
|
| 290 |
-
def prepare_extra_step_kwargs(
|
| 291 |
-
scheduler: KarrasDiffusionSchedulers,
|
| 292 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 293 |
-
eta: float,
|
| 294 |
-
) -> Dict[str, object]:
|
| 295 |
-
kwargs: Dict[str, object] = {}
|
| 296 |
-
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 297 |
-
if "eta" in step_params:
|
| 298 |
-
kwargs["eta"] = eta
|
| 299 |
-
if "generator" in step_params:
|
| 300 |
-
kwargs["generator"] = generator
|
| 301 |
-
return kwargs
|
| 302 |
|
| 303 |
def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
|
| 304 |
if guidance_scale <= 1.0:
|
|
@@ -402,6 +395,7 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 402 |
timestep_batch[:batch_size] if do_cfg else timestep_batch,
|
| 403 |
latents_cfg,
|
| 404 |
next_timestep=next_timestep,
|
|
|
|
| 405 |
).prev_sample
|
| 406 |
latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
|
| 407 |
latents = latents[:batch_size]
|
|
|
|
| 1 |
+
"""Hub custom pipeline: DiTMoEPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import importlib
|
| 8 |
import inspect
|
| 9 |
import json
|
| 10 |
import sys
|
| 11 |
from pathlib import Path
|
| 12 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 13 |
|
| 14 |
import torch
|
| 15 |
|
|
|
|
| 64 |
Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
|
| 65 |
"""
|
| 66 |
|
| 67 |
+
@staticmethod
|
| 68 |
+
def prepare_extra_step_kwargs(
|
| 69 |
+
scheduler,
|
| 70 |
+
generator=None,
|
| 71 |
+
eta: float | None = None,
|
| 72 |
+
):
|
| 73 |
+
kwargs = {}
|
| 74 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 75 |
+
if "generator" in step_params:
|
| 76 |
+
kwargs["generator"] = generator
|
| 77 |
+
if eta is not None and "eta" in step_params:
|
| 78 |
+
kwargs["eta"] = eta
|
| 79 |
+
return kwargs
|
| 80 |
+
|
| 81 |
model_cpu_offload_seq = "transformer->vae"
|
| 82 |
_optional_components = ["vae"]
|
| 83 |
|
|
|
|
| 292 |
dtype=dtype,
|
| 293 |
)
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
|
| 297 |
if guidance_scale <= 1.0:
|
|
|
|
| 395 |
timestep_batch[:batch_size] if do_cfg else timestep_batch,
|
| 396 |
latents_cfg,
|
| 397 |
next_timestep=next_timestep,
|
| 398 |
+
**extra_step_kwargs,
|
| 399 |
).prev_sample
|
| 400 |
latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
|
| 401 |
latents = latents[:batch_size]
|
DiT-MoE-XL-8E2A/pipeline.py
CHANGED
|
@@ -1,23 +1,15 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
|
| 15 |
import importlib
|
| 16 |
import inspect
|
| 17 |
import json
|
| 18 |
import sys
|
| 19 |
from pathlib import Path
|
| 20 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 21 |
|
| 22 |
import torch
|
| 23 |
|
|
@@ -72,6 +64,20 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 72 |
Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
|
| 73 |
"""
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
model_cpu_offload_seq = "transformer->vae"
|
| 76 |
_optional_components = ["vae"]
|
| 77 |
|
|
@@ -133,12 +139,7 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 133 |
|
| 134 |
id2label_override = kwargs.pop("id2label", None)
|
| 135 |
null_class_id_override = kwargs.pop("null_class_id", None)
|
| 136 |
-
use_flash_attn_override = kwargs.pop("use_flash_attn", None)
|
| 137 |
model_kwargs = dict(kwargs)
|
| 138 |
-
if use_flash_attn_override is not None:
|
| 139 |
-
model_kwargs["use_flash_attn"] = use_flash_attn_override
|
| 140 |
-
elif torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8:
|
| 141 |
-
model_kwargs["use_flash_attn"] = False
|
| 142 |
inserted: List[str] = []
|
| 143 |
|
| 144 |
def _load_component(folder: str, module_name: str, class_name: str):
|
|
@@ -291,19 +292,6 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 291 |
dtype=dtype,
|
| 292 |
)
|
| 293 |
|
| 294 |
-
@staticmethod
|
| 295 |
-
def prepare_extra_step_kwargs(
|
| 296 |
-
scheduler: KarrasDiffusionSchedulers,
|
| 297 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 298 |
-
eta: float,
|
| 299 |
-
) -> Dict[str, object]:
|
| 300 |
-
kwargs: Dict[str, object] = {}
|
| 301 |
-
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 302 |
-
if "eta" in step_params:
|
| 303 |
-
kwargs["eta"] = eta
|
| 304 |
-
if "generator" in step_params:
|
| 305 |
-
kwargs["generator"] = generator
|
| 306 |
-
return kwargs
|
| 307 |
|
| 308 |
def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
|
| 309 |
if guidance_scale <= 1.0:
|
|
@@ -407,6 +395,7 @@ class DiTMoEPipeline(DiffusionPipeline):
|
|
| 407 |
timestep_batch[:batch_size] if do_cfg else timestep_batch,
|
| 408 |
latents_cfg,
|
| 409 |
next_timestep=next_timestep,
|
|
|
|
| 410 |
).prev_sample
|
| 411 |
latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
|
| 412 |
latents = latents[:batch_size]
|
|
|
|
| 1 |
+
"""Hub custom pipeline: DiTMoEPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import importlib
|
| 8 |
import inspect
|
| 9 |
import json
|
| 10 |
import sys
|
| 11 |
from pathlib import Path
|
| 12 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 13 |
|
| 14 |
import torch
|
| 15 |
|
|
|
|
| 64 |
Each checkpoint keeps an English `id2label` map in `model_index.json` (DiT-style).
|
| 65 |
"""
|
| 66 |
|
| 67 |
+
@staticmethod
|
| 68 |
+
def prepare_extra_step_kwargs(
|
| 69 |
+
scheduler,
|
| 70 |
+
generator=None,
|
| 71 |
+
eta: float | None = None,
|
| 72 |
+
):
|
| 73 |
+
kwargs = {}
|
| 74 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 75 |
+
if "generator" in step_params:
|
| 76 |
+
kwargs["generator"] = generator
|
| 77 |
+
if eta is not None and "eta" in step_params:
|
| 78 |
+
kwargs["eta"] = eta
|
| 79 |
+
return kwargs
|
| 80 |
+
|
| 81 |
model_cpu_offload_seq = "transformer->vae"
|
| 82 |
_optional_components = ["vae"]
|
| 83 |
|
|
|
|
| 139 |
|
| 140 |
id2label_override = kwargs.pop("id2label", None)
|
| 141 |
null_class_id_override = kwargs.pop("null_class_id", None)
|
|
|
|
| 142 |
model_kwargs = dict(kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
inserted: List[str] = []
|
| 144 |
|
| 145 |
def _load_component(folder: str, module_name: str, class_name: str):
|
|
|
|
| 292 |
dtype=dtype,
|
| 293 |
)
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
def _apply_cfg(self, model_output: torch.Tensor, guidance_scale: float) -> torch.Tensor:
|
| 297 |
if guidance_scale <= 1.0:
|
|
|
|
| 395 |
timestep_batch[:batch_size] if do_cfg else timestep_batch,
|
| 396 |
latents_cfg,
|
| 397 |
next_timestep=next_timestep,
|
| 398 |
+
**extra_step_kwargs,
|
| 399 |
).prev_sample
|
| 400 |
latents = step_output if not do_cfg else torch.cat([step_output, step_output], dim=0)
|
| 401 |
latents = latents[:batch_size]
|
README.md
CHANGED
|
@@ -9,9 +9,6 @@ tags:
|
|
| 9 |
- class-conditional
|
| 10 |
- dit-moe
|
| 11 |
pipeline_tag: unconditional-image-generation
|
| 12 |
-
widget:
|
| 13 |
-
- output:
|
| 14 |
-
url: DiT-MoE-XL-8E2A/demo.png
|
| 15 |
---
|
| 16 |
|
| 17 |
# DiT-MoE-diffusers
|
|
|
|
| 9 |
- class-conditional
|
| 10 |
- dit-moe
|
| 11 |
pipeline_tag: unconditional-image-generation
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
# DiT-MoE-diffusers
|