from math import pi
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel
from diffusers.utils.torch_utils import randn_tensor


class DPSPipeline(DiffusionPipeline):
    r"""
    Pipeline for Diffusion Posterior Sampling (DPS), as introduced in "Diffusion Posterior Sampling for General
    Noisy Inverse Problems" (Chung et al., 2023).

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        measurement: torch.Tensor,
        operator: torch.nn.Module,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        zeta: float = 0.3,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            measurement (`torch.Tensor`, *required*):
                The corrupted image, i.e. the measurement `y = f(x)`.
            operator (`torch.nn.Module`, *required*):
                The forward operator `f(.)` that generated the corrupted image.
            loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*):
                The distance between the measurement and the predicted measurement; in most cases an RMSE-style
                loss works well.
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            zeta (`float`, *optional*, defaults to `0.3`):
                The step size that scales the measurement-consistency gradient applied at each denoising step.

        Example:

        ```py
        >>> from diffusers import DDPMScheduler, UNet2DModel

        >>> # load model and scheduler
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
        >>> model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
        >>> pipe = DPSPipeline(model, scheduler)

        >>> # run pipeline in inference; `operator`, `measurement`, and `RMSELoss` are
        >>> # defined as in the `__main__` block at the bottom of this file
        >>> image = pipe(measurement=measurement, operator=operator, loss_fn=RMSELoss, zeta=1.0).images[0]

        >>> # save image
        >>> image.save("dps_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """

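        # DPS update implemented in the loop below (paper notation):
        #   x0_hat   : the scheduler's pred_original_sample
        #   x'_{t-1} : the scheduler's prev_sample
        #   x_{t-1}  = x'_{t-1} - zeta * grad_{x_t} loss(y, f(x0_hat))
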
        # Sample gaussian noise to begin the loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            image = randn_tensor(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            with torch.enable_grad():
                # 1. predict noise model_output
                image = image.requires_grad_()
                model_output = self.unet(image, t).sample

                # 2. compute previous image x'_{t-1} and original prediction x0_{t}
                scheduler_out = self.scheduler.step(model_output, t, image, generator=generator)
                image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample

                # 3. compute y'_t = f(x0_{t})
                measurement_pred = operator(origi_pred)

                # 4. compute loss = d(y, y'_t)
                loss = loss_fn(measurement, measurement_pred)
                loss.backward()

                print("distance: {0:.4f}".format(loss.item()))

                # 5. apply the DPS correction: x_{t-1} = x'_{t-1} - zeta * grad
                with torch.no_grad():
                    image_pred = image_pred - zeta * image.grad
                    image = image_pred.detach()

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)


if __name__ == "__main__":
    import scipy.ndimage
    from torch import nn
    from torchvision.utils import save_image

    # defining the operators f(.) of y = f(x)
    # super-resolution operator
    class SuperResolutionOperator(nn.Module):
        def __init__(self, in_shape, scale_factor):
            super().__init__()

            # Resizer local class, do not use outside the SR operator class
            class Resizer(nn.Module):
                def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
                    super().__init__()

                    # standardize scale_factor / output_shape, deriving the missing one from the other
                    scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)

                    def cubic(x):
                        absx = np.abs(x)
                        absx2 = absx**2
                        absx3 = absx**3
                        return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + (
                            -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
                        ) * ((1 < absx) & (absx <= 2))

                    def lanczos2(x):
                        return (
                            (np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps)
                            / ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)
                        ) * (abs(x) < 2)

                    def box(x):
                        return ((-0.5 <= x) & (x < 0.5)) * 1.0

                    def lanczos3(x):
                        return (
                            (np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps)
                            / ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)
                        ) * (abs(x) < 3)

                    def linear(x):
                        return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))

                    # choose the interpolation method; each method has a matching kernel width
                    method, kernel_width = {
                        "cubic": (cubic, 4.0),
                        "lanczos2": (lanczos2, 4.0),
                        "lanczos3": (lanczos3, 6.0),
                        "box": (box, 1.0),
                        "linear": (linear, 2.0),
                        None: (cubic, 4.0),
                    }.get(kernel)

                    # antialiasing is only used when downscaling
                    antialiasing *= np.any(np.array(scale_factor) < 1)

                    # sort the dimensions by scale; since we resize dim by dim, this is efficient
                    sorted_dims = np.argsort(np.array(scale_factor))
                    self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]

                    # iterate over the dimensions and precompute the local weights for resizing each one
                    field_of_view_list = []
                    weights_list = []
                    for dim in self.sorted_dims:
                        # get the 1d set of weights and fields of view for each output location along this dim
                        weights, field_of_view = self.contributions(
                            in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing
                        )

                        # convert to a torch tensor
                        weights = torch.tensor(weights.T, dtype=torch.float32)

                        # add singleton dimensions to the weight matrix so it broadcasts against the gathered tensor
                        weights_list.append(
                            nn.Parameter(
                                torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
                                requires_grad=False,
                            )
                        )
                        field_of_view_list.append(
                            nn.Parameter(
                                torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False
                            )
                        )

                    self.field_of_view = nn.ParameterList(field_of_view_list)
                    self.weights = nn.ParameterList(weights_list)

                def forward(self, in_tensor):
                    x = in_tensor

                    # resize one dimension at a time, using the precomputed weights and fields of view
                    for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
                        # swap so that dim 0 is the dimension we resize
                        x = torch.transpose(x, dim, 0)

                        # x[fov] gathers, for each output pixel, the input positions that influence it (adding one
                        # dim); multiplying by the matching weights and summing over that dim performs the 1d resize
                        x = torch.sum(x[fov] * w, dim=0)

                        # swap the axes back to the original order
                        x = torch.transpose(x, dim, 0)

                    return x

                def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
                    # standardize the scale factor (if given) into a list of per-dimension scales
                    if scale_factor is not None:
                        # by default, a scalar scale factor is assumed to mean 2d resizing and is duplicated
                        if np.isscalar(scale_factor) and len(input_shape) > 1:
                            scale_factor = [scale_factor, scale_factor]

                        # extend the scale factor list to the number of input dims, assigning 1 to unspecified dims
                        scale_factor = list(scale_factor)
                        scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor

                    # standardize the output shape (if given): extend it to the size of the input shape by keeping
                    # the original input size for all unspecified dimensions
                    if output_shape is not None:
                        output_shape = list(input_shape[len(output_shape) :]) + list(np.uint(np.array(output_shape)))

                    # if no scale factor was given, derive it from the output shape
                    if scale_factor is None:
                        scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)

                    # if no output shape was given, derive it from the scale factor
                    if output_shape is None:
                        output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))

                    return scale_factor, output_shape

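                # Worked example for fix_scale_and_size above (illustrative only): an input shape of
                # (1, 3, 256, 256) with scalar scale_factor=0.25 yields scale_factor=[1, 1, 0.25, 0.25]
                # and output_shape=[1, 3, 64, 64].
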
                def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
                    # This method calculates a set of filter weights and a matching field_of_view, such that each
                    # output position is computed from the input positions it "sees", weighted by the interpolation
                    # kernel. It handles one dimension of the image at a time.

                    # when antialiasing is active (only for downscaling), the kernel is stretched by the scale factor
                    # so that it acts as a low-pass filter
                    fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
                    kernel_width *= 1.0 / scale if antialiasing else 1.0

                    # the coordinates of the output image
                    out_coordinates = np.arange(1, out_length + 1)

                    # shift the output coordinates so that the interval centers stay consistent with the input
                    # (scale factor and output size can be provided simultaneously)
                    shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2

                    # the matching positions of the output coordinates on the input grid: the center of each output
                    # pixel mapped back onto the input image
                    match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)

                    # the leftmost input pixel the filter starts from; depends on the kernel width
                    left_boundary = np.floor(match_coordinates - kernel_width / 2)

                    # enlarge the kernel width by 2 to cover sub-pixel borders on both sides (the extra weights can
                    # come out zero)
                    expanded_kernel_width = np.ceil(kernel_width) + 2

                    # determine the field_of_view of each output position: the input pixels that output pixel "sees"
                    field_of_view = np.squeeze(
                        np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)
                    )

                    # assign a weight to each pixel in the field of view, based on its distance from the sub-pixel
                    # match coordinate
                    weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)

                    # normalize the weights; guard against division by zero
                    sum_weights = np.sum(weights, axis=1)
                    sum_weights[sum_weights == 0] = 1.0
                    weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)

                    # mirror structure used as a trick to apply reflection padding at the boundaries
                    mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
                    field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]

                    # drop weights (and matching pixel positions) that are identically zero
                    non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
                    weights = np.squeeze(weights[:, non_zero_out_pixels])
                    field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])

                    # final products: the relative positions and the matching weights
                    return weights, field_of_view

            self.down_sample = Resizer(in_shape, 1 / scale_factor)
            for param in self.parameters():
                param.requires_grad = False

        def forward(self, data, **kwargs):
            return self.down_sample(data)

    # Gaussian blur operator
    class GaussianBlurOperator(nn.Module):
        def __init__(self, kernel_size, intensity):
            super().__init__()

            class Blurkernel(nn.Module):
                def __init__(self, blur_type="gaussian", kernel_size=31, std=3.0):
                    super().__init__()
                    self.blur_type = blur_type
                    self.kernel_size = kernel_size
                    self.std = std
                    self.seq = nn.Sequential(
                        nn.ReflectionPad2d(self.kernel_size // 2),
                        nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
                    )
                    self.weights_init()

                def forward(self, x):
                    return self.seq(x)

                def weights_init(self):
                    if self.blur_type == "gaussian":
                        # build a Gaussian kernel by filtering a unit impulse
                        n = np.zeros((self.kernel_size, self.kernel_size))
                        n[self.kernel_size // 2, self.kernel_size // 2] = 1
                        k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
                        k = torch.from_numpy(k)
                        self.k = k
                        for f in self.parameters():
                            f.data.copy_(k)

                def update_weights(self, k):
                    if not torch.is_tensor(k):
                        k = torch.from_numpy(k)
                    for f in self.parameters():
                        f.data.copy_(k)

                def get_kernel(self):
                    return self.k

            self.kernel_size = kernel_size
            self.conv = Blurkernel(blur_type="gaussian", kernel_size=kernel_size, std=intensity)
            self.kernel = self.conv.get_kernel()
            self.conv.update_weights(self.kernel.type(torch.float32))

            for param in self.parameters():
                param.requires_grad = False

        def forward(self, data, **kwargs):
            return self.conv(data)

        def transpose(self, data, **kwargs):
            return data

        def get_kernel(self):
            return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)

    # measurement distance: the L2 norm of the residual (sum, not mean), used by the DPS gradient step
    def RMSELoss(yhat, y):
        return torch.sqrt(torch.sum((yhat - y) ** 2))
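
    # A mean-based alternative (a sketch, not used by the script below): a true
    # root-mean-square error. Its gradients are much smaller than the sum-based
    # loss above, so zeta would need retuning if this were used instead.
    def MeanRMSELoss(yhat, y):
        return torch.sqrt(torch.mean((yhat - y) ** 2))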

    # load an image and normalize it to [-1, 1], as expected by the UNet
    src = Image.open("sample.png")
    # (1, C, H, W) float tensor
    src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]
    src = (src / 127.5) - 1.0
    src = src.to("cuda")

    # generate the measurement y = f(x) with the blur operator
    operator = GaussianBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
    measurement = operator(src)
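
    # A hypothetical alternative measurement (a sketch, not exercised here): downscale
    # the 256x256 source by 4x with the super-resolution operator defined above.
    # operator = SuperResolutionOperator(in_shape=(1, 3, 256, 256), scale_factor=4).to("cuda")
    # measurement = operator(src)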

    # set up the DDPM scheduler and the pretrained UNet
    scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
    scheduler.set_timesteps(1000)

    model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")

    # save the source and the measurement for inspection
    save_image((src + 1.0) / 2.0, "dps_src.png")
    save_image((measurement + 1.0) / 2.0, "dps_mea.png")

    # run the DPS pipeline on the measurement
    dpspipe = DPSPipeline(model, scheduler)
    image = dpspipe(
        measurement=measurement,
        operator=operator,
        loss_fn=RMSELoss,
        zeta=1.0,
    ).images[0]

    image.save("dps_generated_image.png")
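
    # Optional sanity check (a minimal sketch added for illustration): PSNR between the
    # reconstruction and the clean source, both scaled to [0, 1].
    recon = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0).permute(2, 0, 1)[None].to("cuda")
    target = (src + 1.0) / 2.0
    mse = torch.mean((recon - target) ** 2)
    psnr = 10.0 * torch.log10(1.0 / mse)
    print("PSNR: {0:.2f} dB".format(psnr.item()))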