| import torch |
| from diffusers import DiffusionPipeline |
|
|
|
|
class MyPipeline(DiffusionPipeline):
    """Minimal sampling pipeline that pairs a denoising model with a scheduler.

    The pipeline starts from Gaussian noise and repeatedly asks ``unet`` for a
    model output, which ``scheduler`` turns into the sample for the next
    timestep.
    """

    def __init__(self, unet, scheduler):
        """Store the two components on the pipeline.

        Args:
            unet: Model called as ``unet(sample, t)``; its result must expose
                a ``.sample`` attribute, and its ``config`` must provide
                ``in_channels`` and ``sample_size``.
            scheduler: Object providing ``set_timesteps``, ``timesteps`` and
                ``step(model_output, t, sample)`` whose result exposes
                ``.prev_sample``.
        """
        super().__init__()

        # Registration is what makes the components reachable later as
        # ``self.unet`` and ``self.scheduler``.
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
        """Run the denoising loop and return the generated batch.

        Args:
            batch_size: Number of samples to generate.
            num_inference_steps: Number of steps passed to the scheduler's
                ``set_timesteps``.

        Returns:
            A NumPy array with values clamped to ``[0, 1]`` and the channel
            axis moved last, i.e. ``(batch, height, width, channels)`` provided
            the scheduler preserves the sample's shape.
        """
        # Read model hyper-parameters through ``config``: accessing them
        # directly on the model (``self.unet.in_channels``) is deprecated in
        # diffusers.
        unet_config = self.unet.config
        noise_shape = (
            batch_size,
            unet_config.in_channels,
            unet_config.sample_size,
            unet_config.sample_size,
        )

        # Draw the initial Gaussian noise, then move it to the device the
        # pipeline lives on.
        image = torch.randn(noise_shape)
        image = image.to(self.device)

        # Let the scheduler compute the discrete timesteps to iterate over.
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. Ask the model for its output at the current timestep.
            model_output = self.unet(image, t).sample

            # 2. Let the scheduler produce the sample for the next timestep.
            image = self.scheduler.step(model_output, t, image).prev_sample

        # Rescale by x / 2 + 0.5 and clamp to [0, 1]
        # (presumably the model works in [-1, 1] -- TODO confirm).
        image = (image / 2 + 0.5).clamp(0, 1)
        # Channels-first -> channels-last, on CPU, as a NumPy array.
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return image
|
|