import cv2 import requests from io import BytesIO from PIL import Image import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as transforms import gradio as gr from utils import * # torchvision transforms normalize = transforms.Normalize( mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225] ) unnormalize = transforms.Compose([ transforms.Normalize( mean = [0., 0., 0.], std = [1/0.229, 1/0.224, 1/0.225] ), transforms.Normalize( mean = [-0.485, -0.456, -0.406], std = [1., 1., 1.] ) ]) # inference script for huggingface space (assumes model is already loaded) # TODO: interface supports drag & drop or url def inference(url=None): # assert url is not None and pil_img is not None if url: response = requests.get(url) original_img = Image.open(BytesIO(response.content)) else: original_img = pil_img img = np.array(original_img) img = cv2.resize(img, (512, 512)) img = img / 255 assert np.min(img) >= 0 assert np.max(img) <= 1 if len(img.shape) < 3: # grayscale coloring x = torch.Tensor(img) x = torch.stack([x, x, x], dim=0) x = normalize(x) x = x.unsqueeze(0) else: # RGB reconstruction x = torch.Tensor(img).permute(2, 0, 1) x = normalize(x) x_gs = cv2.cvtColor(x.permute(1, 2, 0).detach().cpu().numpy(), cv2.COLOR_BGR2GRAY) x_gs = np.dstack([x_gs, x_gs, x_gs]) x = torch.Tensor(x_gs).permute(2, 0, 1).unsqueeze(0) pred = model(x) res = unnormalize(pred.squeeze(0)) res = res.clamp(0, 1) res = res.permute(1, 2, 0).detach().cpu().numpy() colored_img = cv2.resize(res, original_img.size) colored_img = Image.fromarray((colored_img * 255).astype(np.uint8)) if len(img.shape) >= 3: colored_img = postprocess_img(original_img, colored_img) return colored_img, original_img, original_img.convert('L') # load torchscript model model = torch.jit.load('torchscript/generator_torchscript.pt') model = model.eval() # gradio interface : i think it can take either a url or a drag and drop by setting one or the other to None iface = gr.Interface( fn=inference, # inputs=["text", gr.Image(type='pil')], inputs='text', outputs=['image', 'image', 'image'], title='paste a url (right click -> copy image address)' ) iface.launch()