| from typing import Dict, List, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from networks import deeplabv3plus_resnet50 | |
| from networks import convert_to_separable_conv, set_bn_momentum | |
def get_network() -> torch.nn.Module:
    """Build a DeepLabv3+ (ResNet-50) model and load the NamedMask VOC2012 weights.

    The model is created with 21 output classes and no pretrained backbone.
    ``convert_to_separable_conv`` is then applied to its classifier head, and the
    backbone BatchNorm momentum is set to 0.01. Finally the checkpoint is downloaded
    and loaded with ``strict=True``.

    NOTE(review): the architecture patches presumably have to happen before
    loading so that parameter names/shapes match the checkpoint; confirm against
    the ``networks`` helpers.

    Returns:
        The model with the downloaded weights loaded. This function never calls
        ``.to(...)`` on the model. The checkpoint tensors are only mapped to CUDA
        (when available, otherwise CPU) while being read.
    """
    model = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
    convert_to_separable_conv(model.classifier)
    set_bn_momentum(model.backbone, momentum=0.01)

    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")
    checkpoint_url = "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt"
    weights = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=target_device)

    model.load_state_dict(weights, strict=True)
    return model
# PASCAL VOC 2012 colour palette: label id -> (R, G, B). Id 255 is drawn white.
# Built once at import time instead of on every call.
_VOC2012_PALETTE = {
    0: (0, 0, 0),
    1: (128, 0, 0),
    2: (0, 128, 0),
    3: (128, 128, 0),
    4: (0, 0, 128),
    5: (128, 0, 128),
    6: (0, 128, 128),
    7: (128, 128, 128),
    8: (64, 0, 0),
    9: (192, 0, 0),
    10: (64, 128, 0),
    11: (192, 128, 0),
    12: (64, 0, 128),
    13: (192, 0, 128),
    14: (64, 128, 128),
    15: (192, 128, 128),
    16: (0, 64, 0),
    17: (128, 64, 0),
    18: (0, 192, 0),
    19: (128, 192, 0),
    20: (0, 64, 128),
    255: (255, 255, 255),
}


def colourise_mask(
    mask: np.ndarray,
) -> np.ndarray:
    """Render a 2-D label mask as an RGB image using the VOC 2012 palette.

    Args:
        mask: 2-D array of label ids, shape ``(H, W)``. Ids 0-20 and 255 have
            palette entries.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``, where each pixel holds the
        colour of its label.

    Raises:
        ValueError: if ``mask`` is not 2-D.
        KeyError: if ``mask`` contains a label id with no palette entry.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if mask.ndim != 2:
        raise ValueError(mask.shape)

    h, w = mask.shape
    grid = np.zeros((h, w, 3), dtype=np.uint8)
    for label in np.unique(mask):
        # A dict lookup raises KeyError (not IndexError). Re-raise it with a
        # readable message, keeping the exception type callers already see.
        try:
            colour = _VOC2012_PALETTE[label]
        except KeyError:
            raise KeyError(f"No colour is found for a label id: {label}") from None
        grid[mask == label] = colour
    return grid