Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import PIL | |
| from PIL import Image | |
| import torch | |
| from diffusion_arch import ILVRUNetModel, ConditionalUNetModel | |
| from guided_diffusion.script_util import create_gaussian_diffusion | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| from torchvision.utils import make_grid | |
def preprocess_image(image):
    """Convert a PIL image into a batched, channel-first float tensor.

    Both sides are rounded down to the nearest multiple of 32 and the image
    is resized to that size with Lanczos resampling. The pixel array is
    divided by 255 and then mapped through ``2 * x - 1`` (so 8-bit input
    ends up in [-1, 1]).

    The channel-last to channel-first transpose needs a 3-D pixel array, so
    the image must have a channel axis (e.g. mode "RGB").

    Returns:
        A float32 tensor of shape (1, C, H, W).
    """
    width, height = (side - side % 32 for side in image.size)
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    # HWC -> CHW, then add the leading batch axis.
    batch = torch.from_numpy(pixels.transpose(2, 0, 1)).unsqueeze(0)
    return 2.0 * batch - 1.0
def preprocess_mask(mask):
    """Convert a PIL mask into a binary, 3-channel, batched float tensor.

    The mask is converted to single-channel ("L"), resized with
    nearest-neighbour resampling to the largest size whose sides are
    multiples of 32, replicated across three channels, and binarised:
    every strictly positive value becomes 1.

    Returns:
        A float32 tensor of shape (1, 3, H, W).
    """
    gray = mask.convert("L")
    width, height = (side - side % 32 for side in gray.size)
    gray = gray.resize((width, height), resample=PIL.Image.NEAREST)
    values = np.array(gray).astype(np.float32) / 255.0
    # Copy the single plane three times along a new leading channel axis.
    stacked = np.stack([values] * 3, axis=0)
    tensor = torch.from_numpy(stacked).unsqueeze(0)
    tensor[tensor > 0] = 1
    return tensor
class DiffusionPipeline:
    """Sampler that runs two reverse-diffusion chains over the same input.

    Two U-Nets are held:

    * ``diffusion_restoration_model`` (``ConditionalUNetModel``) is always
      called with the preprocessed input image as the ``lq`` keyword.
    * ``diffusion_model`` (``ILVRUNetModel``) is called without any
      conditioning input.

    Calling the pipeline returns a generator of ``[single, combined]`` pairs
    of PIL images: ``single`` comes from the chain driven only by the
    restoration model, ``combined`` from the chain that uses both models.
    """

    def __init__(self, device):
        """Build both models on *device*, in eval mode, and load weights.

        Weights are read from the working directory: ``./ffhq_10m.pt`` is
        used directly as a state dict, ``./net_g_250000.pth`` is a dict whose
        ``'params'`` entry is the state dict.
        """
        super().__init__()
        self.device = device

        diffusion_model = ILVRUNetModel(
            in_channels=3,
            model_channels=128,
            out_channels=6,
            num_res_blocks=1,
            attention_resolutions=[16],
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=False
        )
        diffusion_model = diffusion_model.to(device)
        diffusion_model = diffusion_model.eval()
        ilvr_pretraining = torch.load('./ffhq_10m.pt', map_location='cpu')
        diffusion_model.load_state_dict(ilvr_pretraining)
        self.diffusion_model = diffusion_model

        diffusion_restoration_model = ConditionalUNetModel(
            in_channels=3,
            model_channels=128,
            out_channels=6,
            num_res_blocks=1,
            attention_resolutions=[16],
            dropout=0.0,
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=False
        )
        diffusion_restoration_model = diffusion_restoration_model.to(device)
        diffusion_restoration_model = diffusion_restoration_model.eval()
        state_dict = torch.load('./net_g_250000.pth', map_location='cpu')
        diffusion_restoration_model.load_state_dict(state_dict['params'])
        self.diffusion_restoration_model = diffusion_restoration_model

    @staticmethod
    def _to_pil_grid(batch):
        """Tile a batch of [-1, 1] images into one grid and return it as PIL."""
        grid = (make_grid(batch) / 2 + 0.5).clamp(0, 1)
        return Image.fromarray(np.uint8(grid.cpu().numpy().transpose(1, 2, 0) * 255.))

    def _sample_previous(self, out, t):
        """Draw the next (less noisy) sample from a ``p_mean_variance`` result."""
        mean = out["mean"]
        # no noise when t == 0
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(mean.shape) - 1)))
        noise = torch.randn_like(mean, device=self.device)
        return mean + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise

    @torch.no_grad()
    def _denoise_step(self, diffusion, i, lq_img_th, s_img, img, binoising_step):
        """Advance both chains by one timestep.

        Runs with autograd disabled: the model parameters require grad by
        default, so otherwise every step would extend a single graph across
        the whole sampling loop and the outputs could not be converted with
        ``.numpy()``.

        Returns:
            ``(s_img, s_img_pred, img, img_pred)`` -- the new sample and the
            predicted clean image for the restoration-only chain, followed
            by the same pair for the combined chain.
        """
        t = torch.tensor([i] * lq_img_th.size(0), device=self.device)

        # Chain 1: restoration model only, conditioned on the input image.
        out = diffusion.p_mean_variance(
            self.diffusion_restoration_model, s_img, t, model_kwargs={'lq': lq_img_th}
        )
        s_img = self._sample_previous(out, t)
        s_img_pred = out["pred_xstart"]

        # Chain 2: below the threshold, first let the restoration model
        # predict the clean image from the current sample and re-noise that
        # prediction back to step t before the second model takes its step.
        if i < binoising_step:
            model_output = diffusion._wrap_model(self.diffusion_restoration_model)(
                img, t, lq=lq_img_th
            )
            channels = img.shape[1]
            # Only the first `channels` outputs are used (as eps); the
            # remaining outputs are ignored here.
            model_output, _ = torch.split(model_output, channels, dim=1)
            pred_xstart = diffusion._predict_xstart_from_eps(img, t, model_output).clamp(-1, 1)
            img = diffusion.q_sample(pred_xstart, t)
        out = diffusion.p_mean_variance(self.diffusion_model, img, t)
        img = self._sample_previous(out, t)
        img_pred = out["pred_xstart"]

        return s_img, s_img_pred, img, img_pred

    def __call__(self, lq, diffusion_step, binoising_step, grid_size):
        """Run both chains and stream preview images.

        Args:
            lq: Input PIL image; converted to RGB and resized to 256x256.
            diffusion_step: Number of respaced sampling steps (cast to int).
            binoising_step: Timestep threshold; for ``i < binoising_step``
                the combined chain is additionally driven by the
                restoration model.
            grid_size: Number of samples drawn in parallel (cast to int);
                they are tiled into one grid per yielded image.

        Yields:
            ``[single, combined]`` lists of two PIL images: the predicted
            clean images at every even timestep, then the final samples once
            the loop is done.
        """
        lq = lq.convert("RGB").resize((256, 256), resample=Image.LANCZOS)
        eval_gaussian_diffusion = create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            noise_schedule='linear',
            use_kl=False,
            timestep_respacing=str(int(diffusion_step)),
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
        )

        # preprocess image
        lq_img_th = preprocess_image(lq).to(self.device)
        # Cast like diffusion_step: Tensor.repeat rejects float repeat counts.
        lq_img_th = lq_img_th.repeat([int(grid_size), 1, 1, 1])

        img = torch.randn_like(lq_img_th, device=self.device)
        s_img = torch.randn_like(lq_img_th, device=self.device)

        for i in reversed(range(eval_gaussian_diffusion.num_timesteps)):
            s_img, s_img_pred, img, img_pred = self._denoise_step(
                eval_gaussian_diffusion, i, lq_img_th, s_img, img, binoising_step
            )
            # Yield outside the no-grad helper so grad mode is never left
            # disabled in the caller while this generator is suspended.
            if i % 2 == 0:
                yield [self._to_pil_grid(s_img_pred), self._to_pil_grid(img_pred)]

        yield [self._to_pil_grid(s_img), self._to_pil_grid(img)]