# s3_net/scripts/decode_demo.py
# (upload-page residue converted to comments: uploaded by zzuxzt via
# huggingface_hub, revision d9c5371 verified)
#!/usr/bin/env python
#
# file: $ISIP_EXP/tuh_dpath/exp_0074/scripts/decode.py
#
# revision history:
# 20190925 (TE): first version
#
# usage:
# python decode.py odir mfile data
#
# arguments:
# odir: the directory where the hypotheses will be stored
# mfile: input model file
# data: the input data list to be decoded
#
# This script decodes data using a simple MLP model.
#------------------------------------------------------------------------------
# import pytorch modules
#
import torch
import torch.nn as nn
from tqdm import tqdm
# visualize:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
matplotlib.style.use('ggplot')
# import the model and all of its variables/functions
#
from model import *
# import modules
#
import sys
import os
#-----------------------------------------------------------------------------
#
# global variables are listed here
#
#-----------------------------------------------------------------------------
# general global values
#
NUM_ARGS = 3
SPACE = " "

# Hokuyo UTM-30LX-EW lidar / model constants
POINTS = 1081                      # number of lidar points per scan
NUM_CLASSES = 9
NUM_INPUT_CHANNELS = 1
NUM_OUTPUT_CHANNELS = NUM_CLASSES
# NOTE(review): "AGNLE" is a typo for "ANGLE"; the misspelled names are kept
# so existing references elsewhere keep working.
AGNLE_MIN = -2.356194496154785     # minimum scan angle [rad] (-135 deg)
AGNLE_MAX = 2.356194496154785      # maximum scan angle [rad] (+135 deg)
RANGE_MAX = 60.0                   # maximum sensor range [m]

# for reproducibility, we seed the rng (set_seed and SEED1 come from model)
#
set_seed(SEED1)
#------------------------------------------------------------------------------
#
# the main program starts here
#
#------------------------------------------------------------------------------
# function: main
#
# arguments: none
#
# return: none
#
# This method is the main function.
#
def _plot_polar_semantic(smap, r, scan_index, odir, prefix):
    """Render one polar scatter of per-point semantic labels and save a jpg.

    Arguments:
      smap: 1-D numpy array of POINTS integer class labels
      r: 1-D numpy array of POINTS range values (one per lidar beam)
      scan_index: integer index of the scan (used in the output filename)
      odir: output directory the image is written into
      prefix: filename prefix, e.g. "semantic_ground_truth"
    """
    # class names indexed by label value
    classes = ['Other', 'Chair', 'Door', 'Elevator', 'Person', 'Pillar',
               'Sofa', 'Table', 'Trash bin', 'Wall']
    # beam angles; endpoint=True (the original passed the string 'true',
    # which only worked because any non-empty string is truthy)
    theta = np.linspace(AGNLE_MIN, AGNLE_MAX, num=POINTS, endpoint=True)
    # append one synthetic background point (label 0) so the colormap always
    # contains the background class and colors stay consistent across frames
    theta = np.insert(theta, -1, np.pi)
    r = np.insert(r, -1, 1)
    smap = np.insert(smap, -1, 0)
    label_val = np.unique(smap).astype(int)
    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(1, 1, 1, projection='polar', facecolor='seashell')
    scatter = ax.scatter(theta, r, c=smap, s=6, cmap='nipy_spectral',
                         alpha=0.95, linewidth=10)
    ax.set_xticks(np.linspace(AGNLE_MIN, AGNLE_MAX, 8, endpoint=True))
    ax.set_thetamin(-135)
    ax.set_thetamax(135)
    ax.set_yticklabels([])
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    # produce a legend with the unique colors from the scatter
    plt.legend(handles=scatter.legend_elements(num=[j for j in label_val])[0],
               labels=[classes[j] for j in label_val],
               bbox_to_anchor=(0.5, -0.08), loc='lower center', fontsize=18)
    ax.grid(False)
    ax.set_theta_offset(np.pi / 2)
    # write into odir (the original hard-coded "./output/" and never used
    # the odir argument it created)
    img_name = os.path.join(odir, "%s_%d.jpg" % (prefix, scan_index))
    plt.savefig(img_name, bbox_inches='tight')
    # close the figure: creating a new figure per scan without closing the
    # previous one leaks memory over the evaluation loop
    plt.close(fig)


def main(argv):
    """Decode an evaluation set with S3Net and save polar label plots.

    Arguments:
      argv: [odir, mdl_path, eval_set] — output dir, model checkpoint path,
            and the dataset identifier handed to VaeTestDataset.

    Returns:
      True on success (exits with -1 on bad usage).
    """
    # ensure we have the correct number of arguments
    if len(argv) != NUM_ARGS:
        print("usage: python nedc_decode_mdl.py [ODIR] [MDL_PATH] [EVAL_SET]")
        exit(-1)
    odir = argv[0]
    mdl_path = argv[1]
    fImg = argv[2]
    # if the odir doesn't exist, we make it
    if not os.path.exists(odir):
        os.makedirs(odir)
    # use the GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # build the evaluation dataset/loader (batch_size=1: one scan per step)
    eval_dataset = VaeTestDataset(fImg, 'dev')
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=1, shuffle=False, drop_last=True)
    # instantiate the model and load the trained weights
    model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                  output_channels=NUM_OUTPUT_CHANNELS)
    model.to(device)
    model.eval()
    criterion = nn.MSELoss(reduction='sum')
    criterion.to(device)
    checkpoint = torch.load(mdl_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    num_samples = 32   # stochastic forward passes per scan (majority vote)
    # number of batches (for the progress bar)
    num_batches = int(len(eval_dataset) / eval_dataloader.batch_size)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(eval_dataloader), total=num_batches):
            # collect the samples as a batch
            scans = batch['scan'].to(device)
            intensities = batch['intensity'].to(device)
            angle_incidence = batch['angle_incidence'].to(device)
            labels = batch['label']
            # replicate the single scan so the stochastic model is sampled
            # num_samples times in one forward pass
            inputs_samples = scans.repeat(num_samples, 1, 1)
            intensity_samples = intensities.repeat(num_samples, 1, 1)
            angle_samples = angle_incidence.repeat(num_samples, 1, 1)
            semantic_scan, semantic_channels, kl_loss = model(
                inputs_samples, intensity_samples, angle_samples)
            # per-point class via argmax, then a majority vote across the
            # num_samples draws; the vote must happen on the torch tensor
            # (.mode is a torch.Tensor method — the original converted to
            # numpy first and crashed, numpy arrays have no .mode)
            semantic_mx = semantic_scan.argmax(dim=1)
            semantic_vote = semantic_mx.mode(0).values.cpu().numpy()
            # ranges back on CPU for plotting
            r = scans.cpu().numpy().reshape(POINTS)
            # plot ground truth, then the s3-net prediction; each call builds
            # theta/r fresh (the original re-inserted into already-extended
            # arrays for the second plot, producing a length mismatch)
            _plot_polar_semantic(labels.cpu().numpy().reshape(POINTS), r, i,
                                 odir, "semantic_ground_truth")
            _plot_polar_semantic(semantic_vote.reshape(POINTS), r, i,
                                 odir, "semantic_s3net")
            print(i)
    # exit gracefully
    return True
#
# end of function
# begin gracefully
#
if __name__ == '__main__':
    # pass through the command-line arguments (minus the program name)
    main(sys.argv[1:])
#
# end of file