| | from typing import List, Tuple |
| | from scipy import interpolate |
| | import numpy as np |
| | import torch |
| | import matplotlib.pyplot as plt |
| | from IPython.display import clear_output |
| | import abc |
| |
|
| |
|
class GuideModel(torch.nn.Module, abc.ABC):
    """Abstract interface for a model that scores decoded images.

    Concrete subclasses must implement both :meth:`preprocess` and
    :meth:`compute_loss`; the class cannot be instantiated until they do.
    """

    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        """Turn a decoded image batch into the input for :meth:`compute_loss`."""

    @abc.abstractmethod
    def compute_loss(self, inp):
        """Return the loss for an input produced by :meth:`preprocess`."""
| |
|
| |
|
class Guider(torch.nn.Module):
    """Adjust a sampler's noise prediction with gradients of a guide-model loss.

    The instance keeps an internal step counter (``self.index``) that is
    advanced by one on every :meth:`modify_score` call and is used to look up
    both the guidance scale and ``sampler.ddim_sqrt_one_minus_alphas``.

    NOTE(review): the counter is never reset here, so a fresh ``Guider`` (or a
    manual ``guider.index = 0``) is needed for each sampling run. It also
    assumes the order of calls matches the index order of ``ddim_timesteps``
    and ``ddim_sqrt_one_minus_alphas`` -- confirm against the sampler; looking
    the index up from ``t`` would be more robust.
    """

    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance

        Specify a guidance scale as either a scalar
        Or a schedule as a list of tuples t = 0->1 and scale, e.g.
        [(0, 10), (0.5, 20), (1, 50)]

        Args:
            sampler: object exposing ``ddim_timesteps``, ``ddpm_num_timesteps``
                and ``ddim_sqrt_one_minus_alphas``.
            guide_model: object exposing ``preprocess`` and ``compute_loss``.
            scale: scalar guidance scale, or a list/tuple of ``(t, scale)``
                pairs that is linearly interpolated.
            verbose: when true, print diagnostics, record them in
                ``self.history`` and plot intermediate images on every step.
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0  # number of modify_score calls that consumed a step
        self.show = verbose
        self.guide_model = guide_model
        self.history = []  # per-step [loss, scale, min correction, max correction]

        # Builtin types instead of the deprecated typing aliases.
        if isinstance(scale, (tuple, list)):
            times = np.array([point[0] for point in scale])
            values = np.array([point[1] for point in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        """Return one guidance scale per entry of ``ddim_timesteps``.

        A scalar scale is repeated for every step. A schedule is linearly
        interpolated at ``ddim_timesteps / ddpm_num_timesteps``; values outside
        the schedule's time range make ``interp1d`` raise ``ValueError``.
        """
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps) * [self.scale_schedule]

        interpolater = interpolate.interp1d(
            self.scale_schedule["times"], self.scale_schedule["values"]
        )
        fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):
        """Return ``e_t`` shifted by the scaled gradient of the guide loss.

        Args:
            model: object providing ``predict_start_from_noise(x, t=, noise=)``
                and ``first_stage_model.decode``.
            e_t: noise prediction to be corrected.
            x: tensor the loss gradient is taken with respect to.
            t: timestep argument forwarded to ``predict_start_from_noise``.
            c: unused; accepted for interface compatibility.

        Returns:
            ``e_t`` itself when the current scale is zero, otherwise
            ``e_t - sqrt_one_minus_alpha * scale * d(loss)/d(x)``.
        """
        # A scalar scale of exactly zero disables guidance outright; skip all
        # per-step bookkeeping so the instance can be reused indefinitely.
        if isinstance(self.scale_schedule, float) and self.scale_schedule == 0:
            return e_t

        # TODO: look up the index from t instead of counting calls.
        scale = self.get_scales()[self.index]

        if scale == 0:
            # Still consume this step; otherwise the counter stalls on the
            # first zero entry and later non-zero scales are never reached.
            self.index += 1
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
        # Gradients are needed even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            # presumably 0.18215 is the latent scaling factor of the first-stage
            # model -- TODO confirm against its config
            x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            total_loss = loss.sum()
            grads = torch.autograd.grad(total_loss, x_in)[0]
            correction = grads * scale

        if self.show:
            # Log the summed loss (the quantity actually differentiated) so a
            # per-sample loss tensor does not make .item() raise.
            loss_value = total_loss.item()
            clear_output(wait=True)
            print(loss_value, scale, correction.abs().max().item(), e_t.abs().max().item())
            self.history.append(
                [loss_value, scale, correction.min().item(), correction.max().item()]
            )
            plt.imshow((inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2)
            plt.axis('off')
            plt.show()
            plt.imshow(correction[0][0].detach().cpu())
            plt.axis('off')
            plt.show()

        e_t_mod = e_t - sqrt_1ma * correction
        if self.show:
            fig, axs = plt.subplots(1, 3)
            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        self.index += 1
        return e_t_mod