Unconditional Image Generation
Diffusers
Safetensors
sit
image-generation
class-conditional
imagenet
Instructions to use BiliSakura/SiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/SiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch device_map to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/SiT-diffusers", dtype=torch.bfloat16, device_map="cuda"
)

# SiT is class-conditional: pass an ImageNet class index (e.g. 207 = golden retriever),
# not a text prompt.
image = pipe(class_labels=207).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from typing import List, Optional, Union | |
| import torch | |
| from diffusers.image_processor import VaeImageProcessor | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput | |
| from diffusers.utils.torch_utils import randn_tensor | |
class SiTPipeline(DiffusionPipeline):
    """Class-conditional latent generation pipeline for SiT transformers.

    Mirrors the structure of diffusers' DiT pipeline: sample Gaussian latents,
    iteratively denoise them with a class-conditional transformer (optionally
    using classifier-free guidance), then decode the latents with the VAE.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(self, transformer, scheduler, vae):
        """Register sub-modules and set up the VAE image processor.

        Args:
            transformer: Class-conditional transformer predicting the model
                output from (latents, timestep, class label). Expected to
                expose ``config.in_channels`` and ``config.num_classes``.
            scheduler: Noise scheduler driving the denoising loop.
            vae: Autoencoder decoding latents to pixel space.
        """
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        # Spatial downsampling factor of the VAE (8 for the standard SD VAE).
        self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()  # pure inference: avoid building an autograd graph over 250 steps
    def __call__(
        self,
        class_labels: Union[int, List[int]] = 207,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 250,
        guidance_scale: float = 4.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate one image per requested class label.

        Args:
            class_labels: A single class index or a list of indices.
            height: Output height in pixels; must be divisible by the VAE
                scale factor.
            width: Output width in pixels; must be divisible by the VAE
                scale factor.
            num_inference_steps: Number of scheduler denoising steps.
            guidance_scale: Classifier-free guidance weight. ``None`` or any
                value <= 1.0 disables guidance.
            generator: Optional RNG for reproducible initial latents.
            output_type: ``"pil"``, ``"np"``, or ``"pt"``. ``"pt"`` keeps the
                raw VAE output range [-1, 1] to match the original SiT scripts.
            return_dict: If False, return a 1-tuple instead of
                ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput`` with the generated images, or ``(images,)``
            when ``return_dict`` is False.
        """
        device = self._execution_device

        if isinstance(class_labels, int):
            class_labels = [class_labels]
        batch_size = len(class_labels)

        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        latents = randn_tensor(
            (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        labels = torch.tensor(class_labels, device=device, dtype=torch.long)
        do_cfg = guidance_scale is not None and guidance_scale > 1.0
        if do_cfg:
            # Index ``num_classes`` selects the learned null (unconditional)
            # embedding; conditional half first, unconditional half second.
            null_label = torch.full(
                (batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long
            )
            labels = torch.cat([labels, null_label], dim=0)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        for t in self.progress_bar(timesteps):
            # Keep the timestep in the scheduler's native dtype. Casting to the
            # latent dtype (e.g. bfloat16) quantizes large step indices: bf16
            # cannot represent all integers above 256, so t=999 would round.
            if torch.is_tensor(t):
                t_batch = t.to(device).expand(batch_size)
            else:
                t_batch = torch.full((batch_size,), t, device=device)

            model_input = latents
            if do_cfg:
                model_input = torch.cat([latents, latents], dim=0)
                t_batch = torch.cat([t_batch, t_batch], dim=0)

            model_pred = self.transformer(
                hidden_states=model_input,
                timestep=t_batch,
                class_labels=labels,
            ).sample

            if do_cfg:
                cond, uncond = model_pred.chunk(2, dim=0)
                model_pred = uncond + guidance_scale * (cond - uncond)

            latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample

        # Undo the VAE latent scaling before decoding; fall back to the SD-VAE
        # constant when the config does not expose ``scaling_factor``.
        scaling_factor = getattr(self.vae.config, "scaling_factor", None) or 0.18215
        image = self.vae.decode(latents / scaling_factor).sample

        # ``"pt"`` keeps the raw VAE range [-1, 1] to match original SiT scripts.
        if output_type != "pt":
            image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)