#!/usr/bin/env python
#
# file: $ISIP_EXP/tuh_dpath/exp_0074/scripts/decode.py
#
# revision history:
#  20190925 (TE): first version
#
# usage:
#  python decode.py odir mfile data
#
# arguments:
#  odir: the directory where the hypotheses will be stored
#  mfile: input model file
#  data: the input data list to be decoded
#
# This script decodes lidar scans with the S3Net model (imported from the
# local `model` module). For each scan it renders the ground-truth labels and
# the model's majority-vote labels as polar scatter plots saved as images.
#------------------------------------------------------------------------------

# import pytorch modules
#
import torch
import torch.nn as nn
from tqdm import tqdm

# visualize:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
matplotlib.style.use('ggplot')

# import the model and all of its variables/functions
# NOTE: this wildcard import supplies S3Net, VaeTestDataset, set_seed and
# SEED1 used below; keep it before the stdlib imports so that sys/os are not
# shadowed by anything `model` happens to export.
#
from model import *

# import modules
#
import sys
import os

#-----------------------------------------------------------------------------
#
# global variables are listed here
#
#-----------------------------------------------------------------------------

# general global values
#
NUM_ARGS = 3
SPACE = " "

# model constants:
# 9 output classes, 1 input channel; one output channel per class.
#
NUM_CLASSES = 9
NUM_INPUT_CHANNELS = 1
NUM_OUTPUT_CHANNELS = NUM_CLASSES

# Hokuyo UTM-30LX-EW scan geometry:
# POINTS is the number of lidar points per scan (previously assigned twice).
# AGNLE_MIN/AGNLE_MAX equal -/+3*pi/4 (i.e. -/+135 degrees as radians).
# NOTE: "AGNLE" is a misspelling of "ANGLE"; the names are kept because
# main() refers to them.
#
POINTS = 1081
AGNLE_MIN = -2.356194496154785
AGNLE_MAX = 2.356194496154785
RANGE_MAX = 60.0  # presumably the sensor's maximum range in metres - verify

# for reproducibility, we seed the rng
#
set_seed(SEED1)

#------------------------------------------------------------------------------
#
# the main program starts here
#
#------------------------------------------------------------------------------

# function: main
#
# arguments: none
#
# return: none
#
# This method is the main function.
#

# legend names, indexed by integer class label.
# NOTE(review): 10 names vs NUM_CLASSES = 9 - confirm which mapping is right.
_CLASS_NAMES = ('Other', 'Chair', 'Door', 'Elevator', 'Person', 'Pillar',
                'Sofa', 'Table', 'Trash bin', 'Wall')


def _plot_semantic_scan(ranges, smap, out_path):
    """Draw one labelled scan as a polar scatter plot and save it to out_path.

    Args:
        ranges: per-beam ranges; any array-like with POINTS elements in total.
        smap: per-beam integer class labels; POINTS elements in total.
        out_path: path of the image file to write.

    The figure is always closed, even if saving fails, so repeated calls do
    not accumulate open matplotlib figures.
    """
    # build fresh coordinate arrays on every call, so that the extra
    # background point below is added exactly once per plot:
    theta = np.linspace(AGNLE_MIN, AGNLE_MAX, num=POINTS, endpoint=True)
    r = np.asarray(ranges).reshape(POINTS)
    smap = np.asarray(smap).reshape(POINTS)

    # add the background label: one extra point with label 0 at theta = pi
    # (outside the -135..135 degree sector drawn below, hence not visible).
    # Presumably this keeps label 0 in every plot's colour scale - confirm.
    theta = np.insert(theta, -1, np.pi)
    r = np.insert(r, -1, 1)
    smap = np.insert(smap, -1, 0)
    label_val = np.unique(smap).astype(int)

    fig = plt.figure(figsize=(12, 12))
    try:
        ax = fig.add_subplot(1, 1, 1, projection='polar',
                             facecolor='seashell')
        scatter = ax.scatter(theta, r, c=smap, s=6, cmap='nipy_spectral',
                             alpha=0.95, linewidth=10)
        ax.set_xticks(np.linspace(AGNLE_MIN, AGNLE_MAX, 8, endpoint=True))
        ax.set_thetamin(-135)
        ax.set_thetamax(135)
        ax.set_yticklabels([])
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)

        # produce a legend with one entry per label present in this scan:
        plt.legend(handles=scatter.legend_elements(num=list(label_val))[0],
                   labels=[_CLASS_NAMES[j] for j in label_val],
                   bbox_to_anchor=(0.5, -0.08), loc='lower center',
                   fontsize=18)
        ax.grid(False)
        ax.set_theta_offset(np.pi / 2)
        plt.savefig(out_path, bbox_inches='tight')
    finally:
        plt.close(fig)


def main(argv):
    """Decode an evaluation set with a trained S3Net model and plot results.

    Args:
        argv: [odir, mdl_path, eval_set]:
            odir: output directory for the images (created if missing).
            mdl_path: checkpoint file; its 'model' entry holds the state dict.
            eval_set: data list passed to VaeTestDataset(eval_set, 'dev').

    For every scan, two images are written to odir:
    semantic_ground_truth_<i>.jpg (dataset labels) and semantic_s3net_<i>.jpg
    (majority vote over num_samples model predictions).

    Returns:
        True on completion. Exits the process with status -1 on bad usage.
    """
    # ensure we have the correct number of arguments:
    if len(argv) != NUM_ARGS:
        print("usage: python nedc_decode_mdl.py [ODIR] [MDL_PATH] [EVAL_SET]")
        sys.exit(-1)

    odir, mdl_path, fImg = argv

    # if the odir doesn't exist, we make it:
    os.makedirs(odir, exist_ok=True)

    # set the device to use GPU if available:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # one scan per batch; the plotting code reshapes each batch to POINTS.
    eval_dataset = VaeTestDataset(fImg, 'dev')
    eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=1,
                                                  shuffle=False,
                                                  drop_last=True)

    # instantiate the model, move it to the device, and load the weights:
    model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                  output_channels=NUM_OUTPUT_CHANNELS)
    model.to(device)
    model.eval()
    checkpoint = torch.load(mdl_path, map_location=device)
    model.load_state_dict(checkpoint['model'])

    # number of model evaluations per scan used for the majority vote:
    num_samples = 32

    # with drop_last=True this is floor(len(dataset) / batch_size):
    num_batches = len(eval_dataloader)

    with torch.no_grad():
        for i, batch in tqdm(enumerate(eval_dataloader), total=num_batches):
            # collect the samples as a batch:
            scans = batch['scan'].to(device)
            intensities = batch['intensity'].to(device)
            angle_incidence = batch['angle_incidence'].to(device)
            labels = batch['label']

            # replicate the scan num_samples times and feed it to the model:
            inputs_samples = scans.repeat(num_samples, 1, 1)
            intensity_samples = intensities.repeat(num_samples, 1, 1)
            angle_incidence_samples = angle_incidence.repeat(num_samples, 1, 1)
            semantic_scan, _, _ = model(inputs_samples, intensity_samples,
                                        angle_incidence_samples)

            # per-sample class = argmax over dim 1, then majority vote over
            # the replicated samples (dim 0). This must be done on the torch
            # tensor: NumPy arrays have no .mode() method.
            semantic_scans_mx = semantic_scan.argmax(dim=1)
            semantic_scans_mx_mean = semantic_scans_mx.mode(0).values

            r = scans.cpu().numpy()

            # plot semantic label (ground truth):
            _plot_semantic_scan(
                r, labels.cpu().numpy(),
                os.path.join(odir, "semantic_ground_truth_" + str(i) + ".jpg"))

            # plot s3-net semantic segmentation:
            _plot_semantic_scan(
                r, semantic_scans_mx_mean.cpu().numpy(),
                os.path.join(odir, "semantic_s3net_" + str(i) + ".jpg"))

            print(i)

    # exit gracefully
    #
    return True
#
# end of function


# begin gracefully
#
if __name__ == '__main__':
    main(sys.argv[1:])
#
# end of file