WRU-Net / inference.py
HirraA's picture
Upload 9 files
c44d232 verified
import argparse
import collections
import glob
import shutil
import sys
from datetime import datetime
from pathlib import Path
import PIL.Image as Image
import torch
from torchvision.transforms import transforms
from tqdm import tqdm
import data_loader.data_loaders as module_data
import model.model as module_arch
from parse_config import ConfigParser
def main(config):
logger = config.get_logger('infernece')
# setup data_loader instances
data_loader = getattr(module_data, config['data_loader']['type'])(
config['data_loader']['args']['data_dir'],
patch_size=config['data_loader']['args']['patch_size'],
batch_size=512,
shuffle=False,
validation_split=0.0,
training=False,
num_workers=2
)
# build model architecture
model = config.init_obj('arch', module_arch)
logger.info(model)
# prepare model for testing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info('Loading checkpoint: {} ...'.format(config.resume))
checkpoint = torch.load(config.resume, map_location=device)
state_dict = checkpoint['state_dict']
if config['n_gpu'] > 1:
model = torch.nn.DataParallel(model)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
with torch.no_grad():
for i, (data, target) in enumerate(tqdm(data_loader)):
data, target = data.to(device), target.to(device)
output = model(data)
batch_size = data_loader.batch_size
patch_idx = torch.arange(
batch_size * i, batch_size * i + data.shape[0])
pred = torch.argmax(output, dim=1)
data_loader.dataset.patches.store_data(
patch_idx, [pred.unsqueeze(1)])
preds = [(data_loader.dataset.patches.combine(idx, data_idx=0).cpu(),
data_loader.dataset.data[idx])
for idx in range(len(data_loader.dataset.data))]
trsfm = transforms.ToPILImage()
out_dir = list(config.save_dir.parts)
out_dir[-3] = 'output'
out_dir = Path(*out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for pred, path in preds:
filename = Path(path).stem + '.png'
pred = trsfm(pred.float())
pred.save(out_dir / filename)
if __name__ == '__main__':
args = argparse.ArgumentParser(description='PyTorch Template')
args.add_argument('-c', '--config', default=None, type=str,
help='config file path (default: None)')
args.add_argument('-r', '--resume', default=None, type=str,
help='path to latest checkpoint (default: None)')
args.add_argument('-d', '--device', default=None, type=str,
help='indices of GPUs to enable (default: all)')
args.add_argument('--data', default=None, type=str,
help='path to data (default: None)')
run_id = datetime.now().strftime(r'%m%d_%H%M%S')
dst_data = Path('.data/', run_id)
data_dir = dst_data / 'test' / 'images'
masks_dir = dst_data / 'test' / 'masks'
data_dir.mkdir(parents=True, exist_ok=True)
masks_dir.mkdir(parents=True, exist_ok=True)
if '--data' in sys.argv:
src_data = Path(sys.argv[sys.argv.index('--data') + 1])
if src_data.is_file():
shutil.copy(src_data, data_dir)
mask = Image.new('1', Image.open(src_data).size)
mask.save(masks_dir / (src_data.stem + '.png'), 'PNG')
else:
for filename in glob.glob('*.jpg', root_dir=src_data):
file_path = src_data / filename
if file_path.is_file():
shutil.copy(file_path, data_dir)
mask = Image.new('1', Image.open(file_path).size)
mask.save(masks_dir / (file_path.stem + '.png'), 'PNG')
sys.argv += ['--data_dir', str(dst_data)]
# custom cli options to modify configuration from default values given in json file.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
CustomArgs(['--data_dir'], type=str,
target='data_loader;args;data_dir')
]
config = ConfigParser.from_args(args, options)
main(config)
shutil.rmtree(dst_data)