Spaces:
Runtime error
Runtime error
| import os | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| try: | |
| from torchvision.transforms import InterpolationMode | |
| bic = InterpolationMode.BICUBIC | |
| except ImportError: | |
| bic = Image.BICUBIC | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import functools | |
| IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".webp"] | |
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested UnetSkipConnectionBlock modules."""

    def __init__(
        self,
        input_nc,
        output_nc,
        num_downs,
        ngf=64,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Build the generator from the bottleneck outwards.

        Parameters:
            input_nc (int)     -- number of channels in the input images
            output_nc (int)    -- number of channels in the output images
            num_downs (int)    -- number of stride-2 downsamplings; e.g. with
                                  num_downs == 7 a 128x128 image is reduced to
                                  1x1 at the bottleneck
            ngf (int)          -- base filter count (inner width of the
                                  outermost block); deeper blocks use
                                  2x, 4x and 8x this value
            norm_layer         -- normalization layer factory handed to every block
            use_dropout (bool) -- forwarded only to the intermediate
                                  (ngf * 8 wide) blocks

        Each new block wraps the previously built one, so the network is
        grown layer by layer from the innermost block to the outermost one.
        """
        super().__init__()
        widest = ngf * 8

        # Innermost (bottleneck) block: it has no nested submodule.
        nested = UnetSkipConnectionBlock(
            outer_nc=widest,
            inner_nc=widest,
            input_nc=None,
            submodule=None,
            norm_layer=norm_layer,
            innermost=True,
        )

        # Intermediate blocks that stay at the widest filter count.
        for _ in range(num_downs - 5):
            nested = UnetSkipConnectionBlock(
                outer_nc=widest,
                inner_nc=widest,
                input_nc=None,
                submodule=nested,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )

        # Step the filter count down from ngf * 8 to ngf, one block per step.
        narrowing_widths = (
            (ngf * 4, ngf * 8),
            (ngf * 2, ngf * 4),
            (ngf, ngf * 2),
        )
        for outer_width, inner_width in narrowing_widths:
            nested = UnetSkipConnectionBlock(
                outer_nc=outer_width,
                inner_nc=inner_width,
                input_nc=None,
                submodule=nested,
                norm_layer=norm_layer,
            )

        # Outermost block: maps input_nc image channels in, output_nc out.
        self.model = UnetSkipConnectionBlock(
            outer_nc=output_nc,
            inner_nc=ngf,
            input_nc=input_nc,
            submodule=nested,
            outermost=True,
            norm_layer=norm_layer,
        )

    def forward(self, input):
        """Run the input batch through the nested blocks and return the result."""
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

    X -------------------identity----------------------
    |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features;
                              defaults to outer_nc when None
            submodule (UnetSkipConnectionBlock) -- previously defined submodules;
                              placed between the down and up paths unless
                              this block is the innermost one
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer (a class or a
                                   functools.partial wrapping one)
            use_dropout (bool)  -- if use dropout layers (only applied to
                                   blocks that are neither outermost nor
                                   innermost)
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        # Convolutions get a bias term only when the normalization layer is
        # InstanceNorm2d. isinstance (rather than an exact type comparison)
        # also recognises subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc

        # Stride-2 convolution halves the spatial size on the way down.
        downconv = nn.Conv2d(
            input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
        )
        # The second positional argument is `inplace`: this activation
        # overwrites its input tensor (see the note in forward()).
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # The submodule returns its input concatenated with its output,
            # hence inner_nc * 2 input channels. No normalization here; the
            # output is squashed with Tanh.
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1
            )
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No nested submodule, so the up path sees inner_nc channels only.
            upconv = nn.ConvTranspose2d(
                inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2,
                outer_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the block; every block but the outermost returns cat([x, model(x)]).

        Parameters:
            x (tensor) -- input feature map; channels are concatenated on dim 1
        """
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            # NOTE: in non-outermost blocks the first layer is the in-place
            # LeakyReLU, so the `x` concatenated here has already been
            # overwritten by that activation. Replacing it with a
            # non-in-place activation would change the block's output.
            return torch.cat([x, self.model(x)], 1)
class Anime2Sketch:
    """Wraps a UnetGenerator and converts a PIL image into a sketch image."""

    def __init__(
        self, model_path: str = "./models/netG.pth", device: str = "cpu"
    ) -> None:
        """Build the generator, load its weights and place it on a device.

        Parameters:
            model_path (str) -- path of the checkpoint file given to torch.load
            device (str)     -- "cuda" requests the GPU; any other value, or a
                                machine without CUDA, selects the CPU
        """
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
        net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)

        # Deserialize onto the CPU so a checkpoint saved from a GPU still
        # loads on a CPU-only host; the model is moved to the device below.
        ckpt = torch.load(model_path, map_location="cpu")
        # Strip "module." from parameter names (presumably left behind by a
        # DataParallel-style wrapper -- confirm against the training code).
        # NOTE(review): renamed tensors are cast to half precision here and
        # cast back when copied into the float32 model; kept as-is so the
        # outputs stay unchanged, but the cast looks unnecessary -- confirm.
        for key in list(ckpt.keys()):
            if "module." in key:
                ckpt[key.replace("module.", "")] = ckpt[key].half()
                del ckpt[key]
        net.load_state_dict(ckpt)
        # The model is only used for inference.
        net.eval()
        self.model = net

        if torch.cuda.is_available() and device == "cuda":
            self.device = "cuda"
            self.model.to(device)
        else:
            self.device = "cpu"
            self.model.to("cpu")

    def predict(self, image: Image.Image, load_size: int = 512) -> Image.Image:
        """Run the generator on one image and return the result as a PIL image.

        Parameters:
            image (PIL.Image.Image) -- input image; it is converted to RGB
            load_size (int)         -- side length the image is resized to
                                       before inference; when > 0 the output
                                       is resized back to the input size

        Raises:
            Exception -- if the image cannot be preprocessed (the underlying
                         error is chained as the cause)
        """
        # Resolve the name up front: in-memory images carry no `filename`,
        # and reading it inside the handler would mask the real error.
        source_name = getattr(image, "filename", "") or "<in-memory image>"
        try:
            aus_resize = None
            if load_size > 0:
                aus_resize = image.size
            transform = self.get_transform(load_size=load_size)
            # The transform normalizes three channels, so force RGB input
            # (RGBA / grayscale / palette images would otherwise fail).
            img = transform(image.convert("RGB")).unsqueeze(0)
        except Exception as err:
            raise Exception(
                "Error in reading image {}".format(source_name)
            ) from err

        # No gradients are needed for inference.
        with torch.no_grad():
            aus_tensor = self.model(img.to(self.device))
        aus_img = self.tensor_to_img(aus_tensor)
        image_pil = Image.fromarray(aus_img)
        if aus_resize:
            image_pil = image_pil.resize(aus_resize, Image.BICUBIC)
        return image_pil

    def get_transform(self, load_size=0, grayscale=False, method=bic, convert=True):
        """Compose the preprocessing pipeline applied to an input image.

        Parameters:
            load_size (int)  -- when > 0, resize to load_size x load_size
            grayscale (bool) -- convert to one channel and use 1-channel
                                normalization
            method           -- interpolation used by the resize step
            convert (bool)   -- when True, append ToTensor and a
                                mean 0.5 / std 0.5 normalization
        """
        transform_list = []
        if grayscale:
            transform_list.append(transforms.Grayscale(1))
        if load_size > 0:
            osize = [load_size, load_size]
            transform_list.append(transforms.Resize(osize, method))
        if convert:
            transform_list.append(transforms.ToTensor())
            if grayscale:
                transform_list.append(transforms.Normalize((0.5,), (0.5,)))
            else:
                transform_list.append(
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                )
        return transforms.Compose(transform_list)

    def tensor_to_img(self, input_image, imtype=np.uint8):
        """Converts a Tensor array into a numpy image array.

        Parameters:
            input_image (tensor) -- the input image tensor array; only its
                                    first batch element is converted
            imtype (type)        -- the desired type of the converted numpy array

        A numpy array is only cast to imtype; any other non-tensor input is
        returned unchanged.
        """
        if isinstance(input_image, np.ndarray):
            # Already a numpy array: just cast.
            return input_image.astype(imtype)
        if not isinstance(input_image, torch.Tensor):
            return input_image

        image_numpy = input_image.detach()[0].cpu().float().numpy()
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        # Post-processing: channels-last layout, then map [-1, 1] to [0, 255].
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
        if np.issubdtype(np.dtype(imtype), np.integer):
            # Keep values inside the integer range so the cast cannot wrap.
            limits = np.iinfo(imtype)
            image_numpy = np.clip(image_numpy, limits.min, limits.max)
        return image_numpy.astype(imtype)