Spaces:
Build error
Build error
| # 外部から簡単にupscalerを呼ぶためのスクリプト | |
| # 単体で動くようにモデル定義も含めている | |
| import argparse | |
| import glob | |
| import os | |
| import cv2 | |
| from diffusers import AutoencoderKL | |
| from typing import Dict, List | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from tqdm import tqdm | |
| from PIL import Image | |
class ResidualBlock(nn.Module):
    """Residual block: conv -> BN -> ReLU -> conv -> BN, skip-add, ReLU.

    The block input is added in place to the output of the second batch
    norm, so both tensors must be shape-compatible for ``forward`` to
    succeed (true for the defaults: equal channel counts, 3x3 kernel,
    stride 1, padding 1).

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by both convolutions;
            ``None`` means the same as ``in_channels``.
        kernel_size: Kernel size used by both convolutions.
        stride: Stride used by both convolutions.
        padding: Padding used by both convolutions.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super().__init__()
        width = in_channels if out_channels is None else out_channels
        conv_args = (kernel_size, stride, padding)

        # Convolutions carry no bias; each is followed by a BatchNorm2d.
        self.conv1 = nn.Conv2d(in_channels, width, *conv_args, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(width, width, *conv_args, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        # Original author's note: it might be better to apply this ReLU
        # before adding the residual.
        self.relu2 = nn.ReLU(inplace=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize every conv / batch-norm / linear submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        shortcut = x
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu2(y)
class Upscaler(nn.Module):
    """Convolutional network that refines 4-channel latents.

    Layout: a 4->128 stem, twenty 128-channel ``ResidualBlock`` modules
    (applied in groups of four, each group wrapped in an extra skip
    connection), two 3x3 convolutions down to 64 channels and a final
    1x1 convolution back to 4 channels. The network output is added to
    its input, and the final convolution starts at zero, so a freshly
    constructed model returns its input unchanged.
    """

    _NUM_RESBLOCKS = 20
    _RESBLOCKS_PER_GROUP = 4

    def __init__(self):
        super().__init__()

        # Stem: latents have 4 channels.
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # Brute force with 20 blocks: the author expects stacking more
        # blocks to widen the receptive field more than adding channels.
        # Registered as resblock1 .. resblock20, in that order.
        for index in range(1, self._NUM_RESBLOCKS + 1):
            setattr(self, f"resblock{index}", ResidualBlock(128))

        # Tail convolutions.
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)  # defined but not applied in forward()

        # Final 1x1 convolution back to 4 channels.
        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize all submodules, then zero the final convolution."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

        # "Zero conv": the final convolution starts with all-zero weights.
        nn.init.zeros_(self.conv_final.weight)

    def forward(self, x):
        """Return ``x`` plus the correction estimated by the network."""
        inp = x
        features = self.relu1(self.bn1(self.conv1(x)))

        # Add a skip connection around every group of four residual
        # blocks; the author expects this to help accuracy and training speed.
        group = self._RESBLOCKS_PER_GROUP
        for first in range(1, self._NUM_RESBLOCKS + 1, group):
            skip = features
            for index in range(first, first + group):
                features = getattr(self, f"resblock{index}")(features)
            features = features + skip

        features = self.relu2(self.bn2(self.conv2(features)))
        # Author's note: it is probably better not to apply a ReLU here.
        features = self.bn3(self.conv3(features))
        delta = self.conv_final(features)

        # The network estimates the difference between input and output.
        return delta + inp

    def support_latents(self) -> bool:
        """Return ``False``: ``upscale`` works from images, not from latents."""
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        """Resize images, encode them with the VAE and refine the latents.

        Args:
            vae: Autoencoder used to encode the resized images.
            lowreso_images: Images to upscale; must not be ``None``.
            lowreso_latents: Not used by this implementation.
            dtype: Dtype the pixel tensor is cast to before encoding.
            width: Target width every image is resized to.
            height: Target height every image is resized to.
            batch_size: Number of latents per forward pass of this model.
            vae_batch_size: Number of images per VAE encode call.

        Returns:
            The refined latents, concatenated along dim 0 and multiplied
            by 0.18215.
        """
        assert lowreso_images is not None, "Upscaler requires lowreso image"

        # Lanczos-resize each image to the target size and convert it
        # from HWC to a CHW float tensor.
        target_size = (width, height)
        image_tensors = [
            torch.from_numpy(np.array(image.resize(target_size, Image.LANCZOS))).permute(2, 0, 1).float()
            for image in lowreso_images
        ]

        # The stacked tensor is not moved to the VAE device as a whole
        # (too large for CUDA, per the original author); only the
        # per-batch slices below are moved.
        pixels = torch.stack(image_tensors, dim=0).to(dtype)
        pixels = pixels / 127.5 - 1.0  # normalize to [-1, 1]

        # Encode the resized images to latents, vae_batch_size at a time.
        encoded = []
        for start in tqdm(range(0, pixels.shape[0], vae_batch_size)):
            chunk = pixels[start : start + vae_batch_size].to(vae.device)
            with torch.no_grad():
                chunk = vae.encode(chunk).latent_dist.sample()
            encoded.append(chunk)
        upsampled_latents = torch.cat(encoded, dim=0)

        # Refine the latents with this model, batch_size at a time.
        print("Upscaling latents...")
        refined = []
        with torch.no_grad():
            for start in range(0, upsampled_latents.shape[0], batch_size):
                refined.append(self.forward(upsampled_latents[start : start + batch_size]))
        upscaled_latents = torch.cat(refined, dim=0)

        # NOTE(review): 0.18215 is presumably the Stable Diffusion latent
        # scaling factor -- confirm against the VAE config.
        return upscaled_latents * 0.18215
| # external interface: returns a model | |
def create_upscaler(**kwargs):
    """Build an ``Upscaler`` and load its weights from a checkpoint file.

    Keyword Args:
        weights: Path to the checkpoint. A file whose extension is
            ``.safetensors`` (compared case-insensitively) is read with
            ``safetensors.torch.load_file``; anything else is read with
            ``torch.load`` mapped onto the CPU.

    Returns:
        The ``Upscaler`` instance with the loaded state dict applied.

    Raises:
        KeyError: If the ``weights`` keyword argument is not supplied.
    """
    weights = kwargs["weights"]
    model = Upscaler()

    print(f"Loading weights from {weights}...")
    # Compare the extension case-insensitively so e.g. ".SAFETENSORS"
    # is not sent down the pickle-based torch.load path.
    extension = os.path.splitext(weights)[1].lower()
    if extension == ".safetensors":
        from safetensors.torch import load_file

        sd = load_file(weights)
    else:
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources (or prefer the
        # .safetensors format).
        sd = torch.load(weights, map_location=torch.device("cpu"))

    model.load_state_dict(sd)
    return model
| # another interface: upscale images with a model for given images from command line | |
def _load_rgb_cropped_to_multiple_of_8(image_path):
    """Open ``image_path`` as RGB and crop it from the top-left so both sides are divisible by 8."""
    # The context manager closes the file handle; convert() returns an
    # independent in-memory image that stays valid afterwards.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    width = image.width - image.width % 8
    height = image.height - image.height % 8
    if (width, height) != image.size:
        image = image.crop((0, 0, width, height))
    return image


def upscale_images(args: argparse.Namespace):
    """Upscale every image matching ``args.image_pattern`` to 2x and save it.

    Each image is cropped to dimensions divisible by 8, upscaled through
    ``Upscaler.upscale`` and the VAE decoder, and written to
    ``args.output_dir`` as ``<name>_upscaled<ext>``. With ``args.debug``
    set, a plain Lanczos 2x resize is also saved as
    ``<name>_lanczos4<ext>``. Images are processed in groups of equal
    size so that every image is upscaled by exactly 2x of its own
    dimensions.

    Args:
        args: Namespace with ``vae_path``, ``weights``, ``image_pattern``,
            ``output_dir``, ``batch_size``, ``vae_batch_size`` and ``debug``.

    Raises:
        AssertionError: If ``args.vae_path`` is ``None``.
        FileNotFoundError: If no file matches ``args.image_pattern``.
    """
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # fp16 on GPU; fall back to fp32 on CPU, where half-precision
    # convolutions are unsupported on many PyTorch builds.
    # TODO: support bf16
    us_dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32

    os.makedirs(args.output_dir, exist_ok=True)

    assert args.vae_path is not None, "VAE path is required"

    # Fail fast, before loading any model, when there is nothing to do.
    # Sorting makes the processing order deterministic.
    image_paths = sorted(glob.glob(args.image_pattern))
    if not image_paths:
        raise FileNotFoundError(f"No images match pattern: {args.image_pattern}")

    # load VAE with Diffusers
    print(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    # prepare model
    print("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # Load images and group them by (width, height): Upscaler.upscale
    # resizes a whole batch to one target size, so mixing sizes in a
    # single call would distort all but one of them.
    images_by_size = {}
    for image_path in image_paths:
        image = _load_rgb_cropped_to_multiple_of_8(image_path)
        if image.width == 0 or image.height == 0:
            print(f"Skipping {image_path}: smaller than 8 pixels on one side")
            continue

        if args.debug:
            # Save a plain Lanczos 2x resize for visual comparison.
            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_path))
            image_debug.save(os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}"))

        images_by_size.setdefault(image.size, []).append((image_path, image))

    for (width, height), members in images_by_size.items():
        group_paths = [path for path, _ in members]
        group_images = [image for _, image in members]

        # upscale
        print("Upscaling...")
        upscaled_latents = upscaler.upscale(
            vae,
            group_images,
            None,
            us_dtype,
            width * 2,
            height * 2,
            batch_size=args.batch_size,
            vae_batch_size=args.vae_batch_size,
        )
        # Undo the 0.18215 factor applied by Upscaler.upscale before decoding.
        upscaled_latents /= 0.18215

        # decode with batch
        print("Decoding...")
        decoded = []
        for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
            with torch.no_grad():
                batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
            decoded.append(batch.to("cpu"))
        pixels = torch.cat(decoded, dim=0)

        # NCHW in [-1, 1] -> NHWC uint8; do the arithmetic in fp32.
        pixels = pixels.permute(0, 2, 3, 1).float().numpy()
        pixels = ((pixels + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
        # RGB -> BGR, the channel order cv2.imwrite expects.
        pixels = pixels[..., ::-1]

        # save images
        for image_path, image in zip(group_paths, pixels):
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_path))
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(dest_file_name, np.ascontiguousarray(image)):
                print(f"Failed to write {dest_file_name}")
if __name__ == "__main__":
    # Command-line entry point: parse the options and upscale the matching images.
    parser = argparse.ArgumentParser()
    # These three paths are mandatory: without them the pipeline fails later
    # with an obscure TypeError / AssertionError, so let argparse report it.
    parser.add_argument("--vae_path", type=str, required=True, help="VAE path")
    parser.add_argument("--weights", type=str, required=True, help="Weights path")
    parser.add_argument("--image_pattern", type=str, required=True, help="Image pattern")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    args = parser.parse_args()

    # Batch sizes are used as range() steps, which must be positive.
    if args.batch_size < 1 or args.vae_batch_size < 1:
        parser.error("--batch_size and --vae_batch_size must be positive integers")

    upscale_images(args)