# Scraped from a Hugging Face Spaces page whose status was "Runtime error"
# (preprocess_image references cv2 / pil_image / xception_default_data_transforms,
# none of which are imported in this file).
import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn, autograd
from torch.utils import data
# Checkpoint filenames that ensure_checkpoint_exists() knows how to fetch,
# mapped to their public Google Drive download URLs.
google_drive_paths = {
    "GNR_checkpoint.pt": "https://drive.google.com/uc?id=1IMIVke4WDaGayUa7vk_xVw1uqIHikGtC",
}
def ensure_checkpoint_exists(model_weights_filename):
    """Download a known checkpoint from Google Drive if it is missing locally.

    Looks the filename up in ``google_drive_paths``; unknown filenames only
    produce a warning. Requires the optional ``gdown`` package to download.
    """
    if os.path.isfile(model_weights_filename):
        return  # already present; nothing to do
    if model_weights_filename not in google_drive_paths:
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights."
        )
        return
    gdrive_url = google_drive_paths[model_weights_filename]
    try:
        # gdown is optional; import lazily so the module loads without it.
        from gdown import download as drive_download
        drive_download(gdrive_url, model_weights_filename, quiet=False)
    except ModuleNotFoundError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            gdrive_url
        )
def shuffle_batch(x):
    """Return `x` with its batch (first) dimension randomly permuted."""
    perm = torch.randperm(x.size(0))
    return x[perm]
def data_sampler(dataset, shuffle, distributed):
    """Choose the appropriate torch sampler for `dataset`.

    Distributed training gets a DistributedSampler; otherwise a random or
    sequential sampler depending on `shuffle`.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def accumulate(model1, model2, decay=0.999):
    """Exponential-moving-average update of model1's parameters toward model2's.

    In place: param1 <- decay * param1 + (1 - decay) * param2, for every
    parameter name shared by both models (keys are taken from model1).

    :param model1: the EMA model being updated in place
    :param model2: the live model supplying fresh parameter values
    :param decay: EMA decay factor in [0, 1]
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        # Tensor.add_(other, alpha=...) replaces the deprecated positional
        # add_(alpha, other) form, which modern PyTorch rejects.
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
def sample_data(loader):
    """Yield batches from `loader` endlessly, restarting it on exhaustion."""
    while True:
        yield from loader
def d_logistic_loss(real_pred, fake_pred):
    """Logistic discriminator loss, summed over paired prediction scales.

    Per scale: softplus(-real).mean() + softplus(fake).mean().
    """
    return sum(
        F.softplus(-r).mean() + F.softplus(f).mean()
        for r, f in zip(real_pred, fake_pred)
    )
def d_r1_loss(real_pred, real_img):
    """R1 gradient penalty: squared grad norm of each prediction w.r.t. the input.

    Builds the penalty with create_graph=True so it can itself be
    backpropagated during discriminator regularization.
    """
    penalty = 0
    for pred in real_pred:
        (grads,) = autograd.grad(
            outputs=pred.mean(), inputs=real_img, create_graph=True, only_inputs=True
        )
        penalty = penalty + grads.flatten(1).pow(2).sum(1).mean()
    return penalty
def g_nonsaturating_loss(fake_pred, weights):
    """Weighted non-saturating generator loss, averaged over prediction scales."""
    total = sum(w * F.softplus(-f).mean() for f, w in zip(fake_pred, weights))
    return total / len(fake_pred)
def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    """Display a [3,H,W] or [1,3,H,W] tensor with matplotlib.

    :param image: image tensor with values in [0,1] (or [-1,1] with unnorm=True)
    :param size: optional target spatial size; resizes via F.interpolate
    :param mode: interpolation mode passed to F.interpolate
    :param unnorm: map values from [-1,1] to [0,1] before display. The
        original accepted this flag but never used it — now honored.
    :param title: figure title
    """
    if image.is_cuda:
        image = image.cpu()
    # F.interpolate needs a batched NCHW tensor for a 2-D `size`; the
    # original passed 3-D inputs straight in, which raises. Add the batch
    # dim up front and drop it once at the end.
    if image.dim() == 3:
        image = image.unsqueeze(0)
    if size is not None and image.size(-1) != size:
        image = F.interpolate(image, size=(size, size), mode=mode)
    if unnorm:
        # Previously a silent no-op: undo [-1,1] normalization for display.
        image = ((image + 1) / 2).clamp(0, 1)
    image = image[0].permute(1, 2, 0).detach().numpy()
    plt.figure()
    plt.title(title)
    plt.axis('off')
    plt.imshow(image)
def normalize(x):
    """Map a tensor from [-1, 1] to [0, 1], clamping out-of-range values."""
    shifted = (x + 1) / 2
    return shifted.clamp(0, 1)
def get_boundingbox(face, width, height, scale=1.3, minsize=None):
    """
    Build a square, in-frame bounding box around a dlib face detection.

    :param face: dlib rectangle exposing left()/top()/right()/bottom()
    :param width: frame width in pixels
    :param height: frame height in pixels
    :param scale: multiplier applied to the face extent for extra context
    :param minsize: optional lower bound on the box edge length
    :return: (x, y, size) of the square box in OpenCV convention
    """
    left, top = face.left(), face.top()
    right, bottom = face.right(), face.bottom()
    # Square box: scaled longest side of the detection.
    size_bb = int(max(right - left, bottom - top) * scale)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    # Clamp the top-left corner into the frame.
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)
    # Shrink the box if it would spill past the right/bottom frame edges.
    size_bb = min(width - x1, size_bb, height - y1)
    return x1, y1, size_bb
def preprocess_image(image, cuda=True):
    """
    Preprocesses the image such that it can be fed into our network.
    During this process we envoke PIL to cast it into a PIL image.

    :param image: numpy image in opencv form (i.e., BGR channel order)
    :param cuda: move the resulting tensor to the GPU when True
    :return: pytorch tensor of shape [1, 3, image_size, image_size], not
        necessarily casted to cuda
    """
    # NOTE(review): `cv2`, `xception_default_data_transforms`, and `pil_image`
    # are not imported anywhere in this file, so this function raises
    # NameError as written — TODO confirm the intended imports (cv2,
    # PIL.Image as pil_image, and the Xception transform dict).
    # Revert from BGR
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Preprocess using the preprocessing function used during training and
    # casting it to PIL image
    preprocess = xception_default_data_transforms['test']
    preprocessed_image = preprocess(pil_image.fromarray(image))
    # Add first dimension as the network expects a batch
    preprocessed_image = preprocessed_image.unsqueeze(0)
    if cuda:
        preprocessed_image = preprocessed_image.cuda()
    return preprocessed_image
def truncate(x, truncation, mean_style):
    """Style-truncation trick: blend `x` toward `mean_style` by `truncation`.

    truncation=1 returns x unchanged; truncation=0 returns mean_style.
    """
    return (1 - truncation) * mean_style + truncation * x