TheKernel01's picture
Sync from GitHub via hub-sync
0788e19 verified
import argparse
import torch
import torch.nn
import torchvision.transforms as transforms
from networks.resnet import resnet50
from PIL import Image
# Command-line interface.
# ArgumentDefaultsHelpFormatter only appends "(default: ...)" to options that
# have a help string, so every option carries one; otherwise its default
# would be missing from the --help output.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# NOTE(review): the default value looks like a directory name, but the script
# opens this path as a single image -- confirm the intended default.
parser.add_argument(
    '-f',
    '--file',
    default='examples_realfakedir',
    help='path of the image to classify',
)
parser.add_argument(
    '-m',
    '--model_path',
    type=str,
    default='weights/blur_jpg_prob0.5.pth',
    help='path of the checkpoint file to load',
)
parser.add_argument(
    '-c',
    '--crop',
    type=int,
    default=None,
    help='by default, do not crop. specify crop size',
)
parser.add_argument(
    '--use_cpu', action='store_true', help='uses gpu by default, turn on to use cpu'
)
# Parsed options: opt.file, opt.model_path, opt.crop, opt.use_cpu.
opt = parser.parse_args()
# Build the single-output ResNet-50 and restore its weights from the checkpoint.
model = resnet50(num_classes=1)
# SECURITY: torch.load is pickle-based; unless the installed torch restricts
# it to weights-only loading, a malicious checkpoint can execute arbitrary
# code.  Only pass --model_path files that come from a trusted source.
# map_location='cpu' remaps the stored tensors onto the CPU while loading.
state_dict = torch.load(opt.model_path, map_location='cpu')
# The checkpoint is indexed by key; the network weights live under 'model'.
model.load_state_dict(state_dict['model'])
if not opt.use_cpu and not torch.cuda.is_available():
    # model.cuda() would fail on a machine without a usable GPU.  Fall back
    # to the CPU and record that on the options object, so any later check of
    # opt.use_cpu agrees with where the model actually lives.
    print('CUDA is not available, falling back to CPU')
    opt.use_cpu = True
if not opt.use_cpu:
    model.cuda()
# Switch the module to evaluation mode before running inference.
model.eval()
# Preprocessing pipeline: an optional centre crop, followed by conversion to
# a tensor and per-channel normalisation with fixed mean / std constants.
if opt.crop is None:
    trans_init = []
    print('Not cropping')
else:
    trans_init = [transforms.CenterCrop(opt.crop)]
    print('Cropping to [%i]' % opt.crop)

trans = transforms.Compose(
    [
        *trans_init,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Load the input image, force it to three RGB channels and run it through the
# preprocessing pipeline.  The context manager closes the underlying file as
# soon as the pixel data has been converted, instead of leaking the handle.
with Image.open(opt.file) as raw_img:
    img = trans(raw_img.convert('RGB'))

# Gradients are not needed for a single forward pass.
with torch.no_grad():
    # Prepend a size-1 leading dimension so the single image forms a batch of one.
    in_tens = img.unsqueeze(0)
    if not opt.use_cpu:
        in_tens = in_tens.cuda()
    # The model emits one raw score; sigmoid maps it into [0, 1] and item()
    # extracts it as a Python float.
    prob = model(in_tens).sigmoid().item()

print('probability of being synthetic: {:.2f}%'.format(prob * 100))