Spaces:
Running
Running
| import torch | |
| import gc | |
| import os | |
| import torch.nn as nn | |
| import urllib.request | |
| import cv2 | |
| from tqdm import tqdm | |
# URL prefixes that read_image() treats as remote sources to fetch with
# urllib before decoding.
HTTP_PREFIXES = ['http', 'data:image/jpeg']

# Pretrained generator checkpoints published as GitHub release assets.
# Maps a lookup key ("<style>" or "<style>:<version>") to a
# (generator version, download URL) pair; the bare "<style>" keys point at
# the same files as their ":v1" counterparts.
RELEASED_WEIGHTS = dict([
    # v2 generator, trained with Google Landmark (micro) as the real-photo set
    ("hayao:v2", ("v2", "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt")),
    ("hayao:v1", ("v1", "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth")),
    ("hayao", ("v1", "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth")),
    ("shinkai:v1", ("v1", "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth")),
    ("shinkai", ("v1", "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth")),
])
def is_image_file(path):
    """Return True when *path* ends in .png, .jpg or .jpeg (case-insensitive)."""
    extension = os.path.splitext(path)[1]
    return extension.lower() in {".png", ".jpg", ".jpeg"}
def read_image(path):
    """
    Read an image from a local path or a URL and return it in RGB channel order.

    If *path* starts with one of HTTP_PREFIXES (e.g. an ``http(s)`` URL or a
    ``data:image/jpeg`` URL), it is first fetched with
    ``urllib.request.urlretrieve`` into a private temporary file. That file is
    removed once the image has been decoded.

    Args:
        path: Local file path, or a string starting with one of HTTP_PREFIXES.

    Returns:
        The array from ``cv2.imread`` with its last (channel) axis reversed,
        i.e. OpenCV's BGR order converted to RGB. This is a reversed-stride
        view; call ``.copy()`` if a contiguous array is required.

    Raises:
        ValueError: if OpenCV cannot read or decode the image.
    """
    import tempfile

    if path.startswith(tuple(HTTP_PREFIXES)):
        # Use a unique temp file instead of a fixed "temp.jpg" in the CWD:
        # concurrent calls can no longer overwrite each other's download,
        # and no stray file is left behind.
        fd, local_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            urllib.request.urlretrieve(path, local_path)
            image = cv2.imread(local_path)
        finally:
            os.remove(local_path)
    else:
        image = cv2.imread(path)

    # cv2.imread reports failure by returning None rather than raising.
    if image is None:
        raise ValueError(f"Could not read image from {path!r}")
    # OpenCV decodes to BGR; reverse the channel axis to get RGB.
    return image[:, :, ::-1]
def save_checkpoint(model, path, optimizer=None, epoch=None):
    """
    Write a training checkpoint to *path* with ``torch.save``.

    The saved dict always contains ``model_state_dict`` and ``epoch`` (which
    is None unless given). ``optimizer_state_dict`` is added only when an
    *optimizer* is passed.
    """
    payload = dict(model_state_dict=model.state_dict(), epoch=epoch)
    if optimizer is not None:
        payload.update(optimizer_state_dict=optimizer.state_dict())
    torch.save(payload, path)
def maybe_remove_module(state_dict):
    """
    Return a new dict with one leading ``module.`` removed from each key.

    Checkpoints saved during DDP training carry a ``module.`` prefix on every
    parameter name; see
    https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
    Keys without the prefix are copied unchanged, and insertion order is kept.
    """
    prefix = 'module.'
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): value
        for name, value in state_dict.items()
    }
def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
    """
    Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        model: Module to load weights into. Keys must match exactly
            (``strict=True``) after any leading ``module.`` prefix is stripped.
        path: Checkpoint file path, or a RELEASED_WEIGHTS key (resolved by
            load_state_dict).
        optimizer: If given and the checkpoint holds optimizer state, that
            state is loaded into it.
        strip_optimizer: If True and the checkpoint holds optimizer state,
            the checkpoint at *path* is rewritten without it.
        map_location: Forwarded to load_state_dict.

    Returns:
        The epoch stored in the checkpoint, or 0 when it is missing or None.
    """
    state_dict = load_state_dict(path, map_location)
    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
    model.load_state_dict(model_state_dict, strict=True)

    if 'optimizer_state_dict' in state_dict:
        if optimizer is not None:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        if strip_optimizer:
            del state_dict['optimizer_state_dict']
            # NOTE(review): if *path* is a RELEASED_WEIGHTS key, the data came
            # from the .cache download, so this writes a new file named after
            # the key in the CWD. Confirm that is intended.
            # Save to a sibling temp file, then atomically swap it in, so an
            # interrupted save cannot leave a truncated checkpoint at *path*.
            tmp_path = f"{path}.tmp"
            torch.save(state_dict, tmp_path)
            os.replace(tmp_path, path)
            print(f"Optimizer stripped and saved to {path}")

    # save_checkpoint() stores epoch=None by default; keep the int contract.
    epoch = state_dict.get('epoch')
    return 0 if epoch is None else epoch
def load_state_dict(weight, map_location) -> dict:
    """
    Load a serialized checkpoint with ``torch.load``.

    Args:
        weight: Path to a checkpoint file, or a key of RELEASED_WEIGHTS
            (matched case-insensitively). A key is downloaded or taken from
            the cache via _download_weight first.
        map_location: Device argument for ``torch.load``. None selects
            'cuda' when available, otherwise 'cpu'.
    """
    release_key = weight.lower()
    if release_key in RELEASED_WEIGHTS:
        weight = _download_weight(release_key)

    device = map_location
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (consider weights_only=True where supported).
    return torch.load(weight, map_location=device)
def initialize_weights(net):
    """
    Re-initialize, in place, the parameters of supported layers in *net*.

    * Conv2d, ConvTranspose2d, Linear: Xavier-uniform weight, zero bias.
    * BatchNorm2d: weight filled with 1, bias with 0.

    Absent parameters (layers built with ``bias=False``, BatchNorm2d with
    ``affine=False``) are skipped explicitly. Other module types are left
    untouched.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
def set_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of *optimizer* with *lr*."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
class DownloadProgressBar(tqdm):
    '''
    tqdm progress bar whose ``update_to`` fits the ``reporthook`` callback of
    ``urllib.request.urlretrieve``.

    Adapted from https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    '''

    def update_to(self, b=1, bsize=1, tsize=None):
        """
        Move the bar to ``b * bsize`` transferred units.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block.
            tsize: Total size. urlretrieve passes -1 when the server sends no
                Content-Length; non-positive sizes are ignored so the bar
                never gets a negative total.
        """
        if tsize is not None and tsize > 0:
            self.total = tsize
        # update() takes an increment, so subtract what is already counted.
        self.update(b * bsize - self.n)
def _download_weight(weight):
    '''
    Download a released checkpoint into the local ``.cache`` directory.

    Args:
        weight: Key of RELEASED_WEIGHTS.

    Returns:
        Path of the cached file, ``.cache/<basename of the release URL>``.
        If that file already exists it is returned without downloading.
    '''
    os.makedirs('.cache', exist_ok=True)
    url = RELEASED_WEIGHTS[weight][1]
    filename = os.path.basename(url)
    save_path = f'.cache/{filename}'
    if os.path.isfile(save_path):
        return save_path

    # Download under a temporary name and move it into place only once it is
    # complete. Otherwise an interrupted download would leave a truncated file
    # that the isfile() check above would treat as a valid cache hit forever.
    partial_path = f'{save_path}.part'
    desc = f'Downloading {url} to {save_path}'
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
            urllib.request.urlretrieve(url, partial_path, reporthook=t.update_to)
        os.replace(partial_path, save_path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # otherwise drop the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return save_path