Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| import cv2 | |
def dt(a):
    """Return the L2 distance transform of the mask ``a``.

    ``a`` is scaled by 255 and cast to ``uint8`` before being handed to
    ``cv2.distanceTransform`` with ``cv2.DIST_L2`` and a mask size of 0.

    NOTE(review): presumably ``a`` is a 2-D array with values in [0, 1]
    (so the scaled values fit in uint8) -- confirm against callers.
    """
    as_uint8 = (a * 255).astype(np.uint8)
    return cv2.distanceTransform(as_uint8, cv2.DIST_L2, 0)
def trimap_transform(trimap, L=320):
    """Build six Gaussian-of-distance maps from the first two trimap channels.

    For each of channels 0 and 1 of ``trimap`` the distance transform of
    ``1 - channel`` is squared and negated, then turned into three maps
    ``exp(-d**2 / (2 * (s * L)**2))`` for ``s`` in 0.02, 0.08 and 0.16.

    Args:
        trimap: array indexed as ``trimap[:, :, k]`` for ``k`` in {0, 1}.
            NOTE(review): presumably an (H, W, 2) mask with values in
            [0, 1] -- confirm against callers.
        L: length scale multiplied by each factor to obtain the Gaussian
            width.

    Returns:
        ``np.ndarray`` stacking the six maps along a new leading axis, ordered
        channel 0 (three scales) followed by channel 1 (three scales).
    """
    scale_factors = (0.02, 0.08, 0.16)
    maps = []
    for channel in range(2):
        # Unary minus binds looser than ``**``: this is -(distance ** 2).
        neg_sq_dist = -dt(1 - trimap[:, :, channel]) ** 2
        for factor in scale_factors:
            maps.append(np.exp(neg_sq_dist / (2 * ((factor * L) ** 2))))
    return np.array(maps)
# Per-channel normalisation constants, listed in RGB order.
# Each is a float32 CPU tensor of shape (1, 3, 1, 1).
imagenet_norm_std = torch.tensor(
    [0.229, 0.224, 0.225], dtype=torch.float32, device="cpu"
).view(1, 3, 1, 1)
imagenet_norm_mean = torch.tensor(
    [0.485, 0.456, 0.406], dtype=torch.float32, device="cpu"
).view(1, 3, 1, 1)
def normalise_image(image, mean=imagenet_norm_mean, std=imagenet_norm_std):
    """Return ``image`` with ``mean`` subtracted and then divided by ``std``.

    Args:
        image: input to normalise.
            NOTE(review): presumably an (N, 3, H, W) RGB tensor so that it
            broadcasts against the (1, 3, 1, 1) defaults -- confirm.
        mean: offset to subtract; defaults to ``imagenet_norm_mean``.
        std: divisor applied after centring; defaults to
            ``imagenet_norm_std``.
    """
    centred = image - mean
    return centred / std