from typing import Dict, List, Tuple, Union

import numpy as np
import torch

from networks import deeplabv3plus_resnet50
from networks import convert_to_separable_conv, set_bn_momentum


# Three-channel colour for each label id handled by `colourise_mask`
# (ids 0-20 plus 255). Presumably RGB order -- confirm against the consumer
# of the colourised image.
_VOC2012_PALETTE: Dict[int, List[int]] = {
    0: [0, 0, 0],
    1: [128, 0, 0],
    2: [0, 128, 0],
    3: [128, 128, 0],
    4: [0, 0, 128],
    5: [128, 0, 128],
    6: [0, 128, 128],
    7: [128, 128, 128],
    8: [64, 0, 0],
    9: [192, 0, 0],
    10: [64, 128, 0],
    11: [192, 128, 0],
    12: [64, 0, 128],
    13: [192, 0, 128],
    14: [64, 128, 128],
    15: [192, 128, 128],
    16: [0, 64, 0],
    17: [128, 64, 0],
    18: [0, 192, 0],
    19: [128, 192, 0],
    20: [0, 64, 128],
    255: [255, 255, 255],
}


def get_network() -> torch.nn.Module:
    """Build the DeepLabV3+ (ResNet-50) network and load the NamedMask weights.

    The network is created with 21 output classes and without a pretrained
    backbone, its classifier is converted to separable convolutions, and the
    batch-norm momentum of the backbone is set to 0.01. The VOC2012 checkpoint
    is then fetched through ``torch.hub`` and loaded with ``strict=True``.

    Returns:
        The network with the downloaded state dict loaded. This function does
        not move the network to a device or change its train/eval mode.
    """
    network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
    convert_to_separable_conv(network.classifier)
    set_bn_momentum(network.backbone, momentum=0.01)

    state_dict = torch.hub.load_state_dict_from_url(
        "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
        map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    network.load_state_dict(state_dict, strict=True)
    return network


def colourise_mask(
        mask: np.ndarray,
) -> np.ndarray:
    """Convert a 2-D label mask into a colour image.

    Args:
        mask: array of shape (h, w) whose values are label ids present in the
            palette (0-20 or 255).

    Returns:
        A ``uint8`` array of shape (h, w, 3) in which every pixel carries the
        palette colour of its label.

    Raises:
        AssertionError: if ``mask`` is not two-dimensional.
        KeyError: if ``mask`` contains a label id that has no palette entry.
    """
    # NOTE: kept as an assert so callers relying on AssertionError still work;
    # be aware that it is stripped when Python runs with -O.
    assert len(mask.shape) == 2, ValueError(mask.shape)
    h, w = mask.shape
    grid = np.zeros((h, w, 3), dtype=np.uint8)

    # np.unique visits each distinct label once; tolist() yields plain Python
    # scalars for the dictionary lookup and the error message.
    for label in np.unique(mask).tolist():
        try:
            colour = _VOC2012_PALETTE[label]
        except KeyError:
            # A missing dictionary key raises KeyError (not IndexError), so
            # this is the handler that can actually attach the message.
            raise KeyError(
                f"No colour is found for a label id: {label}"
            ) from None
        grid[mask == label] = np.array(colour)
    return grid