# from PIL import Image
from dataset import get_loader
import torch
from torchvision import transforms
# from util import save_tensor_img, Logger
from tqdm import tqdm
from torch import nn
import os
from models.main import *
import argparse
# import numpy as np
# import cv2
# from skimage import img_as_ubyte

# NOTE(review): `save_tensor_img` is called in main() but its explicit import
# from `util` is commented out above. Presumably it is re-exported through
# `from models.main import *` -- confirm, otherwise restore the util import.

# Checkpoint tried first; if loading it fails, `<param_root>/dcfm.pth` is used.
# A previously used alternative was
# os.path.join(args.param_root, 'best_ep198_Smeasure0.7019.pth').
_PRIMARY_CHECKPOINT = "/scratch/wej36how/codes/DCFM-master/best_ep12_Smeasure0.7256.pth"

# Test set name -> (image directory, ground-truth directory).
_DATASET_DIRS = {
    'CoCA': ('./data/images/CoCA/', './data/gts/CoCA/'),
    'CoSOD3k': ('./data/images/CoSOD3k/', './data/gts/CoSOD3k/'),
    'CoSal2015': ('./data/images/CoSal2015/', './data/gts/CoSal2015/'),
    # NOTE(review): the ground-truth directory is the same as the image
    # directory here -- presumably no separate masks exist for NWRD; confirm.
    'NWRD': ('/home/wej36how/codes/crossvit/results/nwrd22/images/',
             '/home/wej36how/codes/crossvit/results/nwrd22/images/'),
}


def _load_checkpoint(param_root):
    """Return the state dict from the primary checkpoint or the fallback.

    The fallback is ``<param_root>/dcfm.pth``. Errors raised while loading the
    fallback propagate to the caller.
    """
    try:
        state_dict = torch.load(_PRIMARY_CHECKPOINT)
    except Exception as err:  # best-effort: any load failure triggers fallback
        fallback = os.path.join(param_root, 'dcfm.pth')
        print('could not load', _PRIMARY_CHECKPOINT, '-', err)
        print('falling back to', fallback)
        return torch.load(fallback)
    print('loaded', _PRIMARY_CHECKPOINT)
    return state_dict


def _resolve_dataset(testset, save_root):
    """Return ``(image_dir, gt_dir, output_dir)`` for a known test set name.

    Raises:
        ValueError: if ``testset`` is not a key of ``_DATASET_DIRS``.
    """
    try:
        img_dir, gt_dir = _DATASET_DIRS[testset]
    except KeyError:
        raise ValueError(f'Unknown test dataset: {testset!r}') from None
    return img_dir, gt_dir, os.path.join(save_root, testset)


def main(args):
    """Run the DCFM model over the configured test sets and save predictions.

    Args:
        args: namespace providing ``size`` (passed to the data loader),
            ``param_root`` (folder holding the fallback ``dcfm.pth``) and
            ``save_root`` (output folder; one sub-folder per test set).

    Raises:
        ValueError: if a configured test set name is unknown.
    """
    # Init model
    device = torch.device("cuda")
    model = DCFM()
    model = model.to(device)

    model.dcfmnet.load_state_dict(_load_checkpoint(args.param_root))
    model.eval()
    model.set_mode('test')

    for testset in ['NWRD']:
        test_img_path, test_gt_path, saved_root = _resolve_dataset(
            testset, args.save_root)

        test_loader = get_loader(
            test_img_path, test_gt_path, args.size, 1, istrain=False,
            shuffle=False, num_workers=8, pin=True)

        # Inference only: skip autograd bookkeeping for the whole loop.
        with torch.no_grad():
            for batch in tqdm(test_loader):
                # squeeze(0) drops the leading dimension added by the loader
                # (presumably the size-1 batch axis -- confirm in get_loader).
                inputs = batch[0].to(device).squeeze(0)
                gts = batch[1].to(device).squeeze(0)
                subpaths = batch[2]
                ori_sizes = batch[3]

                # Only the last element of the model output is used.
                scaled_preds = model(inputs, gts)
                scaled_preds = torch.sigmoid(scaled_preds[-1])

                # Output sub-folder is the first path component of the first
                # sample's relative path.
                group_dir = subpaths[0][0].split('/')[0]
                os.makedirs(os.path.join(saved_root, group_dir), exist_ok=True)

                for inum in range(gts.shape[0]):
                    subpath = subpaths[inum][0]
                    ori_size = (ori_sizes[inum][0].item(),
                                ori_sizes[inum][1].item())
                    # Resize the prediction back to the recorded original size.
                    res = nn.functional.interpolate(
                        scaled_preds[inum].unsqueeze(0), size=ori_size,
                        mode='bilinear', align_corners=True)
                    save_tensor_img(res, os.path.join(saved_root, subpath))


if __name__ == '__main__':
    # Parameter from command line
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--size',
                        default=224,
                        type=int,
                        help='input size')
    parser.add_argument('--param_root',
                        default='/data1/dcfm/temp',
                        type=str,
                        help='model folder')
    parser.add_argument('--save_root',
                        default='./CoSODmaps/pred',
                        type=str,
                        help='Output folder')
    args = parser.parse_args()

    main(args)