# DaniilOr's picture
# Upload folder using huggingface_hub
# 5f0437a verified
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
import sys, os
import argparse
import numpy as np
from tqdm import tqdm
from glob import glob
import torch
from torch.nn import functional as F
# Put the parent of this script's directory at the front of sys.path (once),
# so that imports are also searched for there regardless of the current
# working directory.
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
if path not in sys.path:
    sys.path.insert(0, path)
from lib.config import config, update_config
from lib.utils import get_model
from dataset.dataset_test import TestDataset
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Test TruFor')

# Device selection: a GPU index, or -1 to run on the CPU.
parser.add_argument(
    '-g', '--gpu',
    default=0,
    type=int,
    help='device, use -1 for cpu',
)

# What to analyse and where to write the results.
parser.add_argument(
    '-in', '--input',
    default='../images',
    type=str,
    help='can be a single file, a directory or a glob statement',
)
parser.add_argument(
    '-out', '--output',
    default='../output',
    type=str,
    help='output folder',
)

# Name of the experiment (no help text is shown for this option).
parser.add_argument(
    '-exp', '--experiment',
    default='trufor_ph3',
    type=str,
)

# Flag: when present, the Noiseprint++ is saved together with the results.
parser.add_argument(
    '-save_np', '--save_np',
    action='store_true',
    help='whether to save the Noiseprint++ or not',
)

# Everything left on the command line is collected, unparsed, into `opts`.
parser.add_argument(
    'opts',
    nargs=argparse.REMAINDER,
    default=None,
    help="other options",
)

# Parse the arguments of the current process (sys.argv).
args = parser.parse_args()
# Hand the parsed arguments to update_config together with the shared `config`
# object (presumably it folds them into `config` -- see lib.config).
update_config(config, args)

# Unpack the options that the rest of the script reads as plain globals.
# NOTE: `input` shadows the builtin of the same name; the name is kept because
# later parts of the script refer to it.
input = args.input      # single file, directory or glob statement
output = args.output    # output folder (or output file name)
gpu = args.gpu          # GPU index, negative for CPU
save_np = args.save_np  # True when the Noiseprint++ must be saved as well

# Torch device string: 'cuda:<index>' for a non-negative index, 'cpu' otherwise.
device = 'cuda:%d' % gpu if gpu >= 0 else 'cpu'

if device != 'cpu':
    # cudnn setting: copy the three CUDNN switches from the configuration.
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
# Build `list_img`, the list of files to analyse, from `input`:
#   * a string containing '*' is expanded as a (recursive) glob pattern;
#   * an existing file is used as-is;
#   * an existing directory is searched recursively.
# Directories returned by glob are always discarded.
# Raises ValueError when `input` is none of the above.
if '*' in input:
    list_img = [img for img in glob(input, recursive=True)
                if not os.path.isdir(img)]
elif os.path.isfile(input):
    list_img = [input]
elif os.path.isdir(input):
    list_img = [img for img in glob(os.path.join(input, '**/*'), recursive=True)
                if not os.path.isdir(img)]
else:
    raise ValueError("input is neither a file or a folder")
# Wrap the collected file list in the project's TestDataset.
test_dataset = TestDataset(list_img=list_img)

# Iterate over the dataset one sample per batch.
testloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,  # 1 to allow arbitrary input sizes
)
# Load the checkpoint named by config.TEST.MODEL_FILE and build the model.
# Raises ValueError when no model file is configured.
if not config.TEST.MODEL_FILE:
    raise ValueError("Model file is not specified.")
model_state_file = config.TEST.MODEL_FILE

print('=> loading model from {}'.format(model_state_file))
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(model_state_file, map_location=torch.device(device))
# The checkpoint is indexed with the keys 'epoch' and 'state_dict'.
print("Epoch: {}".format(checkpoint['epoch']))

# Instantiate the network from the configuration, restore its weights and
# move it to the selected device.
model = get_model(config)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)
import traceback


def _strip_leading(text, prefix):
    """Return `text` without `prefix` if it starts with it, else `text` unchanged."""
    if prefix and text.startswith(prefix):
        return text[len(prefix):]
    return text


# Run the model on every input image and save one .npz file per image.
#
# Keys written to each .npz file:
#   'map'     : channel 1 of the softmax over the first output of the model
#   'imgsize' : rgb.shape[2:] as a tuple
#   'score'   : sigmoid of the third output (only when it is not None)
#   'conf'    : sigmoid of the second output (only when it is not None)
#   'np++'    : fourth output of the model (only when save_np is set)
#
# Existing output files are never overwritten. An error on one image is
# printed and the loop moves on to the next image.

# An `output` whose last path component has no extension is a directory.
output_is_dir = os.path.splitext(os.path.basename(output))[1] == ''

# Leading part of each input path that is dropped when the input tree is
# mirrored under `output`: the text before the first '*', or the containing
# directory when `input` is a single file.
root = input.split('*')[0]
strip_prefix = os.path.dirname(root) if os.path.isfile(input) else root

# Evaluation mode only needs to be set once, not once per image.
model.eval()

with torch.no_grad():
    for rgb, path in tqdm(testloader):
        if output_is_dir:  # output is a directory
            path = path[0]  # first (and, with batch_size=1, presumably only) path
            # Only the leading prefix is removed, so a directory name that
            # appears again deeper in the path is preserved.
            sub_path = _strip_leading(path, strip_prefix).strip()
            if sub_path.startswith(('/', os.sep)):
                # A leading separator would make os.path.join discard `output`.
                sub_path = sub_path[1:]
            filename_out = os.path.join(output, sub_path) + '.npz'
        else:  # output is a filename
            filename_out = output
            if not filename_out.endswith('.npz'):
                filename_out = filename_out + '.npz'

        # by default it does not overwrite
        if os.path.isfile(filename_out):
            continue

        try:
            rgb = rgb.to(device)
            pred, conf, det, npp = model(rgb, save_np=save_np)

            if conf is not None:
                conf = torch.squeeze(conf, 0)
                conf = torch.sigmoid(conf)[0]
                conf = conf.cpu().numpy()

            if npp is not None:
                npp = torch.squeeze(npp, 0)[0]
                npp = npp.cpu().numpy()

            # Drop the batch dimension, softmax over the first remaining
            # dimension and keep index 1.
            pred = torch.squeeze(pred, 0)
            pred = F.softmax(pred, dim=0)[1]
            pred = pred.cpu().numpy()

            out_dict = dict()
            out_dict['map'] = pred
            out_dict['imgsize'] = tuple(rgb.shape[2:])
            if det is not None:
                out_dict['score'] = torch.sigmoid(det).item()
            if conf is not None:
                out_dict['conf'] = conf
            if save_np:
                out_dict['np++'] = npp

            # os.makedirs('') raises, so only create the folder when the
            # output path actually has a directory component.
            out_dir = os.path.dirname(filename_out)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            np.savez(filename_out, **out_dict)
        except Exception:
            # Best effort: report the failure and go on with the next image.
            # KeyboardInterrupt / SystemExit are deliberately not caught.
            traceback.print_exc()