File size: 3,380 Bytes
168ec29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# from PIL import Image
from dataset import get_loader
import torch
from torchvision import transforms
# from util import save_tensor_img, Logger
from tqdm import tqdm
from torch import nn
import os
from models.main import *
import argparse
# import numpy as np
# import cv2
# from skimage import img_as_ubyte


# Checkpoint tried first when ``args`` carries no (or an empty) ``model_path``.
_DEFAULT_CHECKPOINT = "/scratch/wej36how/codes/DCFM-master/best_ep12_Smeasure0.7256.pth"

# Test-set name -> (image folder, ground-truth folder).
# NOTE(review): NWRD points both entries at the same image folder, presumably
# because no ground truth exists for it and get_loader still requires a gt
# path; confirm.
_TEST_SETS = {
    'CoCA': ('./data/images/CoCA/', './data/gts/CoCA/'),
    'CoSOD3k': ('./data/images/CoSOD3k/', './data/gts/CoSOD3k/'),
    'CoSal2015': ('./data/images/CoSal2015/', './data/gts/CoSal2015/'),
    'NWRD': ('/home/wej36how/codes/crossvit/results/nwrd22/images/',
             '/home/wej36how/codes/crossvit/results/nwrd22/images/'),
}


def _load_weights(args, device):
    """Load the state dict for ``model.dcfmnet``.

    Tries ``args.model_path`` (when present and non-empty) or
    ``_DEFAULT_CHECKPOINT`` first. If that cannot be loaded, falls back to
    ``<args.param_root>/dcfm.pth`` and reports why.

    Args:
        args: parsed CLI namespace; must provide ``param_root``.
        device: device the loaded tensors are mapped onto.

    Returns:
        The object returned by ``torch.load`` for the checkpoint that loaded.

    Raises:
        Whatever ``torch.load`` raises if the fallback checkpoint fails too.
    """
    primary = getattr(args, 'model_path', None) or _DEFAULT_CHECKPOINT
    try:
        state = torch.load(primary, map_location=device)
    except Exception as err:
        # Deliberate best-effort fallback, but not a bare ``except``: let
        # KeyboardInterrupt/SystemExit through and say why we fell back.
        fallback = os.path.join(args.param_root, 'dcfm.pth')
        print('could not load {} ({}); falling back to {}'.format(primary, err, fallback))
        state = torch.load(fallback, map_location=device)
        print('loaded', fallback)
    else:
        print('loaded', primary)
    return state


def main(args):
    """Run DCFM inference over the configured test sets and save prediction maps.

    For each test set, every item from ``get_loader`` is pushed through the
    model. After the batch dim is squeezed, each item holds several images
    indexed along dim 0. The sigmoid of the model's last output is resized
    back to each image's original size and written to
    ``<args.save_root>/<testset>/<subpath>``.

    Args:
        args: parsed CLI namespace providing ``size``, ``param_root`` and
            ``save_root``. An optional ``model_path`` attribute overrides the
            default checkpoint.
    """
    # Use CUDA when available; otherwise run on CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DCFM().to(device)
    model.dcfmnet.load_state_dict(_load_weights(args, device))
    model.eval()
    model.set_mode('test')

    for testset in ['NWRD']:
        if testset not in _TEST_SETS:
            # Skip unknown names instead of continuing with undefined paths.
            print('Unknown test dataset:', testset)
            continue
        test_img_path, test_gt_path = _TEST_SETS[testset]
        saved_root = os.path.join(args.save_root, testset)

        test_loader = get_loader(
            test_img_path, test_gt_path, args.size, 1, istrain=False,
            shuffle=False, num_workers=8, pin=device.type == 'cuda')

        # Inference only: no autograd graph is needed, which saves memory.
        with torch.no_grad():
            for batch in tqdm(test_loader):
                inputs = batch[0].to(device).squeeze(0)
                gts = batch[1].to(device).squeeze(0)
                subpaths = batch[2]
                ori_sizes = batch[3]
                scaled_preds = torch.sigmoid(model(inputs, gts)[-1])
                for inum in range(gts.shape[0]):
                    # Each subpath entry carries an extra length-1 level
                    # (presumably from batch-size-1 collation), hence [0].
                    subpath = subpaths[inum][0]
                    ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
                    res = nn.functional.interpolate(
                        scaled_preds[inum].unsqueeze(0), size=ori_size,
                        mode='bilinear', align_corners=True)
                    save_path = os.path.join(saved_root, subpath)
                    # Create this file's own parent directory; works for both
                    # 'group/name.png' and flat 'name.png' subpaths.
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    # NOTE(review): the ``util`` import of save_tensor_img is
                    # commented out at the top of the file, so this name must
                    # be supplied by ``from models.main import *`` — confirm.
                    save_tensor_img(res, save_path)


def _parse_cli():
    """Define this script's command-line options and parse them from ``sys.argv``.

    Options:
        --size        input size handed to ``get_loader`` (int, default 224)
        --param_root  folder holding the fallback ``dcfm.pth`` checkpoint
        --save_root   folder under which per-dataset prediction maps are written
    """
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('--size', type=int, default=224, help='input size')
    cli.add_argument('--param_root', type=str, default='/data1/dcfm/temp', help='model folder')
    cli.add_argument('--save_root', type=str, default='./CoSODmaps/pred', help='Output folder')
    return cli.parse_args()


if __name__ == '__main__':
    args = _parse_cli()
    main(args)