| |
| import torch |
|
|
| from diffusers import DiffusionPipeline |
|
|
|
|
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    """Minimal pipeline: one UNet forward pass followed by one scheduler step.

    ``__call__`` samples a single random input, runs it through the UNet and
    the scheduler at a fixed timestep, and returns a tensor of ones with the
    shape, dtype and device of the scheduler output. Because the returned
    values are constant, the pipeline only demonstrates that the registered
    components run end to end (presumably used as a smoke-test fixture —
    confirm against callers).

    Args:
        unet: Denoising model. ``unet.config.in_channels`` and
            ``unet.config.sample_size`` determine the input shape, and calling
            ``unet(sample, timestep)`` must return an object with a ``.sample``
            attribute.
        scheduler: Object whose ``step(model_output, timestep, sample)``
            returns an object with a ``.prev_sample`` tensor attribute.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        # Exposes the components as ``self.unet`` / ``self.scheduler``,
        # which ``__call__`` relies on.
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self):
        """Run one UNet forward pass and one scheduler step.

        Returns:
            torch.Tensor: an all-ones tensor shaped like the scheduler's
            ``prev_sample`` (assumed to be
            ``(1, in_channels, sample_size, sample_size)`` for a
            shape-preserving scheduler — TODO confirm).
        """
        config = self.unet.config
        shape = (1, config.in_channels, config.sample_size, config.sample_size)
        # Allocate the noise where the UNet's weights live and in their dtype,
        # so the forward pass also works for GPU and half-precision models
        # (a bare ``torch.randn(shape)`` is always CPU/float32).
        sample = torch.randn(shape, device=self.unet.device, dtype=self.unet.dtype)
        timestep = 1

        model_output = self.unet(sample, timestep).sample
        prev_sample = self.scheduler.step(model_output, timestep, sample).prev_sample

        # The output is deliberately constant so it does not depend on the
        # random input. ``ones_like`` matches the former
        # ``x - x + torch.ones_like(x)`` exactly for finite values, but does
        # not turn into NaN when ``x`` contains inf or NaN.
        return torch.ones_like(prev_sample)
|
|