Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from scipy.io import loadmat | |
def init_spixel_grid(args, b_train=True, ratio=1, downsize=16, device="cuda"):
    """Build a per-pixel (x, y) coordinate grid for a square crop.

    Args:
        args: object exposing ``crop_size``; it is used as both the height and
            the width of the grid.
        b_train, ratio, downsize: accepted for call-site compatibility; they
            are not used by this function.
        device: device of the returned tensor. Defaults to ``"cuda"``, which
            matches the previous hard-coded ``.cuda()``; pass ``"cpu"`` on
            hosts without a GPU.

    Returns:
        float32 tensor of shape (1, 2, crop_size, crop_size). Channel 0 holds
        each pixel's column index (x) and channel 1 its row index (y).
    """
    height = width = args.crop_size
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    # Order channels as (x, y), then add a leading batch axis of size 1.
    coords = np.stack([cols, rows])[np.newaxis].astype(np.float32)
    return torch.from_numpy(coords).to(device)
def label2one_hot_torch(labels, C=14):
    """Expand an integer label map into one-hot channels.

    Args:
        labels: tensor of shape (B, 1, H, W) holding class indices; the
            values must lie in ``[0, C)``.
        C: number of output channels (classes).

    Returns:
        float32 tensor of shape (B, C, H, W) on the same device as ``labels``,
        holding 1 in each pixel's label channel and 0 elsewhere.
    """
    batch, _, height, width = labels.shape
    # Allocate with the dtype and device of ``labels`` (same as ``.to(labels)``).
    canvas = torch.zeros(
        (batch, C, height, width), dtype=labels.dtype, device=labels.device
    )
    index = labels.long()  # scatter_ needs an int64 index tensor
    return canvas.scatter_(1, index, 1).to(torch.float32)
# Colour table used by colorEncode: row k is the colour drawn for label k.
# The .mat file is read at import time from a path relative to the current
# working directory. The table is stacked four times, so labels beyond the
# file's row count reuse the same colours cyclically.
colors = np.concatenate([loadmat('data/color150.mat')['colors']] * 4)
def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    """Return the sorted unique elements of ``ar`` (flattened), numpy-style.

    This is a local port of the 1-D path of ``np.unique``.

    Args:
        ar: array-like input; it is flattened (copied) first, so the caller's
            array is never modified.
        return_index: also return, for each unique value, the index of its
            first occurrence in the flattened input.
        return_inverse: also return indices that rebuild the flattened input
            from the unique values.
        return_counts: also return how many times each unique value occurs.

    Returns:
        The sorted unique values when no flag is set. Otherwise, a tuple
        ``(values, [index], [inverse], [counts])`` containing only the
        requested extras, in that order. All index and count outputs have
        dtype ``np.intp``.
    """
    ar = np.asanyarray(ar).flatten()  # flatten() copies, so in-place sorting is safe
    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            return ar
        ret = (ar,)
        # Index outputs are np.intp. The former ``np.bool`` alias was removed
        # in NumPy 1.24 (AttributeError) and was the wrong dtype for indices.
        if return_index:
            ret += (np.empty(0, np.intp),)
        if return_inverse:
            ret += (np.empty(0, np.intp),)
        if return_counts:
            ret += (np.empty(0, np.intp),)
        return ret

    if optional_indices:
        # A stable sort keeps the earliest element first within each run of
        # equal values, so perm[flag] yields first-occurrence indices.
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()
        aux = ar
    # flag[k] is True where aux[k] starts a new run of equal values.
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        return aux[flag]

    ret = (aux[flag],)
    if return_index:
        ret += (perm[flag],)
    if return_inverse:
        iflag = np.cumsum(flag) - 1
        inv_idx = np.empty(ar.shape, dtype=np.intp)
        inv_idx[perm] = iflag
        ret += (inv_idx,)
    if return_counts:
        # Run boundaries plus the end position; their differences are run lengths.
        idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
        ret += (np.diff(idx),)
    return ret
def colorEncode(labelmap, mode='RGB', palette=None):
    """Map an integer label image to a colour image.

    Args:
        labelmap: 2-D array-like of labels with shape (H, W); values are cast
            to int. Negative labels mean "no label" and stay black.
        mode: 'RGB' (default) or 'BGR'; 'BGR' reverses the channel axis.
        palette: optional (K, 3) array whose row k is the colour of label k.
            Defaults to the module-level ``colors`` table. Values are stored
            into a uint8 image.

    Returns:
        uint8 array of shape (H, W, 3). For ``mode='BGR'`` this is a
        channel-reversed view of that array.

    Raises:
        IndexError: if a label is >= the number of palette rows.
    """
    if palette is None:
        palette = colors
    palette = np.asarray(palette)
    labelmap = np.asarray(labelmap).astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    # One vectorised gather replaces the former per-label loop, which built a
    # full tiled (H, W, 3) image for every distinct label.
    valid = labelmap >= 0
    labelmap_rgb[valid] = palette[labelmap[valid]]
    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    return labelmap_rgb
| def get_edges(sp_label, sp_num): | |
| # This function returns a (hw) * (hw) matrix N. | |
| # If Nij = 1, then superpixel i and j are neighbors | |
| # Otherwise Nij = 0. | |
| top = sp_label[:, :, :-1, :] - sp_label[:, :, 1:, :] | |
| left = sp_label[:, :, :, :-1] - sp_label[:, :, :, 1:] | |
| top_left = sp_label[:, :, :-1, :-1] - sp_label[:, :, 1:, 1:] | |
| top_right = sp_label[:, :, :-1, 1:] - sp_label[:, :, 1:, :-1] | |
| n_affs = [] | |
| edge_indices = [] | |
| for i in range(sp_label.shape[0]): | |
| # change to torch.ones below to include self-loop in graph | |
| n_aff = torch.zeros(sp_num, sp_num).unsqueeze(0).cuda() | |
| # top/bottom | |
| top_i = top[i].squeeze() | |
| x, y = torch.nonzero(top_i, as_tuple = True) | |
| sp1 = sp_label[i, :, x, y].squeeze().long() | |
| sp2 = sp_label[i, :, x+1, y].squeeze().long() | |
| n_aff[:, sp1, sp2] = 1 | |
| n_aff[:, sp2, sp1] = 1 | |
| # left/right | |
| left_i = left[i].squeeze() | |
| try: | |
| x, y = torch.nonzero(left_i, as_tuple = True) | |
| except: | |
| import pdb; pdb.set_trace() | |
| sp1 = sp_label[i, :, x, y].squeeze().long() | |
| sp2 = sp_label[i, :, x, y+1].squeeze().long() | |
| n_aff[:, sp1, sp2] = 1 | |
| n_aff[:, sp2, sp1] = 1 | |
| # top left | |
| top_left_i = top_left[i].squeeze() | |
| x, y = torch.nonzero(top_left_i, as_tuple = True) | |
| sp1 = sp_label[i, :, x, y].squeeze().long() | |
| sp2 = sp_label[i, :, x+1, y+1].squeeze().long() | |
| n_aff[:, sp1, sp2] = 1 | |
| n_aff[:, sp2, sp1] = 1 | |
| # top right | |
| top_right_i = top_right[i].squeeze() | |
| x, y = torch.nonzero(top_right_i, as_tuple = True) | |
| sp1 = sp_label[i, :, x, y+1].squeeze().long() | |
| sp2 = sp_label[i, :, x+1, y].squeeze().long() | |
| n_aff[:, sp1, sp2] = 1 | |
| n_aff[:, sp2, sp1] = 1 | |
| n_affs.append(n_aff) | |
| edge_index = torch.stack(torch.nonzero(n_aff.squeeze(), as_tuple=True)) | |
| edge_indices.append(edge_index.cuda()) | |
| return edge_indices | |
def draw_color_seg(seg):
    """Colourise a batch of segmentation maps for visualisation.

    Args:
        seg: tensor whose first axis is the batch; each item is squeezed to a
            2-D label map (presumably (B, 1, H, W) integer labels; confirm
            against callers).

    Returns:
        float32 tensor of shape (B, 3, H, W) with values in [0, 1], made by
        scaling the uint8 output of ``colorEncode`` by 1/255.
    """
    seg_np = seg.detach().cpu().numpy()
    frames = [
        torch.from_numpy(colorEncode(item.squeeze()) / 255.0).float().permute(2, 0, 1)
        for item in seg_np
    ]
    return torch.stack(frames)