| | from __future__ import annotations |
| |
|
| | from typing import Sequence, Union |
| |
|
| | import torch |
| |
|
| | from diffusers import DiffusionPipeline |
| | from diffusers.pipelines.pipeline_utils import ImagePipelineOutput |
| |
|
| |
|
class BitDanceImageNetPipeline(DiffusionPipeline):
    """Class-conditional image generation pipeline built around a transformer.

    Sampling is delegated entirely to ``transformer.sample``; this class only
    builds the class-id tensor and converts the sampled tensor into the
    requested output format.
    """

    # Only the transformer is listed in the model CPU offload sequence.
    model_cpu_offload_seq = "transformer"

    def __init__(self, transformer, autoencoder=None):
        """Register the pipeline components.

        Args:
            transformer: Model exposing a ``sample(class_ids, sample_steps,
                cfg_scale, chunk_size)`` method that returns an image tensor.
            autoencoder: Optional component. It is registered with the
                pipeline but is not referenced directly by ``__call__``.
        """
        super().__init__()
        self.register_modules(transformer=transformer, autoencoder=autoencoder)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, Sequence[int]] = 0,
        num_images_per_label: int = 1,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate images for one or more class labels.

        Args:
            class_labels: A single class id or a sequence of class ids.
            num_images_per_label: Number of images to generate per label.
                Labels are repeated consecutively, so the outputs for one
                label are adjacent in the result. Must be at least 1.
            sample_steps: Forwarded unchanged to ``transformer.sample``.
            cfg_scale: Forwarded unchanged to ``transformer.sample``.
            chunk_size: Forwarded unchanged to ``transformer.sample``.
            output_type: ``"pil"`` converts the result with
                ``numpy_to_pil``; any other value returns the NumPy array.
            return_dict: When true, wrap the images in an
                ``ImagePipelineOutput``; otherwise return a 1-tuple.

        Returns:
            ``ImagePipelineOutput(images=...)`` when ``return_dict`` is true,
            else ``(images,)``.

        Raises:
            ValueError: If ``num_images_per_label`` is smaller than 1.
        """
        # Previously a value <= 0 silently behaved like 1; reject it instead
        # so a bad request cannot quietly produce images.
        if num_images_per_label < 1:
            raise ValueError(
                "`num_images_per_label` must be >= 1, "
                f"got {num_images_per_label}."
            )

        device = self._execution_device

        # Normalise a scalar label to a one-element list.
        if isinstance(class_labels, int):
            labels = [class_labels]
        else:
            labels = list(class_labels)

        class_ids = torch.tensor(labels, device=device, dtype=torch.long)
        if num_images_per_label > 1:
            # [a, b] with 2 images per label becomes [a, a, b, b].
            class_ids = class_ids.repeat_interleave(num_images_per_label)

        images = self.transformer.sample(
            class_ids=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

        # NOTE(review): this rescaling presumably maps sampler output from
        # [-1, 1] to [0, 1] -- confirm against `transformer.sample`.
        images = (images / 2 + 0.5).clamp(0, 1)
        # NOTE(review): the permute assumes a 4-D (batch, channels, height,
        # width) tensor and yields channels-last -- confirm.
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            images = self.numpy_to_pil(images)

        if not return_dict:
            return (images,)

        return ImagePipelineOutput(images=images)
| |
|