| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from math import pi |
| | from typing import Callable, List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| |
|
| | from diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel |
| | from diffusers.utils.torch_utils import randn_tensor |
| |
|
| |
|
class DPSPipeline(DiffusionPipeline):
    r"""
    Pipeline for Diffusion Posterior Sampling (DPS).

    At every denoising step the scheduler's estimate of the clean sample is passed through a user-supplied forward
    `operator`, compared against the observed `measurement` with `loss_fn`, and the gradient of that loss with
    respect to the current noisy sample is used to correct the scheduler's proposed next sample.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` used to denoise the image samples (this pipeline works directly in image space).
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the image. Its `step` output must expose
            both `prev_sample` and `pred_original_sample`, e.g. [`DDPMScheduler`] or [`DDIMScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        measurement: torch.Tensor,
        operator: torch.nn.Module,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        zeta: float = 0.3,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            measurement (`torch.Tensor`, *required*):
                The observed (corrupted) measurement `y` that the generated image should be consistent with.
            operator (`torch.nn.Module`, *required*):
                The forward model `A` that produced `measurement`. It is applied to the predicted clean sample and
                must be differentiable, because the loss is back-propagated through it.
            loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*):
                Scalar distance between the observed measurement and the simulated measurement, called as
                `loss_fn(measurement, operator(predicted_clean_sample))`. An L2 / RMSE-style distance is usually
                fine.
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` returns `PIL.Image` objects; any other value
                returns a NHWC `np.ndarray` with values in `[0, 1]`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            zeta (`float`, *optional*, defaults to 0.3):
                Step size of the measurement-consistency correction: at every step the scheduler's proposed
                sample is moved by `-zeta * grad(loss)` where the gradient is taken w.r.t. the current noisy sample.

        The value of `loss_fn` is printed to stdout at every denoising step.

        Example:

        ```py
        >>> from diffusers import DDPMScheduler, UNet2DModel

        >>> # load an unconditional image model and its scheduler
        >>> unet = UNet2DModel.from_pretrained("google/ddpm-celebahq-256")
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
        >>> pipe = DPSPipeline(unet, scheduler)

        >>> # `measurement` is the corrupted observation, `operator` the differentiable forward model
        >>> # that produced it, and `loss_fn` a distance between two measurements
        >>> image = pipe(measurement=measurement, operator=operator, loss_fn=loss_fn, zeta=1.0).images[0]

        >>> # save image
        >>> image.save("dps_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
        # `sample_size` may be a single int (square images) or an explicit (height, width) pair.
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if self.device.type == "mps":
            # Sample on the generator's default device and move afterwards
            # (presumably because seeded randn is not reproducible on mps -- confirm).
            image = randn_tensor(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # __call__ runs under no_grad; DPS needs d(loss)/d(x_t) through the UNet, so re-enable autograd here.
            with torch.enable_grad():
                # 1. Predict the model output for the current noisy sample x_t (tracked by autograd).
                image = image.requires_grad_()
                model_output = self.unet(image, t).sample

                # 2. Scheduler proposal x'_{t-1} and its estimate of the clean sample x0_t.
                scheduler_out = self.scheduler.step(model_output, t, image, generator=generator)
                image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample

                # 3. Simulate the measurement of the estimated clean sample: y'_t = A(x0_t).
                measurement_pred = operator(origi_pred)

                # 4. Distance between observed and simulated measurements; back-propagate to x_t.
                loss = loss_fn(measurement, measurement_pred)
                loss.backward()

                print("distance: {0:.4f}".format(loss.item()))

                # 5. Measurement-consistency correction, done without building a graph, then start the next
                #    step from a fresh leaf tensor so the graph of this step can be freed.
                with torch.no_grad():
                    image_pred = image_pred - zeta * image.grad
                    image = image_pred.detach()

        # Map from [-1, 1] to [0, 1] and convert to a NHWC numpy array.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
| |
|
| |
|
if __name__ == "__main__":
    # Importing the submodule explicitly guarantees `scipy.ndimage` is loaded on every SciPy version.
    import scipy.ndimage
    from torch import nn
    from torchvision.utils import save_image

    # ------------------------------------------------------------------
    # 1-D interpolation kernels used by Resizer (x is a numpy array of offsets).
    # ------------------------------------------------------------------
    def _cubic(x):
        absx = np.abs(x)
        absx2 = absx**2
        absx3 = absx**3
        return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + (
            -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
        ) * ((1 < absx) & (absx <= 2))

    def _lanczos2(x):
        return (
            (np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps)
            / ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)
        ) * (abs(x) < 2)

    def _box(x):
        return ((-0.5 <= x) & (x < 0.5)) * 1.0

    def _lanczos3(x):
        return (
            (np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps)
            / ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)
        ) * (abs(x) < 3)

    def _linear(x):
        return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))

    # kernel name -> (kernel function, support width in input pixels); None selects cubic.
    _RESIZE_KERNELS = {
        "cubic": (_cubic, 4.0),
        "lanczos2": (_lanczos2, 4.0),
        "lanczos3": (_lanczos3, 6.0),
        "box": (_box, 1.0),
        "linear": (_linear, 2.0),
        None: (_cubic, 4.0),
    }

    class Resizer(nn.Module):
        """Separable resampler: resizes each axis whose scale != 1 with precomputed gather indices + weights."""

        def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
            super().__init__()

            # Complete whichever of scale_factor / output_shape is missing so both cover every axis.
            scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)

            try:
                method, kernel_width = _RESIZE_KERNELS[kernel]
            except KeyError:
                raise ValueError(f"Unknown interpolation kernel: {kernel!r}") from None

            # Antialiasing only applies if at least one axis is being downscaled.
            antialiasing *= np.any(np.array(scale_factor) < 1)

            # Process axes in ascending scale order; axes with scale 1 need no work.
            sorted_dims = np.argsort(np.array(scale_factor))
            self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]

            field_of_view_list = []
            weights_list = []
            for dim in self.sorted_dims:
                weights, field_of_view = self.contributions(
                    in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing
                )

                # Stored as (taps, out_len); trailing singleton dims let them broadcast over the other axes.
                weights = torch.tensor(weights.T, dtype=torch.float32)
                weights_list.append(
                    nn.Parameter(
                        torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
                        requires_grad=False,
                    )
                )
                field_of_view_list.append(
                    nn.Parameter(
                        torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False
                    )
                )

            # ParameterList so the buffers follow .to(device) together with the module.
            self.field_of_view = nn.ParameterList(field_of_view_list)
            self.weights = nn.ParameterList(weights_list)

        def forward(self, in_tensor):
            x = in_tensor
            for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
                # Move the axis being resized to the front so indexing with `fov` gathers along it.
                x = torch.transpose(x, dim, 0)
                # x[fov] is (taps, out_len, *rest); weighting and summing over taps yields (out_len, *rest).
                x = torch.sum(x[fov] * w, dim=0)
                x = torch.transpose(x, dim, 0)
            return x

        def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
            """Return (scale_factor, output_shape), each with one entry per axis of `input_shape`."""
            if scale_factor is None and output_shape is None:
                raise ValueError("Either scale_factor or output_shape must be given")

            if scale_factor is not None:
                # A scalar scale applies to the last two (spatial) axes, or to the only axis of a 1-D shape.
                if np.isscalar(scale_factor):
                    scale_factor = [scale_factor] * min(2, len(input_shape))
                # Left-pad with 1s so the leading (e.g. batch/channel) axes are left untouched.
                scale_factor = list(scale_factor)
                scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor

            if output_shape is not None:
                # output_shape describes the trailing axes; keep the LEADING axes of the input unchanged.
                num_leading = len(input_shape) - len(output_shape)
                output_shape = list(input_shape[:num_leading]) + list(np.uint(np.array(output_shape)))

            if scale_factor is None:
                scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)

            if output_shape is None:
                output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))

            return scale_factor, output_shape

        def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
            """For one axis, return (weights, field_of_view): per output pixel, input indices and their weights."""
            # When antialiasing a downscale, stretch the kernel by 1/scale (and rescale its amplitude).
            fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
            kernel_width *= 1.0 / scale if antialiasing else 1.0

            # 1-based output pixel coordinates, mapped back to (continuous) input coordinates.
            out_coordinates = np.arange(1, out_length + 1)
            shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
            match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)

            # Left-most input pixel touched by each output pixel; +2 taps of slack on the support.
            left_boundary = np.floor(match_coordinates - kernel_width / 2)
            expanded_kernel_width = np.ceil(kernel_width) + 2

            # (out_length, taps) input indices (0-based, may fall outside [0, in_length)).
            field_of_view = np.squeeze(
                np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)
            )

            # Kernel evaluated at the distance between each output coordinate and each tap.
            weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)

            # Normalize each row to sum to 1 (leave all-zero rows as zeros).
            sum_weights = np.sum(weights, axis=1)
            sum_weights[sum_weights == 0] = 1.0
            weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)

            # Out-of-range taps are mapped back into range by mirror reflection.
            mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
            field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]

            # Drop tap columns whose weight is zero for every output pixel.
            non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
            weights = np.squeeze(weights[:, non_zero_out_pixels])
            field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])

            return weights, field_of_view

    class SuperResolutionOperator(nn.Module):
        """Forward model that downsamples the input by `scale_factor` with a fixed (non-trainable) Resizer."""

        def __init__(self, in_shape, scale_factor):
            super().__init__()
            self.down_sample = Resizer(in_shape, 1 / scale_factor)
            for param in self.parameters():
                param.requires_grad = False

        def forward(self, data, **kwargs):
            return self.down_sample(data)

    class Blurkernel(nn.Module):
        """Depthwise 3-channel blur: reflection padding followed by a fixed kernel_size x kernel_size conv."""

        def __init__(self, blur_type="gaussian", kernel_size=31, std=3.0):
            super().__init__()
            self.blur_type = blur_type
            self.kernel_size = kernel_size
            self.std = std
            self.seq = nn.Sequential(
                nn.ReflectionPad2d(self.kernel_size // 2),
                nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
            )
            self.weights_init()

        def forward(self, x):
            return self.seq(x)

        def weights_init(self):
            if self.blur_type != "gaussian":
                # Only the gaussian kernel is implemented; anything else would leave random conv weights.
                raise ValueError(f"Unsupported blur_type: {self.blur_type!r}")
            # Gaussian-filtering a centred delta yields the sampled Gaussian kernel.
            n = np.zeros((self.kernel_size, self.kernel_size))
            n[self.kernel_size // 2, self.kernel_size // 2] = 1
            k = torch.from_numpy(scipy.ndimage.gaussian_filter(n, sigma=self.std))
            self.k = k
            self.update_weights(k)

        def update_weights(self, k):
            if not torch.is_tensor(k):
                k = torch.from_numpy(k)
            # The (k, k) kernel broadcasts into the (3, 1, k, k) depthwise conv weight.
            for param in self.parameters():
                param.data.copy_(k)

        def get_kernel(self):
            return self.k

    class GaussianBlurOperator(nn.Module):
        """Forward model that applies a fixed Gaussian blur (std = `intensity`) to a 3-channel image."""

        def __init__(self, kernel_size, intensity):
            super().__init__()
            self.kernel_size = kernel_size
            self.conv = Blurkernel(blur_type="gaussian", kernel_size=kernel_size, std=intensity)
            self.kernel = self.conv.get_kernel()
            self.conv.update_weights(self.kernel.type(torch.float32))
            for param in self.parameters():
                param.requires_grad = False

        def forward(self, data, **kwargs):
            return self.conv(data)

        def transpose(self, data, **kwargs):
            return data

        def get_kernel(self):
            return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)

    def RMSELoss(yhat, y):
        # NOTE: despite the name this is the un-normalized L2 norm ||yhat - y||_2 (no mean is taken),
        # summed over the whole batch.
        return torch.sqrt(torch.sum((yhat - y) ** 2))

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the clean reference image as a (1, 3, H, W) tensor in [-1, 1]. convert("RGB") drops any alpha
    # channel, since the blur operator is built for exactly 3 channels.
    with Image.open("sample.png") as img:
        src = img.convert("RGB")
    src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]
    src = (src / 127.5) - 1.0
    src = src.to(device)

    # Simulate the corrupted observation y = A(x).
    operator = GaussianBlurOperator(kernel_size=61, intensity=3.0).to(device)
    measurement = operator(src)

    # The pipeline calls scheduler.set_timesteps itself.
    scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
    model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to(device)

    save_image((src + 1.0) / 2.0, "dps_src.png")
    save_image((measurement + 1.0) / 2.0, "dps_mea.png")

    dpspipe = DPSPipeline(model, scheduler)
    image = dpspipe(
        measurement=measurement,
        operator=operator,
        loss_fn=RMSELoss,
        zeta=1.0,
    ).images[0]

    image.save("dps_generated_image.png")
| |
|