Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from networks.drn_seg import DRNSeg, DRNSub | |
| from utils.tools import * | |
| from utils.visualize import * | |
def load_classifier(model_path, gpu_id):
    """Build the global DRNSub classifier and load its weights from disk.

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry is the
            state dict for ``DRNSub(1)``.
        gpu_id: CUDA device index to use; ``-1`` forces CPU. CPU is also
            used whenever CUDA is not available.

    Returns:
        The classifier in eval mode, moved to the selected device, with the
        device string stored on it as ``.device``.
    """
    use_cuda = torch.cuda.is_available() and gpu_id != -1
    target = 'cuda:{}'.format(gpu_id) if use_cuda else 'cpu'

    classifier = DRNSub(1)
    # Deserialize onto the CPU first so the checkpoint can be read on a
    # CPU-only machine; the module is moved to `target` afterwards.
    checkpoint = torch.load(model_path, map_location='cpu')
    classifier.load_state_dict(checkpoint['model'])

    classifier.to(target)
    # Callers (e.g. classify_fake) read this to place input tensors.
    classifier.device = target
    classifier.eval()
    return classifier
# Checkpoint locations for the two networks.
local_model_path = 'weights/local.pth'    # DRNSeg(2), used by heatmap_analysis
global_model_path = 'weights/global.pth'  # DRNSub(1), used by classify_fake
gpu_id = 0  # CUDA device index; set to -1 to force CPU

# Use the same device rule as load_classifier: CUDA only when it is available
# AND not disabled via gpu_id == -1. Previously gpu_id == -1 produced the
# invalid device 'cuda:-1' for the local model while the global model went
# to CPU.
if torch.cuda.is_available() and gpu_id != -1:
    device = 'cuda:{}'.format(gpu_id)
else:
    device = 'cpu'

# Local model: its output is treated as a per-pixel flow field by
# heatmap_analysis.
local_model = DRNSeg(2)
state_dict = torch.load(local_model_path, map_location=device)
local_model.load_state_dict(state_dict['model'])
local_model.to(device)
local_model.eval()

# Global model: whole-face classifier used by classify_fake.
global_model = load_classifier(global_model_path, gpu_id)
# Input preprocessing shared by classify_fake and heatmap_analysis:
# convert a PIL image to a float tensor, then normalize each of the three
# channels (these values match the commonly used ImageNet channel statistics).
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def classify_fake(img, no_crop=False, global_model=global_model,
                  model_file='utils/dlib_face_detector/mmod_human_face_detector.dat'):
    """Return the global classifier's probability for the face in ``img``.

    Args:
        img: Input PIL image. It is converted to RGB first because the
            normalization in ``tf`` has exactly three channel means/stds.
        no_crop: If True, classify the whole image instead of the first face
            returned by ``face_detection``.
        global_model: Classifier to run; it must carry a ``.device``
            attribute (``load_classifier`` sets one).
        model_file: Face-detector weights path forwarded to
            ``face_detection``.

    Returns:
        float: sigmoid of the classifier output for the (cropped, resized)
        face.

    Raises:
        ValueError: if cropping is requested and no face is detected.
    """
    img = img.convert('RGB')
    if no_crop:
        face = img
    else:
        faces = face_detection(img, verbose=False, model_file=model_file)
        if len(faces) == 0:
            # Raise instead of sys.exit(): SystemExit is a BaseException and,
            # when this runs inside a long-lived app, it would terminate the
            # whole process rather than just failing this call.
            raise ValueError("no face detected by dlib")
        # Each entry is presumably a (face_image, box) pair; only the crop
        # is used here.
        face, _box = faces[0]
    # resize_shorter_side presumably returns a tuple whose first element is
    # the resized image.
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(global_model.device)

    # Prediction: add a batch dimension, take the single output, squash it.
    with torch.no_grad():
        prob = global_model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()
    return prob
def heatmap_analysis(img, no_crop=False, model_file=None):
    """Predict the local warping field for a face and visualize it.

    Args:
        img: Input PIL image. It is converted to RGB first because the
            normalization in ``tf`` has exactly three channel means/stds.
        no_crop: If True, analyse the whole image instead of the first face
            returned by ``face_detection``.
        model_file: Optional face-detector weights path forwarded to
            ``face_detection``; when None, that function's own default is
            used (the previous behaviour).

    Returns:
        tuple: ``(modified, reverse, heatmap)`` PIL images. These are:
        ``modified``, the face resized to the flow's spatial size;
        ``reverse``, ``modified`` after ``warp`` applies the predicted flow;
        ``heatmap``, the output of ``get_heatmap_cv`` over the per-pixel flow
        magnitude.

    Raises:
        ValueError: if cropping is requested and no face is detected.
    """
    img = img.convert('RGB')
    if no_crop:
        face = img
    else:
        if model_file is None:
            faces = face_detection(img, verbose=False)
        else:
            faces = face_detection(img, verbose=False, model_file=model_file)
        if len(faces) == 0:
            # Raise instead of sys.exit(): SystemExit would terminate the
            # whole process when this runs inside a long-lived app.
            raise ValueError("no face detected by dlib")
        # Each entry is presumably a (face_image, box) pair.
        face, _box = faces[0]
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)

    # Warping field prediction; move channels last so flow is (H, W, C).
    with torch.no_grad():
        flow = local_model(face_tens.unsqueeze(0))[0].cpu().numpy()
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undo the warps on the input resized to the flow's resolution.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    reverse = Image.fromarray(reverse_np)

    # Magnitude of the two flow components drives the heatmap.
    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    cv_out = get_heatmap_cv(modified_np, flow_magn, 7)
    heatmap = Image.fromarray(cv_out)
    return modified, reverse, heatmap
| # Saving the results | |
| # modified.save( | |
| # os.path.join(dest_folder, 'cropped_input.jpg'), | |
| # quality=90) | |
| # reverse.save( | |
| # os.path.join(dest_folder, 'warped.jpg'), | |
| # quality=90) | |
| # flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2) | |
| # save_heatmap_cv( | |
| # modified_np, flow_magn, | |
| # os.path.join(dest_folder, 'heatmap.jpg')) | |