# NOTE: file recovered from a Hugging Face Spaces page scrape; viewer status
# lines ("Spaces: Running") and table markers were stripped during cleanup.
| import os | |
| import torch | |
| import random | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
def recursive_glob(rootdir='.', suffix=''):
    """Recursively collect files under *rootdir* whose names end with *suffix*.

    :param rootdir: root directory to walk
    :param suffix: filename suffix to match (empty string matches everything)
    :return: list of full paths, in os.walk order
    """
    matches = []
    for dirpath, _, filenames in os.walk(rootdir):
        matches.extend(
            os.path.join(dirpath, name)
            for name in filenames
            if name.endswith(suffix)
        )
    return matches
def get_cityscapes_labels():
    """Return the RGB color table for the 19 Cityscapes training classes.

    Returns:
        np.ndarray with dimensions (19, 3), one RGB row per class index.
    """
    palette = [
        # [0, 0, 0],  # void entry kept commented out, as in the original
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32],
    ]
    return np.array(palette)
def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors.

    Returns:
        np.ndarray with dimensions (21, 3), one RGB row per class index.
    """
    palette = [
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128],
    ]
    return np.asarray(palette)
def get_mhp_labels():
    """Load the mapping that associates MHP classes with label colors.

    Returns:
        np.ndarray with dimensions (59, 3), one RGB row per class index.
        (The inline `# 21`, `# 41`, `# 59` comments mark running row counts.)
    """
    # NOTE(review): [0, 0, 96] appears twice (row indices 24 and 44), so two
    # classes share one color and cannot be told apart in a decoded image.
    # Possibly [0, 0, 48] was intended for the second group — confirm upstream.
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128],  # 21
                       [96, 0, 0], [0, 96, 0], [96, 96, 0],
                       [0, 0, 96], [96, 0, 96], [0, 96, 96], [96, 96, 96],
                       [32, 0, 0], [160, 0, 0], [32, 96, 0], [160, 96, 0],
                       [32, 0, 96], [160, 0, 96], [32, 96, 96], [160, 96, 96],
                       [0, 32, 0], [96, 32, 0], [0, 160, 0], [96, 160, 0],
                       [0, 32, 96],  # 41
                       [48, 0, 0], [0, 48, 0], [48, 48, 0],
                       [0, 0, 96], [48, 0, 48], [0, 48, 48], [48, 48, 48],
                       [16, 0, 0], [80, 0, 0], [16, 48, 0], [80, 48, 0],
                       [16, 0, 48], [80, 0, 48], [16, 48, 48], [80, 48, 48],
                       [0, 16, 0], [48, 16, 0], [0, 80, 0],  # 59
                       ])
def encode_segmap(mask):
    """Encode segmentation label images as pascal classes.

    Args:
        mask (np.ndarray): raw segmentation label image of dimension
            (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M, N), where the value at
            a given location is the integer denoting the class index.
            Pixels whose colour matches no palette entry stay 0 (background).
    """
    mask = mask.astype(int)
    label_mask = np.zeros(mask.shape[:2], dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        # Boolean-mask indexing replaces the original, equivalent but obscure
        # np.where(...)[:2] construction: select pixels whose RGB triple
        # matches this class colour exactly.
        label_mask[np.all(mask == label, axis=-1)] = ii
    return label_mask.astype(int)
def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Decode a batch of class-index masks into a (N, 3, H, W) RGB tensor.

    Args:
        label_masks: iterable of (H, W) integer label maps.
        dataset: palette selector passed through to decode_segmap.
    Returns:
        torch.Tensor of shape (N, 3, H, W) with values in [0, 1].
    """
    decoded = [decode_segmap(mask, dataset) for mask in label_masks]
    # Stack to (N, H, W, 3) then move channels first for torch convention.
    stacked = np.array(decoded).transpose([0, 3, 1, 2])
    return torch.from_numpy(stacked)
def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): one of 'pascal', 'cityscapes', 'mhp' — selects the
            palette; anything else raises NotImplementedError.
        plot (bool, optional): whether to show the resulting color image
            in a figure instead of returning it.
    Returns:
        (np.ndarray, optional): the (M, N, 3) decoded color image with
            values scaled to [0, 1]; None when plot=True.
    """
    if dataset == 'pascal':
        n_classes, label_colours = 21, get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes, label_colours = 19, get_cityscapes_labels()
    elif dataset == 'mhp':
        n_classes, label_colours = 59, get_mhp_labels()
    else:
        raise NotImplementedError
    # One working copy per channel; labels outside [0, n_classes) keep their
    # raw value, exactly as in the original implementation.
    channels = [label_mask.copy() for _ in range(3)]
    for cls in range(n_classes):
        hit = label_mask == cls
        for ch in range(3):
            channels[ch][hit] = label_colours[cls, ch]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    for ch in range(3):
        rgb[:, :, ch] = channels[ch] / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb
def generate_param_report(logfile, param):
    """Write every key:value pair of *param* to *logfile*, one per line.

    Args:
        logfile: path of the report file (overwritten if it already exists).
        param (dict): mapping of parameter names to values; values are
            rendered with str().
    """
    # Context manager guarantees the handle is closed even if a write raises
    # (the original left the file open on error).
    with open(logfile, 'w') as log_file:
        for key, val in param.items():
            log_file.write(key + ':' + str(val) + '\n')
def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    """2D cross-entropy loss between logits and an integer label map.

    Args:
        logit: (N, C, H, W) raw class scores.
        target: (N, 1, H, W) or (N, H, W) integer labels; a singleton channel
            dim is squeezed away.
        ignore_index: label value excluded from the loss.
        weight: optional per-class weights (anything np.array accepts);
            moved to CUDA as in the original implementation.
        size_average: True -> mean reduction, False -> sum. Kept for
            backward compatibility; mapped onto the modern `reduction` arg
            because `size_average` is deprecated in torch.nn losses.
        batch_average: unused; kept so existing call sites don't break.
    Returns:
        scalar loss tensor.
    """
    target = target.squeeze(1)
    # Legacy size_average semantics (reduce defaulted to True).
    reduction = 'mean' if size_average else 'sum'
    if weight is not None:
        weight = torch.from_numpy(np.array(weight)).float().cuda()
    criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index,
                                    reduction=reduction)
    return criterion(logit, target.long())
def cross_entropy2d_dataparallel(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    """Data-parallel variant of cross_entropy2d.

    Same contract as cross_entropy2d, but the loss module is wrapped in
    nn.DataParallel so each GPU replica computes its share; the per-replica
    losses are then summed into one scalar. On a machine without GPUs,
    DataParallel simply runs the wrapped module directly.

    Args:
        logit: (N, C, H, W) raw class scores.
        target: (N, 1, H, W) or (N, H, W) integer labels.
        ignore_index: label value excluded from the loss.
        weight: optional per-class weights, moved to CUDA as in the original.
        size_average: True -> mean, False -> sum; mapped onto `reduction`
            because `size_average` is deprecated in torch.nn losses.
        batch_average: unused; kept for interface compatibility.
    Returns:
        scalar loss tensor (sum over replica outputs).
    """
    target = target.squeeze(1)
    reduction = 'mean' if size_average else 'sum'
    if weight is not None:
        weight = torch.from_numpy(np.array(weight)).float().cuda()
    criterion = nn.DataParallel(
        nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index,
                            reduction=reduction))
    loss = criterion(logit, target.long())
    return loss.sum()
def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
    """Polynomial learning-rate decay: base_lr * (1 - iter_/max_iter) ** power."""
    remaining = 1 - float(iter_) / max_iter
    return base_lr * remaining ** power
def get_iou(pred, gt, n_classes=21):
    """Sum of per-image mean IoU over a batch of predictions.

    Args:
        pred: sequence of integer prediction tensors, one per image.
        gt: sequence of integer ground-truth tensors, same shapes as pred.
        n_classes: number of classes to evaluate; classes absent from both
            tensors (union == 0) are skipped in the per-image mean.
    Returns:
        float: the SUM of per-image mean IoUs (caller divides by batch size
        to obtain the average, as in the original implementation).
    """
    total_iou = 0.0
    for i in range(len(pred)):
        pred_tmp = pred[i]
        gt_tmp = gt[i]
        intersect = [0] * n_classes
        union = [0] * n_classes
        for j in range(n_classes):
            # Cast the boolean masks to int before adding: '+' on bool
            # tensors acts as a logical OR in modern torch, which made
            # 'match == 2' always false and zeroed every intersection.
            match = (pred_tmp == j).int() + (gt_tmp == j).int()
            intersect[j] += torch.sum(match == 2).item()
            union[j] += torch.sum(match > 0).item()
        iou = [intersect[k] / union[k]
               for k in range(n_classes) if union[k] != 0]
        # Guard against an empty list (no evaluated class present at all),
        # which previously raised ZeroDivisionError.
        if iou:
            total_iou += sum(iou) / len(iou)
    return total_iou
def scale_tensor(input, size=512, mode='bilinear'):
    """Resize a 4-D (N, C, H, W) tensor to (size, size).

    Fixes in this revision: removed a leftover debug print; the early-exit
    check now compares against the `size` parameter instead of a hard-coded
    512 (behavior is unchanged for the default size=512); the two identical
    bilinear branches were collapsed; deprecated F.upsample/F.upsample_nearest
    replaced by F.interpolate.

    Args:
        input: 4-D tensor to resize.
        size: target side length.
        mode: 'nearest' or any mode accepted by F.interpolate.
    Returns:
        the resized tensor; the input itself when mode='nearest' and it is
        already size x size.
    """
    _, _, h, w = input.size()
    if mode == 'nearest':
        if h == size and w == size:
            return input
        return F.interpolate(input, size=(size, size), mode='nearest')
    # align_corners=True matches the original F.upsample call.
    return F.interpolate(input, size=(size, size), mode=mode, align_corners=True)
def scale_tensor_list(input):
    """Upsample every tensor at every scale to the spatial size of the
    corresponding tensor in the last (finest) scale.

    Args:
        input: list of scales; each scale is a list of 4-D tensors. The last
            scale is taken as the size reference and passed through untouched.
    Returns:
        new list with the same structure; all but the last scale are
        bilinearly upsampled (align_corners=True). Uses F.interpolate in
        place of the deprecated F.upsample.
    """
    output = []
    for i in range(len(input) - 1):
        output_item = []
        for j in range(len(input[i])):
            # Match the spatial size of the reference tensor at this index.
            _, _, h, w = input[-1][j].size()
            output_item.append(
                F.interpolate(input[i][j], size=(h, w),
                              mode='bilinear', align_corners=True))
        output.append(output_item)
    output.append(input[-1])
    return output
def scale_tensor_list_0(input, base_input):
    """Upsample each tensor in *input* to its counterpart's size in
    *base_input* and add it element-wise (residual-style merge).

    Args:
        input: list of 4-D tensors to upsample.
        base_input: list of 4-D reference tensors, same length as input.
            NOTE: this list is MUTATED in place and also returned, matching
            the original behavior.
    Returns:
        base_input, with base_input[j] += upsample(input[j]) for every j.
    """
    assert len(input) == len(base_input)
    for j in range(len(input)):
        _, _, h, w = base_input[j].size()
        # F.interpolate replaces the deprecated F.upsample; align_corners=True
        # matches the original call.
        upsampled = F.interpolate(input[j], size=(h, w),
                                  mode='bilinear', align_corners=True)
        base_input[j] = base_input[j] + upsampled
    return base_input
if __name__ == '__main__':
    # Quick sanity check: LR produced by the poly schedule near the end of
    # a 150-iteration run.
    final_lr = lr_poly(0.007, iter_=99, max_iter=150)
    print(final_lr)