Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| from sam_diffsr.utils_sr.hparams import hparams | |
| from .diffusion import GaussianDiffusion, noise_like, extract | |
| from .module_util import default | |
class GaussianDiffusion_sam(GaussianDiffusion):
    """GaussianDiffusion variant that threads an optional SAM mask through training and sampling.

    ``sam_mask`` is forwarded to ``denoise_fn`` on every call. When
    ``sam_config['p_losses_sam']`` is truthy, :meth:`p_losses` also adds the
    (resized, optionally timestep-scaled) mask to the training noise.
    """

    def __init__(self, denoise_fn, rrdb_net, timesteps=1000, loss_type='l1', sam_config=None):
        """Build the diffusion model.

        Args:
            denoise_fn: Noise-prediction network; called with a ``sam_mask`` keyword.
            rrdb_net: Passed through to ``GaussianDiffusion``.
            timesteps: Passed through to ``GaussianDiffusion``.
            loss_type: ``'l1'``, ``'l2'`` or ``'ssim'`` (see :meth:`p_losses`).
            sam_config: Optional dict of SAM options. Keys read in this class:
                ``'p_losses_sam'`` (add the mask to the training noise) and
                ``'mask_coefficient'`` (scale that mask per timestep).
                ``None`` or a missing key disables the option.
        """
        super().__init__(denoise_fn, rrdb_net, timesteps, loss_type)
        self.sam_config = sam_config

    def p_losses(self, x_start, t, cond, img_lr_up, noise=None, sam_mask=None):
        """Compute the training loss for one batch.

        Args:
            x_start: Target batch, passed as ``x_start`` to ``q_sample``.
            t: Timestep tensor (presumably long, shape ``(b,)`` — TODO confirm).
            cond: Conditioning forwarded to ``denoise_fn`` and ``p_sample``.
            img_lr_up: Forwarded to ``denoise_fn`` and ``p_sample``.
            noise: Optional noise shaped like ``x_start``; drawn with
                ``torch.randn_like`` when omitted. It is never modified in place.
            sam_mask: SAM mask forwarded to ``denoise_fn``/``p_sample``. Required
                when ``sam_config['p_losses_sam']`` is enabled, in which case it
                is bilinearly resized to ``noise.shape[2:]`` and added to the noise.

        Returns:
            ``(loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred)``. The loss
            compares ``noise_pred`` against the (possibly mask-shifted) noise.

        Raises:
            ValueError: if mask-noise is enabled but ``sam_mask`` is ``None``.
            NotImplementedError: for an unsupported ``loss_type``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        # The constructor allows sam_config=None; treat that as "all options off".
        sam_config = self.sam_config or {}
        if sam_config.get('p_losses_sam', False):
            if sam_mask is None:
                raise ValueError("sam_config['p_losses_sam'] is enabled but no sam_mask was given")
            _sam_mask = F.interpolate(sam_mask, noise.shape[2:], mode='bilinear')
            if sam_config.get('mask_coefficient', False):
                # NOTE(review): self.mask_coefficient is not assigned in this class;
                # presumably provided elsewhere — confirm it exists when this is enabled.
                _sam_mask = _sam_mask * extract(
                    self.mask_coefficient.to(_sam_mask.device), t, x_start.shape)
            # Out-of-place add so a caller-supplied ``noise`` tensor is not mutated.
            noise = noise + _sam_mask
        x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
        # NOTE(review): t - 1 is -1 wherever t == 0; this relies on q_sample
        # accepting that index — confirm against the parent implementation.
        x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
        noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up, sam_mask=sam_mask)
        x_t_pred, x0_pred = self.p_sample(
            x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred, sam_mask=sam_mask)
        if self.loss_type == 'l1':
            loss = (noise - noise_pred).abs().mean()
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, noise_pred)
        elif self.loss_type == 'ssim':
            loss = (noise - noise_pred).abs().mean()
            loss = loss + (1 - self.ssim_loss(noise, noise_pred))
        else:
            raise NotImplementedError()
        return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred

    def p_sample(self, x, t, cond, img_lr_up, noise_pred=None, clip_denoised=True, repeat_noise=False, sam_mask=None):
        """Take one reverse-diffusion step from ``x`` at timestep ``t``.

        Args:
            x: Current sample; its first dimension is the batch size.
            t: Per-sample timestep tensor; samples with ``t == 0`` receive no noise.
            cond, img_lr_up, sam_mask: Forwarded to ``denoise_fn`` when
                ``noise_pred`` is not supplied.
            noise_pred: Precomputed noise prediction; computed with
                ``denoise_fn`` when ``None``.
            clip_denoised: Forwarded to ``p_mean_variance``.
            repeat_noise: Forwarded to ``noise_like``.

        Returns:
            ``(x_prev, x0_pred)``: the sampled tensor for the previous step and
            the ``x0`` estimate returned by ``p_mean_variance``.
        """
        if noise_pred is None:
            noise_pred = self.denoise_fn(x, t, cond=cond, img_lr_up=img_lr_up, sam_mask=sam_mask)
        b, device = x.shape[0], x.device
        model_mean, _, model_log_variance, x0_pred = self.p_mean_variance(
            x=x, t=t, noise_pred=noise_pred, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0: mask is 0 for those samples, 1 otherwise,
        # reshaped to (b, 1, ..., 1) so it broadcasts over x.
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        # exp(0.5 * log_variance) is the standard deviation of the step.
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0_pred

    @torch.no_grad()
    def sample(self, img_lr, img_lr_up, shape, sam_mask=None, save_intermediate=False):
        """Run the reverse-diffusion chain from step ``num_timesteps - 1`` down to 0.

        Runs under ``torch.no_grad()``. Without it, when the networks'
        parameters require grad, autograd state from every step of the chain
        would be retained.

        Args:
            img_lr: Low-resolution input; fed to ``self.rrdb`` when
                ``hparams['use_rrdb']`` is set, otherwise used directly as ``cond``.
            img_lr_up: Upsampled low-resolution input, forwarded to every step
                and to ``res2img``.
            shape: Shape of the sample tensor; ``shape[0]`` is the batch size.
            sam_mask: Forwarded to every ``p_sample`` call.
            save_intermediate: When true, also collect per-step results.

        Returns:
            ``(img, rrdb_out)``, or ``(img, rrdb_out, images)`` when
            ``save_intermediate`` is true. ``images`` holds one
            ``(img, x_recon)`` pair per step, passed through ``res2img`` and
            moved to CPU.
        """
        device = self.betas.device
        b = shape[0]
        if not hparams['res']:
            # Start from img_lr_up diffused to the last timestep.
            t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
            img = self.q_sample(img_lr_up, t, noise=None)
        else:
            # Start from pure Gaussian noise.
            img = torch.randn(shape, device=device)
        if hparams['use_rrdb']:
            rrdb_out, cond = self.rrdb(img_lr, True)
        else:
            rrdb_out = img_lr_up
            cond = img_lr
        it = reversed(range(0, self.num_timesteps))
        if self.sample_tqdm:
            it = tqdm(it, desc='sampling loop time step', total=self.num_timesteps)
        images = []
        for i in it:
            img, x_recon = self.p_sample(
                img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up, sam_mask=sam_mask)
            if save_intermediate:
                img_ = self.res2img(img, img_lr_up)
                x_recon_ = self.res2img(x_recon, img_lr_up)
                images.append((img_.cpu(), x_recon_.cpu()))
        img = self.res2img(img, img_lr_up)
        if save_intermediate:
            return img, rrdb_out, images
        else:
            return img, rrdb_out