|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Sequence, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
from diffusers import DiffusionPipeline |
|
|
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput |
|
|
|
|
|
|
|
|
class BitDanceImageNetPipeline(DiffusionPipeline):
    """Pipeline for class-conditional ImageNet generation with a BitDance transformer.

    The transformer module runs the full sampling loop itself via its
    ``sample`` method. An optional autoencoder module is registered alongside
    it but is not invoked in this pipeline.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(self, transformer, autoencoder=None):
        """Register the sampling transformer and, optionally, an autoencoder."""
        super().__init__()
        self.register_modules(transformer=transformer, autoencoder=autoencoder)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, Sequence[int]] = 0,
        num_images_per_label: int = 1,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate images conditioned on ImageNet class label(s).

        Args:
            class_labels: A single class id or a sequence of class ids.
            num_images_per_label: Number of images to draw per class id.
            sample_steps: Number of sampling steps forwarded to the transformer.
            cfg_scale: Classifier-free guidance scale.
            chunk_size: Chunking parameter forwarded to ``transformer.sample``.
            output_type: ``"pil"`` converts outputs to PIL images; any other
                value leaves them as a channels-last NumPy array.
            return_dict: If False, return a plain one-element tuple instead of
                an ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput`` wrapping the images, or ``(images,)`` when
            ``return_dict`` is False.
        """
        device = self._execution_device

        # Normalize the label argument to a flat Python list of ids.
        label_list = [class_labels] if isinstance(class_labels, int) else list(class_labels)

        class_ids = torch.tensor(label_list, device=device, dtype=torch.long)
        if num_images_per_label > 1:
            # Duplicate each label so consecutive samples share a class.
            class_ids = class_ids.repeat_interleave(num_images_per_label)

        samples = self.transformer.sample(
            class_ids=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

        # Rescale from [-1, 1] to [0, 1], then convert NCHW -> NHWC NumPy.
        samples = samples.div(2).add(0.5).clamp(0, 1)
        images = samples.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            images = self.numpy_to_pil(images)

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
|
|
|