# DRgaddam's picture
# adding files
# 362a3b9 verified
# raw
# history blame
# 2.05 kB
import os
import sys
import tqdm
import torch
import numpy as np
from PIL import Image
from torch.utils.data.dataloader import DataLoader
# Make the repository root (the parent of this script's directory) importable,
# so the project-local `lib`, `utils.misc` and `data.dataloader` imports that
# follow can be resolved.
filepath = os.path.dirname(os.path.abspath(__file__))
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
from lib import *
from utils.misc import *
from data.dataloader import *
# Disable TensorFloat-32 for both cuBLAS matmuls and cuDNN convolutions so the
# model runs at full FP32 precision (presumably to keep the saved predictions
# reproducible across GPU generations -- confirm with the authors).
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = False
def _maybe_progress(iterable, verbose, **tqdm_kwargs):
    """Return *iterable* wrapped in a tqdm progress bar if *verbose*, else unchanged.

    Every bar in this script shares one fixed-width ``bar_format``. Extra
    keyword arguments (``desc``, ``position``, ``leave``, ...) are forwarded
    to ``tqdm.tqdm``.
    """
    if not verbose:
        return iterable
    return tqdm.tqdm(
        iterable,
        total=len(iterable),
        bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}',
        **tqdm_kwargs,
    )


def test(opt, args):
    """Run the configured model on every test set and save predictions as PNGs.

    For each name in ``opt.Test.Dataset.sets``, a dataset is built and the model
    is run on it one sample at a time (batch size 1). Each prediction is written
    to ``<checkpoint_dir>/<set name>/<sample name>.png`` as an 8-bit image.

    Args:
        opt: Config object. This function reads ``opt.Model.name`` and the rest
            of ``opt.Model`` (passed as constructor kwargs), plus
            ``opt.Test.Checkpoint.checkpoint_dir``,
            ``opt.Test.Dataset.{type, root, sets, transforms}`` and
            ``opt.Test.Dataloader.{num_workers, pin_memory}``.
        args: Parsed CLI arguments. Only ``args.verbose`` is read; a truthy
            value enables progress bars.

    Returns:
        None. All results are written to disk.
    """
    checkpoint_dir = opt.Test.Checkpoint.checkpoint_dir

    # NOTE(review): eval() turns a config string into a class object. It will
    # execute arbitrary code, so only run this with trusted config files.
    model = eval(opt.Model.name)(**opt.Model)
    # Load the tensors onto the CPU first, so a checkpoint saved on a GPU index
    # that does not exist here still loads. model.cuda() below moves the
    # weights to the GPU afterwards.
    state_dict = torch.load(os.path.join(checkpoint_dir, 'latest.pth'), map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
    model.cuda()
    model.eval()

    set_names = _maybe_progress(
        opt.Test.Dataset.sets, args.verbose, desc='Total TestSet', position=0)

    # Inference only: no autograd graph is needed for anything below.
    with torch.no_grad():
        for set_name in set_names:
            save_path = os.path.join(checkpoint_dir, set_name)
            os.makedirs(save_path, exist_ok=True)

            # NOTE(review): same eval() trust caveat as above for the dataset class.
            test_dataset = eval(opt.Test.Dataset.type)(
                opt.Test.Dataset.root, [set_name], opt.Test.Dataset.transforms)
            test_loader = DataLoader(
                dataset=test_dataset,
                batch_size=1,
                num_workers=opt.Test.Dataloader.num_workers,
                pin_memory=opt.Test.Dataloader.pin_memory,
            )

            samples = _maybe_progress(
                test_loader, args.verbose,
                desc=set_name + ' - Test', position=1, leave=False)
            for sample in samples:
                sample = to_cuda(sample)
                out = model(sample)
                pred = to_numpy(out['pred'], sample['shape'])
                # Clip before the uint8 cast. Casting floats outside [0, 255]
                # to uint8 wraps around or is platform-dependent. Values
                # already in range are unaffected.
                mask = np.clip(pred * 255, 0, 255).astype(np.uint8)
                # batch_size=1, so the batch holds a single name at index 0.
                Image.fromarray(mask).save(
                    os.path.join(save_path, sample['name'][0] + '.png'))
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, load the config file named by
    # `args.config`, then run inference over every configured test set.
    args = parse_args()
    opt = load_config(args.config)
    test(args=args, opt=opt)