Spaces:
Configuration error
Configuration error
File size: 2,354 Bytes
d01f62c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | import numpy as np
import torch
from dataset.util import all_to_onehot
class MaskMapper:
    """
    Convert an indexed mask into a one-hot (per-object) representation.

    It also takes care of remapping non-contiguous label indices onto the
    contiguous internal indices 1, 2, 3, ... (new labels are registered in
    the order they are first seen, ascending within a single mask).
    It has two modes:
    1. Default. Only masks with new indices are supposed to go into the remapper.
       This is also the case for YouTubeVOS.
       i.e., regions with index 0 are not "background", but "don't care".
    2. Exhaustive. Regions with index 0 are considered "background".
       Every single pixel is considered to be "labeled".
    """

    def __init__(self):
        # Original label values, in registration order. The label at
        # position i is assigned internal index i + 1.
        self.labels = []
        # Original label -> internal index (1-based, contiguous).
        self.remappings = {}
        # True while every registered label equals its internal index;
        # if coherent, no mapping is required.
        self.coherent = True

    def _register_labels(self, mask, exhaustive=False):
        """Register the labels of ``mask`` that have not been seen yet.

        Every non-zero value in ``mask`` is a label. Labels not yet in
        ``self.labels`` are appended to it (in ascending order) and each is
        assigned the next free internal index in ``self.remappings``.

        Args:
            mask: H*W numpy array of label indices.
            exhaustive: see ``convert_mask``.

        Returns:
            The internal indices to report for this mask: all indices seen so
            far in exhaustive mode; otherwise only those of the newly
            registered labels (returned as the label list itself while no
            remapping is needed, else as a ``range``).

        Raises:
            AssertionError: in non-exhaustive mode, if ``mask`` contains a
                label that was already registered.
        """
        # Cast to int64, not uint8: a uint8 cast silently wraps ids >= 256
        # (256 -> 0 would be dropped as background, 257 -> 1 would collide
        # with label 1).
        labels = np.unique(mask).astype(np.int64)
        labels = labels[labels != 0].tolist()
        known = set(self.labels)
        # np.unique returns sorted values, so new labels are registered in
        # ascending order rather than in (implementation-defined) set order.
        new_labels = [label for label in labels if label not in known]

        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type seen by callers is unchanged.
        if not exhaustive and len(new_labels) != len(labels):
            raise AssertionError('Old labels found in non-exhaustive mode')

        first_index = len(self.labels) + 1
        for offset, label in enumerate(new_labels):
            index = first_index + offset
            self.remappings[label] = index
            if self.coherent and index != label:
                self.coherent = False

        if exhaustive:
            new_mapped_labels = range(1, first_index + len(new_labels))
        elif self.coherent:
            # Labels already equal their internal indices.
            new_mapped_labels = new_labels
        else:
            new_mapped_labels = range(first_index, first_index + len(new_labels))

        self.labels.extend(new_labels)
        return new_mapped_labels

    def convert_mask(self, mask, exhaustive=False):
        """Convert an index mask into one binary channel per registered label.

        Args:
            mask: H*W numpy array in index representation; 0 is never an object.
            exhaustive: if False (default), every non-zero label in ``mask``
                must be new (an ``AssertionError`` is raised otherwise); if
                True, previously seen labels may reappear.

        Returns:
            ``(onehot, new_mapped_labels)``: ``onehot`` is a float tensor built
            by ``all_to_onehot`` with one channel per entry of ``self.labels``
            (all labels registered so far, in registration order;
            num_objects*H*W per the original author's note), and
            ``new_mapped_labels`` are the internal indices described in
            ``_register_labels``.
        """
        new_mapped_labels = self._register_labels(mask, exhaustive)
        onehot = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
        return onehot, new_mapped_labels

    def remap_index_mask(self, mask):
        """Translate an index mask from internal indices back to original labels.

        Args:
            mask: H*W numpy array whose values are internal indices.

        Returns:
            ``mask`` itself when no remapping is needed (``self.coherent``);
            otherwise a new array where each internal index is replaced by its
            original label and any value not in ``self.remappings`` becomes 0.
        """
        if self.coherent:
            return mask
        new_mask = np.zeros_like(mask)
        if self.remappings and np.issubdtype(new_mask.dtype, np.integer):
            # Widen only when an original label does not fit the input's
            # integer dtype (e.g. label 300 with a uint8 index mask).
            info = np.iinfo(new_mask.dtype)
            if min(self.remappings) < info.min or max(self.remappings) > info.max:
                new_mask = new_mask.astype(np.int64)
        for label, index in self.remappings.items():
            new_mask[mask == index] = label
        return new_mask