File size: 2,380 Bytes
0e83290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import numpy as np
import torch.nn.functional as F
import torch
def compute_tensor_iu(seg, gt):
    """Return the intersection and union sizes of two masks.

    Args:
        seg: predicted mask tensor; combined with ``gt`` via ``&`` and ``|``.
        gt: reference mask tensor, combined element-wise with ``seg``.

    Returns:
        A ``(intersection, union)`` pair of scalar float tensors holding the
        summed ``seg & gt`` and ``seg | gt`` respectively.
    """
    overlap = seg & gt
    combined = seg | gt
    return overlap.float().sum(), combined.float().sum()
def compute_tensor_iou(seg, gt):
    """Return the smoothed intersection-over-union of two masks.

    Args:
        seg: predicted mask tensor.
        gt: reference mask tensor.

    Returns:
        Scalar tensor ``(intersection + eps) / (union + eps)`` with
        ``eps = 1e-6``.
    """
    # The same epsilon on both sides keeps the ratio defined (and equal to 1)
    # when both masks are empty.
    eps = 1e-6
    overlap, combined = compute_tensor_iu(seg, gt)
    return (overlap + eps) / (combined + eps)
def compute_array_iou(seg, gt):
    """Return the mean per-object IoU between two grayscale label masks.

    Each distinct non-zero gray value is treated as one object and value 0 as
    background. The set of objects is taken from the values present in
    ``seg``; for each one, the IoU of ``seg == value`` against ``gt == value``
    is computed.

    Args:
        seg: predicted label mask; squeezed before use.
        gt: reference label mask; squeezed before use.

    Returns:
        The average of the per-object IoU values. When ``seg`` contains no
        non-zero value, the IoU of the two background regions is returned
        instead.
    """
    seg = seg.squeeze()
    gt = gt.squeeze()

    # NOTE(review): labels are enumerated from `seg` only, so objects that
    # appear solely in `gt` do not contribute a term — confirm this is the
    # intended metric.
    per_object = [
        compute_tensor_iou(
            torch.tensor(seg == label),
            torch.tensor(gt == label),
        )
        for label in np.unique(seg)
        if label != 0  # background is not scored as an object
    ]

    if not per_object:
        # Nothing but background was found above: score how well the two
        # background regions agree so the result is still well defined.
        per_object.append(
            compute_tensor_iou(
                torch.tensor(seg == 0),
                torch.tensor(gt == 0),
            )
        )

    return sum(per_object) / len(per_object)
# STM
def pad_divide_by(in_img, d):
    """Pad the last two dimensions of ``in_img`` up to multiples of ``d``.

    The extra rows/columns are split between both sides; when the amount is
    odd, the trailing side (bottom / right) receives the extra one.

    Args:
        in_img: tensor whose last two dimensions are padded.
        d: divisor that the padded sizes must be multiples of.

    Returns:
        ``(out, pad_array)`` where ``out`` is the result of ``F.pad`` with its
        default mode and ``pad_array`` is ``(left, right, top, bottom)``, the
        tuple that was passed to ``F.pad``.
    """
    height, width = in_img.shape[-2:]

    # Amount needed to reach the next multiple of d (zero if already aligned).
    extra_h = d - height % d if height % d > 0 else 0
    extra_w = d - width % d if width % d > 0 else 0

    top = int(extra_h / 2)
    bottom = int(extra_h) - top
    left = int(extra_w / 2)
    right = int(extra_w) - left

    # F.pad expects the last dimension's padding first.
    pad_array = (left, right, top, bottom)
    return F.pad(in_img, pad_array), pad_array
def unpad(img, pad):
    """Crop padding off the last two dimensions of a 3-D or 4-D array.

    Args:
        img: array/tensor with 3 or 4 dimensions.
        pad: sequence indexed as ``(left, right, top, bottom)``: ``pad[0]`` and
            ``pad[1]`` are removed from the start/end of the last dimension,
            ``pad[2]`` and ``pad[3]`` from the start/end of the second-to-last.

    Returns:
        The cropped view of ``img``; ``img`` itself when no padding is given.

    Raises:
        NotImplementedError: if ``img`` is neither 3-D nor 4-D.
    """
    if len(img.shape) not in (3, 4):
        raise NotImplementedError

    # End indices are written explicitly (size - pad) rather than as `-pad`:
    # a trailing pad of 0 would otherwise become the slice `start:-0`, i.e.
    # `start:0`, and silently yield an empty result.
    if pad[2] + pad[3] > 0:
        img = img[..., pad[2]:img.shape[-2] - pad[3], :]
    if pad[0] + pad[1] > 0:
        img = img[..., pad[0]:img.shape[-1] - pad[1]]
    return img
def get_bbox_from_mask(mask):
    """Return the bounding box of the non-zero region of a mask.

    Args:
        mask: tensor that is two-dimensional once singleton dimensions are
            squeezed out.

    Returns:
        ``(min_y, min_x, max_y, max_x)`` as Python ints, where y indexes the
        first and x the second dimension of the squeezed mask; both maxima are
        inclusive.
    """
    flat_mask = torch.squeeze(mask)
    assert flat_mask.ndim == 2

    # One (row, col) pair per non-zero element; reduce over the rows of that
    # list to get the per-axis extremes.
    coords = torch.nonzero(flat_mask)
    lower = coords.min(dim=0).values
    upper = coords.max(dim=0).values

    top, left = lower
    bottom, right = upper
    return int(top), int(left), int(bottom), int(right)