Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| """Utils for the pix2pixHD model.""" | |
| import numpy as np | |
| import torch | |
| from imaginaire.utils.data import get_paired_input_label_channel_number | |
| from imaginaire.utils.distributed import dist_all_gather_tensor, is_master | |
| from imaginaire.utils.distributed import master_only_print as print | |
| from imaginaire.utils.trainer import (get_optimizer, get_optimizer_for_params, | |
| wrap_model_and_optimizer) | |
| from sklearn.cluster import KMeans | |
| def cluster_features(cfg, train_data_loader, net_E, | |
| preprocess=None, small_ratio=0.0625, is_cityscapes=True): | |
| r"""Use clustering to compute the features. | |
| Args: | |
| cfg (obj): Global configuration file. | |
| train_data_loader (obj): Dataloader for iterate through the training | |
| set. | |
| net_E (nn.Module): Pytorch network. | |
| preprocess (function): Pre-processing function. | |
| small_ratio (float): We only consider instance that at least occupy | |
| $(small_ratio) amount of image space. | |
| is_cityscapes (bool): Is this is the cityscape dataset? In the | |
| Cityscapes dataset, the instance labels for car start with 26001, | |
| 26002, ... | |
| Returns: | |
| ( num_labels x num_cluster_centers x feature_dims): cluster centers. | |
| """ | |
| # Encode features. | |
| label_nc = get_paired_input_label_channel_number(cfg.data) | |
| feat_nc = cfg.gen.enc.num_feat_channels | |
| n_clusters = getattr(cfg.gen.enc, 'num_clusters', 10) | |
| # Compute features. | |
| features = {} | |
| for label in range(label_nc): | |
| features[label] = np.zeros((0, feat_nc + 1)) | |
| for data in train_data_loader: | |
| if preprocess is not None: | |
| data = preprocess(data) | |
| feat = encode_features(net_E, feat_nc, label_nc, | |
| data['images'], data['instance_maps'], | |
| is_cityscapes) | |
| # We only collect the feature vectors for the master GPU. | |
| if is_master(): | |
| for label in range(label_nc): | |
| features[label] = np.append( | |
| features[label], feat[label], axis=0) | |
| # Clustering. | |
| # We only perform clustering for the master GPU. | |
| if is_master(): | |
| for label in range(label_nc): | |
| feat = features[label] | |
| # We only consider segments that are greater than a pre-set | |
| # threshold. | |
| feat = feat[feat[:, -1] > small_ratio, :-1] | |
| if feat.shape[0]: | |
| n_clusters = min(feat.shape[0], n_clusters) | |
| kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat) | |
| n, d = kmeans.cluster_centers_.shape | |
| this_cluster = getattr(net_E, 'cluster_%d' % label) | |
| this_cluster[0:n, :] = torch.Tensor( | |
| kmeans.cluster_centers_).float() | |
| def encode_features(net_E, feat_nc, label_nc, image, inst, | |
| is_cityscapes=True): | |
| r"""Compute feature embeddings for an image image. | |
| TODO(Ting-Chun): To make this funciton dataset independent. | |
| Args: | |
| net_E (nn.Module): The encoder network. | |
| feat_nc (int): Feature dimensions | |
| label_nc (int): Number of segmentation labels. | |
| image (tensor): Input image tensor. | |
| inst (tensor): Input instance map. | |
| is_cityscapes (bool): Is this is the cityscape dataset? In the | |
| Cityscapes dataset, the instance labels for car start with 26001, | |
| 26002, ... | |
| Returns: | |
| (list of list of numpy vectors): We will have $(label_nc) | |
| list. For each list, it will record a list of feature vectors of | |
| dimension $(feat_nc+1) where the first $(feat_nc) dimensions is | |
| the representative feature of an instance and the last dimension | |
| is the proportion. | |
| """ | |
| # h, w = inst.size()[2:] | |
| feat_map = net_E(image, inst) | |
| feature_map_gather = dist_all_gather_tensor(feat_map) | |
| inst_gathered = dist_all_gather_tensor(inst) | |
| # Initialize the cluster centers. | |
| # For each feature vector, | |
| # 0:feat_nc will be the feature vector. | |
| # The feat_nc dimension record the percentage of the instance. | |
| feature = {} | |
| for i in range(label_nc): | |
| feature[i] = np.zeros((0, feat_nc + 1)) | |
| if is_master(): | |
| all_feat_map = torch.cat(feature_map_gather, 0) | |
| all_inst_map = torch.cat(inst_gathered, 0) | |
| # Scan through the batches. | |
| for n in range(all_feat_map.size()[0]): | |
| feat_map = all_feat_map[n:(n + 1), :, :, :] | |
| inst = all_inst_map[n:(n + 1), :, :, :] | |
| fh, fw = feat_map.size()[2:] | |
| inst_np = inst.cpu().numpy().astype(int) | |
| for i in np.unique(inst_np): | |
| if is_cityscapes: | |
| label = i if i < 1000 else i // 1000 | |
| else: | |
| label = i | |
| idx = (inst == int(i)).nonzero() | |
| num = idx.size()[0] | |
| # We will just pick the middle pixel as its representative | |
| # feature. | |
| idx = idx[num // 2, :] | |
| val = np.zeros((1, feat_nc + 1)) | |
| for k in range(feat_nc): | |
| # We expect idx[0]=0 and idx[1]=0 as the number of sample | |
| # per processing is 1 (idx[0]=0) and the channel number of | |
| # the instance map is 1. | |
| val[0, k] = feat_map[ | |
| idx[0], idx[1] + k, idx[2], idx[3]].item() | |
| val[0, feat_nc] = float(num) / (fh * fw) | |
| feature[label] = np.append(feature[label], val, axis=0) | |
| return feature | |
| else: | |
| return feature | |
| def get_edges(t): | |
| r""" Compute edge maps for a given input instance map. | |
| Args: | |
| t (4D tensor): Input instance map. | |
| Returns: | |
| (4D tensor): Output edge map. | |
| """ | |
| edge = torch.cuda.ByteTensor(t.size()).zero_() | |
| edge[:, :, :, 1:] = edge[:, :, :, 1:] | ( | |
| t[:, :, :, 1:] != t[:, :, :, :-1]).byte() | |
| edge[:, :, :, :-1] = edge[:, :, :, :-1] | ( | |
| t[:, :, :, 1:] != t[:, :, :, :-1]).byte() | |
| edge[:, :, 1:, :] = edge[:, :, 1:, :] | ( | |
| t[:, :, 1:, :] != t[:, :, :-1, :]).byte() | |
| edge[:, :, :-1, :] = edge[:, :, :-1, :] | ( | |
| t[:, :, 1:, :] != t[:, :, :-1, :]).byte() | |
| return edge.float() | |
| def get_train_params(net, param_names_start_with=[], param_names_include=[]): | |
| r"""Get train parameters. | |
| Args: | |
| net (obj): Network object. | |
| param_names_start_with (list of strings): Params whose names | |
| start with any of the strings will be trained. | |
| param_names_include (list of strings): Params whose names include | |
| any of the strings will be trained. | |
| """ | |
| params_to_train = [] | |
| params_dict = net.state_dict() | |
| list_of_param_names_to_train = set() | |
| # Iterate through all params in the network and check if we need to | |
| # train it. | |
| for key, value in params_dict.items(): | |
| do_train = False | |
| # If the param name starts with the target string (excluding | |
| # the 'module' part etc), we will train this param. | |
| key_s = key.replace('module.', '').replace('averaged_model.', '') | |
| for param_name in param_names_start_with: | |
| if key_s.startswith(param_name): | |
| do_train = True | |
| list_of_param_names_to_train.add(param_name) | |
| # Otherwise, if the param name includes the target string, | |
| # we will also train it. | |
| if not do_train: | |
| for param_name in param_names_include: | |
| if param_name in key_s: | |
| do_train = True | |
| full_param_name = \ | |
| key_s[:(key_s.find(param_name) + len(param_name))] | |
| list_of_param_names_to_train.add(full_param_name) | |
| # If we decide to train the param, add it to the list to train. | |
| if do_train: | |
| module = net | |
| key_list = key.split('.') | |
| for k in key_list: | |
| module = getattr(module, k) | |
| params_to_train += [module] | |
| print('Training layers: ', sorted(list_of_param_names_to_train)) | |
| return params_to_train | |
| def get_optimizer_with_params(cfg, net_G, net_D, param_names_start_with=[], | |
| param_names_include=[]): | |
| r"""Return the optimizer object. | |
| Args: | |
| cfg (obj): Global config. | |
| net_G (obj): Generator network. | |
| net_D (obj): Discriminator network. | |
| param_names_start_with (list of strings): Params whose names | |
| start with any of the strings will be trained. | |
| param_names_include (list of strings): Params whose names include | |
| any of the strings will be trained. | |
| """ | |
| # If any of the param name lists is not empty, will only train | |
| # these params. Otherwise will train the entire network (all params). | |
| if param_names_start_with or param_names_include: | |
| params = get_train_params(net_G, param_names_start_with, | |
| param_names_include) | |
| else: | |
| params = net_G.parameters() | |
| opt_G = get_optimizer_for_params(cfg.gen_opt, params) | |
| opt_D = get_optimizer(cfg.dis_opt, net_D) | |
| return wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D) | |