|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
|
|
|
def all_to_onehot(masks, labels):
    """
    Expand an index mask into one binary channel per requested label.

    Parameters:
        masks: numpy array of label ids. Any rank is accepted; the shape is
            carried over unchanged behind the new leading label axis.
        labels: sequence of label ids; channel ``k`` of the result marks the
            positions where ``masks == labels[k]``.

    Returns:
        ``np.uint8`` array of shape ``(len(labels), *masks.shape)`` holding
        1 where the mask equals the channel's label and 0 elsewhere. An empty
        ``labels`` yields an array whose leading dimension is 0.
    """
    # Derive the output shape from the mask itself rather than special-casing
    # 2-D and 3-D inputs, so masks of any rank are handled the same way.
    onehot = np.zeros((len(labels),) + tuple(masks.shape), dtype=np.uint8)

    for channel, label in enumerate(labels):
        onehot[channel] = (masks == label).astype(np.uint8)

    return onehot
|
|
|
|
|
|
|
|
|
class MaskMapper:
    """
    Convert an index mask to a one-hot representation.

    It also takes care of remapping non-contiguous label ids: every new id is
    assigned the next free internal index (starting at 1) and the assignment
    is remembered in ``remappings``.

    It has two modes:
        1. Default. Only masks with new indices are supposed to go into the
           remapper. This is also the case for YouTubeVOS.
           i.e., regions with index 0 are not "background", but "don't care".

        2. Exhaustive. Regions with index 0 are considered "background".
           Every single pixel is considered to be "labeled".
    """

    def __init__(self):
        self.clear_lables()

    def clear_lables(self):
        """Forget all known labels and remappings (historical spelling, kept for callers)."""
        # Original label ids, in the order they were registered.
        self.labels = []
        # Original label id -> internal index (1-based).
        self.remappings = {}
        # True while every registered id equals its internal index, in which
        # case remap_index_mask has nothing to translate.
        self.coherent = True

    # Correctly spelled alias; the misspelled name above remains available.
    clear_labels = clear_lables

    def convert_mask(self, mask, exhaustive=False):
        """
        Register the labels found in ``mask`` and return its one-hot form.

        Parameters:
            mask: numpy array in index representation (presumably H x W;
                confirm against callers). The value 0 is never registered
                as a label.
            exhaustive: when False, every non-zero label in ``mask`` must be
                new to this mapper; when True, already-known labels may
                reappear.

        Returns:
            A ``(mask, new_mapped_labels)`` tuple. ``mask`` is a float tensor
            with one channel per label known so far. ``new_mapped_labels``
            holds all internal indices in exhaustive mode; otherwise only
            those of the labels added by this call (the label list itself
            while the mapper is coherent, a ``range`` otherwise).

        Raises:
            AssertionError: in non-exhaustive mode, if ``mask`` contains a
                label that was already registered.
        """
        # NOTE: label ids are narrowed to uint8 here, so ids above 255 wrap.
        labels = np.unique(mask).astype(np.uint8)
        labels = labels[labels != 0].tolist()

        # Sort the unseen labels: set iteration order is unspecified, and an
        # arbitrary order would make the id -> index assignment arbitrary and
        # could needlessly mark an already-contiguous labelling as incoherent.
        new_labels = sorted(set(labels) - set(self.labels))

        # Raised explicitly (instead of a bare ``assert``) so the check is
        # not stripped when Python runs with -O; the exception type is the same.
        if not exhaustive and len(new_labels) != len(labels):
            raise AssertionError('Old labels found in non-exhaustive mode')

        # Add the new remappings, continuing after the indices already in use.
        first_new_index = len(self.labels) + 1
        for index, label in enumerate(new_labels, start=first_new_index):
            self.remappings[label] = index
            if index != label:
                self.coherent = False

        last_index = len(self.labels) + len(new_labels)
        if exhaustive:
            new_mapped_labels = range(1, last_index + 1)
        elif self.coherent:
            new_mapped_labels = new_labels
        else:
            new_mapped_labels = range(first_new_index, last_index + 1)

        self.labels.extend(new_labels)
        mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()

        return mask, new_mapped_labels

    def remap_index_mask(self, mask):
        """
        Translate a mask of internal indices back to the original label ids.

        Parameters:
            mask: numpy array in index representation using internal indices.

        Returns:
            ``mask`` itself when the mapper is coherent (ids and indices
            coincide); otherwise a new array of the same shape and dtype in
            which each internal index is replaced by its original label id
            and every other position is 0.
        """
        if self.coherent:
            return mask

        new_mask = np.zeros_like(mask)
        for label, index in self.remappings.items():
            new_mask[mask == index] = label
        return new_mask
|
|
|
|