| | from PIL import Image |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision.transforms.functional as TF |
| | from einops import rearrange |
| |
|
| | from .permutations import get_inv_perm |
| | from .view_base import BaseView |
| |
|
| |
|
class PatchPermuteView(BaseView):
    '''
    View that cuts a square image into a `num_patches` x `num_patches` grid
    of square patches and reorders them with a fixed random permutation
    drawn once at construction time.
    '''

    def __init__(self, num_patches=8):
        '''
        Implements random patch permutations, with `num_patches`
            patches per side

        num_patches (int) :
            Number of patches in one dimension. Total number
            of patches will be num_patches**2. Must divide both 64 and
            256, i.e. be a power of 2 no larger than 64.
        '''
        assert 64 % num_patches == 0 and 256 % num_patches == 0, \
            "`num_patches` must divide image side lengths of 64 and 256"

        self.num_patches = num_patches

        # Random ordering of the num_patches**2 patches (indexed row-major,
        # matching the `(h w)` axis produced by `rearrange`) and its inverse.
        # NOTE(review): assumes get_inv_perm returns perm_inv satisfying
        # perm_inv[perm[i]] == i -- confirm against .permutations.
        self.perm = torch.randperm(self.num_patches**2)
        self.perm_inv = get_inv_perm(self.perm)

    def _permute_patches(self, im, perm):
        '''
        Split a (c, h, w) tensor into a num_patches x num_patches grid of
        square patches, reorder them so that output patch i is input
        patch perm[i] (row-major order), and reassemble the image.

        im (torch.tensor) :
            Tensor of shape (c, s, s). It must be square with `s` divisible
            by `num_patches`, otherwise `rearrange` raises.
        perm (torch.tensor) :
            Index tensor of length num_patches**2.
        '''
        patch_size = im.shape[-1] // self.num_patches

        # (c, h*p, w*p) -> (h*w, c, p, p): one entry per patch, row-major
        patches = rearrange(im,
                            'c (h p1) (w p2) -> (h w) c p1 p2',
                            p1=patch_size,
                            p2=patch_size)

        patches = patches[perm]

        # Inverse of the split above
        return rearrange(patches,
                         '(h w) c p1 p2 -> c (h p1) (w p2)',
                         h=self.num_patches,
                         w=self.num_patches,
                         p1=patch_size,
                         p2=patch_size)

    def view(self, im):
        '''
        Apply the patch permutation: output patch j is input patch perm[j].

        im (torch.tensor) :
            Square tensor of shape (c, s, s) with `s` divisible by
            `num_patches`.
        '''
        return self._permute_patches(im, self.perm)

    def inverse_view(self, noise):
        '''
        Undo `view`: output patch j is input patch perm_inv[j].

        noise (torch.tensor) :
            Square tensor of shape (c, s, s) with `s` divisible by
            `num_patches`.
        '''
        return self._permute_patches(noise, self.perm_inv)

    def make_frame(self, im, t, canvas_size=384, scale=4, knot_seed=0):
        '''
        Render one animation frame in which each patch of `im` travels along
        a quadratic Bezier curve from its original grid cell (t=0) to the
        cell it occupies in `self.view(im)` (t=1).

        im (PIL.Image) :
            Square image whose side is divisible by `num_patches`.
        t (float) :
            Animation time; 0 gives the original layout, 1 the permuted one.
        canvas_size (int) :
            Side of the returned square frame; the image is centred on a
            white canvas.
        scale (int) :
            Supersampling factor. This is a hack, because PIL doesn't support
            pasting at floating point coordinates, so the frame is rendered
            `scale` times larger and then resized by 1/scale.
        knot_seed (int) :
            Seed for the random curve control points. Calls with the same
            seed use the same trajectories, so frames of one animation stay
            consistent. The global numpy RNG is not touched.

        Returns an RGBA PIL image of size (canvas_size, canvas_size).
        '''
        n = self.num_patches

        # Centre the image on the canvas; all coordinates below are in
        # supersampled pixels.
        im_size = im.size[0]
        offset = (canvas_size - im_size) // 2 * scale
        canvas_size = canvas_size * scale

        im = TF.to_tensor(im)
        patch_size = im_size // n

        # (num_patches**2, c, p, p) patches in row-major order
        patches = rearrange(im,
                            'c (h p1) (w p2) -> (h w) c p1 p2',
                            p1=patch_size,
                            p2=patch_size)

        # Top-left (x, y) of every grid cell, row-major like `patches`
        rows, cols = np.divmod(np.arange(n * n), n)
        start_locs = np.stack([cols, rows], axis=1) * patch_size * scale
        start_locs = start_locs + offset

        # view() places input patch perm[j] in cell j, so input patch k ends
        # up in cell perm_inv[k].
        end_locs = start_locs[np.asarray(self.perm_inv)]

        # Local generator: same stream as seeding the global numpy RNG, but
        # without mutating (or having to restore) global state.
        rng = np.random.RandomState(knot_seed)
        rand_offsets = rng.rand(n * n, 1) * 2 - 1
        rand_offsets = rand_offsets * 2 * scale
        # Jitter so a patch that does not move still gets a non-zero
        # direction to normalise.
        eps = rng.randn(*start_locs.shape)

        # Control point: midpoint of the path, pushed sideways along the
        # unit normal of the displacement by a random signed amount.
        avg_locs = (start_locs + end_locs) / 2.
        direction = end_locs - start_locs + eps
        direction = direction / np.linalg.norm(direction, axis=1,
                                               keepdims=True)
        # Row vector (x, y) @ rot -> (-y, x): rotation by 90 degrees
        normal = direction @ np.array([[0, 1], [-1, 0]])
        rand_offsets = rand_offsets * (im_size / 4)
        knot_locs = avg_locs + normal * rand_offsets

        # Quadratic Bezier via de Casteljau; truncate toward zero because
        # PIL needs integer paste coordinates.
        spline_0 = start_locs * (1 - t) + knot_locs * t
        spline_1 = knot_locs * (1 - t) + end_locs * t
        paste_locs = (spline_0 * (1 - t) + spline_1 * t).astype(int)

        canvas = Image.new("RGBA", (canvas_size, canvas_size),
                           (255, 255, 255, 255))
        tile_size = patch_size * scale
        for patch, (x, y) in zip(patches, paste_locs):
            tile = TF.to_pil_image(patch).convert('RGBA')
            tile = tile.resize((tile_size, tile_size))
            # Third argument uses the tile's alpha channel as paste mask
            canvas.paste(tile, (int(x), int(y)), tile)

        if scale != 1.0:
            canvas = canvas.resize((canvas_size // scale,
                                    canvas_size // scale))

        return canvas
| |
|