| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | import gc |
| | import numpy as np |
| | import numpy as np |
| | from PIL import Image |
| | from scripts.refine_lr_to_sr import run_sr_fast |
| |
|
# Filesystem location of the Gradio cache directory (presumably where Gradio
# stores temporary uploads/outputs — not referenced in the visible code, confirm
# against callers).
GRADIO_CACHE = "/tmp/gradio/"
| |
|
def clean_up():
    """Release unreferenced Python objects and cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to give its unused blocks back; doing it in the opposite order leaves
    those blocks cached until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
| |
|
def remove_color(arr, threshold=80):
    """Turn the background colour of an image array into transparent white.

    The pixel at ``arr[0, 0]`` is taken as the background colour. Every pixel
    whose summed absolute per-channel difference from it is at most
    ``threshold`` is painted white (255) and gets alpha 0; all other pixels
    keep their colour and get alpha 255.

    Args:
        arr: ``(H, W, 3)`` or ``(H, W, 4)`` numpy array. For a 4-channel
            input the existing fourth channel is discarded.
        threshold: Maximum summed channel difference still treated as
            background. Defaults to 80, the previously hard-coded value.

    Returns:
        A new ``(H, W, 4)`` array. The input array is left untouched.
    """
    # Copy so that painting the background below never writes through to the
    # caller's array (slicing off the alpha channel alone only yields a view).
    rgb = np.array(arr[..., :3] if arr.shape[-1] == 4 else arr)

    # Compare in int32 so uint8 subtraction cannot wrap around.
    reference = rgb[0, 0].astype(np.int32)
    distance = np.abs(rgb.astype(np.int32) - reference).sum(axis=-1)
    background = distance <= threshold

    rgb[background] = 255
    opacity = (~background)[..., None].astype(np.int32) * 255
    return np.concatenate([rgb, opacity], axis=-1)
| |
|
def simple_remove(imgs, run_sr=True):
    """Strip the background from one image or a list of images.

    Original author's note: "Only works for normal" (presumably normal-map
    images — confirm with the caller).

    Args:
        imgs: A single image or a ``list`` of images; anything that is not a
            ``list`` is treated as a single image.
        run_sr: When true, the batch is first passed through ``run_sr_fast``.

    Returns:
        A single PIL image when a single image was given, otherwise a list of
        PIL images, each produced by ``remove_color``.
    """
    single_input = not isinstance(imgs, list)
    batch = [imgs] if single_input else imgs

    if run_sr:
        batch = run_sr_fast(batch)

    results = [
        Image.fromarray(remove_color(np.array(frame)).astype(np.uint8))
        for frame in batch
    ]
    return results[0] if single_input else results
| |
|
def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
    """Flatten an RGBA image onto a solid background and return it as RGB.

    Args:
        rgba: Image to flatten; it is pasted using itself as the paste mask.
        bkgd: Fill colour of the backing canvas, passed to ``Image.new``.

    Returns:
        A new RGB image with the same size as ``rgba``.
    """
    canvas = Image.new("RGBA", rgba.size, bkgd)
    canvas.paste(rgba, (0, 0), rgba)
    return canvas.convert('RGB')
| |
|
def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
    """Replace the colour behind an RGBA image while keeping its alpha channel.

    The colour channels come from flattening ``rgba`` onto ``bkgd`` via
    ``rgba_to_rgb``; the fourth channel is copied unchanged from ``rgba``.

    Args:
        rgba: Source image; its array form must have a fourth channel.
        bkgd: Background colour forwarded to ``rgba_to_rgb``.

    Returns:
        A new image built from the recoloured channels plus the original
        fourth channel.
    """
    flattened = np.array(rgba_to_rgb(rgba, bkgd))
    alpha_plane = np.array(rgba)[:, :, 3:4]
    stacked = np.concatenate([flattened, alpha_plane], axis=-1)
    return Image.fromarray(stacked)
| |
|
def split_image(image, rows=None, cols=None):
    """Cut a grid image into its tiles (inverse of ``make_image_grid``).

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and
            ``crop(box)``, e.g. a PIL image.
        rows: Number of tile rows, or ``None`` to infer it.
        cols: Number of tile columns, or ``None`` to infer it.

    Layout rules:
        * neither given: a single row of square tiles whose side equals the
          image height;
        * only one given: square tiles whose side is derived from the given
          count, the other count is inferred;
        * both given: tiles are ``width // cols`` by ``height // rows`` and
          may be rectangular, matching what ``make_image_grid`` produces.

    Returns:
        List of cropped tiles in row-major order.

    Raises:
        AssertionError: If an inferred count does not tile the image exactly.
            Raised explicitly (rather than via ``assert``) so the check
            survives ``python -O`` while keeping the original exception type.
    """
    width, height = image.size

    if rows is not None and cols is not None:
        tile_w = width // cols
        tile_h = height // rows
    else:
        if rows is None and cols is None:
            rows = 1
            side = height
            cols = width // side
            exact = cols * side == width
        elif rows is None:
            side = width // cols
            rows = height // side
            exact = rows * side == height
        else:
            side = height // rows
            cols = width // side
            exact = cols * side == width
        if not exact:
            raise AssertionError(
                f"image of size {width}x{height} cannot be split evenly "
                f"into square tiles of side {side}"
            )
        tile_w = tile_h = side

    return [
        image.crop((c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h))
        for r in range(rows)
        for c in range(cols)
    ]
| |
|
def make_image_grid(images, rows=None, cols=None, resize=None):
    """Paste a sequence of images into a single grid image, row by row.

    Args:
        images: Non-empty sequence of PIL images. The sequence itself is
            never modified.
        rows: Number of grid rows; inferred from ``cols`` when omitted.
        cols: Number of grid columns; inferred from ``rows`` when omitted.
            With neither given, all images go into a single row.
        resize: If not ``None``, every image is resized to
            ``(resize, resize)`` before being pasted.

    Returns:
        A new image of size ``(cols * w, rows * h)``, where ``(w, h)`` is the
        size of the first (possibly resized) image. Cells without a
        corresponding input image are filled with blank images created by
        ``Image.new`` with its default fill.
    """
    # Work on a private copy: the padding step below used to append the blank
    # filler images to the caller's list (and failed outright for tuples).
    tiles = list(images)
    count = len(tiles)

    if rows is None and cols is None:
        rows, cols = 1, count
    elif rows is None:
        rows = -(-count // cols)  # ceiling division
    elif cols is None:
        cols = -(-count // rows)  # ceiling division

    missing = rows * cols - count
    if missing > 0:
        mode, size = tiles[0].mode, tiles[0].size
        tiles.extend([Image.new(mode, size) for _ in range(missing)])

    if resize is not None:
        tiles = [tile.resize((resize, resize)) for tile in tiles]

    w, h = tiles[0].size
    grid = Image.new(tiles[0].mode, size=(cols * w, rows * h))

    for index, tile in enumerate(tiles):
        row, col = divmod(index, cols)
        grid.paste(tile, box=(col * w, row * h))
    return grid
| |
|
| |
|