# BitDance-ImageNet-diffusers / BitDance_B_16x / pipeline_bitdance_imagenet.py
# NOTE: Hugging Face file-viewer residue removed (uploader "BiliSakura",
# verified commit 42b230a, "Update all files for BitDance-ImageNet-diffusers").
from __future__ import annotations
from typing import Sequence, Union
import torch
from diffusers import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
class BitDanceImageNetPipeline(DiffusionPipeline):
    """Class-conditional image generation pipeline for a BitDance transformer
    trained on ImageNet.

    The heavy lifting happens in ``transformer.sample``; this wrapper only
    normalizes the label argument, post-processes the output tensor, and
    packages the result in the standard diffusers output format.
    """

    # Order in which modules are moved to GPU under sequential CPU offload.
    model_cpu_offload_seq = "transformer"

    def __init__(self, transformer, autoencoder=None):
        """Register the sampling transformer and an (optional) autoencoder.

        NOTE(review): ``autoencoder`` is registered but never referenced in
        this file — presumably ``transformer.sample`` returns pixel-space
        images directly; confirm against the model implementation.
        """
        super().__init__()
        self.register_modules(transformer=transformer, autoencoder=autoencoder)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, Sequence[int]] = 0,
        num_images_per_label: int = 1,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate images for the given ImageNet class label(s).

        Args:
            class_labels: A single class index or a sequence of indices.
            num_images_per_label: How many images to draw per label.
            sample_steps: Number of sampling steps forwarded to the model.
            cfg_scale: Classifier-free guidance scale forwarded to the model.
            chunk_size: Chunking parameter forwarded to the model.
            output_type: ``"pil"`` for PIL images, anything else yields the
                raw NHWC float numpy array.
            return_dict: When ``False``, return a plain one-element tuple
                instead of an :class:`ImagePipelineOutput`.

        Returns:
            ``ImagePipelineOutput`` (or a tuple) holding the generated images.
        """
        device = self._execution_device

        # Normalize the label argument into a flat list of ints.
        if isinstance(class_labels, int):
            label_list = [class_labels]
        else:
            label_list = list(class_labels)
        class_ids = torch.tensor(label_list, dtype=torch.long, device=device)

        # Duplicate each label so every one yields the requested batch size.
        if num_images_per_label > 1:
            class_ids = class_ids.repeat_interleave(num_images_per_label)

        samples = self.transformer.sample(
            class_ids=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

        # Rescale from the model's output range (presumably [-1, 1] — the
        # affine map below implies it) into [0, 1], then convert NCHW tensor
        # -> NHWC float32 numpy on the host.
        samples = (samples / 2 + 0.5).clamp(0, 1)
        arrays = samples.cpu().permute(0, 2, 3, 1).float().numpy()

        result = self.numpy_to_pil(arrays) if output_type == "pil" else arrays

        if not return_dict:
            return (result,)
        return ImagePipelineOutput(images=result)