|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Created in September 2022
|
| @author: fabrizio.guillaro
|
| """
|
|
|
import argparse
import os
import sys
import traceback
from glob import glob

import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm
|
|
|
# Put the parent directory of this script at the front of the import search
# path (presumably where the `lib` and `dataset` packages imported below
# live -- confirm against the repository layout).
_script_dir = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(_script_dir, '..')
if path not in sys.path:
    sys.path.insert(0, path)
|
|
|
| from lib.config import config, update_config
|
| from lib.utils import get_model
|
| from dataset.dataset_test import TestDataset
|
|
|
# Command-line interface. Anything left over after the named options is
# collected into `opts` and handed to `update_config` together with the rest
# of the parsed arguments.
parser = argparse.ArgumentParser(description='Test TruFor')
parser.add_argument(
    '-g', '--gpu',
    type=int,
    default=0,
    help='device, use -1 for cpu',
)
parser.add_argument(
    '-in', '--input',
    type=str,
    default='../images',
    help='can be a single file, a directory or a glob statement',
)
parser.add_argument(
    '-out', '--output',
    type=str,
    default='../output',
    help='output folder',
)
parser.add_argument(
    '-exp', '--experiment',
    type=str,
    default='trufor_ph3',
)
parser.add_argument(
    '-save_np', '--save_np',
    action='store_true',
    help='whether to save the Noiseprint++ or not',
)
parser.add_argument(
    'opts',
    help="other options",
    default=None,
    nargs=argparse.REMAINDER,
)

args = parser.parse_args()

# Fold the command-line arguments into the shared project configuration.
update_config(config, args)
|
|
|
# Unpack the parsed options into the module-level names used by the rest of
# the script. NOTE(review): `input` shadows the builtin of the same name; it
# is kept because later sections of the script read it.
input = args.input
output = args.output
gpu = args.gpu
save_np = args.save_np

# A negative GPU index selects the CPU; any other value selects that CUDA
# device.
if gpu < 0:
    device = 'cpu'
else:
    device = 'cuda:%d' % gpu

    # cuDNN settings are only applied when running on a CUDA device; the
    # values come from the project configuration.
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
|
|
|
# Resolve `input` into the list of files to process. It may be:
#   * a glob expression (anything containing '*'), expanded recursively,
#   * a single existing file, used as-is,
#   * a directory, which is walked recursively.
# Directories matched by a glob are dropped so only files remain.
if '*' in input:
    _pattern = input
elif os.path.isfile(input):
    _pattern = None
elif os.path.isdir(input):
    _pattern = os.path.join(input, '**/*')
else:
    raise ValueError("input is neither a file or a folder")

if _pattern is None:
    list_img = [input]
else:
    list_img = [
        candidate
        for candidate in glob(_pattern, recursive=True)
        if not os.path.isdir(candidate)
    ]
|
|
|
# Wrap the file list in the project's test dataset and iterate over it one
# image per batch.
test_dataset = TestDataset(list_img=list_img)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=1)
|
|
|
# The checkpoint location comes from the configuration; refuse to continue
# without one.
model_state_file = config.TEST.MODEL_FILE
if not model_state_file:
    raise ValueError("Model file is not specified.")

print('=> loading model from {}'.format(model_state_file))

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects, so only load checkpoints from a trusted source.
checkpoint = torch.load(
    model_state_file,
    map_location=torch.device(device),
    weights_only=False,
)
print("Epoch: {}".format(checkpoint['epoch']))

# Build the network from the configuration, restore its weights and move it
# to the selected device.
model = get_model(config)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)
|
|
|
def _output_filename(img_path, input_spec, output_spec):
    """Return the ``.npz`` file in which the result for *img_path* is stored.

    If *output_spec* has a file extension it is treated as the output file
    itself. Otherwise it is treated as a directory, and the part of
    *img_path* that follows the fixed (non-glob) prefix of *input_spec* is
    appended to it so the input folder layout is reproduced. The returned
    name always ends in ``.npz``.
    """
    if os.path.splitext(os.path.basename(output_spec))[1] != '':
        # Output names a file: use it directly.
        filename = output_spec
    else:
        # Output names a directory: work out the image path relative to the
        # fixed part of the input specification (everything before the first
        # '*', or the containing folder when the input is a single file).
        root = input_spec.split('*')[0]
        if os.path.isfile(input_spec):
            root = os.path.dirname(root)

        # Strip only the *leading* occurrence of root. A plain str.replace
        # would also delete later occurrences (e.g. 'a/a.png' -> '/.png').
        if img_path.startswith(root):
            sub_path = img_path[len(root):]
        else:
            sub_path = img_path
        sub_path = sub_path.strip()

        # Drop one leading separator so os.path.join does not treat the
        # remainder as an absolute path.
        if sub_path.startswith(('/', '\\')):
            sub_path = sub_path[1:]

        filename = os.path.join(output_spec, sub_path) + '.npz'

    if not filename.endswith('.npz'):
        filename = filename + '.npz'
    return filename


# Evaluation mode does not change between images, so set it once up front.
model.eval()

with torch.no_grad():
    for rgb, path in tqdm(testloader):
        # With batch_size=1 the loader yields the path wrapped in a
        # one-element sequence.
        filename_out = _output_filename(path[0], input, output)

        # Existing results are never overwritten.
        if os.path.isfile(filename_out):
            continue

        try:
            rgb = rgb.to(device)

            pred, conf, det, npp = model(rgb, save_np=save_np)

            if conf is not None:
                # Drop the batch dimension, squash to (0, 1) and keep the
                # first remaining channel.
                conf = torch.squeeze(conf, 0)
                conf = torch.sigmoid(conf)[0]
                conf = conf.cpu().numpy()

            if npp is not None:
                npp = torch.squeeze(npp, 0)[0]
                npp = npp.cpu().numpy()

            # Softmax over the first dimension left after removing the batch
            # dimension, keeping index 1 (presumably the "forged" class --
            # confirm against the model definition).
            pred = torch.squeeze(pred, 0)
            pred = F.softmax(pred, dim=0)[1]
            pred = pred.cpu().numpy()

            out_dict = {
                'map': pred,
                'imgsize': tuple(rgb.shape[2:]),
            }
            if det is not None:
                out_dict['score'] = torch.sigmoid(det).item()
            if conf is not None:
                out_dict['conf'] = conf
            if save_np:
                out_dict['np++'] = npp

            # os.makedirs('') raises, so only create the folder when the
            # output path actually has a directory component.
            out_dir = os.path.dirname(filename_out)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            np.savez(filename_out, **out_dict)
        except Exception:
            # Best effort: report the failure and move on to the next image.
            # Catching Exception (not a bare except) keeps Ctrl-C and
            # SystemExit able to stop the run.
            traceback.print_exc()
|
|
|
|
|