# Provenance (Hugging Face Hub file-viewer header): uploaded by BiliSakura,
# commit 24196fc (verified), message "Upload folder using huggingface_hub", 10.4 kB.
"""Hub custom pipeline: ProMoEPipeline.
Load with native Hugging Face diffusers and trust_remote_code=True.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
try:
    from diffusers.pipelines.pipeline_utils import DiffusionPipeline
except Exception:  # pragma: no cover
    # The diffusers import failed for any reason: define a small stand-in that
    # offers only the base-class methods this module calls.

    class DiffusionPipeline:
        """Minimal stand-in for the diffusers base pipeline class."""

        def __init__(self):
            # The execution device starts out as CPU until `to()` is called.
            self._execution_device = torch.device("cpu")

        def register_modules(self, **kwargs):
            """Store every keyword argument as an attribute of the same name."""
            for name in kwargs:
                setattr(self, name, kwargs[name])

        def to(self, device):
            """Record `device` and move the transformer and VAE to it, if present."""
            self._execution_device = torch.device(device)
            # Only these two attributes are moved; missing ones are skipped.
            components = tuple(getattr(self, attr, None) for attr in ("transformer", "vae"))
            for component in components:
                if component is None or not hasattr(component, "to"):
                    continue
                component.to(device)
            return self

        def progress_bar(self, iterable):
            """Return `iterable` unchanged (no progress display)."""
            return iterable

        def maybe_free_model_hooks(self):
            """Do nothing; the stand-in has no offload hooks to release."""
            return None
@dataclass
class ProMoEPipelineOutput:
    """Result container returned by the ProMoE pipeline.

    `images` holds the generated batch as a list of PIL images, a NumPy
    array, or a torch tensor.
    """

    images: Union[List[Image.Image], np.ndarray, torch.Tensor]
class ProMoEPipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with ProMoE.

    Parameters:
        transformer ([`ProMoETransformer2DModel`]):
            Class-conditional ProMoE transformer for flow-matching in latent space.
        scheduler ([`ProMoEFlowMatchScheduler`]):
            Flow-matching scheduler used during denoising.
        vae ([`AutoencoderKL`], *optional*):
            Variational autoencoder used to decode latents to pixels.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer->vae"
    _optional_components = ["vae"]

    def __init__(
        self,
        transformer,
        scheduler,
        vae=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # True once a non-empty mapping is available, so the lazy loader can stop.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        r"""Lazily load `id2label` from `model_index.json` if none was given at init."""
        if self._labels_loaded_from_model_index:
            return
        # The fallback base class defines no `config`, so guard the lookup
        # instead of raising AttributeError.
        config = getattr(self, "config", None)
        loaded = self._read_id2label_from_model_index(getattr(config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        r"""Return a copy of `id2label` with keys cast to `int` (empty dict if falsy)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        r"""
        Read the `id2label` entry of `<variant_path>/model_index.json`.

        Returns an empty dict when the path is falsy, the file is missing, or
        the entry is not a dict.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.is_file():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        # Share the key normalisation with the constructor path.
        return ProMoEPipeline._normalize_id2label(id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        r"""
        Invert `id2label` into a synonym -> class id mapping, sorted by synonym.

        Each value is split on commas and stripped. If a synonym occurs for
        several class ids, the one visited last wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings. Each string must match a synonym in `id2label`.

        Returns:
            `list[int]`: One class id per input label, in input order.

        Raises:
            ValueError: If no labels are loaded or a label is unknown.
        """
        self._ensure_labels_loaded()
        label2id = self.labels
        if not label2id:
            raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
        if isinstance(label, str):
            label = [label]
        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [label2id[item] for item in label]

    def _get_vae_spatial_downsample(self) -> int:
        r"""Return the pixel-to-latent downsampling factor (8 when there is no VAE)."""
        if self.vae is None:
            return 8
        # A four-entry default keeps the factor at 8 when the config lacks the field.
        block_out_channels = getattr(getattr(self.vae, "config", None), "block_out_channels", [0, 0, 0, 0])
        return 2 ** (len(block_out_channels) - 1)

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        device: torch.device,
    ) -> torch.LongTensor:
        r"""
        Convert any supported `class_labels` form into a flat long tensor on `device`.

        Tensors are cast and flattened. Scalars (Python or NumPy integers, or a
        single label string) become a batch of one. In sequences, strings are
        resolved with `get_label_ids`; integers and strings may be mixed.
        """
        if torch.is_tensor(class_labels):
            return class_labels.to(device=device, dtype=torch.long).reshape(-1)
        # NumPy integer scalars are not `int` instances, so check for them explicitly.
        if isinstance(class_labels, (int, np.integer, str)):
            class_labels = [class_labels]
        items = list(class_labels)
        names = [item for item in items if isinstance(item, str)]
        if names:
            # Resolve all names in one call so unknown labels are reported together,
            # then put the ids back in their original positions.
            resolved = iter(self.get_label_ids(names))
            class_label_ids = [next(resolved) if isinstance(item, str) else int(item) for item in items]
        else:
            class_label_ids = items
        return torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)

    @staticmethod
    def _sample_noise(shape, generator, device, dtype) -> torch.Tensor:
        r"""Draw standard-normal noise, sampling on the generator's device if it differs."""
        target = torch.device(device)
        if generator is not None and generator.device.type != target.type:
            # `torch.randn` rejects a generator on a different device type, so
            # sample where the generator lives and move the result afterwards.
            noise = torch.randn(shape, generator=generator, device=generator.device, dtype=dtype)
            return noise.to(target)
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)

    def _prepare_latents(
        self,
        batch_size: int,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> torch.Tensor:
        r"""
        Sample the initial noise latents of shape `(batch, in_channels, h, w)`.

        A list of generators yields one sample per generator; a one-element
        list is treated like a single generator.

        Raises:
            ValueError: If a generator list does not match `batch_size`.
        """
        shape = (batch_size, self.transformer.in_channels, latent_height, latent_width)
        if isinstance(generator, list) and len(generator) == 1:
            generator = generator[0]
        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective"
                    f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
            per_sample_shape = (1, *shape[1:])
            latents = [self._sample_noise(per_sample_shape, g, device, dtype) for g in generator]
            return torch.cat(latents, dim=0)
        return self._sample_noise(shape, generator, device, dtype)

    def _decode_latents(self, latents: torch.Tensor, output_type: str):
        r"""
        Turn final latents into the requested output format.

        `"latent"` returns the latents untouched. Otherwise the latents are
        decoded by the VAE (when present), mapped from `[-1, 1]` to `[0, 1]`
        and returned as a tensor (`"pt"`), a channels-last float NumPy array
        (`"np"`), or a list of PIL images (any other value).
        """
        if output_type == "latent":
            return latents
        if self.vae is not None:
            scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
            # Decode in the VAE's own parameter dtype.
            decode_dtype = next(self.vae.parameters()).dtype
            latents = (latents / scaling_factor).to(dtype=decode_dtype)
            image = self.vae.decode(latents, return_dict=False)[0]
        else:
            image = latents
        image = (image / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return image
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "np":
            return image
        return [Image.fromarray((img * 255).round().astype("uint8")) for img in image]

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ProMoEPipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with ProMoE.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
                ImageNet class indices or human-readable English label strings.
            height (`int`, defaults to 256):
                Output height in pixels; floor-divided by the VAE downsampling factor.
            width (`int`, defaults to 256):
                Output width in pixels; floor-divided by the VAE downsampling factor.
            num_inference_steps (`int`, defaults to 50):
                Number of timesteps passed to the scheduler.
            guidance_scale (`float`, defaults to 1.0):
                Classifier-free guidance is applied only when this is greater than 1.0.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                Random generator(s) for the initial noise; also forwarded to the scheduler step.
            output_type (`str`, defaults to `"pil"`):
                `"latent"`, `"pt"`, `"np"`, or anything else for PIL images.
            return_dict (`bool`, defaults to `True`):
                Return a `ProMoEPipelineOutput` instead of a one-element tuple.

        Returns:
            `ProMoEPipelineOutput` or `tuple`: The generated images.
        """
        device = self._execution_device if hasattr(self, "_execution_device") else torch.device("cpu")
        model_dtype = next(self.transformer.parameters()).dtype
        class_labels = self._normalize_class_labels(class_labels, device)
        batch_size = class_labels.shape[0]

        vae_scale = self._get_vae_spatial_downsample()
        latent_height = height // vae_scale
        latent_width = width // vae_scale
        latents = self._prepare_latents(batch_size, latent_height, latent_width, model_dtype, device, generator)
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        use_guidance = guidance_scale > 1.0
        if use_guidance:
            # The unconditional label is the embedder's `num_classes` (1000 if absent).
            # Build the doubled label batch once; it does not change between steps.
            null_class = getattr(self.transformer.backbone.y_embedder, "num_classes", 1000)
            null_labels = torch.full_like(class_labels, null_class)
            labels = torch.cat([class_labels, null_labels], dim=0)
        else:
            labels = class_labels

        for t in self.progress_bar(self.scheduler.timesteps):
            latent_input = torch.cat([latents, latents], dim=0) if use_guidance else latents
            timestep = torch.full((labels.shape[0],), t, device=device, dtype=model_dtype)
            model_output = self.transformer(
                hidden_states=latent_input,
                timestep=timestep,
                class_labels=labels,
                return_dict=True,
            ).sample
            # If the model emits a different channel count than the latents,
            # keep only the first half along the channel axis.
            if model_output.shape[1] != latents.shape[1]:
                model_output = model_output.chunk(2, dim=1)[0]
            if use_guidance:
                model_output_cond, model_output_uncond = model_output.chunk(2)
                model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
            latents = self.scheduler.step(model_output, t, latents, generator=generator).prev_sample

        images = self._decode_latents(latents, output_type)
        self.maybe_free_model_hooks()
        if not return_dict:
            return (images,)
        return ProMoEPipelineOutput(images=images)