moebius / infer /utils_dataset.py
Mike0021's picture
Implement Moebius Gradio Space
166ab04 verified
Raw
History Blame Contribute Delete
7.11 kB
# ---------------- Dataset Utils -----------------------
import warnings
from pathlib import Path
from typing import Tuple, Optional
import math
import os
import numpy as np
import torch
from PIL import Image, ImageDraw
from torch.utils.data import Dataset, DataLoader
# Suppress every warning category process-wide as soon as this module is imported.
warnings.filterwarnings("ignore")
def RandomBrush(
    max_tries,
    s,
    min_num_vertex=4,
    max_num_vertex=18,
    mean_angle=2*math.pi / 5,
    angle_range=2*math.pi / 15,
    min_width=12,
    max_width=48
):
    """Draw a random number of thick free-form strokes on an ``s`` x ``s`` canvas.

    Args:
        max_tries: Exclusive upper bound on the number of strokes; the actual
            count is drawn uniformly from ``[0, max_tries)``. Values ``<= 0``
            produce no strokes instead of raising.
        s: Side length of the square canvas in pixels.
        min_num_vertex: Inclusive lower bound on the segments per stroke.
        max_num_vertex: Exclusive upper bound on the segments per stroke.
        mean_angle: Centre of the turning-angle distribution, in radians.
        angle_range: Maximum random widening of the angle interval on each side.
        min_width: Lower bound of the stroke width in pixels.
        max_width: Upper bound of the stroke width in pixels.

    Returns:
        A ``(s, s)`` ``uint8`` array that is 1 on painted pixels and 0 elsewhere.

    Note:
        The in-loop flips used to discard the result of ``Image.transpose``
        (which returns a new image), so they had no effect. They are applied
        now; the sequence of random draws is unchanged, but masks generated
        from a fixed seed differ from those of the previous implementation.
    """
    H, W = s, s
    average_radius = math.sqrt(H*H + W*W) / 8
    mask = Image.new('L', (W, H), 0)

    # np.random.randint(n) raises for n <= 0; an empty budget means "no strokes".
    num_strokes = np.random.randint(max_tries) if max_tries > 0 else 0
    for _ in range(num_strokes):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        vertex = []
        # Alternate between an angle and its reflection so the stroke zig-zags.
        for i in range(num_vertex):
            if i % 2 == 0:
                angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))

        w, h = mask.size  # PIL reports (width, height)
        vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
        for i in range(num_vertex):
            r = np.clip(
                np.random.normal(loc=average_radius, scale=average_radius//2),
                0, 2*average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
            vertex.append((int(new_x), int(new_y)))

        draw = ImageDraw.Draw(mask)
        width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=1, width=width)
        # Round every joint so the polyline has no gaps at its corners.
        for v in vertex:
            draw.ellipse((v[0] - width//2,
                          v[1] - width//2,
                          v[0] + width//2,
                          v[1] + width//2),
                         fill=1)

        # Image.transpose returns a copy, so the result must be kept.
        if np.random.random() > 0.5:
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.random() > 0.5:
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)

    mask = np.asarray(mask, np.uint8)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 0)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 1)
    return mask
def RandomMask(s, hole_range=(0, 1)):
    """Sample a random ``s`` x ``s`` inpainting mask built from rectangles and strokes.

    Masks are re-sampled until the fraction of hole pixels lies strictly
    between ``hole_range[0]`` and ``hole_range[1]``.

    Args:
        s: Side length of the square mask in pixels.
        hole_range: ``(low, high)`` bounds on the accepted hole ratio. Their
            sum (capped at 1.0) also scales how many shapes may be drawn.
            ``None`` uses the full shape budget and accepts any ratio.

    Returns:
        A ``(s, s)`` ``uint8`` array holding 0 on hole pixels and 255 elsewhere.

    Raises:
        ValueError: If ``hole_range`` is so small that no shape could ever be
            drawn, which would otherwise make the sampling loop endless.
    """
    if hole_range is None:
        coef = 1.0
    else:
        coef = min(hole_range[0] + hole_range[1], 1.0)

    small_tries = int(10 * coef)
    large_tries = int(5 * coef)
    brush_tries = int(20 * coef)

    # np.random.randint(n) draws from [0, n), so a budget <= 1 never yields a
    # shape. The brush budget is the largest of the three; if it cannot draw,
    # nothing can.
    if brush_tries <= 1:
        raise ValueError(
            f"hole_range={hole_range!r} is too small to generate any hole; "
            "its bounds must sum to at least 0.1"
        )

    def draw_count(budget):
        # np.random.randint(0) raises, so an empty budget means zero shapes.
        return np.random.randint(budget) if budget > 0 else 0

    while True:
        mask = np.ones((s, s), np.uint8)

        # Small rectangles first, then large ones (same order of random draws
        # as the nested Fill/MultiFill helpers this replaces).
        for budget, max_size in ((small_tries, s // 2), (large_tries, s)):
            for _ in range(draw_count(budget)):
                w, h = np.random.randint(max_size), np.random.randint(max_size)
                ww, hh = w // 2, h // 2
                # Rectangles may overhang the border by up to half their size.
                x = np.random.randint(-ww, s - w + ww)
                y = np.random.randint(-hh, s - h + hh)
                mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0

        mask = np.logical_and(mask, 1 - RandomBrush(brush_tries, s))
        hole_ratio = 1 - np.mean(mask)
        if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
            continue
        return (mask * 255).astype(np.uint8)
class InferDataset(Dataset):  # ABC
    """Inference dataset that turns image files into masked network inputs.

    Every file in ``real_dir`` whose suffix is in ``img_ext`` is one sample.
    The mask is read from ``mask_dir`` when given, otherwise generated with
    ``RandomMask``. In the mask, 1 marks pixels to keep and 0 marks holes.
    """

    img_ext = {".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"}

    def __init__(
        self,
        real_dir: Path,
        mask_dir: Optional[Path] = None,
        resolution: Optional[int] = None,
        mask_prefix: str = "img000",
    ):
        """Index the images found directly inside ``real_dir``.

        Args:
            real_dir: Directory holding the input images.
            mask_dir: Directory holding one ``<mask_prefix><image stem>.png``
                mask per image, or ``None`` to use random masks.
            resolution: Side length every image and mask is resized to.
                ``None`` keeps each image's native size.
            mask_prefix: Text prepended to the image stem to form the mask
                file name.
        """
        super().__init__()
        self.img_paths = sorted(i for i in Path(real_dir).iterdir() if i.suffix in self.img_ext)
        # Coerce to Path so plain strings (e.g. from argparse) support the "/"
        # join used when a mask is looked up.
        self.mask_dir = Path(mask_dir) if mask_dir is not None else None
        self.resolution = resolution
        self.mask_prefix = mask_prefix

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_paths)

    def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray, torch.Tensor, str]:
        """Load one sample.

        Returns:
            ``(x, img, mask, img_name)`` where ``mask`` is a ``(1, H, W)``
            float tensor, ``img`` is the ``(3, H, W)`` image scaled to
            ``[-1, 1]`` as a NumPy array, ``x`` concatenates ``mask - 0.5``
            with the masked image along the channel axis, and ``img_name`` is
            the file stem.
        """
        img_path = Path(self.img_paths[index])
        img_name = img_path.stem

        # convert() returns an independent image, so the file can be closed.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        if self.resolution is None:
            target_size = img.size
        else:
            target_size = (self.resolution, self.resolution)
        if img.size != target_size:
            img = img.resize(target_size, Image.BICUBIC)

        if self.mask_dir is not None:
            mask_path = self.mask_dir / f"{self.mask_prefix}{img_name}.png"
            with Image.open(mask_path) as src:
                mask = src.convert("L")
            mask = mask.resize(target_size, Image.NEAREST)
        else:
            mask = Image.fromarray(RandomMask(img.size[0])).convert("L")
            # Only differs for non-square images kept at native size.
            if mask.size != target_size:
                mask = mask.resize(target_size, Image.NEAREST)

        img = np.array(img)
        # Integer division keeps only pixels that are exactly 255.
        mask = np.array(mask)[:, :, np.newaxis] // 255
        img = torch.Tensor(img).float() * 2 / 255 - 1
        mask = torch.Tensor(mask).float()
        img = img.permute(2, 0, 1)
        mask = mask.permute(2, 0, 1)
        x = torch.cat([mask - 0.5, img * mask], dim=0)
        return x, np.array(img), mask, img_name
class SimpleInferDataset(torch.utils.data.Dataset):
    """Inference dataset that yields resized PIL images and masks.

    Images and masks are paired by position in their sorted file lists, so
    both directories must sort into the same order. Without ``mask_dir`` a
    mask is generated with ``RandomMask`` for every access.
    """

    def __init__(
        self,
        real_dir: Path,
        mask_dir: Optional[Path] = None,
        resolution: int = 512
    ):
        """Index the image files (and mask files, if any).

        Args:
            real_dir: Directory holding the input images.
            mask_dir: Directory holding the masks, or a falsy value to use
                random masks.
            resolution: Side length every image and mask is resized to.
        """
        super().__init__()
        img_extensions = {".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"}
        self.img_paths = sorted(i for i in Path(real_dir).iterdir() if i.suffix in img_extensions)
        self.img_dir = real_dir
        # Always define the attribute, even when random masks are used.
        if mask_dir:
            self.mask_paths = sorted(i for i in Path(mask_dir).iterdir() if i.suffix in img_extensions)
        else:
            self.mask_paths = []
        self.mask_dir = mask_dir
        self.resolution = resolution

    def __getitem__(self, index):
        """Return ``(image, mask, file_name)`` for the sample at ``index``.

        ``image`` is an RGB PIL image and ``mask`` an 8-bit grayscale PIL
        image, both ``resolution`` x ``resolution``; ``file_name`` keeps the
        extension.
        """
        img_path = Path(self.img_paths[index])
        img_name = img_path.name

        # convert() returns an independent image, so the file can be closed.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        if self.mask_dir:
            with Image.open(Path(self.mask_paths[index])) as src:
                mask = src.convert("L")
        else:
            # Generated at the image's native width, then resized below.
            mask = Image.fromarray(RandomMask(img.size[0])).convert("L")

        target_size = (self.resolution, self.resolution)
        mask = mask.resize(target_size, Image.NEAREST)
        if img.size != target_size:
            img = img.resize(target_size, Image.BICUBIC)
        return img, mask, img_name

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_paths)
def collate_fn(inputs):
    """Regroup a batch of samples into three parallel lists.

    Only the first three fields of every sample are kept; any further fields
    are ignored. No stacking or tensor conversion takes place.

    Returns:
        A tuple ``(field_0_list, field_1_list, field_2_list)``.
    """
    def column(position):
        # Gather one field across the whole batch, preserving sample order.
        return [sample[position] for sample in inputs]

    return column(0), column(1), column(2)
def build_dataloader(args, dataset_class=InferDataset):
    """Create a sequential (unshuffled) inference ``DataLoader``.

    Args:
        args: Object exposing ``real_dir``, ``mask_dir``, ``resolution``,
            ``batch_size`` and ``num_workers`` attributes.
        dataset_class: Dataset type to instantiate; it is called with the
            keyword arguments ``real_dir``, ``mask_dir`` and ``resolution``.

    Returns:
        A ``DataLoader`` over the dataset that keeps the final partial batch
        and batches samples with the module-level ``collate_fn``.

    Note:
        ``collate_fn`` keeps only the first three fields of each sample, so
        the fourth field returned by ``InferDataset`` (the image name) does
        not reach the batch.
    """
    dataset = dataset_class(
        real_dir=args.real_dir,
        mask_dir=args.mask_dir,
        resolution=args.resolution,
    )
    loader_options = dict(
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=False,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)