# P-DFD / inference.py
# Last commit by mrneuralnet: "Modify inference script args" (f4c3cd9)
import cv2
import torch
import random
import argparse
from glob import glob
from os.path import join
from model.network import Recce
from model.common import freeze_weights
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2
import os
# Work around Intel OpenMP's "duplicate runtime" abort (OMP Error #15).
# NOTE(review): the OpenMP runtime's own message warns that this setting may
# cause crashes or silently incorrect results; keep it only if needed.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Seed the Python and PyTorch (CPU + CUDA) generators so that repeated runs
# are reproducible.
seed = 0
random.seed(seed)
for _seed_torch in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
    _seed_torch(seed)
del _seed_torch
# Command-line interface. Each pair --weight/--bin and --image/--image_folder
# is meant to be used one-at-a-time; argparse does not enforce that here, the
# code reading ``args`` does.
parser = argparse.ArgumentParser(
    description="This code helps you use a trained model to do inference.")

# -- model weights --------------------------------------------------------
parser.add_argument(
    "--weight", "-w", type=str, default=None,
    help=("Specify the path to the model weight (the state dict file). "
          "Do not use this argument when '--bin' is set."))
parser.add_argument(
    "--bin", "-b", type=str, default=None,
    help=("Specify the path to the model bin which ends up with '.bin' "
          "(which is generated by the trainer of this project). "
          "Do not use this argument when '--weight' is set."))

# -- input images ---------------------------------------------------------
parser.add_argument(
    "--image", "-i", type=str, default=None,
    help=("Specify the path to the input image. "
          "Do not use this argument when '--image_folder' is set."))
parser.add_argument(
    "--image_folder", "-f", type=str, default=None,
    help=("Specify the directory to evaluate all the images. "
          "Do not use this argument when '--image' is set."))

# -- runtime options ------------------------------------------------------
parser.add_argument(
    "--device", "-d", type=str, default="cpu",
    help="Specify the device to load the model. Default: 'cpu'.")
parser.add_argument(
    "--image_size", "-s", type=int, default=299,
    help="Specify the spatial size of the input image(s). Default: 299.")
parser.add_argument(
    "--visualize", "-v", action="store_true", default=False,
    help="Visualize images.")
def preprocess(file_path, image_size=None):
    """Load an image from disk and turn it into a normalized model input batch.

    The image is converted from OpenCV's BGR order to RGB, resized to a
    square ``image_size`` x ``image_size``, scaled with mean 0.5 / std 0.5 per
    channel, converted to a tensor and given a leading batch dimension.

    Args:
        file_path: Path of the image file to read.
        image_size: Side length to resize to. ``None`` (the default) uses the
            ``--image_size`` command-line value (``args.image_size``).

    Returns:
        A tensor with a leading batch dimension of 1. NOTE(review): expected
        shape is (1, 3, image_size, image_size), since ``ToTensorV2`` moves
        channels first.

    Raises:
        ValueError: if OpenCV cannot read ``file_path``.
    """
    if image_size is None:
        image_size = args.image_size
    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread signals failure (missing file, no permission, unsupported
        # format) by returning None instead of raising.
        raise ValueError(f"Could not read image file '{file_path}'.")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    compose = Compose([Resize(height=image_size, width=image_size),
                       Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                       ToTensorV2()])
    return compose(image=img)['image'].unsqueeze(0)
def prepare_data():
    """Collect and preprocess the input image(s) named on the command line.

    Reads ``--image`` / ``--image_folder`` from the global ``args``. Exactly one
    of them must be set. For a folder, every ``.jpg``, ``.jpeg`` and ``.png``
    file directly inside it is used (extension matched case-insensitively,
    no recursion, hidden dot-files skipped). Files are taken in sorted path
    order so that repeated runs report results in the same order.

    Returns:
        ``(paths, images)``: parallel lists of file paths and the tensors
        ``preprocess`` produced for them.

    Raises:
        ValueError: if both or neither argument is set, or if the folder holds
            no supported image (this includes a folder that does not exist).
    """
    from glob import escape as glob_escape

    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    if args.image:
        paths = [args.image]
    elif args.image_folder:
        extensions = {".jpg", ".jpeg", ".png"}
        # Escape the folder so glob metacharacters in its name ([, ], *, ?)
        # match literally; the "*" pattern itself still skips dot-files.
        candidates = glob(join(glob_escape(args.image_folder), "*"))
        paths = sorted(p for p in candidates
                       if os.path.splitext(p)[1].lower() in extensions)
        if not paths:
            raise ValueError(f"No .jpg/.jpeg/.png images found in '{args.image_folder}'.")
    else:
        raise ValueError("Neither of '--image' nor '--image_folder' is set. Please specify either "
                         "one of these two arguments to load input image(s) properly.")
    images = [preprocess(p) for p in paths]
    return paths, images
def inference(model, images, paths, device, visualize=None):
    """Score each preprocessed image with ``model`` and print a verdict per file.

    The sigmoid of each model output is treated as the probability that the
    image is fake. Probabilities >= 0.5 are labelled 'fake' and lower ones
    'real'.

    Args:
        model: Callable mapping a batched image tensor to a logit. ``.item()``
            is taken on its sigmoid, so the output must hold exactly one value.
        images: Preprocessed image tensors, parallel to ``paths``.
        paths: File paths of the images, used for printing and visualization.
        device: Device the images are moved to before the forward pass.
        visualize: If true, show each image annotated with its score in an
            OpenCV window and wait for a key press. ``None`` (the default)
            uses the ``--visualize`` command-line flag (``args.visualize``).

    Returns:
        The mean fake probability over all scored images.

    Raises:
        ValueError: if there is nothing to score (empty ``images``/``paths``).
    """
    if visualize is None:
        visualize = args.visualize
    total = 0.0
    count = 0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for img, pt in zip(images, paths):
            logit = model(img.to(device))
            prob = torch.sigmoid(logit).cpu().item()
            fake = prob >= 0.5
            label = 'fake' if fake else 'real'
            total += prob
            count += 1
            print(f"path: {pt} \t\t| fake probability: {prob:.4f} \t| "
                  f"prediction: {label}")
            if visualize:
                cvimg = cv2.imread(pt)
                # Text color is BGR: red for fake, blue for real.
                cvimg = cv2.putText(cvimg, f'p: {prob:.2f}, {label}',
                                    (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    (0, 0, 255) if fake else (255, 0, 0), 2)
                cv2.imshow("image", cvimg)
                cv2.waitKey(0)
                cv2.destroyWindow("image")
    if count == 0:
        # A mean over zero images is undefined; fail clearly instead of
        # raising ZeroDivisionError.
        raise ValueError("No images to run inference on.")
    # Divide by the number actually scored (zip stops at the shorter list).
    return total / count
def main():
    """Entry point: load the Recce detector and score the requested image(s).

    Everything is driven by the global ``args`` parsed from the command line.
    Model weights come from exactly one of ``--weight`` (a bare state dict) or
    ``--bin`` (a trainer checkpoint whose state dict is stored under the
    ``"model"`` key).

    Raises:
        ValueError: if both or neither of ``--weight`` and ``--bin`` are set,
            or (from ``prepare_data``) if the image arguments are invalid.
    """
    print("Arguments:\n", args, end="\n\n")
    # Validate the weight arguments before doing any expensive work.
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    if not (args.weight or args.bin):
        raise ValueError("Neither of '--weight' nor '--bin' is set. Please specify either "
                         "one of these two arguments to load model's weight properly.")
    device = torch.device(args.device)
    # One output logit; inference() turns it into a fake probability.
    model = Recce(num_classes=1)
    # NOTE(security): torch.load unpickles the file. With weights_only=False
    # (the default before PyTorch 2.6) a malicious checkpoint can execute
    # arbitrary code, so only load weight files from trusted sources.
    if args.weight:
        weights = torch.load(args.weight, map_location=device)
    else:
        weights = torch.load(args.bin, map_location=device)["model"]
    model.load_state_dict(weights)
    model = model.to(device)
    freeze_weights(model)
    model.eval()
    paths, images = prepare_data()
    print("Inference:")
    mean_pred = inference(model, images=images, paths=paths, device=device)
    print("Mean prediction:", mean_pred)
if __name__ == "__main__":
    # The functions above read this module-level ``args``, so it only exists
    # when the file is executed as a script, not when it is imported.
    args = parser.parse_args()
    main()