# NOTE(review): the three lines below were paste artifacts, not Python;
# kept as comments so the file parses.
# Spaces:
# Runtime error
# Runtime error
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import random | |
| from torchvision import transforms | |
| import torchvision.transforms.functional as TF | |
def apply_joint_transforms(rgb, mask, img_size, img_aug=True, test=True):
    """Jointly transform an RGB image and its mask to a square tensor pair.

    Pads both to a centered square (white background for RGB, zeros for the
    mask), optionally applies light augmentation (rotation / crop), resizes
    to ``img_size``, adds an extra border and resizes again (shrinking the
    content inward), and finally converts both to tensors in [0, 1].

    Parameters
    ----------
    rgb : PIL.Image.Image
        Input RGB image.
    mask : PIL.Image.Image
        Matching single-channel mask.
    img_size : int
        Side length of the square output.
    img_aug : bool
        Enable random rotation/crop (only honored when ``test`` is False).
    test : bool
        Deterministic test-time behavior (fixed 16-px border, no augmentation).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        The transformed RGB and mask tensors.
    """
    # Border width for step 5: fixed at test time, randomized for training.
    extra_pad = 16 if test else random.randint(0, 32)

    # 1. pad to a square canvas, centering the original content
    width, height = rgb.size[:2]
    side = max(width, height)
    pad_top = (side - height) // 2
    pad_bottom = side - height - pad_top
    pad_left = (side - width) // 2
    pad_right = side - width - pad_left
    padding = (pad_left, pad_top, pad_right, pad_bottom)
    rgb = TF.pad(rgb, padding, fill=255)   # white background
    mask = TF.pad(mask, padding, fill=0)   # empty mask outside

    if img_aug and not test:
        # 2. occasional small rotation (10% of the time)
        if random.random() < 0.1:
            angle = random.uniform(-10, 10)
            rgb = TF.rotate(rgb, angle, fill=255)
            mask = TF.rotate(mask, angle, fill=0)
        # 3. occasional random square crop at 90-100% of the side
        if random.random() < 0.1:
            crop_side = int(side * random.uniform(0.9, 1.0))
            i, j, h, w = transforms.RandomCrop.get_params(
                rgb, (crop_side, crop_side))
            rgb = TF.crop(rgb, i, j, h, w)
            mask = TF.crop(mask, i, j, h, w)

    # 4. resize to the working resolution
    out_size = (img_size, img_size)
    rgb = TF.resize(rgb, out_size, interpolation=TF.InterpolationMode.BILINEAR)
    mask = TF.resize(mask, out_size, interpolation=TF.InterpolationMode.NEAREST)

    # 5. add a border, then resize back down — shrinks the content inward
    rgb = TF.pad(rgb, extra_pad, fill=255)
    mask = TF.pad(mask, extra_pad, fill=0)
    rgb = TF.resize(rgb, out_size, interpolation=TF.InterpolationMode.BILINEAR)
    mask = TF.resize(mask, out_size, interpolation=TF.InterpolationMode.NEAREST)

    # 6. to float tensors in [0, 1]
    return TF.to_tensor(rgb), TF.to_tensor(mask)
def crop_recenter(image_no_bg, thereshold=100):
    """Crop an image to the tight bounding box of its last (alpha) channel.

    Parameters
    ----------
    image_no_bg : PIL.Image.Image or np.ndarray
        Image whose last channel acts as a foreground/alpha mask.
    thereshold : int
        Alpha values strictly greater than this count as foreground.
        (Parameter name kept misspelled for backward compatibility —
        callers pass it by keyword.)

    Returns
    -------
    np.ndarray
        The image cropped to the mask's bounding box, or the full image
        when the mask covers less than 0.1% of the pixels.
    """
    image_no_bg_np = np.array(image_no_bg)
    alpha = image_no_bg_np[..., -1].astype(np.uint8)
    mask_bin = alpha > thereshold
    H, W = image_no_bg_np.shape[:2]
    if mask_bin.sum() < (H * W) * 0.001:
        # Mask is (nearly) empty: keep the whole image.
        min_h, max_h = 0, H - 1
        min_w, max_w = 0, W - 1
    else:
        # Bounding box of foreground pixels.  nonzero() on the boolean mask
        # directly avoids the float32 copy the original made, and is only
        # computed when actually needed.  The original's min/max clamping
        # was dead code: nonzero indices are always within [0, dim-1].
        rows, cols = mask_bin.nonzero()
        min_h, max_h = rows.min(), rows.max()
        min_w, max_w = cols.min(), cols.max()
    return image_no_bg_np[min_h:max_h + 1, min_w:max_w + 1]
def preprocess_image(img):
    """Load an image, composite it on white, recenter-crop, and tensorize.

    Parameters
    ----------
    img : str | PIL.Image.Image | np.ndarray
        Path to an image file, a PIL image, or an (H, W, C) array.
        Grayscale (H, W) arrays are replicated to 3 channels.

    Returns
    -------
    torch.Tensor
        A (4, 518, 518) float tensor: white-composited RGB plus alpha mask.
    """
    if isinstance(img, str):
        img = np.array(Image.open(img))
    elif isinstance(img, Image.Image):
        img = np.array(img)
    if img.ndim == 2:
        # Grayscale input: replicate to 3 channels so the RGBA path works.
        img = np.stack([img] * 3, axis=-1)
    if img.shape[-1] == 3:
        # No alpha channel: treat every pixel as fully opaque.
        # BUG FIX: the original used np.ones_like (alpha == 1); after the
        # /255 normalization below that gave alpha ~= 0.004, so the
        # white-composite washed the image out to nearly pure white.
        mask = np.full_like(img[..., 0:1], 255)
        img = np.concatenate([img, mask], axis=-1)
    img = crop_recenter(img, thereshold=0) / 255.
    mask = img[..., 3]
    # Composite onto a white background: rgb * a + (1 - a).
    img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])
    img = Image.fromarray((img * 255).astype(np.uint8))
    mask = Image.fromarray((mask * 255).astype(np.uint8))
    img, mask = apply_joint_transforms(img, mask, img_size=518,
                                       img_aug=False, test=True)
    # Append the mask as a 4th channel on the RGB tensor.
    return torch.cat([img, mask], dim=0)