Instructions to use BiliSakura/ProMoE-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/ProMoE-diffusers with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```

```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/ProMoE-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```

- Notebooks
- Google Colab
- Kaggle
| """Hub custom pipeline: ProMoEPipeline. | |
| Load with native Hugging Face diffusers and trust_remote_code=True. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
try:
    from diffusers.pipelines.pipeline_utils import DiffusionPipeline
except Exception:  # pragma: no cover
    # NOTE(review): the broad `except Exception` presumably tolerates
    # non-ImportError failures raised while importing diffusers -- confirm
    # before narrowing it to `ImportError`.

    class DiffusionPipeline:
        """Minimal stand-in used when diffusers cannot be imported.

        Provides only the handful of methods this module calls on its base
        class: module registration, device placement, a progress-bar hook and
        the offload-hook cleanup.
        """

        def __init__(self):
            # Until `to()` is called, the pipeline reports CPU as its device.
            self._execution_device = torch.device("cpu")

        def register_modules(self, **kwargs):
            """Attach every keyword argument to the instance as an attribute."""
            for key, value in kwargs.items():
                setattr(self, key, value)

        def to(self, device):
            """Record `device` and move `transformer` / `vae` to it when present.

            Returns `self` so calls can be chained.
            """
            self._execution_device = torch.device(device)
            for module in (getattr(self, "transformer", None), getattr(self, "vae", None)):
                # Components may be missing or may not be movable objects.
                if module is not None and hasattr(module, "to"):
                    module.to(device)
            return self

        def progress_bar(self, iterable):
            """Return `iterable` unchanged (no progress display in the stub)."""
            return iterable

        def maybe_free_model_hooks(self):
            """No-op: the stub installs no offload hooks."""
            return None
@dataclass
class ProMoEPipelineOutput:
    """Container returned by the pipeline call when `return_dict=True`.

    The `@dataclass` decorator supplies the keyword constructor, so the class
    can be built as `ProMoEPipelineOutput(images=...)`.
    """

    # Generated images: a list of PIL images, a NumPy array, or a torch tensor,
    # depending on the `output_type` requested from the pipeline.
    images: Union[List[Image.Image], np.ndarray, torch.Tensor]
class ProMoEPipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with ProMoE.

    Parameters:
        transformer ([`ProMoETransformer2DModel`]):
            Class-conditional ProMoE transformer for flow-matching in latent space.
        scheduler ([`ProMoEFlowMatchScheduler`]):
            Flow-matching scheduler used during denoising.
        vae ([`AutoencoderKL`], *optional*):
            Variational autoencoder used to decode latents to pixels.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer->vae"
    _optional_components = ["vae"]

    def __init__(
        self,
        transformer,
        scheduler,
        vae=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        self._id2label = self._normalize_id2label(id2label)
        # `labels` is the reverse lookup: one entry per synonym -> class id.
        self.labels = self._build_label2id(self._id2label)
        # When no mapping was passed in, it is looked up lazily from
        # model_index.json the first time labels are needed.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Populate the label tables from `model_index.json` if still empty."""
        if self._labels_loaded_from_model_index:
            return
        # `config` is looked up defensively: the fallback base class defined
        # when diffusers is unavailable does not provide it.
        config = getattr(self, "config", None)
        loaded = self._read_id2label_from_model_index(getattr(config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with integer keys (`{}` for None/empty)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the `id2label` entry of `<variant_path>/model_index.json`.

        Returns `{}` when the path is empty, the file does not exist, or the
        entry is not a JSON object.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        # JSON object keys are always strings; convert them to ints.
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Build a synonym -> class id lookup, sorted by synonym.

        Each value of `id2label` is split on commas; every non-empty, stripped
        piece becomes a key. A synonym shared by several ids keeps the last one.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings. Each string must match a synonym in `id2label`.

        Returns:
            `list[int]`: One class id per input label, in input order.

        Raises:
            ValueError: If no labels are available or any label is unknown.
        """
        self._ensure_labels_loaded()
        label2id = self.labels
        if not label2id:
            raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
        if isinstance(label, str):
            label = [label]
        missing = [item for item in label if item not in label2id]
        if missing:
            # Show a few valid keys so the caller can see the expected spelling.
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [label2id[item] for item in label]

    def _get_vae_spatial_downsample(self) -> int:
        """Return the pixel-to-latent downsampling factor.

        Computed as `2 ** (len(block_out_channels) - 1)` from the VAE config;
        falls back to 8 when there is no VAE or the config lacks the field.
        """
        if self.vae is None:
            return 8
        block_out_channels = getattr(getattr(self.vae, "config", None), "block_out_channels", [0, 0, 0, 0])
        return 2 ** (len(block_out_channels) - 1)

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        device: torch.device,
    ) -> torch.LongTensor:
        """Convert any accepted `class_labels` form to a 1-D long tensor on `device`.

        Strings are resolved through `get_label_ids`; integers are used as-is.
        Sequences may mix both kinds.
        """
        if torch.is_tensor(class_labels):
            return class_labels.to(device=device, dtype=torch.long).reshape(-1)
        items = [class_labels] if isinstance(class_labels, (int, str)) else list(class_labels)
        # Resolve all string labels in a single call so that an error message
        # lists every unknown label at once.
        names = [item for item in items if isinstance(item, str)]
        resolved = iter(self.get_label_ids(names)) if names else iter(())
        class_label_ids = [next(resolved) if isinstance(item, str) else item for item in items]
        return torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)

    def _prepare_latents(
        self,
        batch_size: int,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> torch.Tensor:
        """Sample the initial Gaussian latents.

        With a list of generators, one sample is drawn per generator and the
        results are concatenated along the batch dimension.

        Raises:
            ValueError: If a generator list does not have `batch_size` entries.
        """
        shape = (batch_size, self.transformer.in_channels, latent_height, latent_width)
        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an "
                    f"effective batch size of {batch_size}. Make sure the batch size matches the length "
                    "of the generators."
                )
            latents = [torch.randn((1, *shape[1:]), generator=g, device=device, dtype=dtype) for g in generator]
            return torch.cat(latents, dim=0)
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)

    def _decode_latents(self, latents: torch.Tensor, output_type: str):
        """Turn final latents into the requested output format.

        `"latent"` returns the latents untouched; `"pt"` returns a tensor
        clamped to [0, 1]; `"np"` returns a channels-last float NumPy array;
        any other value returns a list of PIL images.
        """
        if output_type == "latent":
            return latents
        if self.vae is not None:
            scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
            # Decode in the VAE's own dtype, which may differ from the transformer's.
            decode_dtype = next(self.vae.parameters()).dtype
            latents = (latents / scaling_factor).to(dtype=decode_dtype)
            image = self.vae.decode(latents, return_dict=False)[0]
        else:
            image = latents
        # Rescale to [0, 1]; assumes the decoded values are in [-1, 1] -- TODO confirm.
        image = (image / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return image
        # (batch, channels, H, W) -> (batch, H, W, channels) for NumPy / PIL.
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "np":
            return image
        return [Image.fromarray((img * 255).round().astype("uint8")) for img in image]

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ProMoEPipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with ProMoE.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
                ImageNet class indices or human-readable English label strings.
            height (`int`, defaults to 256):
                Output height in pixels; floor-divided by the VAE downsampling factor.
            width (`int`, defaults to 256):
                Output width in pixels; floor-divided by the VAE downsampling factor.
            num_inference_steps (`int`, defaults to 50):
                Number of steps passed to `scheduler.set_timesteps`.
            guidance_scale (`float`, defaults to 1.0):
                Classifier-free guidance weight. Guidance is applied only when this is greater than 1.0.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                Random generator(s) for the initial latents and the scheduler step.
            output_type (`str`, defaults to `"pil"`):
                One of `"pil"`, `"np"`, `"pt"` or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `ProMoEPipelineOutput` instead of a plain tuple.

        Returns:
            `ProMoEPipelineOutput` or `tuple`: The generated images, wrapped in a one-element tuple when
            `return_dict` is `False`.
        """
        device = getattr(self, "_execution_device", torch.device("cpu"))
        model_dtype = next(self.transformer.parameters()).dtype

        class_labels = self._normalize_class_labels(class_labels, device)
        batch_size = class_labels.shape[0]

        vae_scale = self._get_vae_spatial_downsample()
        latent_height = height // vae_scale
        latent_width = width // vae_scale
        latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator)

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        # The guidance decision and the label batch do not change between
        # steps, so they are built once outside the loop.
        use_guidance = guidance_scale > 1.0
        if use_guidance:
            # The label index equal to `num_classes` serves as the unconditional label.
            null_labels = torch.full_like(
                class_labels, getattr(self.transformer.backbone.y_embedder, "num_classes", 1000)
            )
            labels = torch.cat([class_labels, null_labels], dim=0)
        else:
            labels = class_labels

        for t in self.progress_bar(self.scheduler.timesteps):
            # With guidance, the conditional and unconditional halves run in one forward pass.
            latent_input = torch.cat([latents, latents], dim=0) if use_guidance else latents
            timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype)
            model_output = self.transformer(
                hidden_states=latent_input,
                timestep=timestep,
                class_labels=labels,
                return_dict=True,
            ).sample
            # If the model emits extra channels, keep only the first half
            # (presumably the second half is a learned variance -- TODO confirm).
            if model_output.shape[1] != latents.shape[1]:
                model_output = model_output.chunk(2, dim=1)[0]
            if use_guidance:
                model_output_cond, model_output_uncond = model_output.chunk(2)
                model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
            latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample

        images = self._decode_latents(latents, output_type)
        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ProMoEPipelineOutput(images=images)