| |
| |
| |
|
|
| import torch |
| import random |
| import numpy as np |
| from einops import rearrange |
|
|
def batch_inpainging_from_grad(im_in, mask, gradx, grady):
    '''
    Recover a batch of masked images from their gradients (torch tensors).

    Each sample is converted to numpy, solved with inpainting_from_grad and
    written back into a new tensor.

    Input:
        im_in: N x c x h x w, torch tensor, masked image
        mask: N x 1 x h x w, torch tensor, 1 represents missing value
        gradx, grady: N x c x h x w, torch tensor, image gradient
    Output:
        N x c x h x w torch tensor on im_in's device with im_in's dtype.
        The input tensors are not modified.
    '''
    def to_numpy(tensor):
        # detach(): Tensor.numpy() refuses tensors that require grad.
        # copy(): Tensor.numpy() on a CPU tensor shares its storage, and
        # inpainting_from_grad may write into the numpy arrays it receives;
        # without the copy the caller's tensors could be silently changed.
        return tensor.detach().cpu().numpy().copy()

    im_out = torch.zeros_like(im_in.data)
    for ii in range(im_in.shape[0]):
        im_current, gradx_current, grady_current = [
            rearrange(to_numpy(x[ii]), 'c h w -> h w c')
            for x in (im_in, gradx, grady)
        ]
        mask_current = to_numpy(mask[ii, 0])
        out_current = inpainting_from_grad(im_current, mask_current, gradx_current, grady_current)
        im_out[ii] = torch.from_numpy(rearrange(out_current, 'h w c -> c h w')).to(
            device=im_in.device,
            dtype=im_in.dtype,
        )
    return im_out
|
|
def inpainting_from_grad(im_in, mask, gradx, grady):
    '''
    Recover the missing pixels of an image by integrating its gradients.

    Strategy:
      * if an interior column is fully known, integrate gradx left/right from it;
      * else if an interior row is fully known, integrate grady up/down from it;
      * otherwise pick an interior column (weighted by its number of known
        pixels), complete it from grady with fill_line, then integrate gradx
        from that column.

    Input:
        im_in: h x w x c (or h x w), masked image, numpy array. Missing pixels
               are expected to be 0 because the result is im_in + fill * mask.
        mask: h x w, image mask, 1 represents missing value
        gradx: same shape as im_in, gradient along x-axis (columns), used as a
               backward difference: gradx[:, j] = im[:, j] - im[:, j-1]
        grady: same shape as im_in, gradient along y-axis (rows), used as a
               backward difference: grady[i] = im[i] - im[i-1]
    Output:
        im_out: recovered image, same shape as im_in. im_in and mask are not
                modified.
    '''
    # Work on copies: fill_line writes the completed column into im_in and
    # clears the matching mask entries in place. The final blend below relies
    # on that cleared column, but callers must not see either change.
    im_in = np.copy(im_in)
    mask = np.copy(mask)

    h, w = im_in.shape[:2]
    counts_h = np.sum(1 - mask, axis=0)  # known pixels per column (length w)
    counts_w = np.sum(1 - mask, axis=1)  # known pixels per row (length h)
    if np.any(counts_h[1:-1] == h):
        # An interior column is fully known.
        idx = find_first_index(counts_h[1:-1], h) + 1
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    elif np.any(counts_w[1:-1] == w):
        # An interior row is fully known. Swap the two spatial axes so the row
        # becomes a column and grady becomes a horizontal gradient. Integrate,
        # then swap back. The channel axis, if any, is left in place.
        idx = find_first_index(counts_w[1:-1], w) + 1
        out_t = fill_image_from_gradx(
            np.swapaxes(im_in, 0, 1), mask.T, np.swapaxes(grady, 0, 1), idx
        )
        im_out = np.swapaxes(out_t, 0, 1)
    else:
        # No complete interior line: complete one column first.
        idx = random.choices(list(range(1, w - 1)), k=1, weights=counts_h[1:-1])[0]
        line = fill_line(im_in[:, idx], mask[:, idx], grady[:, idx])
        im_in[:, idx] = line
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    if im_in.ndim > mask.ndim:
        mask = mask[:, :, None]
    # Keep known pixels from im_in and take missing ones from the integration.
    im_out = im_in + im_out * mask
    return im_out
|
|
def fill_image_from_gradx(im_in, mask, gradx, idx):
    '''
    Rebuild a whole image from a single column and the horizontal gradient.

    Input:
        im_in: h x w x c (or h x w) array; only column idx is read
        mask: not used by this function
        gradx: same shape as im_in, backward difference along x:
               gradx[:, j] = im[:, j] - im[:, j-1] (column 0 is never read)
        idx: index of the anchor column
    Output:
        array with im_in's spatial size. Column idx is copied from im_in,
        columns right of it are accumulated rightwards with gradx, and columns
        left of it are accumulated leftwards.
    '''
    width = im_in.shape[1]
    anchor = np.zeros_like(im_in)
    anchor[:, idx, ] = im_in[:, idx, ]

    # Columns idx+1 .. width-1: x[j] = x[j-1] + gradx[j]
    forward_steps = anchor[:, idx:width - 1, ] + gradx[:, idx + 1:width, ]
    right_part = forward_steps.cumsum(axis=1)

    # Columns idx-1 .. 0: x[j-1] = x[j] - gradx[j], accumulated outward, then
    # flipped back into left-to-right order.
    backward_steps = anchor[:, idx:0:-1, ] - gradx[:, idx:0:-1, ]
    left_part = np.flip(backward_steps.cumsum(axis=1), axis=1)

    middle = np.expand_dims(im_in[:, idx, ], axis=1)
    return np.concatenate([left_part, middle, right_part], axis=1)
|
|
def fill_line(xx, mm, grad):
    '''
    Fill one line (e.g. an image column) by integrating its gradient outward
    from the known entries.

    Input:
        xx: n x c (or (n,)) array, masked vector; filled in place
        mm: (n,) array, mask, 1 represents missing value; every filled entry
            is reset to 0 in place
        grad: array with the same shape as xx, backward difference along the
              line: grad[k] = x[k] - x[k-1]
    Output:
        xx (the same object), with every missing entry filled
    Raises:
        AssertionError if every entry is missing
        ValueError if the mask contains values other than 0 and 1
    '''
    n = xx.shape[0]
    assert mm.sum() < n, 'fill_line needs at least one known entry'
    # One iteration per missing run. This is a loop rather than recursion so
    # lines with many missing runs cannot hit the interpreter recursion limit.
    while mm.sum() != 0:
        idx1 = find_first_index(mm, 1)
        if idx1 >= n:
            # Non-zero mask entries that are not 1 would otherwise never clear.
            raise ValueError('fill_line expects a binary mask of 0s and 1s')
        if idx1 == 0:
            # Leading run [0, idx2): walk backwards from the first known
            # entry, x[k-1] = x[k] - grad[k].
            idx2 = find_first_index(mm, 0)
            steps = np.cumsum(grad[idx2:0:-1, ], axis=0)
            xx[:idx2, ] = (xx[idx2, ] - steps)[::-1]
        else:
            # Interior or trailing run [idx1, idx2): walk forwards from the
            # known entry just before it, x[k] = x[k-1] + grad[k].
            idx2 = find_first_index(mm[idx1:, ], 0) + idx1
            xx[idx1:idx2, ] = xx[idx1 - 1, ] + np.cumsum(grad[idx1:idx2, ], axis=0)
        mm[idx1:idx2, ] = 0
    return xx
|
|
def find_first_index(mm, value):
    '''
    Locate the first occurrence of a value.

    Input:
        mm: (n, ) array
        value: scalar
    Output:
        index (along the first axis, scanning in C order) of the first entry
        equal to value, or mm.shape[0] when value does not occur
    '''
    hits = np.argwhere(np.asarray(mm) == value)
    if hits.size == 0:
        return mm.shape[0]
    return int(hits[0][0])
|
|
if __name__ == '__main__':
    import sys
    from pathlib import Path

    # Make the directory one level above this file importable (presumably the
    # project root, where utils/ and datapipe/ live).
    sys.path.append(str(Path(__file__).resolve().parents[1]))
    from utils import util_image
    from datapipe.masks.train import process_mask

    data_dir = Path('./testdata/inpainting/val/places/')
    mask_file_names = list(data_dir.glob('*mask*.png'))
    # Each image is named like its mask file, up to the last '_mask', + '.png'.
    file_names = []
    for mask_path in mask_file_names:
        image_stem = mask_path.stem.rsplit('_mask', 1)[0]
        file_names.append(mask_path.parent / (image_stem + '.png'))

    for im_path, mask_path in zip(file_names, mask_file_names):
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        mask_rgb = util_image.imread(mask_path, chn='rgb', dtype='float32')
        mask = process_mask(mask_rgb[:, :, 0])
        grad_dict = util_image.imgrad(im)

        # Blank the missing pixels, reconstruct, and report the worst error.
        im_masked = im * (1 - mask[:, :, None])
        im_recover = inpainting_from_grad(im_masked, mask, grad_dict['gradx'], grad_dict['grady'])
        error_max = np.abs(im_recover - im).max()
        print('Error Max: {:.2e}'.format(error_max))
|
|