File size: 1,501 Bytes
0788e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse

import torch
import torch.nn
import torchvision.transforms as transforms
from networks.resnet import resnet50
from PIL import Image

# Command-line interface for the single-image demo.
# ArgumentDefaultsHelpFormatter only appends "(default: ...)" to arguments
# that carry help text, so every option below is given a help string.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# NOTE(review): the default looks like a directory name, but the value is
# opened as a single image later in this script -- confirm the intended default.
parser.add_argument(
    '-f',
    '--file',
    default='examples_realfakedir',
    help='path of the image to classify',
)
parser.add_argument(
    '-m',
    '--model_path',
    type=str,
    default='weights/blur_jpg_prob0.5.pth',
    help='path of the checkpoint file holding the model weights',
)
parser.add_argument(
    '-c',
    '--crop',
    type=int,
    default=None,
    help='by default, do not crop. specify crop size',
)
parser.add_argument(
    '--use_cpu', action='store_true', help='uses gpu by default, turn on to use cpu'
)

# Parsed options; read by the rest of the script as `opt.<name>`.
opt = parser.parse_args()

# Restore the trained network.
# The checkpoint is mapped to the CPU first so it opens on any machine; the
# model is moved to the GPU afterwards unless --use_cpu was given.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(opt.model_path, map_location='cpu')

# The network is built with a single output (num_classes=1); its weights are
# stored under the 'model' key of the checkpoint.
model = resnet50(num_classes=1)
model.load_state_dict(checkpoint['model'])
model.eval()  # switch to evaluation mode for inference
if not opt.use_cpu:
    model = model.cuda()

# Preprocessing pipeline: optional centre crop, then tensor conversion and
# per-channel normalisation.
if opt.crop is None:
    trans_init = []
    print('Not cropping')
else:
    print('Cropping to [%i]' % opt.crop)
    trans_init = [transforms.CenterCrop(opt.crop)]

trans = transforms.Compose(
    [
        *trans_init,
        transforms.ToTensor(),
        # NOTE(review): presumably the ImageNet channel statistics -- confirm
        # against the training configuration.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load the image, force it to three RGB channels and run the preprocessing
# pipeline. The context manager releases the underlying file handle as soon
# as the pixels have been read instead of leaving that to garbage collection.
with Image.open(opt.file) as image_file:
    img = trans(image_file.convert('RGB'))

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    in_tens = img.unsqueeze(0)  # add a leading batch dimension of size 1
    if not opt.use_cpu:
        # The model was moved to the GPU under the same condition.
        in_tens = in_tens.cuda()
    # The network has one output; sigmoid maps it into [0, 1] and .item()
    # extracts it as a plain Python float.
    prob = model(in_tens).sigmoid().item()

print('probability of being synthetic: {:.2f}%'.format(prob * 100))