Spaces:
Sleeping
Sleeping
| import cv2 | |
| import torch | |
| import random | |
| import argparse | |
| from glob import glob | |
| from os.path import join | |
| from model.network import Recce | |
| from model.common import freeze_weights | |
| from albumentations import Compose, Normalize, Resize | |
| from albumentations.pytorch.transforms import ToTensorV2 | |
| import os | |
# Allow the Intel OpenMP runtime to be loaded more than once in this process.
# NOTE(review): presumably a workaround for the "OMP: Error #15" abort seen
# when several bundled libraries each ship libiomp — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def _seed_everything(value):
    """Seed Python's `random` and torch (CPU generator and CUDA generators) with *value*."""
    random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    torch.cuda.manual_seed_all(value)


# fix random seed so repeated runs produce the same results
seed = 0
_seed_everything(seed)
def _build_parser():
    """Return the command-line parser for this inference script.

    Exposes two ways of supplying weights (``--weight`` / ``--bin``), two ways
    of supplying input (``--image`` / ``--image_folder``), plus device, input
    size and an optional visualization switch.
    """
    cli = argparse.ArgumentParser(
        description="This code helps you use a trained model to do inference.")

    # Source of the trained weights; the help texts ask for only one of these.
    cli.add_argument(
        "--weight", "-w", type=str, default=None,
        help="Specify the path to the model weight (the state dict file). "
             "Do not use this argument when '--bin' is set.")
    cli.add_argument(
        "--bin", "-b", type=str, default=None,
        help="Specify the path to the model bin which ends up with '.bin' "
             "(which is generated by the trainer of this project). "
             "Do not use this argument when '--weight' is set.")

    # Input selection: one image, or every image of a directory.
    cli.add_argument(
        "--image", "-i", type=str, default=None,
        help="Specify the path to the input image. "
             "Do not use this argument when '--image_folder' is set.")
    cli.add_argument(
        "--image_folder", "-f", type=str, default=None,
        help="Specify the directory to evaluate all the images. "
             "Do not use this argument when '--image' is set.")

    # Runtime options.
    cli.add_argument(
        '--device', '-d', type=str, default="cpu",
        help="Specify the device to load the model. Default: 'cpu'.")
    cli.add_argument(
        '--image_size', '-s', type=int, default=299,
        help="Specify the spatial size of the input image(s). Default: 299.")
    cli.add_argument(
        '--visualize', '-v', action="store_true", default=False,
        help='Visualize images.')
    return cli


parser = _build_parser()
def preprocess(file_path):
    """Load one image from disk and turn it into a batched model input.

    The image is read with OpenCV, converted from BGR to RGB, resized to
    ``args.image_size`` x ``args.image_size``, normalized with mean 0.5 and
    std 0.5 per channel, converted with ``ToTensorV2`` and given a leading
    batch axis.

    Args:
        file_path: Path to the image file.

    Returns:
        The transformed image tensor with a batch dimension of size 1
        prepended (presumably ``(1, 3, H, W)`` float — confirm against the
        albumentations version in use).

    Raises:
        ValueError: If OpenCV cannot read ``file_path`` (missing file,
            unsupported or corrupt image).
    """
    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check cvtColor below fails with an opaque cv2.error.
        raise ValueError(f"Could not read image file: {file_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    compose = Compose([Resize(height=args.image_size, width=args.image_size),
                       Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                       ToTensorV2()])
    return compose(image=img)['image'].unsqueeze(0)
def _collect_image_paths(folder, extensions=(".jpg", ".jpeg", ".png")):
    """Return the sorted paths of the image files directly inside ``folder``.

    Extensions are matched case-insensitively, so ``a.JPG`` is found as well
    as ``a.jpg``. Hidden entries (names starting with '.') and anything that
    is not a regular file are skipped. Sorting makes the processing order
    deterministic across platforms and filesystems.

    Raises:
        OSError: (e.g. FileNotFoundError) if ``folder`` cannot be listed.
    """
    with os.scandir(folder) as entries:
        return sorted(
            entry.path
            for entry in entries
            if not entry.name.startswith(".")
            and entry.is_file()
            and os.path.splitext(entry.name)[1].lower() in extensions
        )


def prepare_data():
    """Collect and preprocess the input image(s) selected on the command line.

    Exactly one of ``args.image`` (a single file) or ``args.image_folder``
    (every .jpg/.jpeg/.png file directly inside the directory) must be set.

    Returns:
        ``(paths, images)``: the image paths and the matching preprocessed
        tensors, in the same order.

    Raises:
        ValueError: If both or neither of the two options are set, or if the
            folder contains no supported images.
    """
    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    if args.image:
        paths = [args.image]
    elif args.image_folder:
        paths = _collect_image_paths(args.image_folder)
        if not paths:
            # Fail here with a clear message instead of letting an empty
            # batch reach the averaging step downstream.
            raise ValueError(f"No .jpg/.jpeg/.png images found in '{args.image_folder}'.")
    else:
        raise ValueError("Neither of '--image' nor '--image_folder' is set. Please specify either "
                         "one of these two arguments to load input image(s) properly.")
    images = [preprocess(path) for path in paths]
    return paths, images
def inference(model, images, paths, device):
    """Run the model on every image, print per-image results, return the mean.

    Each output is passed through a sigmoid and treated as the probability
    that the image is fake; ``>= 0.5`` is labelled 'fake', otherwise 'real'.
    When ``args.visualize`` is set, the original file is re-read with OpenCV
    and shown with the prediction drawn on it until a key is pressed.

    Args:
        model: Callable mapping one batched image tensor to a single-element
            output (``.item()`` is called on it).
        images: Preprocessed image tensors, paired by position with ``paths``.
        paths: File paths the images were loaded from (used for printing and
            visualization).
        device: Device the inputs are moved to before the forward pass.

    Returns:
        The mean fake probability over ``images``.

    Raises:
        ValueError: If ``images`` is empty (the mean would be undefined).
    """
    if not images:
        raise ValueError("No images to run inference on.")
    total = 0.0
    # No gradients are needed at inference time; skip autograd bookkeeping.
    with torch.no_grad():
        for img, pt in zip(images, paths):
            prob = torch.sigmoid(model(img.to(device))).item()
            fake = prob >= 0.5
            label = 'fake' if fake else 'real'
            total += prob
            print(f"path: {pt} \t\t| fake probability: {prob:.4f} \t| "
                  f"prediction: {label}")
            if args.visualize:
                cvimg = cv2.imread(pt)
                # Red text for fake, blue for real (OpenCV colours are BGR).
                cvimg = cv2.putText(cvimg, f"p: {prob:.2f}, {label}",
                                    (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    (0, 0, 255) if fake else (255, 0, 0), 2)
                cv2.imshow("image", cvimg)
                cv2.waitKey(0)
                cv2.destroyWindow("image")
    return total / len(images)
def main():
    """Load the Recce model and its weights, then report fake probabilities.

    Reads the module-level ``args``. Weights come from exactly one of
    ``--weight`` (a state dict file) or ``--bin`` (a trainer checkpoint whose
    state dict is stored under the ``"model"`` key). The model is moved to
    ``--device``, frozen and put into eval mode before inference. Finally the
    mean prediction over all inputs is printed.

    Raises:
        ValueError: If both or neither of ``--weight`` and ``--bin`` are set.
    """
    print("Arguments:\n", args, end="\n\n")
    # Validate the weight arguments before building the model so that a bad
    # command line fails immediately.
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    if not (args.weight or args.bin):
        raise ValueError("Neither of '--weight' nor '--bin' is set. Please specify either "
                         "one of these two arguments to load model's weight properly.")
    device = torch.device(args.device)
    # Single output unit; inference() applies a sigmoid to it.
    model = Recce(num_classes=1)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    if args.weight:
        weights = torch.load(args.weight, map_location=device)
    else:
        weights = torch.load(args.bin, map_location=device)["model"]
    model.load_state_dict(weights)
    model = model.to(device)
    freeze_weights(model)
    model.eval()
    paths, images = prepare_data()
    print("Inference:")
    mean_pred = inference(model, images=images, paths=paths, device=device)
    print("Mean prediction:", mean_pred)
if __name__ == "__main__":
    # The functions above read the parsed options from this module-level
    # global rather than receiving them as parameters.
    args = parser.parse_args()
    main()