|
|
import argparse |
|
|
import collections |
|
|
import glob |
|
|
import shutil |
|
|
import sys |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
import PIL.Image as Image |
|
|
import torch |
|
|
from torchvision.transforms import transforms |
|
|
from tqdm import tqdm |
|
|
import data_loader.data_loaders as module_data |
|
|
import model.model as module_arch |
|
|
from parse_config import ConfigParser |
|
|
|
|
|
|
|
|
def main(config):
    """Run inference with a trained model and save predicted masks as PNGs.

    Restores the checkpoint referenced by ``config.resume``, pushes the test
    data loader through the model patch-by-patch, stores per-patch argmax
    predictions back into the dataset's patch store, reassembles full-image
    predictions, and writes them to an ``output`` directory derived from
    ``config.save_dir``.

    :param config: parsed run configuration (ConfigParser) carrying the data
        loader spec, architecture spec, ``n_gpu``, ``resume`` path and
        ``save_dir``.
    """
    # Fixed typo in logger channel name ('infernece' -> 'inference').
    logger = config.get_logger('inference')

    # Build the test-time data loader: fixed batch size, no shuffling,
    # no validation split, evaluation mode.
    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        patch_size=config['data_loader']['args']['patch_size'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2
    )

    # Instantiate the architecture declared in the config.
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        # Checkpoint keys were saved from a DataParallel-wrapped model;
        # wrap before loading so the 'module.' prefixes line up.
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)
    model = model.to(device)

    model.eval()
    with torch.no_grad():
        for i, (data, target) in enumerate(tqdm(data_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Global indices of this batch's patches; the final batch may
            # be short, hence data.shape[0] rather than batch_size.
            batch_size = data_loader.batch_size
            patch_idx = torch.arange(
                batch_size * i, batch_size * i + data.shape[0])
            pred = torch.argmax(output, dim=1)
            data_loader.dataset.patches.store_data(
                patch_idx, [pred.unsqueeze(1)])

    # Reassemble one full-image prediction per source image, paired with the
    # original image path so the output file can reuse its stem.
    preds = [(data_loader.dataset.patches.combine(idx, data_idx=0).cpu(),
              data_loader.dataset.data[idx])
             for idx in range(len(data_loader.dataset.data))]
    trsfm = transforms.ToPILImage()

    # Mirror the run's save_dir under an 'output' tree by swapping the third
    # path component from the end. Assumes save_dir has >= 3 components —
    # TODO confirm against ConfigParser's save_dir layout.
    out_dir = list(config.save_dir.parts)
    out_dir[-3] = 'output'
    out_dir = Path(*out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for pred, path in preds:
        filename = Path(path).stem + '.png'
        pred = trsfm(pred.float())
        pred.save(out_dir / filename)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')
    args.add_argument('--data', default=None, type=str,
                      help='path to data (default: None)')

    # Stage the input image(s) into a throwaway per-run directory so the
    # data loader sees its expected <root>/test/{images,masks} layout.
    run_id = datetime.now().strftime(r'%m%d_%H%M%S')
    dst_data = Path('.data/', run_id)

    data_dir = dst_data / 'test' / 'images'
    masks_dir = dst_data / 'test' / 'masks'

    data_dir.mkdir(parents=True, exist_ok=True)
    masks_dir.mkdir(parents=True, exist_ok=True)

    # Inspect sys.argv directly: staging must happen before ConfigParser
    # consumes the arguments.
    if '--data' in sys.argv:
        src_data = Path(sys.argv[sys.argv.index('--data') + 1])
        if src_data.is_file():
            # Single image: copy it and create an all-zero dummy mask so
            # the dataset can pair every image with a mask.
            shutil.copy(src_data, data_dir)
            mask = Image.new('1', Image.open(src_data).size)
            mask.save(masks_dir / (src_data.stem + '.png'), 'PNG')
        else:
            # Directory of images. NOTE: the root_dir keyword of glob.glob
            # requires Python >= 3.10; only lowercase '*.jpg' is matched.
            for filename in glob.glob('*.jpg', root_dir=src_data):
                file_path = src_data / filename
                if file_path.is_file():
                    shutil.copy(file_path, data_dir)
                    mask = Image.new('1', Image.open(file_path).size)
                    mask.save(masks_dir / (file_path.stem + '.png'), 'PNG')

    # Point the data loader at the staged copy via a custom config override.
    sys.argv += ['--data_dir', str(dst_data)]

    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--data_dir'], type=str,
                   target='data_loader;args;data_dir')
    ]
    config = ConfigParser.from_args(args, options)
    try:
        main(config)
    finally:
        # Always remove the staged copy — previously this leaked the temp
        # tree whenever main() raised.
        shutil.rmtree(dst_data)
|
|
|