from __future__ import annotations from typing import Sequence, Union import torch from diffusers import DiffusionPipeline from diffusers.pipelines.pipeline_utils import ImagePipelineOutput class BitDanceImageNetPipeline(DiffusionPipeline): model_cpu_offload_seq = "transformer" def __init__(self, transformer, autoencoder=None): super().__init__() self.register_modules(transformer=transformer, autoencoder=autoencoder) @torch.no_grad() def __call__( self, class_labels: Union[int, Sequence[int]] = 0, num_images_per_label: int = 1, sample_steps: int = 100, cfg_scale: float = 4.6, chunk_size: int = 0, output_type: str = "pil", return_dict: bool = True, ): device = self._execution_device if isinstance(class_labels, int): labels = [class_labels] else: labels = list(class_labels) class_ids = torch.tensor(labels, device=device, dtype=torch.long) if num_images_per_label > 1: class_ids = class_ids.repeat_interleave(num_images_per_label) images = self.transformer.sample( class_ids=class_ids, sample_steps=sample_steps, cfg_scale=cfg_scale, chunk_size=chunk_size, ) images = (images / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": images = self.numpy_to_pil(images) if not return_dict: return (images,) return ImagePipelineOutput(images=images)