Spaces:
Runtime error
Runtime error
first version
Browse files- app.py +130 -0
- dnnlib/__init__.py +9 -0
- dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
- dnnlib/__pycache__/util.cpython-39.pyc +0 -0
- dnnlib/util.py +491 -0
- feature_networks/__pycache__/constants.cpython-39.pyc +0 -0
- feature_networks/__pycache__/pretrained_builder.cpython-39.pyc +0 -0
- feature_networks/__pycache__/vit.cpython-39.pyc +0 -0
- feature_networks/clip/__init__.py +1 -0
- feature_networks/clip/__pycache__/__init__.cpython-39.pyc +0 -0
- feature_networks/clip/__pycache__/clip.cpython-39.pyc +0 -0
- feature_networks/clip/__pycache__/model.cpython-39.pyc +0 -0
- feature_networks/clip/__pycache__/simple_tokenizer.cpython-39.pyc +0 -0
- feature_networks/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- feature_networks/clip/clip.py +244 -0
- feature_networks/clip/model.py +453 -0
- feature_networks/clip/simple_tokenizer.py +132 -0
- feature_networks/constants.py +129 -0
- feature_networks/pretrained_builder.py +417 -0
- feature_networks/vit.py +436 -0
- legacy.py +331 -0
- misc.py +275 -0
- pg_modules/__init__.py +0 -0
- pg_modules/__pycache__/MViT.cpython-39.pyc +0 -0
- pg_modules/__pycache__/__init__.cpython-39.pyc +0 -0
- pg_modules/__pycache__/blocks.cpython-38.pyc +0 -0
- pg_modules/__pycache__/blocks.cpython-39.pyc +0 -0
- pg_modules/__pycache__/diffaug.cpython-38.pyc +0 -0
- pg_modules/__pycache__/diffaug.cpython-39.pyc +0 -0
- pg_modules/__pycache__/discriminator.cpython-38.pyc +0 -0
- pg_modules/__pycache__/discriminator.cpython-39.pyc +0 -0
- pg_modules/__pycache__/mae.cpython-39.pyc +0 -0
- pg_modules/__pycache__/models_tnt.cpython-39.pyc +0 -0
- pg_modules/__pycache__/networks_fastgan.cpython-38.pyc +0 -0
- pg_modules/__pycache__/networks_fastgan.cpython-39.pyc +0 -0
- pg_modules/__pycache__/networks_stylegan2.cpython-39.pyc +0 -0
- pg_modules/__pycache__/projector.cpython-38.pyc +0 -0
- pg_modules/__pycache__/projector.cpython-39.pyc +0 -0
- pg_modules/__pycache__/simmim.cpython-39.pyc +0 -0
- pg_modules/__pycache__/vision_transformer.cpython-39.pyc +0 -0
- pg_modules/blocks.py +370 -0
- pg_modules/diffaug.py +76 -0
- pg_modules/discriminator.py +153 -0
- pg_modules/networks_fastgan.py +180 -0
- pg_modules/networks_stylegan2.py +537 -0
- pg_modules/projector.py +158 -0
app.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
"""Generate images using pretrained network pickle."""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
from typing import List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import click
|
| 12 |
+
import dnnlib
|
| 13 |
+
import numpy as np
|
| 14 |
+
import PIL.Image
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
import legacy
|
| 18 |
+
|
| 19 |
+
from huggingface_hub import hf_hub_url
|
| 20 |
+
|
| 21 |
+
#----------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
def parse_range(s: Union[str, List]) -> List[int]:
    """Parse a comma-separated list of numbers or 'a-b' ranges into a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10].
    A list argument is passed through unchanged.
    """
    if isinstance(s, list):
        return s
    pattern = re.compile(r'^(\d+)-(\d+)$')
    values: List[int] = []
    for token in s.split(','):
        match = pattern.match(token)
        if match is None:
            values.append(int(token))
        else:
            lo, hi = int(match.group(1)), int(match.group(2))
            values.extend(range(lo, hi + 1))  # inclusive upper bound
    return values
|
| 37 |
+
|
| 38 |
+
#----------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    """Parse a 2-vector of syntax 'a,b' into a (float, float) tuple.

    Example: '0,1' returns (0.0, 1.0). A tuple argument is passed through
    unchanged.

    Raises:
        ValueError: if the string does not contain exactly two components.
    """
    if isinstance(s, tuple):
        return s
    pieces = s.split(',')
    if len(pieces) != 2:
        raise ValueError(f'cannot parse 2-vector {s}')
    return (float(pieces[0]), float(pieces[1]))
|
| 50 |
+
|
| 51 |
+
#----------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
def make_transform(translate: Tuple[float, float], angle: float):
    """Build a 3x3 homogeneous 2D transform: rotation by `angle` degrees
    followed by translation by `translate` (tx, ty).
    """
    theta = angle / 360.0 * np.pi * 2  # degrees -> radians
    c = np.cos(theta)
    s = np.sin(theta)
    m = np.eye(3)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m
|
| 64 |
+
|
| 65 |
+
#----------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
# Select GPU when available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Directory holding the pretrained generator pickles.
# NOTE(review): hard-coded local Windows path carried over from the original
# code. For deployment, switch to hf_hub_url / open_url downloads instead.
_MODEL_DIR = r'E:\桌面\Preparation of Papers for IEEE Signal Processing Letters (5-page limit)\codes\projected-gan-clc - 副本'

# Load every available generator once at startup.
# `models` maps dataset name -> its G_ema (exponential-moving-average) network.
models = dict()
for name in ["pokemon", "art-paint", "flowers", "landscapes", "obama"]:
    # os.path.join replaces the original fragile string concatenation with a
    # trailing "\\" escape.
    with dnnlib.util.open_url(os.path.join(_MODEL_DIR, name + ".pkl")) as f:
        models[name] = legacy.load_network_pkl(f)['G_ema']
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def generate_images(seeds, name):
    """Generate one image per seed with the pretrained generator `name`.

    Args:
        seeds: sequence of integer random seeds.
        name:  key into the module-level `models` dict.

    Returns:
        PIL.Image for the LAST seed in `seeds` (earlier images are generated
        but not kept — callers pass a single seed).
    """
    G = models[name].to(device)

    # Class label: these are unconditional models, so an all-zero vector.
    label = torch.zeros([1, G.c_dim], device=device)

    pilimg = None
    for seed_idx, seed in enumerate(seeds):
        # fix: 1-based progress (the original printed "0/1" for the first seed)
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx + 1, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()

        # StyleGAN3-style generators expose an input transform; the generator
        # expects the INVERSE rotation/translation matrix to avoid potentially
        # failing numerical operations in the network.
        if hasattr(G.synthesis, 'input'):
            # fix: pass a (0, 0) tuple — the original passed the string '0,0',
            # whose second "component" is the comma character and would crash.
            m = np.linalg.inv(make_transform((0.0, 0.0), 0.0))
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        img = G(z, label, truncation_psi=1, noise_mode='const')
        # [-1, 1] float -> [0, 255] uint8, NCHW -> NHWC.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        pilimg = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    return pilimg
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def inference(seedin, name=None):
    """Gradio callback: generate one image for slider value `seedin` using model `name`."""
    print(name)
    return generate_images([int(seedin)], name)
|
| 125 |
+
|
| 126 |
+
title = "Projected GAN CLC"
description = "Gradio demo for Projected GANs CLC, Pokemon."

# Build the demo UI: a seed slider plus a dataset selector, producing one image.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Slider(label="Seed", minimum=0, maximum=5000, step=1, value=0),
        gr.Radio(["pokemon", "art-paint", "flowers", "landscapes", "obama"], label='Dataset', value='art-paint'),
    ],
    outputs=["image"],
    title=title,
    description=description,
)
demo.launch()
|
dnnlib/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
from .util import EasyDict, make_cache_dir_path
|
dnnlib/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (285 Bytes). View file
|
|
|
dnnlib/__pycache__/util.cpython-39.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
dnnlib/util.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Miscellaneous utility classes and functions."""
|
| 10 |
+
|
| 11 |
+
import ctypes
|
| 12 |
+
import fnmatch
|
| 13 |
+
import importlib
|
| 14 |
+
import inspect
|
| 15 |
+
import numpy as np
|
| 16 |
+
import os
|
| 17 |
+
import shutil
|
| 18 |
+
import sys
|
| 19 |
+
import types
|
| 20 |
+
import io
|
| 21 |
+
import pickle
|
| 22 |
+
import re
|
| 23 |
+
import requests
|
| 24 |
+
import html
|
| 25 |
+
import hashlib
|
| 26 |
+
import glob
|
| 27 |
+
import tempfile
|
| 28 |
+
import urllib
|
| 29 |
+
import urllib.request
|
| 30 |
+
import uuid
|
| 31 |
+
|
| 32 |
+
from distutils.util import strtobool
|
| 33 |
+
from typing import Any, List, Tuple, Union
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Util classes
|
| 37 |
+
# ------------------------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class EasyDict(dict):
    """Dictionary subclass exposing its keys through attribute access (d.key == d['key'])."""

    def __getattr__(self, name: str) -> Any:
        # The attribute protocol requires AttributeError (not KeyError) for
        # missing names, so e.g. hasattr() and copy.copy() behave correctly.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        # Optional mirror file; stays None when no file_name is given.
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Keep references to the real streams so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # From here on, ALL stdout/stderr traffic is routed through this
        # Logger instance (a module-level side effect of construction).
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Context-manager exit restores the original streams via close().
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Cache directories
|
| 116 |
+
# ------------------------------------------------------------------------------------------
|
| 117 |
+
|
| 118 |
+
# Module-level override for the cache root; None means "use env vars/defaults".
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Override the root directory used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Return a path under the dnnlib cache root, joined with *paths.

    Resolution order: explicit set_cache_dir() override, $DNNLIB_CACHE_DIR,
    $HOME/.cache/dnnlib, %USERPROFILE%/.cache/dnnlib, then the system temp dir.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    env = os.environ
    if 'DNNLIB_CACHE_DIR' in env:
        return os.path.join(env['DNNLIB_CACHE_DIR'], *paths)
    for var in ('HOME', 'USERPROFILE'):
        if var in env:
            return os.path.join(env[var], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
| 134 |
+
|
| 135 |
+
# Small util functions
|
| 136 |
+
# ------------------------------------------------------------------------------------------
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration as a human-readable string (s / m+s / h+m+s / d+h+m)."""
    total = int(np.rint(seconds))
    minutes, sec = divmod(total, 60)
    hours, minute = divmod(minutes, 60)
    days, hour = divmod(hours, 24)
    if days > 0:
        return f"{days}d {hour:02}h {minute:02}m"
    if hours > 0:
        return f"{hours}h {minute:02}m {sec:02}s"
    if minutes > 0:
        return f"{minutes}m {sec:02}s"
    return f"{total}s"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
    """Render a duration compactly with at most two units (s / m+s / h+m / d+h)."""
    total = int(np.rint(seconds))
    if total >= 24 * 60 * 60:
        return f"{total // (24 * 60 * 60)}d {(total // (60 * 60)) % 24:02}h"
    if total >= 60 * 60:
        return f"{total // (60 * 60)}h {(total // 60) % 60:02}m"
    if total >= 60:
        return f"{total // 60}m {total % 60:02}s"
    return f"{total}s"
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def ask_yes_no(question: str) -> bool:
    """Ask `question` on stdout and loop until the user types a recognizable yes/no.

    Accepts the same spellings as distutils.util.strtobool, which this replaces
    because distutils was deprecated and removed in Python 3.12:
    y/yes/t/true/on/1 -> True; n/no/f/false/off/0 -> False.
    """
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
        # Unrecognized input: re-ask (mirrors the original ValueError retry loop).
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements (1 for an empty tuple)."""
    product = 1
    for value in t:
        product = product * value
    return product
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Mapping from numpy-style dtype names to ctypes types of identical byte size.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double,
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Resolve `type_obj` (a name string, or anything with a __name__ or name
    attribute, e.g. a numpy dtype) to a (numpy dtype, ctypes type) pair that
    have the same size in bytes."""
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()
    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]
    # Sanity check: both representations must occupy the same number of bytes.
    assert dtype.itemsize == ctypes.sizeof(ctype)
    return dtype, ctype
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def is_pickleable(obj: Any) -> bool:
    """Return True if `obj` survives pickle.dump without raising."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of being reported as "not pickleable".
        return False
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 234 |
+
# ------------------------------------------------------------------------------------------
|
| 235 |
+
|
| 236 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # Allow convenience shorthands, substitute them by full names. The dots
    # are escaped: the original patterns '^np.' / '^tf.' matched ANY third
    # character, so e.g. 'npx...' would have been rewritten incorrectly.
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # List alternatives for (module_name, local_obj_name), longest module first.
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # Try each alternative in turn.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt propagates.
            pass

    # Maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            # Re-raise genuine import errors inside an existing module.
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # Maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # We are out of luck, but we have no idea why.
    raise ImportError(obj_name)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverse the dotted `obj_name` inside `module` and return the final
    (rightmost) attribute. An empty name returns the module itself."""
    target = module
    if obj_name:
        for attr in obj_name.split("."):
            target = getattr(target, attr)
    return target
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name.

    `name` is a fully-qualified dotted path, e.g. 'numpy.ndarray'; resolution
    is delegated to get_module_from_obj_name / get_obj_from_module.
    """
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Resolve `func_name` (keyword-only, required) to a callable and invoke it
    with the given positional and keyword arguments."""
    assert func_name is not None
    fn = get_obj_by_name(func_name)
    assert callable(fn)
    return fn(*args, **kwargs)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments.

    `class_name` is keyword-only; all other arguments are forwarded to the
    class constructor.
    """
    return call_func_by_name(*args, func_name=class_name, **kwargs)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name.

    Resolves `obj_name` via get_module_from_obj_name, then returns the
    directory of that module's source file.
    """
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined
    at module scope using 'def' (its name appears in its module's namespace)."""
    if not callable(obj):
        return False
    return obj.__name__ in sys.modules[obj.__module__].__dict__
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    # Functions defined in __main__ are qualified by the script's file stem
    # instead of the literal '__main__'.
    if module == '__main__':
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# File system helpers
|
| 327 |
+
# ------------------------------------------------------------------------------------------
|
| 328 |
+
|
| 329 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.

    Returns list of tuples containing both absolute and relative paths.
    `ignores` holds fnmatch-style patterns matched against both directory and
    file names; when `add_base_to_relative` is set, relative paths are prefixed
    with the basename of `dir_path`.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place: os.walk (topdown) only prunes
            # subtrees when the SAME list object is mutated.
            for d in dirs_to_remove:
                dirs.remove(d)

            # Each pattern filters the remaining files cumulatively.
            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.

    Will create all necessary destination directories.
    """
    for src, dst in files:
        target_dir = os.path.dirname(dst)
        # exist_ok avoids the check-then-create race of the original
        # `if not os.path.exists(...)` guard; the truthiness check also fixes
        # a crash when dst has no directory part (makedirs('') raises).
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        shutil.copyfile(src, dst)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# URL helpers
|
| 376 |
+
# ------------------------------------------------------------------------------------------
|
| 377 |
+
|
| 378 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    A string qualifies when both the URL itself and its site root parse with a
    scheme and a dotted network location. file:// URLs qualify only when
    `allow_file_urls` is set.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        # Also require the site root to parse, catching malformed bases.
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        return False
    return True
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Local paths and file:// URLs are opened directly. Remote downloads are
    retried up to `num_attempts` times and, when `cache` is True, stored under
    `cache_dir` (default: the dnnlib cache) keyed by the URL's MD5 hash. When
    `return_filename` is True, the on-disk filename is returned instead of a
    file object (requires `cache` for remote URLs).
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        # Cache entries are named "<md5>_<sanitized original name>".
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the payload; detect and retry accordingly.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Follow the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
|
feature_networks/__pycache__/constants.cpython-39.pyc
ADDED
|
Binary file (2.06 kB). View file
|
|
|
feature_networks/__pycache__/pretrained_builder.cpython-39.pyc
ADDED
|
Binary file (8.5 kB). View file
|
|
|
feature_networks/__pycache__/vit.cpython-39.pyc
ADDED
|
Binary file (8.58 kB). View file
|
|
|
feature_networks/clip/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .clip import *
|
feature_networks/clip/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (254 Bytes). View file
|
|
|
feature_networks/clip/__pycache__/clip.cpython-39.pyc
ADDED
|
Binary file (9.19 kB). View file
|
|
|
feature_networks/clip/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
feature_networks/clip/__pycache__/simple_tokenizer.cpython-39.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|
feature_networks/clip/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|
feature_networks/clip/clip.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import os
|
| 3 |
+
import urllib
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Union, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .model import build_model
|
| 14 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
| 15 |
+
|
| 16 |
+
__all__ = ["available_models", "load", "tokenize"]
|
| 17 |
+
_tokenizer = _Tokenizer()
|
| 18 |
+
|
| 19 |
+
_MODELS = {
|
| 20 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
| 21 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
| 22 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
| 23 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Download `url` into `root` and return the local file path.

    The expected SHA256 checksum is taken from the second-to-last URL path
    component (the layout used by OpenAI's model URLs). An existing file with
    a matching checksum is reused; a mismatching file is re-downloaded.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file, or if the
        downloaded data fails the checksum.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    def _file_sha256(path):
        # Hash in chunks so multi-hundred-MB checkpoints are never fully
        # resident in memory, and close the handle deterministically.
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()

    if os.path.isfile(download_target):
        if _file_sha256(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _file_sha256(download_target) != expected_sha256:
        # Original message contained a doubled "not"; checksum failure is fatal.
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _transform(n_px):
    """Build CLIP's image preprocessing pipeline for an `n_px` input:
    resize, center-crop, force RGB, tensorize, and normalize with the
    CLIP training-set statistics."""
    def _to_rgb(image):
        return image.convert("RGB")

    steps = [
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        _to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ]
    return Compose(steps)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def available_models() -> List[str]:
    """Return the names of the available CLIP models."""
    # Iterating a dict yields its keys in insertion order, same as .keys().
    return [name for name in _MODELS]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict (the file was not a TorchScript archive)
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Rebuild the eager-mode model either from the loaded state dict or by
        # extracting the JIT archive's weights.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            # fp16 ops are not supported for CPU inference; fall back to fp32.
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    # The JIT archive has device strings baked into its graph constants; trace a
    # trivial op on the target device to obtain a constant node to copy from.
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Rewrite every hard-coded "cuda*" device constant in the module graph
        # (and its `forward1` overload graph, if present) to the target device.
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            # Rewrite aten::to casts targeting dtype code 5 (presumably the
            # ScalarType enum value for half precision -- TODO confirm) to fp32.
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate : bool
        Whether to truncate over-long inputs to `context_length` (keeping the
        end-of-text token) instead of raising. Defaults to False, which
        preserves the original raising behavior.

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                # Drop the overflow but keep the EOT marker in the last slot,
                # mirroring upstream CLIP's truncation behavior.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
|
| 195 |
+
|
| 196 |
+
def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""All pairwise l_p distances between two samples.

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        First sample, shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        Second sample, shape ``(n_2, d)``.
    norm : float
        The l_p norm to use.
    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape ``(n_1, n_2)`` whose ``[i, j]`` entry is
        ``|| sample_1[i, :] - sample_2[j, :] ||_p`` (up to the ``eps``
        regularizer added for numerical stability)."""
    rows, cols = sample_1.size(0), sample_2.size(0)
    norm = float(norm)

    if norm == 2.:
        # Expand ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b> to avoid forming
        # the (n_1, n_2, d) difference tensor.
        sq_1 = torch.sum(sample_1 ** 2, dim=1, keepdim=True)
        sq_2 = torch.sum(sample_2 ** 2, dim=1, keepdim=True)
        cross = sample_1.mm(sample_2.t())
        squared = sq_1.expand(rows, cols) + sq_2.transpose(0, 1).expand(rows, cols) - 2 * cross
        return torch.sqrt(eps + torch.abs(squared))

    # General l_p norm via broadcasting over the feature dimension.
    dim = sample_1.size(1)
    lhs = sample_1.unsqueeze(1).expand(rows, cols, dim)
    rhs = sample_2.unsqueeze(0).expand(rows, cols, dim)
    powered = torch.abs(lhs - rhs) ** norm
    return (eps + torch.sum(powered, dim=2, keepdim=False)) ** (1. / norm)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ClipHead(nn.Module):
    """CLIP-based text-guidance head.

    Loads the RN50 CLIP model (non-JIT) in eval mode and scores convolutional
    feature maps against a fixed text prompt.
    """

    def __init__(self, prompt, device='cpu'):
        super().__init__()
        self.clip_model = load("RN50", device=device, jit=False)[0].eval()
        self.prompt = prompt

    def calc_loss(self, features):
        """Return the mean negative cosine similarity between the prompt's
        text embedding and the CLIP-pooled image features.

        `features` is a dict whose 'last' entry holds a conv feature map
        (NCHW) -- assumed from `encode_conv_features`'s pooling; TODO confirm
        against the caller.
        """
        # `.device` works for both CPU and CUDA tensors; the original
        # `.get_device()` returns -1 for CPU tensors, which breaks `.to(dev)`.
        dev = features['last'].device
        text_input = tokenize(self.prompt).to(dev)

        text_features = self.clip_model.encode_text(text_input)
        image_features = self.clip_model.encode_conv_features(features['last'])
        loss = -torch.cosine_similarity(text_features, image_features, dim=1)

        return loss.mean()
|
feature_networks/clip/model.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Bottleneck(nn.Module):
    """CLIP's anti-aliased ResNet bottleneck block.

    Every convolution runs with stride 1; when ``stride > 1`` the
    downsampling is done by an AvgPool2d after the second conv (and on the
    shortcut branch), which reduces aliasing compared to strided convs.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Pooling (not a strided conv) performs the spatial reduction.
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != out_planes:
            # Shortcut projection: avgpool first, then a stride-1 1x1 conv.
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(out_planes)),
            ]))

    def forward(self, x: torch.Tensor):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(self.avgpool(out)))

        return self.relu(out + shortcut)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class AttentionPool2d(nn.Module):
    """QKV attention pooling over a 2D feature map (used by ModifiedResNet).

    The flattened spatial positions are prepended with their mean as an extra
    token; one multi-head attention layer runs over the sequence and the
    output at that mean token is returned as the pooled embedding.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # spacial_dim ** 2 spatial positions plus 1 for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # Functional MHA with separate q/k/v projection weights; q == k == v
        # so every token (including the mean token) attends over all tokens.
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Keep only the output at the prepended mean token (sequence index 0).
        return x[0]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        """`layers` gives the Bottleneck count per stage (4 entries); `width`
        is the stem width; `heads` is the attention-pool head count."""
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        # Overall downsampling factor is 32 (stem: 4x, stages 2-4: 2x each).
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage; only its first block may downsample."""
        layers = [Bottleneck(self._inplanes, planes, stride)]

        # Subsequent blocks take the expanded channel count as input.
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        # Cast the input to the weight dtype (supports fp16 inference).
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes in float32 and casts the result back to the
    input dtype, keeping fp16 inputs numerically stable."""

    def forward(self, x: torch.Tensor):
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(x.dtype)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class QuickGELU(nn.Module):
    """Fast GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x))."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        # 4x expansion MLP with the quick-GELU activation.
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Re-cast the stored mask to the input's dtype/device on every call,
        # since the input may change device or precision between calls.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class Transformer(nn.Module):
    """A stack of ``layers`` residual attention blocks of width ``width``,
    all sharing the same (optional) attention mask."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class VisualTransformer(nn.Module):
    """CLIP's ViT image encoder: patchify with a strided conv, prepend a class
    embedding, add positional embeddings, run the transformer, and project the
    class token to ``output_dim``."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding: kernel == stride == patch_size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the class embedding, broadcast across the batch.
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Only the class token (position 0) is used for the image embedding.
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class CLIP(nn.Module):
|
| 240 |
+
def __init__(self,
|
| 241 |
+
embed_dim: int,
|
| 242 |
+
# vision
|
| 243 |
+
image_resolution: int,
|
| 244 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
| 245 |
+
vision_width: int,
|
| 246 |
+
vision_patch_size: int,
|
| 247 |
+
# text
|
| 248 |
+
context_length: int,
|
| 249 |
+
vocab_size: int,
|
| 250 |
+
transformer_width: int,
|
| 251 |
+
transformer_heads: int,
|
| 252 |
+
transformer_layers: int
|
| 253 |
+
):
|
| 254 |
+
super().__init__()
|
| 255 |
+
|
| 256 |
+
self.context_length = context_length
|
| 257 |
+
|
| 258 |
+
if isinstance(vision_layers, (tuple, list)):
|
| 259 |
+
vision_heads = vision_width * 32 // 64
|
| 260 |
+
self.visual = ModifiedResNet(
|
| 261 |
+
layers=vision_layers,
|
| 262 |
+
output_dim=embed_dim,
|
| 263 |
+
heads=vision_heads,
|
| 264 |
+
input_resolution=image_resolution,
|
| 265 |
+
width=vision_width
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
vision_heads = vision_width // 64
|
| 269 |
+
self.visual = VisualTransformer(
|
| 270 |
+
input_resolution=image_resolution,
|
| 271 |
+
patch_size=vision_patch_size,
|
| 272 |
+
width=vision_width,
|
| 273 |
+
layers=vision_layers,
|
| 274 |
+
heads=vision_heads,
|
| 275 |
+
output_dim=embed_dim
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
self.transformer = Transformer(
|
| 279 |
+
width=transformer_width,
|
| 280 |
+
layers=transformer_layers,
|
| 281 |
+
heads=transformer_heads,
|
| 282 |
+
attn_mask=self.build_attention_mask()
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
self.vocab_size = vocab_size
|
| 286 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
| 287 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
| 288 |
+
self.ln_final = LayerNorm(transformer_width)
|
| 289 |
+
|
| 290 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
| 291 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 292 |
+
|
| 293 |
+
self.initialize_parameters()
|
| 294 |
+
|
| 295 |
+
def initialize_parameters(self):
|
| 296 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 297 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 298 |
+
|
| 299 |
+
if isinstance(self.visual, ModifiedResNet):
|
| 300 |
+
if self.visual.attnpool is not None:
|
| 301 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
| 302 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
| 303 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
| 304 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
| 305 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
| 306 |
+
|
| 307 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
| 308 |
+
for name, param in resnet_block.named_parameters():
|
| 309 |
+
if name.endswith("bn3.weight"):
|
| 310 |
+
nn.init.zeros_(param)
|
| 311 |
+
|
| 312 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 313 |
+
attn_std = self.transformer.width ** -0.5
|
| 314 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 315 |
+
for block in self.transformer.resblocks:
|
| 316 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 317 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 318 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 319 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 320 |
+
|
| 321 |
+
if self.text_projection is not None:
|
| 322 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 323 |
+
|
| 324 |
+
def build_attention_mask(self):
|
| 325 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
| 326 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 327 |
+
mask = torch.empty(self.context_length, self.context_length)
|
| 328 |
+
mask.fill_(float("-inf"))
|
| 329 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 330 |
+
return mask
|
| 331 |
+
|
| 332 |
+
@property
|
| 333 |
+
def dtype(self):
|
| 334 |
+
return self.visual.conv1.weight.dtype
|
| 335 |
+
|
| 336 |
+
def encode_image(self, image):
|
| 337 |
+
return self.visual(image.type(self.dtype))
|
| 338 |
+
|
| 339 |
+
def encode_text(self, text):
|
| 340 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
| 341 |
+
|
| 342 |
+
x = x + self.positional_embedding.type(self.dtype)
|
| 343 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 344 |
+
x = self.transformer(x)
|
| 345 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 346 |
+
x = self.ln_final(x).type(self.dtype)
|
| 347 |
+
|
| 348 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
| 349 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 350 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 351 |
+
|
| 352 |
+
return x
|
| 353 |
+
|
| 354 |
+
def encode_conv_features(self, features):
|
| 355 |
+
# pool to 7, the feature map resolution for 224x224 input
|
| 356 |
+
features = nn.AdaptiveAvgPool2d(7)(features)
|
| 357 |
+
return self.visual.attnpool(features)
|
| 358 |
+
|
| 359 |
+
def forward(self, image, text):
|
| 360 |
+
image_features = self.encode_image(image)
|
| 361 |
+
text_features = self.encode_text(text)
|
| 362 |
+
|
| 363 |
+
# normalized features
|
| 364 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
| 365 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
| 366 |
+
|
| 367 |
+
# cosine similarity as logits
|
| 368 |
+
logit_scale = self.logit_scale.exp()
|
| 369 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
| 370 |
+
logits_per_text = logit_scale * text_features @ image_features.t()
|
| 371 |
+
|
| 372 |
+
# shape = [global_batch_size, global_batch_size]
|
| 373 |
+
return logits_per_image, logits_per_text
|
| 374 |
+
|
| 375 |
+
def forward_features(self, features, text):
    """Like `forward`, but starts from precomputed conv feature maps.

    Returns (logits_per_image, logits_per_text), each of shape
    [global_batch_size, global_batch_size].
    """
    img_feat = self.encode_conv_features(features)
    txt_feat = self.encode_text(text)

    # unit-normalize so the dot products below are cosine similarities
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

    # learned temperature, stored in log space
    scale = self.logit_scale.exp()
    logits_per_image = scale * img_feat @ txt_feat.t()
    logits_per_text = scale * txt_feat @ img_feat.t()

    return logits_per_image, logits_per_text
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place."""

    def _to_half(module):
        # conv / linear weights and biases
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        # attention projections stored as plain tensors on MultiheadAttention
        if isinstance(module, nn.MultiheadAttention):
            attrs = [f"{s}_proj_weight" for s in ("in", "q", "k", "v")]
            attrs += ["in_proj_bias", "bias_k", "bias_v"]
            for attr in attrs:
                tensor = getattr(module, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP-specific projection matrices
        for name in ("text_projection", "proj"):
            proj = getattr(module, name, None)
            if proj is not None:
                proj.data = proj.data.half()

    model.apply(_to_half)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def build_model(state_dict: dict):
    """Instantiate a CLIP model whose architecture is inferred from `state_dict`.

    Works for both ViT and modified-ResNet visual towers; hyperparameters are
    reverse-engineered from tensor shapes and parameter-key patterns.
    """
    # only ViT checkpoints carry a final visual projection matrix
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # one attention in-projection per transformer block
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # positional embedding holds 1 cls token + grid*grid patch positions
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # block count per residual stage (visual.layer1 .. visual.layer4)
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        # attnpool positional embedding holds 1 + output_width**2 positions
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        # factor 32: overall downsampling of the ResNet tower (see * 32 below)
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    # head count derived from a fixed 64-dim-per-head convention
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # metadata keys are not model parameters; drop them before load_state_dict
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    # convert_weights(model)  # fp16 conversion left disabled here
    model.load_state_dict(state_dict)
    return model.eval()
|
feature_networks/clip/simple_tokenizer.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gzip
|
| 2 |
+
import html
|
| 3 |
+
import os
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import ftfy
|
| 7 |
+
import regex as re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@lru_cache()
def default_bpe():
    """Path to the bundled BPE vocab archive, located next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping each utf-8 byte (0..255) to a unicode string.

    The reversible bpe codes work on unicode strings, so every byte needs a
    printable unicode stand-in that the bpe code won't choke on (no
    whitespace/control characters). Printable bytes map to themselves; the
    rest are shifted up past 255.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_vals = printable[:]
    char_codes = printable[:]
    shift = 0
    for b in range(2 ** 8):
        if b not in byte_vals:
            byte_vals.append(b)
            char_codes.append(2 ** 8 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_vals, char_codes)}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word`.

    `word` is a tuple of symbols (symbols being variable-length strings).
    """
    prev = word[0]
    pairs = set()
    for sym in word[1:]:
        pairs.add((prev, sym))
        prev = sym
    return pairs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def basic_clean(text):
    """Fix mojibake via ftfy, undo (possibly double) HTML escaping, and strip."""
    fixed = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def whitespace_clean(text):
    """Collapse every whitespace run into a single space and trim the ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer used by CLIP.

    Vocab = 256 byte symbols + 256 end-of-word variants + the BPE merges +
    two special tokens ('<|startoftext|>', '<|endoftext|>'), totalling 49152.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        # reversible byte <-> printable-unicode mapping so BPE works on strings
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # merge list: skip the header line, truncate to the vocab budget
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # lower rank = earlier merge = higher priority
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # memoized token -> merged word; the specials map to themselves
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        # pre-tokenizer: specials, common contractions, letters, digits, other
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply BPE merges to one pre-tokenized token.

        Returns the merged symbols joined by single spaces; results are cached.
        """
        if token in self.cache:
            return self.cache[token]
        # mark the last character as end-of-word before merging
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # merge the adjacent pair with the best (lowest) rank
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:  # `first` no longer occurs past i; keep the tail as-is
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Text -> list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # map raw utf-8 bytes to the reversible unicode alphabet first
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Token ids -> text (inverse of `encode`, up to whitespace)."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
|
feature_networks/constants.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backbone name registries for the pretrained feature-network builder.
# Names are either torchvision model names or timm model names.

TORCHVISION = [
    "vgg11_bn",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19_bn",
    "densenet121",
    "densenet169",
    "densenet201",
    "inception_v3",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "shufflenet_v2_x0_5",
    "mobilenet_v2",
    "wide_resnet50_2",
    "mnasnet0_5",
    "mnasnet1_0",
    "ghostnet_100",
    "cspresnet50",
    "fbnetc_100",
    "spnasnet_100",
    "resnet50d",
    "resnet26",
    "resnet26d",
    "seresnet50",
    "resnetblur50",
    "resnetrs50",
    "tf_mixnet_s",
    "tf_mixnet_m",
    "tf_mixnet_l",
    "ese_vovnet19b_dw",
    "ese_vovnet39b",
    "res2next50",
    "gernet_s",
    "gernet_m",
    "repvgg_a2",
    "repvgg_b0",
    "repvgg_b1",
    "repvgg_b1g4",
    "revnet",
    "dm_nfnet_f1",
    "nfnet_l0",
]

REGNETS = [
    "regnetx_002",
    "regnetx_004",
    "regnetx_006",
    "regnetx_008",
    "regnetx_016",
    "regnetx_032",
    "regnetx_040",
    "regnetx_064",
    "regnety_002",
    "regnety_004",
    "regnety_006",
    "regnety_008",
    "regnety_016",
    "regnety_032",
    "regnety_040",
    "regnety_064",
]

# EfficientNets split by which input normalization they were trained with
EFFNETS_IMAGENET = [
    'tf_efficientnet_b0',
    'tf_efficientnet_b1',
    'tf_efficientnet_b2',
    'tf_efficientnet_b3',
    'tf_efficientnet_b4',
    'tf_efficientnet_b0_ns',
]

EFFNETS_INCEPTION = [
    'tf_efficientnet_lite0',
    'tf_efficientnet_lite1',
    'tf_efficientnet_lite2',
    'tf_efficientnet_lite3',
    'tf_efficientnet_lite4',
    'tf_efficientnetv2_b0',
    'tf_efficientnetv2_b1',
    'tf_efficientnetv2_b2',
    'tf_efficientnetv2_b3',
    'efficientnet_b1',
    'efficientnet_b1_pruned',
    'efficientnet_b2_pruned',
    'efficientnet_b3_pruned',
]

EFFNETS = EFFNETS_IMAGENET + EFFNETS_INCEPTION

VITS_IMAGENET = [
    'deit_tiny_distilled_patch16_224',
    'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224',
]

VITS_INCEPTION = [
    'vit_base_patch16_224',
    'vit_large_patch16_224'
]

# masked autoencoder ViTs
MAES = [
    'mae_vit_base_patch16',
    'mae_vit_large_patch16',
    'mae_vit_huge_patch14'
]

# Swin transformer
# NOTE(review): ST appears in NORMALIZED_INCEPTION below but not in
# ALL_MODELS -- confirm whether that is intentional.
ST= ['swin_base_patch4_window7_224']

TNT = ['tnt_b_patch16_224']

VITS = VITS_IMAGENET + VITS_INCEPTION

CLIP = [
    'resnet50_clip'
]

ALL_MODELS = TORCHVISION + REGNETS + EFFNETS + VITS + CLIP + MAES + TNT

# Group according to input normalization

NORMALIZED_IMAGENET = TORCHVISION + REGNETS + EFFNETS_IMAGENET + VITS_IMAGENET

NORMALIZED_INCEPTION = EFFNETS_INCEPTION + VITS_INCEPTION + MAES + TNT + ST

NORMALIZED_CLIP = CLIP
|
feature_networks/pretrained_builder.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchvision.models as zoomodels
|
| 5 |
+
from torch.autograd import Function
|
| 6 |
+
|
| 7 |
+
import timm
|
| 8 |
+
|
| 9 |
+
from feature_networks import clip
|
| 10 |
+
from feature_networks.vit import _make_vit_b16_backbone, forward_vit
|
| 11 |
+
from feature_networks.constants import ALL_MODELS, VITS, EFFNETS, REGNETS
|
| 12 |
+
from pg_modules.blocks import Interpolate
|
| 13 |
+
|
| 14 |
+
def _feature_splitter(model, idcs):
|
| 15 |
+
pretrained = nn.Module()
|
| 16 |
+
pretrained.layer0 = nn.Sequential(model.features[:idcs[0]])
|
| 17 |
+
pretrained.layer1 = nn.Sequential(model.features[idcs[0]:idcs[1]])
|
| 18 |
+
pretrained.layer2 = nn.Sequential(model.features[idcs[1]:idcs[2]])
|
| 19 |
+
pretrained.layer3 = nn.Sequential(model.features[idcs[2]:idcs[3]])
|
| 20 |
+
return pretrained
|
| 21 |
+
|
| 22 |
+
def _make_resnet(model):
|
| 23 |
+
pretrained = nn.Module()
|
| 24 |
+
pretrained.layer0 = nn.Sequential(
|
| 25 |
+
model.conv1, model.bn1, model.relu, model.maxpool, model.layer1,
|
| 26 |
+
)
|
| 27 |
+
pretrained.layer1 = model.layer2
|
| 28 |
+
pretrained.layer2 = model.layer3
|
| 29 |
+
pretrained.layer3 = model.layer4
|
| 30 |
+
return pretrained
|
| 31 |
+
|
| 32 |
+
def _make_regnet(model):
|
| 33 |
+
pretrained = nn.Module()
|
| 34 |
+
pretrained.layer0 = nn.Sequential(
|
| 35 |
+
model.stem, model.s1
|
| 36 |
+
)
|
| 37 |
+
pretrained.layer1 = model.s2
|
| 38 |
+
pretrained.layer2 = model.s3
|
| 39 |
+
pretrained.layer3 = model.s4
|
| 40 |
+
return pretrained
|
| 41 |
+
|
| 42 |
+
def _make_nfnet(model):
|
| 43 |
+
pretrained = nn.Module()
|
| 44 |
+
pretrained.layer0 = nn.Sequential(
|
| 45 |
+
model.stem, model.stages[0]
|
| 46 |
+
)
|
| 47 |
+
pretrained.layer1 = model.stages[1]
|
| 48 |
+
pretrained.layer2 = model.stages[2]
|
| 49 |
+
pretrained.layer3 = model.stages[3]
|
| 50 |
+
return pretrained
|
| 51 |
+
|
| 52 |
+
def _make_resnet_v2(model):
|
| 53 |
+
pretrained = nn.Module()
|
| 54 |
+
pretrained.layer0 = nn.Sequential(model.stem, model.stages[0])
|
| 55 |
+
pretrained.layer1 = model.stages[1]
|
| 56 |
+
pretrained.layer2 = model.stages[2]
|
| 57 |
+
pretrained.layer3 = model.stages[3]
|
| 58 |
+
return pretrained
|
| 59 |
+
|
| 60 |
+
def _make_resnet_clip(model):
|
| 61 |
+
pretrained = nn.Module()
|
| 62 |
+
|
| 63 |
+
# slightly more complicated than the standard resnet
|
| 64 |
+
pretrained.layer0 = nn.Sequential(
|
| 65 |
+
model.conv1,
|
| 66 |
+
model.bn1,
|
| 67 |
+
model.relu,
|
| 68 |
+
model.conv2,
|
| 69 |
+
model.bn2,
|
| 70 |
+
model.relu,
|
| 71 |
+
model.conv3,
|
| 72 |
+
model.bn3,
|
| 73 |
+
model.relu,
|
| 74 |
+
model.avgpool,
|
| 75 |
+
model.layer1,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
pretrained.layer1 = model.layer2
|
| 79 |
+
pretrained.layer2 = model.layer3
|
| 80 |
+
pretrained.layer3 = model.layer4
|
| 81 |
+
|
| 82 |
+
return pretrained
|
| 83 |
+
|
| 84 |
+
def _make_densenet(model):
    """Regroup a torchvision DenseNet `features` stack into four stages.

    Stages 1-3 get an AvgPool2d(2, 2) prepended; stages 1 and 2 additionally
    have the last sub-module of their final block replaced by Identity
    (presumably the transition layer's pooling -- verify against torchvision's
    DenseNet layout), effectively moving downsampling in front of each stage.
    """
    pretrained = nn.Module()

    pretrained.layer0 = model.features[:6]

    pretrained.layer1 = model.features[6:8]
    pretrained.layer1[-1][-1] = nn.Identity()  # neutralize trailing sub-module
    pretrained.layer1 = nn.Sequential(nn.AvgPool2d(2, 2), pretrained.layer1)

    pretrained.layer2 = model.features[8:10]
    pretrained.layer2[-1][-1] = nn.Identity()  # neutralize trailing sub-module
    pretrained.layer2 = nn.Sequential(nn.AvgPool2d(2, 2), pretrained.layer2)

    # final stage: no Identity swap here
    pretrained.layer3 = model.features[10:12]
    pretrained.layer3 = nn.Sequential(nn.AvgPool2d(2, 2), pretrained.layer3)

    return pretrained
|
| 101 |
+
|
| 102 |
+
def _make_shufflenet(model):
|
| 103 |
+
pretrained = nn.Module()
|
| 104 |
+
pretrained.layer0 = nn.Sequential(model.conv1, model.maxpool)
|
| 105 |
+
pretrained.layer1 = model.stage2
|
| 106 |
+
pretrained.layer2 = model.stage3
|
| 107 |
+
pretrained.layer3 = model.stage4
|
| 108 |
+
return pretrained
|
| 109 |
+
|
| 110 |
+
def _make_cspresnet(model):
|
| 111 |
+
pretrained = nn.Module()
|
| 112 |
+
pretrained.layer0 = nn.Sequential(model.stem, model.stages[0])
|
| 113 |
+
pretrained.layer1 = model.stages[1]
|
| 114 |
+
pretrained.layer2 = model.stages[2]
|
| 115 |
+
pretrained.layer3 = model.stages[3]
|
| 116 |
+
return pretrained
|
| 117 |
+
|
| 118 |
+
def _make_efficientnet(model):
|
| 119 |
+
pretrained = nn.Module()
|
| 120 |
+
pretrained.layer0 = nn.Sequential(
|
| 121 |
+
model.conv_stem, model.bn1, model.act1, *model.blocks[0:2]
|
| 122 |
+
)
|
| 123 |
+
pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
|
| 124 |
+
pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
|
| 125 |
+
pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
|
| 126 |
+
return pretrained
|
| 127 |
+
|
| 128 |
+
def _make_ghostnet(model):
|
| 129 |
+
pretrained = nn.Module()
|
| 130 |
+
pretrained.layer0 = nn.Sequential(
|
| 131 |
+
model.conv_stem, model.bn1, model.act1, *model.blocks[0:3],
|
| 132 |
+
)
|
| 133 |
+
pretrained.layer1 = nn.Sequential(*model.blocks[3:5])
|
| 134 |
+
pretrained.layer2 = nn.Sequential(*model.blocks[5:7])
|
| 135 |
+
pretrained.layer3 = nn.Sequential(*model.blocks[7:-1])
|
| 136 |
+
return pretrained
|
| 137 |
+
|
| 138 |
+
def _make_vit(model, name):
    """Wrap a timm ViT as a 4-stage backbone via intermediate-layer hooks.

    The size keyword in `name` selects (per-stage feature dims, hooked block
    indices, transformer width).
    """
    configs = {
        'tiny': ([24, 48, 96, 192], [2, 5, 8, 11], 192),
        'small': ([48, 96, 192, 384], [2, 5, 8, 11], 384),
        'base': ([96, 192, 384, 768], [2, 5, 8, 11], 768),
        'large': ([256, 512, 1024, 1024], [5, 11, 17, 23], 1024),
    }
    for size, (features, hooks, vit_features) in configs.items():
        if size in name:
            break
    else:
        raise NotImplementedError('Invalid ViT backbone not available')

    return _make_vit_b16_backbone(
        model,
        features=features,
        size=[224, 224],
        hooks=hooks,
        vit_features=vit_features,
        # deit models carry an extra distillation token in front
        start_index=2 if 'deit' in name else 1,
    )
|
| 170 |
+
|
| 171 |
+
def calc_dims(pretrained, is_vit=False):
    """Probe `pretrained` with a dummy 256x256 image to measure stage shapes.

    Returns (channels, res_mult): per-stage channel counts and the ratio of
    each stage's spatial size to the input resolution.
    """
    inp_res = 256
    tmp = torch.zeros(1, 3, inp_res, inp_res)

    if is_vit:
        outs = forward_vit(pretrained, tmp)
        dims = [out.shape[1:3] for out in outs]
    else:
        dims = []
        for layer in (pretrained.layer0, pretrained.layer1,
                      pretrained.layer2, pretrained.layer3):
            tmp = layer(tmp)
            dims.append(tmp.shape[1:3])

    # split into channel counts and resolution multipliers
    dims = np.array(dims)
    channels = dims[:, 0]
    res_mult = dims[:, 1] / inp_res
    return channels, res_mult
|
| 194 |
+
|
| 195 |
+
def _make_pretrained(backbone, verbose=False):
    """Build a 4-stage feature extractor (`layer0..layer3`) from a pretrained net.

    The backbone is loaded from torchvision, timm, or CLIP and regrouped by
    the matching `_make_*` helper; `CHANNELS` and `RES_MULT` are then probed
    via `calc_dims`.

    Fix vs. the original: the `dm_nfnet_f1` branch appeared twice; the second
    (`_make_nfnet`) branch was unreachable dead code. `_make_cspresnet` and
    `_make_nfnet` perform the same regrouping (stem + stages[0..3]), so
    routing `dm_nfnet_f1` through `_make_nfnet` removes the duplicate without
    changing behavior.
    """
    assert backbone in ALL_MODELS

    # torchvision nets split purely by indices into `model.features`
    split_idcs = {
        'vgg11_bn': [7, 14, 21, 28],
        'vgg13_bn': [13, 20, 27, 34],
        'vgg16_bn': [13, 23, 33, 43],
        'vgg19_bn': [13, 26, 39, 52],
        'mobilenet_v2': [4, 7, 14, 18],  # same structure as vgg
    }
    tv_resnets = {'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
                  'wide_resnet50_2', 'wide_resnet101_2'}
    densenets = {'densenet121', 'densenet169', 'densenet201'}
    mnasnets = {'mnasnet0_5', 'mnasnet1_0'}
    # timm resnet-likes that need `relu` aliased to timm's `act1`
    timm_resnets = {'resnet50d', 'resnet26', 'resnet26d', 'seresnet50',
                    'resnetblur50', 'resnetrs50', 'res2next50'}
    # timm nets exposing `stem` + `stages`
    timm_staged = {'cspresnet50', 'ese_vovnet19b_dw', 'ese_vovnet39b',
                   'gernet_s', 'gernet_m', 'repvgg_a2', 'repvgg_b0',
                   'repvgg_b1', 'repvgg_b1g4', 'dm_nfnet_f0'}
    timm_effnet_like = {'fbnetc_100', 'spnasnet_100',
                        'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'}
    nfnets = {'dm_nfnet_f1', 'nfnet_l0'}

    if backbone in split_idcs:
        model = zoomodels.__dict__[backbone](True)
        pretrained = _feature_splitter(model, split_idcs[backbone])

    elif backbone in mnasnets:
        model = zoomodels.__dict__[backbone](True)
        model.features = model.layers  # mnasnet names its feature stack `layers`
        pretrained = _feature_splitter(model, [9, 10, 12, 14])

    elif backbone in densenets:
        model = zoomodels.__dict__[backbone](True)
        pretrained = _make_densenet(model)

    elif backbone in tv_resnets:
        model = zoomodels.__dict__[backbone](True)
        pretrained = _make_resnet(model)

    elif backbone == 'shufflenet_v2_x0_5':
        model = zoomodels.__dict__[backbone](True)
        pretrained = _make_shufflenet(model)

    elif backbone == 'ghostnet_100':
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_ghostnet(model)

    elif backbone in timm_resnets:
        model = timm.create_model(backbone, pretrained=True)
        model.relu = model.act1
        pretrained = _make_resnet(model)

    elif backbone in timm_staged:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_cspresnet(model)

    elif backbone in timm_effnet_like:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_efficientnet(model)

    elif backbone in nfnets:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_nfnet(model)

    elif backbone in REGNETS:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_regnet(model)

    elif backbone in EFFNETS:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_efficientnet(model)

    elif backbone in VITS:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_vit(model, backbone)

    elif backbone == 'resnet50_clip':
        model = clip.load('RN50', device='cpu', jit=False)[0].visual
        pretrained = _make_resnet_clip(model)

    else:
        # NOTE(review): 'vgg16' and 'revnet' are listed in ALL_MODELS but have
        # no construction branch, so they land here -- same as the original.
        raise NotImplementedError('Wrong model name?')

    pretrained.CHANNELS, pretrained.RES_MULT = calc_dims(pretrained, is_vit=backbone in VITS)

    if verbose:
        print(f"Succesfully loaded: {backbone}")
        print(f"Channels: {pretrained.CHANNELS}")
        print(f"Resolution Multiplier: {pretrained.RES_MULT}")
        print(f"Out Res for 256 : {pretrained.RES_MULT*256}")

    return pretrained
|
feature_networks/vit.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import timm
|
| 4 |
+
import types
|
| 5 |
+
import math
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Slice(nn.Module):
    """Drop the leading readout token(s) from a ViT token sequence.

    Keeps only ``x[:, start_index:]``, discarding the CLS (and optional
    distillation) tokens that precede the patch tokens.
    """

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        # x: (B, tokens, C) -> (B, tokens - start_index, C)
        return x[:, self.start_index:]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AddReadout(nn.Module):
    """Fuse the readout token(s) into the patch tokens by addition.

    With ``start_index == 2`` the two leading tokens (CLS + distillation)
    are averaged; otherwise the single CLS token is used. The readout is
    broadcast-added onto every remaining patch token.
    """

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            fused = (x[:, 0] + x[:, 1]) / 2
        else:
            fused = x[:, 0]
        return x[:, self.start_index:] + fused.unsqueeze(1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ProjectReadout(nn.Module):
    """Fuse the readout token via concatenation plus a learned projection.

    Each patch token is concatenated with the CLS token and mapped back to
    ``in_features`` channels through Linear + GELU.
    """

    def __init__(self, in_features, start_index=1):
        super().__init__()
        self.start_index = start_index
        # Maps [patch_token ; readout_token] (2*C) back down to C.
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        patches = x[:, self.start_index:]
        readout = x[:, 0].unsqueeze(1).expand_as(patches)
        return self.project(torch.cat((patches, readout), -1))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Transpose(nn.Module):
    """``nn.Module`` wrapper around ``Tensor.transpose`` returning a
    contiguous tensor (needed before reshaping/unflattening)."""

    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1).contiguous()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def forward_vit(pretrained, x):
    """Run a hooked ViT backbone and return four multi-scale feature maps.

    Relies on the forward hooks registered in ``_make_vit_b16_backbone``,
    which capture intermediate transformer-block outputs into
    ``pretrained.activations`` during the ``forward_flex`` call below.
    """
    b, c, h, w = x.shape

    # Forward pass is needed only for its hook side effects; the returned
    # latent itself is unused here (typo "lantent" kept as-is).
    lantent, _ = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # Stages [0:2] of each post-processing Sequential: readout op + transpose.
    layer_1 = pretrained.layer1[0:2](layer_1)
    layer_2 = pretrained.layer2[0:2](layer_2)
    layer_3 = pretrained.layer3[0:2](layer_3)
    layer_4 = pretrained.layer4[0:2](layer_4)

    # Input-size-dependent unflatten; it replaces the static nn.Unflatten at
    # index 2 of each layerN Sequential, which is why the [3:] slices below
    # deliberately skip index 2.
    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    # ndim == 3 means the activation is still a token sequence (B, C, N).
    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    # Remaining stages [3:]: 1x1 projection conv (+ optional resampling conv).
    layer_1 = pretrained.layer1[3 : len(pretrained.layer1)](layer_1)
    layer_2 = pretrained.layer2[3 : len(pretrained.layer2)](layer_2)
    layer_3 = pretrained.layer3[3 : len(pretrained.layer3)](layer_3)
    layer_4 = pretrained.layer4[3 : len(pretrained.layer4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4
|
| 99 |
+
|
| 100 |
+
def forward_swin(pretrained, x):
    """Run a hooked Swin backbone and return four multi-scale feature maps.

    Mirrors :func:`forward_vit`: the forward pass populates
    ``pretrained.activations`` via registered hooks, and ``layer1``-``layer4``
    post-process the captured activations.

    Bug fixed: ``_make_swin_b16_backbone`` installs ``forward_flex_swin``,
    which returns a single tensor, while the original
    ``lantent, _ = pretrained.model.forward_flex(x)`` assumed a 2-tuple —
    unpacking a tensor that way splits it along the batch dimension
    (wrong result, or ValueError for batch sizes != 2). Both return
    conventions are now accepted.
    """
    b, c, h, w = x.shape

    # Forward pass is needed only for its hook side effects.
    latent = pretrained.model.forward_flex(x)
    if isinstance(latent, tuple):
        latent = latent[0]

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # Stages [0:2] of each post-processing Sequential: readout op + transpose.
    layer_1 = pretrained.layer1[0:2](layer_1)
    layer_2 = pretrained.layer2[0:2](layer_2)
    layer_3 = pretrained.layer3[0:2](layer_3)
    layer_4 = pretrained.layer4[0:2](layer_4)

    # Input-size-dependent unflatten; replaces the static nn.Unflatten at
    # index 2 of each layerN Sequential (hence the [3:] slices below).
    # NOTE(review): patch_size is [16, 16] here even though Swin stages run
    # at strides 4/8/16/32 — confirm the grid matches the hooked stages.
    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    # ndim == 3 means the activation is still a token sequence (B, C, N).
    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    # Remaining stages [3:]: 1x1 projection conv.
    layer_1 = pretrained.layer1[3 : len(pretrained.layer1)](layer_1)
    layer_2 = pretrained.layer2[3 : len(pretrained.layer2)](layer_2)
    layer_3 = pretrained.layer3[3 : len(pretrained.layer3)](layer_3)
    layer_4 = pretrained.layer4[3 : len(pretrained.layer4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4
|
| 143 |
+
|
| 144 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
| 145 |
+
posemb_tok, posemb_grid = (
|
| 146 |
+
posemb[:, : self.start_index],
|
| 147 |
+
posemb[0, self.start_index :],
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
| 151 |
+
|
| 152 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
| 153 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
|
| 154 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
| 155 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
| 156 |
+
|
| 157 |
+
return posemb
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def forward_flex(self, x):
    """Forward a timm VisionTransformer at arbitrary input resolution.

    Injected into the model via ``types.MethodType`` (see
    ``_make_vit_b16_backbone``); the learned position embedding is resized
    on the fly to match the incoming spatial grid.

    Returns:
        (tokens, None): normalized token sequence plus a placeholder so
        callers can unpack two values.
    """
    b, c, h, w = x.shape
    # print(x.shape, self.OOD2ID)
    # x = self.OOD2ID(x)

    # Interpolate the position embedding to the current patch grid.
    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    # Hybrid ViT: run the conv stem first and keep its last feature map.
    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    # Patchify: (B, C, H, W) -> (B, num_patches, embed_dim).
    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    # DeiT-style models carry an extra distillation token.
    if hasattr(self, "dist_token") and self.dist_token is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x, None
|
| 199 |
+
|
| 200 |
+
def forward_flex_swin(self, x):
    """Forward a Swin transformer trunk up to (and including) the final norm.

    NOTE(review): unlike ``forward_flex`` this returns a single tensor,
    not a ``(latent, None)`` tuple — callers that unpack two values
    (see ``forward_swin``) should be checked against this.
    """
    x = self.patch_embed(x)
    if self.absolute_pos_embed is not None:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    x = self.layers(x)
    x = self.norm(x)  # B L C

    return x
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Module-level store shared by every registered hook, keyed by the name
# passed to get_activation at registration time.
activations = {}


def get_activation(name):
    """Return a forward hook that records a module's output under *name*."""

    def _hook(module, inputs, output):
        activations[name] = output

    return _hook
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Build one readout operation per output feature stage.

    Args:
        vit_features: transformer embedding width (used by 'project').
        features: per-stage output channel list; only its length matters here.
        use_readout: one of 'ignore', 'add', 'project'.
        start_index: number of leading readout tokens (1 = CLS only).

    Returns:
        A list of ``len(features)`` readout modules. For 'ignore' and 'add'
        the SAME (stateless) module instance is shared across stages; for
        'project' each stage gets its own learned projection.

    Raises:
        ValueError: for an unknown ``use_readout`` mode. (The original used
        ``assert False``, which is silently stripped under ``python -O``.)
    """
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        raise ValueError(
            "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
        )

    return readout_oper
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a timm ViT-B/16 so it exposes four CNN-like feature stages.

    Forward hooks on the transformer blocks named in ``hooks`` capture token
    sequences into the shared module-level ``activations`` dict;
    ``pretrained.layer1..4`` convert each captured sequence into a 2D feature
    map at a different scale (consumed by ``forward_vit``).

    NOTE: the list default arguments are shared across calls; they are never
    mutated here, so this is safe as written.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # Capture the outputs of four chosen transformer blocks.
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    # Each layerN: readout -> (B, C, N) transpose -> static unflatten
    # (replaced at run time by forward_vit, which skips index 2) ->
    # 1x1 projection conv -> optional up/down-sampling conv.
    pretrained.layer1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        # 4x upsample -> stride-4 feature map.
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.layer2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        # 2x upsample -> stride-8 feature map.
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stride-16 stage: projection only, no resampling.
    pretrained.layer3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.layer4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        # Stride-2 conv -> stride-32 feature map.
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
|
| 351 |
+
|
| 352 |
+
def _make_swin_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a Swin transformer so it exposes four CNN-like feature stages.

    Analogue of ``_make_vit_b16_backbone`` for hierarchical (Swin) models:
    channel width doubles per stage (vit_features * 1/2/4/8) while spatial
    resolution halves (size // 4/8/16/32).

    NOTE(review): hooks are taken from ``model.blocks[...]``; timm Swin
    models expose stages as ``model.layers`` — confirm against the actual
    model class used.
    NOTE(review): ``patch_size`` is set to [16, 16] and ``forward_flex`` is
    bound to ``forward_flex_swin``, which returns a single tensor rather
    than the (latent, None) tuple ``forward_swin`` unpacks — verify.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # Hook the last block of each of the four chosen stages.
    pretrained.model.blocks[hooks[0]].blocks[-1].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].blocks[-1].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].blocks[-1].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].blocks[-1].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    # Each layerN: readout -> transpose -> static unflatten at the stage's
    # stride -> 1x1 projection conv (input width doubles per stage).
    pretrained.layer1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 4, size[1] // 4])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.layer2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 8, size[1] // 8])),
        nn.Conv2d(
            in_channels=vit_features*2,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.layer3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features*4,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.layer4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 32, size[1] // 32])),
        nn.Conv2d(
            in_channels=vit_features*8,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex_swin, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
|
legacy.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Converting legacy network pickle into the new format."""
|
| 10 |
+
|
| 11 |
+
import click
|
| 12 |
+
import pickle
|
| 13 |
+
import re
|
| 14 |
+
import copy
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import io
|
| 18 |
+
import dnnlib
|
| 19 |
+
import misc
|
| 20 |
+
|
| 21 |
+
#----------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
def load_network_pkl(f, force_fp16=False):
    """Load a StyleGAN network pickle, converting legacy TF pickles on the fly.

    Args:
        f: binary file-like object containing the pickle.
        force_fp16: if True, re-instantiate G/D/G_ema with FP16 synthesis
            settings (num_fp16_res=4, conv_clamp=256) and copy weights over.

    Returns:
        dict with keys 'G', 'D', 'G_ema', 'training_set_kwargs', 'augment_pipe'.

    SECURITY NOTE: unpickling executes arbitrary code — only use on
    trusted files.
    """
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            # Networks without a 'synthesis_kwargs' sub-dict get the FP16
            # fields applied at the top level instead.
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            fp16_kwargs.num_fp16_res = 4
            fp16_kwargs.conv_clamp = 256
            # Only rebuild when the kwargs actually changed.
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data
|
| 60 |
+
|
| 61 |
+
#----------------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
class _TFNetworkStub(dnnlib.EasyDict):
    """Stand-in for ``dnnlib.tflib.network.Network`` when unpickling legacy
    TensorFlow pickles without TensorFlow installed (see ``_LegacyUnpickler``)."""
    pass
|
| 65 |
+
|
| 66 |
+
class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that maps legacy TF classes to stubs and forces CPU tensors."""

    def find_class(self, module, name):
        # print(module,name)
        # NOTE(review): returning None for '__builtin__' references is
        # unusual — any pickle that actually instantiates such a class will
        # then fail; presumably intended to neutralize them. Confirm.
        if module == '__builtin__':
            return
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        # Route tensor deserialization through map_location='cpu' so that
        # GPU-saved pickles load on CPU-only machines.
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)
|
| 76 |
+
|
| 77 |
+
#----------------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
def _collect_tf_params(tf_net):
|
| 80 |
+
# pylint: disable=protected-access
|
| 81 |
+
tf_params = dict()
|
| 82 |
+
def recurse(prefix, tf_net):
|
| 83 |
+
for name, value in tf_net.variables:
|
| 84 |
+
tf_params[prefix + name] = value
|
| 85 |
+
for name, comp in tf_net.components.items():
|
| 86 |
+
recurse(prefix + name + '/', comp)
|
| 87 |
+
recurse('', tf_net)
|
| 88 |
+
return tf_params
|
| 89 |
+
|
| 90 |
+
#----------------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def _populate_module_params(module, *patterns):
    """Copy TF values into a module's params/buffers by regex pattern.

    ``patterns`` alternates (regex, value_fn): for each named param/buffer,
    the first regex that fully matches wins and ``value_fn(*groups)``
    supplies the value. A value_fn of None means "recognized, but leave the
    tensor untouched" (e.g. precomputed filters).

    Raises:
        AssertionError: if a param/buffer matches no pattern; its name and
        shape are printed first to aid debugging.
    """
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            # Intentionally broad: print the offending tensor, then re-raise.
            print(name, list(tensor.shape))
            raise
|
| 110 |
+
|
| 111 |
+
#----------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow generator stub into a PyTorch Generator.

    Maps TF ``static_kwargs`` onto the PyTorch constructor kwargs,
    instantiates ``networks_stylegan2.Generator``, then copies every TF
    variable into the matching PyTorch parameter/buffer by regex.

    Raises:
        ValueError: if the pickle version is too old or an unrecognized
        TF kwarg is present.
    """
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        # Record the name as consumed; map an explicit None value to `none`.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    from pg_modules import networks_stylegan2
    network_class = networks_stylegan2.Generator
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        architecture = kwarg('architecture', 'skip'),
        resample_filter = kwarg('resample_kernel', [1,3,3,1]),
        use_noise = kwarg('use_noise', True),
        activation = kwarg('nonlinearity', 'lrelu'),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
    )

    # Check for unknown kwargs.
    # These names are consumed (marked known) but deliberately ignored.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    kwarg('conditioning')
    kwarg('fused_modconv')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        # Progressive-growing pickles store per-lod ToRGB layers; remap them
        # to the per-resolution naming used here.
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            # NOTE(review): `kwargs` has no 'synthesis' entry in this
            # flattened layout, so this line would raise AttributeError for
            # progressive-growing pickles — confirm against such a pickle.
            kwargs.synthesis.kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    G = network_class(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    # TF stores conv weights as HWIO; PyTorch wants OIHW, hence the
    # transpose(3, 2, 0, 1); [::-1, ::-1] flips kernels for up-convs.
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
        r'.*\.act_filter', None,
    )
    return G
|
| 211 |
+
|
| 212 |
+
#----------------------------------------------------------------------------
|
| 213 |
+
|
| 214 |
+
def convert_tf_discriminator(tf_D):
    """Translate a TensorFlow StyleGAN2/ADA discriminator pickle into a PyTorch module.

    `tf_D` is the unpickled TF network object; returns the converted PyTorch
    discriminator with weights copied in. Raises ValueError for pickles that
    are too old or carry kwargs this converter does not recognize.
    """
    # Pickles older than version 4 use an incompatible layout.
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs. `kwarg` records every name it is asked for so that any
    # leftover (unknown) TF kwarg can be reported below.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs: map TF config names onto the PyTorch constructor's names.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,  # TF counts half as many base channels
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs. These two are accepted but deliberately ignored.
    kwarg('structure')
    kwarg('conditioning')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params. Progressive-growing checkpoints store per-lod FromRGB
    # layers; rename them to the fixed-resolution layout and force 'orig'.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    #from pg_modules import networks_stylegan2
    from pg_modules.discriminator import ProjectedDiscriminator

    D = ProjectedDiscriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    # Each (regex, converter) pair maps a PyTorch parameter name onto the TF
    # tensor; regex groups become the lambda's arguments. Conv weights go from
    # TF HWIO to PyTorch OIHW via transpose(3, 2, 0, 1); dense weights are
    # plain 2-D transposes.
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,  # constant buffers regenerated by the constructor
    )
    return D
|
| 298 |
+
|
| 299 |
+
#----------------------------------------------------------------------------
|
| 300 |
+
|
| 301 |
+
@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    # `open_url` accepts both local paths and URLs; the actual TF->PyTorch
    # conversion happens inside load_network_pkl().
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    # Re-save as a plain pickle of native PyTorch modules.
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')
|
| 325 |
+
|
| 326 |
+
#----------------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
if __name__ == "__main__":
|
| 329 |
+
convert_network_pickle() # pylint: disable=no-value-for-parameter
|
| 330 |
+
|
| 331 |
+
#----------------------------------------------------------------------------
|
misc.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
import contextlib
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import warnings
|
| 14 |
+
import dnnlib
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
| 18 |
+
# same constant is used multiple times.
|
| 19 |
+
|
| 20 |
+
# Cache of previously-built constant tensors, keyed by value bytes + options.
_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached constant tensor for `value`.

    Optionally broadcasts to `shape` and applies the given dtype / device /
    memory format. Repeated requests for the same constant return the same
    tensor object, avoiding redundant CPU=>GPU copies.
    """
    value = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    # Raw bytes in the key guarantee distinct values never collide.
    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(key, None)
    if cached is not None:
        return cached

    built = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        built, _ = torch.broadcast_tensors(built, torch.empty(shape))
    built = built.contiguous(memory_format=memory_format)
    _constant_cache[key] = built
    return built
|
| 42 |
+
|
| 43 |
+
#----------------------------------------------------------------------------
|
| 44 |
+
# Replace NaN/Inf with specified numerical values.
|
| 45 |
+
|
| 46 |
+
# Use the native implementation when available (torch >= 1.8); otherwise fall
# back to an equivalent built from nansum/clamp. The fallback only supports
# nan=0, which is the only value this codebase uses.
try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Replace NaN with 0 and clamp +/-inf to the dtype's finite range."""
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        # nansum over an added singleton dim turns NaN into 0 without
        # changing the shape; clamp then handles the infinities.
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
| 57 |
+
|
| 58 |
+
#----------------------------------------------------------------------------
|
| 59 |
+
# Symbolic assert.
|
| 60 |
+
|
| 61 |
+
# Assert that survives torch.jit.trace(): torch._assert is recorded into the
# traced graph (torch >= 1.8); torch.Assert is the 1.7.0 spelling.
try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0
|
| 65 |
+
|
| 66 |
+
#----------------------------------------------------------------------------
|
| 67 |
+
# Context manager to temporarily suppress known warnings in torch.jit.trace().
|
| 68 |
+
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
|
| 69 |
+
|
| 70 |
+
@contextlib.contextmanager
def suppress_tracer_warnings():
    """Temporarily silence torch.jit.TracerWarning.

    Inserts an 'ignore' entry at the front of the global warnings filter list
    and removes it on exit. Mutates `warnings.filters` directly instead of
    using `warnings.catch_warnings` because of
    https://bugs.python.org/issue29672.
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        # Fix: previously the filter leaked permanently if the body raised.
        warnings.filters.remove(flt)
|
| 76 |
+
|
| 77 |
+
#----------------------------------------------------------------------------
|
| 78 |
+
# Assert that the shape of a tensor matches the given list of integers.
|
| 79 |
+
# None indicates that the size of a dimension is allowed to vary.
|
| 80 |
+
# Performs symbolic assertion when used in torch.jit.trace().
|
| 81 |
+
|
| 82 |
+
def assert_shape(tensor, ref_shape):
    """Check that `tensor` has shape `ref_shape`.

    A `None` entry allows any size for that dimension. Tensor-valued sizes
    (as produced inside torch.jit.trace()) are checked with a symbolic assert
    so the check is baked into the trace; plain mismatches raise
    AssertionError immediately.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx in range(len(ref_shape)):
        size, ref_size = tensor.shape[idx], ref_shape[idx]
        if ref_size is None:
            continue
        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
| 96 |
+
|
| 97 |
+
#----------------------------------------------------------------------------
|
| 98 |
+
# Function decorator that calls torch.autograd.profiler.record_function().
|
| 99 |
+
|
| 100 |
+
def profiled_function(fn):
    """Decorator that runs `fn` inside torch.autograd.profiler.record_function
    so its calls appear as a named range in profiler traces.

    Returns a wrapper with the same call signature as `fn`.
    """
    import functools
    @functools.wraps(fn)  # fix: preserve __doc__/__module__ etc., not just __name__
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator
|
| 106 |
+
|
| 107 |
+
#----------------------------------------------------------------------------
|
| 108 |
+
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
| 109 |
+
# indefinitely, shuffling items as it goes.
|
| 110 |
+
|
| 111 |
+
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that yields dataset indices forever, shuffling as it goes.

    Each of `num_replicas` ranks receives a disjoint, interleaved slice of the
    same infinite index stream. `window_size` (fraction of the dataset)
    controls continuous reshuffling: after visiting a position, it may be
    swapped with another position up to `window` steps back.
    """
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        # `order` is mutated in place below to keep the stream shuffled.
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            # Round-robin split of the infinite stream across replicas.
            if idx % self.num_replicas == self.rank:
                yield order[i]
            # Swap the current index with a random one inside the trailing
            # window so the order keeps changing between epochs.
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1
|
| 143 |
+
|
| 144 |
+
#----------------------------------------------------------------------------
|
| 145 |
+
# Utilities for operating with torch.nn.Module parameters and buffers.
|
| 146 |
+
|
| 147 |
+
def params_and_buffers(module):
    """Return all parameters followed by all buffers of `module` as one list."""
    assert isinstance(module, torch.nn.Module)
    tensors = [p for p in module.parameters()]
    tensors.extend(module.buffers())
    return tensors
|
| 150 |
+
|
| 151 |
+
def named_params_and_buffers(module):
    """Return (name, tensor) pairs for all parameters followed by all buffers."""
    assert isinstance(module, torch.nn.Module)
    pairs = list(module.named_parameters())
    pairs.extend(module.named_buffers())
    return pairs
|
| 154 |
+
|
| 155 |
+
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy matching parameters/buffers from `src_module` into `dst_module`.

    Tensors are matched by name. With `require_all`, every destination tensor
    must exist in the source. Copying is best-effort: tensors whose copy fails
    (e.g. shape mismatch) are skipped rather than aborting the whole transfer.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            try:
                tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
            except Exception:  # fix: was a bare `except:` that also swallowed KeyboardInterrupt/SystemExit
                continue
|
| 166 |
+
|
| 167 |
+
#----------------------------------------------------------------------------
|
| 168 |
+
# Context manager for easily enabling/disabling DistributedDataParallel
|
| 169 |
+
# synchronization.
|
| 170 |
+
|
| 171 |
+
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that disables DDP gradient synchronization when `sync`
    is False; a no-op for modules not wrapped in DistributedDataParallel."""
    assert isinstance(module, torch.nn.Module)
    is_ddp = isinstance(module, torch.nn.parallel.DistributedDataParallel)
    if is_ddp and not sync:
        with module.no_sync():
            yield
    else:
        yield
|
| 179 |
+
|
| 180 |
+
#----------------------------------------------------------------------------
|
| 181 |
+
# Check DistributedDataParallel consistency across processes.
|
| 182 |
+
|
| 183 |
+
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that `module`'s tensors are identical across all DDP processes.

    Broadcasts each parameter/buffer from rank 0 and compares element-wise;
    fully-qualified names matching `ignore_regex` are skipped. Must be called
    collectively from every process in the default process group.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        # NaN != NaN would make the equality check fail spuriously, so map
        # NaN/Inf to finite values on both sides before comparing.
        if tensor.is_floating_point():
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname
|
| 195 |
+
|
| 196 |
+
#----------------------------------------------------------------------------
|
| 197 |
+
# Print summary table of module hierarchy.
|
| 198 |
+
|
| 199 |
+
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module(*inputs)` once and print a table of its submodules with
    parameter/buffer counts, output shapes, and dtypes; return the outputs.

    `max_nesting` limits how deep in the module hierarchy rows are recorded;
    `skip_redundant` hides submodules contributing no new tensors.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks: the pre-hook tracks nesting depth, the post-hook records
    # tensor outputs for modules at depth <= max_nesting.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module once to populate `entries`, then detach the hooks.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers: the first entry that
    # owns a tensor gets credited, so shared tensors are not double-counted.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        # Modules with several outputs get one row per output (suffix :0, :1, ...).
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table with columns padded to the widest cell.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
|
| 268 |
+
|
| 269 |
+
#----------------------------------------------------------------------------
|
| 270 |
+
|
| 271 |
+
# Added by Katja
|
| 272 |
+
import os
|
| 273 |
+
|
| 274 |
+
def get_ckpt_path(run_dir):
    """Return the path of the network snapshot pickle inside `run_dir`."""
    # Fix: the filename was an f-string with no placeholders; plain literal.
    return os.path.join(run_dir, 'network-snapshot.pkl')
|
pg_modules/__init__.py
ADDED
|
File without changes
|
pg_modules/__pycache__/MViT.cpython-39.pyc
ADDED
|
Binary file (17.9 kB). View file
|
|
|
pg_modules/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
pg_modules/__pycache__/blocks.cpython-38.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
pg_modules/__pycache__/blocks.cpython-39.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
pg_modules/__pycache__/diffaug.cpython-38.pyc
ADDED
|
Binary file (2.69 kB). View file
|
|
|
pg_modules/__pycache__/diffaug.cpython-39.pyc
ADDED
|
Binary file (2.81 kB). View file
|
|
|
pg_modules/__pycache__/discriminator.cpython-38.pyc
ADDED
|
Binary file (5.65 kB). View file
|
|
|
pg_modules/__pycache__/discriminator.cpython-39.pyc
ADDED
|
Binary file (4.51 kB). View file
|
|
|
pg_modules/__pycache__/mae.cpython-39.pyc
ADDED
|
Binary file (8.85 kB). View file
|
|
|
pg_modules/__pycache__/models_tnt.cpython-39.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
pg_modules/__pycache__/networks_fastgan.cpython-38.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
pg_modules/__pycache__/networks_fastgan.cpython-39.pyc
ADDED
|
Binary file (5.34 kB). View file
|
|
|
pg_modules/__pycache__/networks_stylegan2.cpython-39.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
pg_modules/__pycache__/projector.cpython-38.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
pg_modules/__pycache__/projector.cpython-39.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
pg_modules/__pycache__/simmim.cpython-39.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
pg_modules/__pycache__/vision_transformer.cpython-39.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
pg_modules/blocks.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.nn.utils import spectral_norm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
### single layers
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def conv2d(*args, **kwargs):
    """nn.Conv2d wrapped in spectral normalization."""
    layer = nn.Conv2d(*args, **kwargs)
    return spectral_norm(layer)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def convTranspose2d(*args, **kwargs):
    """nn.ConvTranspose2d wrapped in spectral normalization."""
    layer = nn.ConvTranspose2d(*args, **kwargs)
    return spectral_norm(layer)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def embedding(*args, **kwargs):
    """nn.Embedding wrapped in spectral normalization."""
    layer = nn.Embedding(*args, **kwargs)
    return spectral_norm(layer)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def linear(*args, **kwargs):
    """nn.Linear wrapped in spectral normalization."""
    layer = nn.Linear(*args, **kwargs)
    return spectral_norm(layer)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def NormLayer(c, mode='batch'):
    """Return a normalization layer for `c` channels.

    `mode` is 'batch' (BatchNorm2d) or 'group' (GroupNorm with c//2 groups).
    Raises ValueError for any other mode.
    """
    if mode == 'group':
        return nn.GroupNorm(c//2, c)
    elif mode == 'batch':
        return nn.BatchNorm2d(c)
    # Fix: an unknown mode previously fell through and returned None silently.
    raise ValueError(f'unknown norm mode: {mode!r}')
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
### Activations
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GLU(nn.Module):
    """Gated Linear Unit: split channels in half and gate the first half with
    the sigmoid of the second half."""
    def forward(self, x):
        channels = x.size(1)
        assert channels % 2 == 0, 'channels dont divide 2!'
        half = channels // 2
        gate = torch.sigmoid(x[:, half:])
        return x[:, :half] * gate
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""
    def forward(self, feat):
        gate = torch.sigmoid(feat)
        return feat * gate
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
### Upblocks
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class InitLayer(nn.Module):
    """Project a latent vector to an initial sz x sz feature map via a
    spectral-norm transposed conv, normalization, and GLU."""
    def __init__(self, nz, channel, sz=4):
        super().__init__()
        # Produce 2x channels before GLU, which halves them back to `channel`.
        layers = [
            convTranspose2d(nz, channel*2, sz, 1, 0, bias=False),
            NormLayer(channel*2),
            GLU(),
        ]
        self.init = nn.Sequential(*layers)

    def forward(self, noise):
        flat = noise.view(noise.shape[0], -1, 1, 1)
        return self.init(flat)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def UpBlockSmall(in_planes, out_planes):
    """2x nearest-neighbor upsample -> spectral conv -> norm -> GLU."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NormLayer(out_planes*2),
        GLU(),
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class UpBlockSmallCond(nn.Module):
    """Conditional small up-block: upsample -> conv -> conditional BN -> GLU.

    NOTE(review): CCBN is defined elsewhere in this module; presumably a
    class-conditional BatchNorm called as bn(features, condition) — confirm.
    """
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)

        # Conditional norm over the doubled channels; GLU halves them after.
        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn = which_bn(2*out_planes)
        self.act = GLU()

    def forward(self, x, c):
        x = self.up(x)
        x = self.conv(x)
        x = self.bn(x, c)
        x = self.act(x)
        return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def UpBlockBig(in_planes, out_planes):
    """2x upsample followed by two (conv -> noise -> norm -> GLU) stages."""
    first_stage = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2),
        GLU(),
    ]
    second_stage = [
        conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2),
        GLU(),
    ]
    return nn.Sequential(*first_stage, *second_stage)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class UpBlockBigCond(nn.Module):
    """Conditional big up-block: 2x upsample, then two
    (conv -> noise -> conditional BN -> GLU) stages.

    NOTE(review): CCBN is defined elsewhere in this module; presumably a
    class-conditional BatchNorm called as bn(features, condition) — confirm.
    """
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
        self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False)

        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn1 = which_bn(2*out_planes)
        self.bn2 = which_bn(2*out_planes)
        self.act = GLU()
        # One NoiseInjection instance shared by both stages.
        self.noise = NoiseInjection()

    def forward(self, x, c):
        # block 1
        x = self.up(x)
        x = self.conv1(x)
        x = self.noise(x)
        x = self.bn1(x, c)
        x = self.act(x)

        # block 2
        x = self.conv2(x)
        x = self.noise(x)
        x = self.bn2(x, c)
        x = self.act(x)

        return x
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class SEBlock(nn.Module):
    """Squeeze-and-excitation: gate `feat_big` channel-wise by a sigmoid
    excitation computed from `feat_small`."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = [
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
            Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, feat_small, feat_big):
        gate = self.main(feat_small)
        return feat_big * gate
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
### Downblocks
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: depthwise conv followed by a 1x1
    pointwise conv, both spectral-normalized. 'Same' padding keeps the
    spatial size for odd kernels."""
    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super(SeparableConv2d, self).__init__()
        # Fix: padding was hard-coded to 1, which is only correct for
        # kernel_size == 3 (the only value this file uses). kernel_size // 2
        # gives 'same' padding for any odd kernel and is identical for 3.
        self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                groups=in_channels, bias=bias, padding=kernel_size // 2)
        self.pointwise = conv2d(in_channels, out_channels,
                                kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class DownBlock(nn.Module):
    """Halve spatial resolution: strided spectral conv (default) or a
    separable conv followed by average pooling."""
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        if separable:
            layers = [
                SeparableConv2d(in_planes, out_planes, 3),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
                nn.AvgPool2d(2, 2),
            ]
        else:
            layers = [
                conv2d(in_planes, out_planes, 4, 2, 1),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.main = nn.Sequential(*layers)

    def forward(self, feat):
        return self.main(feat)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class DownBlockPatch(nn.Module):
    """DownBlock followed by an extra 1x1 conv + norm + LeakyReLU stage."""
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        stages = [
            DownBlock(in_planes, out_planes, separable),
            conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
            NormLayer(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, feat):
        return self.main(feat)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
### CSM
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class ResidualConvUnit(nn.Module):
    """3x3 conv with identity skip: returns conv(x) + x.

    `activation` and `bn` are accepted for interface compatibility but unused.
    """
    def __init__(self, cin, activation, bn):
        super().__init__()
        self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
        # FloatFunctional keeps the residual add quantization-friendly.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = self.conv(x)
        return self.skip_add.add(residual, x)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class FeatureFusionBlock(nn.Module):
    """Fuse up to two feature maps, upsample 2x, then 1x1-project (CSM block).

    forward(*xs): xs[0] is the base map; if a second map is given it is added
    (quantization-friendly add). The result is bilinearly upsampled by 2 and
    passed through a 1x1 conv.

    NOTE(review): `activation`, `deconv`, `bn` and `lowest` are accepted but
    unused here — likely kept for API parity with the original DPT/MiDaS block.
    """
    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.expand = expand
        out_features = features
        # When expanding, halve the channel count in the output projection.
        if self.expand==True:
            out_features = features//2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        output = xs[0]

        if len(xs) == 2:
            output = self.skip_add.add(output, xs[1])

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
|
| 249 |
+
|
| 250 |
+
class FeatureFusionBlock_V2(nn.Module):
    """Variant of FeatureFusionBlock that skips the 2x upsampling.

    NOTE(review): unlike FeatureFusionBlock, the `expand` branch assigns
    `out_features = features` — a no-op (V1 halves it). The interpolate call
    is commented out, so this block preserves spatial size. Both look
    deliberate for this variant, but confirm against the training code.
    """
    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.expand = expand
        out_features = features
        if self.expand==True:
            out_features = features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        output = xs[0]

        if len(xs) == 2:
            output = self.skip_add.add(output, xs[1])

        # Upsampling deliberately disabled in this variant (see class note).
        # output = nn.functional.interpolate(
        #     output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        # )

        output = self.out_conv(output)

        return output
|
| 278 |
+
from timm.models.vision_transformer import PatchEmbed, Block
|
| 279 |
+
|
| 280 |
+
class FeatureFusionBlockTrans(nn.Module):
    """Transformer-based fusion: optional skip-add, then a timm ViT Block.

    Expects token-shaped input (as required by timm's `Block`), not 2D
    feature maps; num_heads is fixed at 12.
    """
    def __init__(self, features):
        super().__init__()
        self.out_conv = Block(features,num_heads=12)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        output = xs[0]

        if len(xs) == 2:
            output = self.skip_add.add(output, xs[1])
        output = self.out_conv(output)

        return output
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
### Misc
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class NoiseInjection(nn.Module):
    """Add learnable-scaled per-pixel Gaussian noise (StyleGAN-style).

    The scale parameter is zero-initialized, so the module is an exact
    identity at the start of training.
    """
    def __init__(self):
        super().__init__()
        # Scalar noise strength, learned during training.
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        """feat: (B, C, H, W). noise: optional (B, 1, H, W); sampled if None."""
        if noise is None:
            batch, _, height, width = feat.shape
            # Allocate directly on feat's device/dtype instead of a CPU
            # tensor followed by .to(device) (avoids an extra alloc + copy).
            noise = torch.randn(batch, 1, height, width,
                                device=feat.device, dtype=feat.dtype)

        return feat + self.weight * noise
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class CCBN(nn.Module):
    """Conditional batch norm: normalize x, then apply class-conditional
    gain and bias computed from the conditioning vector y.

    Args:
        output_size: number of channels of x being normalized.
        input_size: dimensionality of the conditioning vector y.
        which_linear: linear-layer factory, e.g. ``nn.Linear``.
        eps: numerical stability constant for the normalization.
        momentum: running-stats momentum passed to ``F.batch_norm``.
    """
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
        super().__init__()
        self.output_size, self.input_size = output_size, input_size

        # Prepare gain and bias layers (gain is offset by +1 in forward,
        # so zero-initialized weights yield an identity affine transform).
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)

        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum

        # Running statistics used in eval mode (updated in train mode).
        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # Bug fix: use the configured momentum (was hard-coded to 0.1,
        # silently ignoring the constructor argument).
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, self.momentum, self.eps)
        return out * gain + bias
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class Interpolate(nn.Module):
    """Resize input tensors to a fixed spatial size.

    Thin module wrapper around ``nn.functional.interpolate``.

    Args:
        size: target spatial size forwarded to ``interpolate``.
        mode (str): interpolation mode (default ``'bilinear'``).
        align_corners (bool): forwarded to ``interpolate``.
    """

    def __init__(self, size, mode='bilinear', align_corners=False):
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Return ``x`` resized to ``self.size``."""
        resized = self.interp(
            x,
            size=self.size,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return resized
|
pg_modules/diffaug.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Differentiable Augmentation for Data-Efficient GAN Training
|
| 2 |
+
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
|
| 3 |
+
# https://arxiv.org/pdf/2006.10738
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def DiffAugment(x, policy='', channels_first=True):
    """Apply the comma-separated augmentation policies from AUGMENT_FNS to x.

    An empty policy returns x untouched. If ``channels_first`` is False the
    tensor is treated as NHWC and converted around the augmentations.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for fn in AUGMENT_FNS[name]:
            x = fn(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def rand_brightness(x):
    """Shift each sample by a random brightness offset in [-0.5, 0.5)."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def rand_saturation(x):
    """Scale each sample's deviation from its channel mean by a factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    scale = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * scale + channel_mean
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def rand_contrast(x):
    """Scale each sample's deviation from its global mean by a factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    scale = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * scale + sample_mean
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def rand_translation(x, ratio=0.125):
    """Randomly shift each sample by up to `ratio` of its spatial size.

    Out-of-bounds pixels are zero-filled (via a 1-pixel pad and clamped
    gather indices).
    """
    max_dx = int(x.size(2) * ratio + 0.5)
    max_dy = int(x.size(3) * ratio + 0.5)
    dx = torch.randint(-max_dx, max_dx + 1, size=[x.size(0), 1, 1], device=x.device)
    dy = torch.randint(-max_dy, max_dy + 1, size=[x.size(0), 1, 1], device=x.device)
    idx_b, idx_x, idx_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # +1 accounts for the zero pad; clamping into the padded range makes
    # shifted-out positions read the zero border.
    idx_x = torch.clamp(idx_x + dx + 1, 0, x.size(2) + 1)
    idx_y = torch.clamp(idx_y + dy + 1, 0, x.size(3) + 1)
    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    return padded.permute(0, 2, 3, 1).contiguous()[idx_b, idx_x, idx_y].permute(0, 3, 1, 2)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def rand_cutout(x, ratio=0.2):
    """Zero out one random rectangle of roughly `ratio` * spatial size per sample."""
    cut_h = int(x.size(2) * ratio + 0.5)
    cut_w = int(x.size(3) * ratio + 0.5)
    off_x = torch.randint(0, x.size(2) + (1 - cut_h % 2), size=[x.size(0), 1, 1], device=x.device)
    off_y = torch.randint(0, x.size(3) + (1 - cut_w % 2), size=[x.size(0), 1, 1], device=x.device)
    idx_b, idx_x, idx_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cut_h, dtype=torch.long, device=x.device),
        torch.arange(cut_w, dtype=torch.long, device=x.device),
    )
    # Center the rectangle on the sampled offset, clamped to the image.
    idx_x = torch.clamp(idx_x + off_x - cut_h // 2, min=0, max=x.size(2) - 1)
    idx_y = torch.clamp(idx_y + off_y - cut_w // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[idx_b, idx_x, idx_y] = 0
    return x * mask.unsqueeze(1)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Maps a policy name to the list of augmentation callables that DiffAugment
# applies in order for that policy.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
|
pg_modules/discriminator.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision.transforms import Normalize
|
| 6 |
+
import pickle
|
| 7 |
+
|
| 8 |
+
from pg_modules.diffaug import DiffAugment
|
| 9 |
+
from pg_modules.blocks import conv2d, DownBlock, DownBlockPatch
|
| 10 |
+
from pg_modules.projector import F_RandomProj
|
| 11 |
+
from feature_networks.constants import VITS
|
| 12 |
+
|
| 13 |
+
class SingleDisc(nn.Module):
    """Single patch-style discriminator operating on one feature level.

    Builds a stack of DownBlocks from `start_sz` down to `end_sz`, ending in a
    4x4 conv producing a 1-channel logit map. `c` in forward is accepted but
    unused (unconditional head).
    """
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, patch=False):
        super().__init__()

        # midas channels
        nfc_midas = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                     256: 32, 512: 16, 1024: 8}

        # interpolate for start sz that are not powers of two
        if start_sz not in nfc_midas.keys():
            sizes = np.array(list(nfc_midas.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        if ndf is None:
            nfc = nfc_midas
        else:
            nfc = {k: ndf for k, v in nfc_midas.items()}

        # for feature map discriminators with nfc not in nfc_midas
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Head if the initial input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Down Blocks
        DB = DownBlockPatch if patch else DownBlock
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz//2]))
            start_sz = start_sz // 2

        # Final 4x4 conv collapses the end_sz x end_sz map to logits.
        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # `c` is ignored; kept for signature parity with conditional discs.
        return self.main(x)
|
| 56 |
+
|
| 57 |
+
class MultiScaleD(nn.Module):
    """Bundle of SingleDisc heads, one per backbone feature level.

    NOTE(review): `proj_type`, `cond` and forward's `rec` are accepted but
    unused in this implementation — presumably kept for config compatibility.
    """
    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        patch=False,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4, 5]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]
        Disc = SingleDisc

        mini_discs = []
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            # Trailing comma builds (key, module) pairs for ModuleDict.
            mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, patch=patch)],

        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c, rec=False):
        # Run each head on its feature level and concatenate flattened logits.
        all_logits = []
        for k, disc in self.mini_discs.items():
            all_logits.append(disc(features[k], c).view(features[k].size(0), -1))

        all_logits = torch.cat(all_logits, dim=1)
        return all_logits
|
| 91 |
+
|
| 92 |
+
class ProjectedDiscriminator(torch.nn.Module):
    """Projected-GAN discriminator: frozen feature backbones + trainable
    multi-scale discriminator heads, one pair per backbone name.

    NOTE(review): `backbone_kwargs={}` is a mutable default argument — safe
    here only because it is never mutated; consider `None` + fallback.
    """
    def __init__(
        self,
        backbones,
        diffaug=True,
        interp224=True,
        backbone_kwargs={},
        **kwargs
    ):
        super().__init__()
        self.backbones = backbones
        self.diffaug = diffaug
        self.interp224 = interp224

        # get backbones and multi-scale discs
        feature_networks, discriminators = [], []

        for i, bb_name in enumerate(backbones):

            feat = F_RandomProj(bb_name, **backbone_kwargs)
            disc = MultiScaleD(
                channels=feat.CHANNELS,
                resolutions=feat.RESOLUTIONS,
                **backbone_kwargs,
            )

            feature_networks.append([bb_name, feat])
            discriminators.append([bb_name, disc])

        self.feature_networks = nn.ModuleDict(feature_networks)
        self.discriminators = nn.ModuleDict(discriminators)

    def train(self, mode=True):
        # Feature backbones stay frozen in eval mode regardless of `mode`;
        # only the discriminator heads follow the requested training mode.
        self.feature_networks = self.feature_networks.train(False)
        self.discriminators = self.discriminators.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, c):
        # NOTE(review): `logits += <tensor>` extends the list with the
        # tensor's first-dim slices; confirm callers expect that shape.
        logits = []
        for bb_name, feat in self.feature_networks.items():

            # apply augmentation (x in [-1, 1])
            x_aug = DiffAugment(x, policy='color,translation,cutout') if self.diffaug else x

            # transform to [0,1]
            x_aug = x_aug.add(1).div(2)

            # apply F-specific normalization
            x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)

            # upsample if smaller, downsample if larger + VIT
            if self.interp224 or bb_name in VITS:
                x_n = F.interpolate(x_n, 224, mode='bilinear', align_corners=False)

            # forward pass
            features = feat(x_n)
            logits += self.discriminators[bb_name](features, c)

        return logits
|
pg_modules/networks_fastgan.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# original implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py
|
| 2 |
+
#
|
| 3 |
+
# modified by Axel Sauer for "Projected GANs Converge Faster"
|
| 4 |
+
#
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d)
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def normalize_second_moment(x, dim=1, eps=1e-8):
    """Scale x so its mean squared value along `dim` is ~1 (pixel-norm style)."""
    second_moment = x.square().mean(dim=dim, keepdim=True)
    return x * (second_moment + eps).rsqrt()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DummyMapping(nn.Module):
    """Identity mapping network that only adds a broadcast axis.

    Exists so the FastGAN generator matches the StyleGAN mapping API.
    """
    def __init__(self):
        super().__init__()

    def forward(self, z, c, **kwargs):
        # Insert a num_ws-like axis so downstream code can index ws[:, 0].
        return z.unsqueeze(1)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class FastganSynthesis(nn.Module):
    """Unconditional FastGAN synthesis network.

    Maps a latent (after pixel-norm) through progressively upsampling blocks
    with skip-excitation (SEBlock) connections, ending in a tanh RGB head.
    `InitLayer`, `UpBlockSmall/Big`, `SEBlock` and `conv2d` are project
    helpers defined elsewhere in this repository.
    """
    def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False):
        super().__init__()
        self.img_resolution = img_resolution
        self.z_dim = z_dim

        # channel multiplier per resolution
        nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 224:0.5, 256:0.5,
                     512:0.25, 1024:0.125}
        nfc = {}
        for k, v in nfc_multi.items():
            nfc[k] = int(v*ngf)

        # layers
        self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

        UpBlock = UpBlockSmall if lite else UpBlockBig

        self.feat_8 = UpBlock(nfc[4], nfc[8])
        self.feat_16 = UpBlock(nfc[8], nfc[16])
        self.feat_32 = UpBlock(nfc[16], nfc[32])
        self.feat_64 = UpBlock(nfc[32], nfc[64])
        self.feat_128 = UpBlock(nfc[64], nfc[128])
        self.feat_256 = UpBlock(nfc[128], nfc[256])

        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)

        if img_resolution > 256:
            self.feat_512 = UpBlock(nfc[256], nfc[512])
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if img_resolution > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024])

    def forward(self, input, c, **kwargs):
        # map noise to hypersphere as in "Progressive Growing of GANS"
        input = normalize_second_moment(input[:, 0])

        feat_4 = self.init(input)
        feat_8 = self.feat_8(feat_4)
        feat_16 = self.feat_16(feat_8)
        feat_32 = self.feat_32(feat_16)
        feat_64 = self.se_64(feat_4, self.feat_64(feat_32))
        # Bug fix: removed a stray line-continuation '\' that dangled after
        # this statement (a latent syntax hazard on edit).
        feat_128 = self.se_128(feat_8, self.feat_128(feat_64))

        # Select the deepest feature map the target resolution requires.
        if self.img_resolution >= 64:
            feat_last = feat_64

        if self.img_resolution >= 128:
            feat_last = feat_128

        if self.img_resolution >= 224:
            feat_last = self.se_256(feat_16, self.feat_256(feat_last))

        if self.img_resolution >= 512:
            feat_last = self.se_512(feat_32, self.feat_512(feat_last))

        if self.img_resolution >= 1024:
            feat_last = self.feat_1024(feat_last)

        return torch.tanh(self.to_big(feat_last))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class FastganSynthesisCond(nn.Module):
    """Class-conditional FastGAN synthesis network.

    Like FastganSynthesis, but every up-block is conditioned on a class
    embedding derived from the one-hot label `c`.
    """
    def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False):
        super().__init__()

        self.z_dim = z_dim
        # channel multiplier per resolution
        nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
                     512:0.25, 1024:0.125, 2048:0.125}
        nfc = {}
        for k, v in nfc_multi.items():
            nfc[k] = int(v*ngf)

        self.img_resolution = img_resolution

        self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

        UpBlock = UpBlockSmallCond if lite else UpBlockBigCond

        self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
        self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
        self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
        self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
        self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim)
        self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim)

        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)

        if img_resolution > 256:
            # Bug fix: pass z_dim like every other conditional UpBlock —
            # forward calls these blocks with the class embedding `c`.
            self.feat_512 = UpBlock(nfc[256], nfc[512], z_dim)
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if img_resolution > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024], z_dim)

        self.embed = nn.Embedding(num_classes, z_dim)

    def forward(self, input, c, update_emas=False):

        # Turn the one-hot label into a dense class embedding.
        c = self.embed(c.argmax(1))

        # map noise to hypersphere as in "Progressive Growing of GANS"
        input = normalize_second_moment(input[:, 0])

        feat_4 = self.init(input)
        feat_8 = self.feat_8(feat_4, c)
        feat_16 = self.feat_16(feat_8, c)
        feat_32 = self.feat_32(feat_16, c)
        feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c))
        feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c))

        if self.img_resolution >= 128:
            feat_last = feat_128

        if self.img_resolution >= 256:
            feat_last = self.se_256(feat_16, self.feat_256(feat_last, c))

        if self.img_resolution >= 512:
            feat_last = self.se_512(feat_32, self.feat_512(feat_last, c))

        if self.img_resolution >= 1024:
            feat_last = self.feat_1024(feat_last, c)
        return self.to_big(feat_last)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Generator(nn.Module):
    """StyleGAN-API-compatible wrapper around the FastGAN synthesis network.

    `mapping` is an identity shim (DummyMapping); `cond` selects the
    class-conditional synthesis variant. `w_dim`/`mapping_kwargs` are kept
    for API parity and are unused here.
    """
    def __init__(
        self,
        z_dim=256,
        c_dim=0,
        w_dim=0,
        img_resolution=256,
        img_channels=3,
        ngf=128,
        cond=0,
        mapping_kwargs={},
        synthesis_kwargs={}
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Mapping and Synthesis Networks
        self.mapping = DummyMapping()  # to fit the StyleGAN API
        Synthesis = FastganSynthesisCond if cond else FastganSynthesis
        self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs)

    def forward(self, z, c, **kwargs):
        w = self.mapping(z, c)
        img = self.synthesis(w, c)
        return img
|
pg_modules/networks_stylegan2.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
#
|
| 9 |
+
# modified by Axel Sauer for "Projected GANs Converge Faster"
|
| 10 |
+
#
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from torch_utils import misc
|
| 14 |
+
from torch_utils import persistence
|
| 15 |
+
from torch_utils.ops import conv2d_resample
|
| 16 |
+
from torch_utils.ops import upfirdn2d
|
| 17 |
+
from torch_utils.ops import bias_act
|
| 18 |
+
from torch_utils.ops import fma
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Pixel-norm: scale x so its mean squared value along `dim` is ~1."""
    inv_norm = (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
    return x * inv_norm
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@misc.profiled_function
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    """StyleGAN2 modulated convolution: per-sample style modulation of the
    conv weights, optional demodulation, and optional up/downsampling.

    Returns a tensor of shape [batch_size, out_channels, out_height, out_width].
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0) # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            # Fused multiply-add: demodulate and add noise in one op.
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    # Fold the batch into the channel dim and use groups=batch_size so each
    # sample is convolved with its own modulated weights.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
    """Fully connected layer with equalized learning rate (StyleGAN2).

    Weights are stored at unit variance and rescaled at runtime by
    lr_multiplier / sqrt(in_features); the bias is rescaled by lr_multiplier.
    """
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        # Fast path: plain affine transform can use the fused addmm kernel.
        if self.activation == 'linear' and b is not None:
            return torch.addmm(b.unsqueeze(0), x, w.t())

        out = x.matmul(w.t())
        return bias_act.bias_act(out, b, act=self.activation)

    def extra_repr(self):
        return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
| 124 |
+
|
| 125 |
+
@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
    """2D convolution with optional up/downsampling, equalized learning rate,
    and a fused bias + activation (StyleGAN2)."""
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        kernel_size,                    # Width and height of the convolution kernel.
        bias            = True,         # Apply additive bias before the activation function?
        activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc.
        up              = 1,            # Integer upsampling factor.
        down            = 1,            # Integer downsampling factor.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping.
        channels_last   = False,        # Expect the input to have memory_format=channels_last?
        trainable       = True,         # Update the weights of this layer during training?
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        # Runtime weight gain implements equalized learning rate.
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            # Learnable parameters.
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            # Frozen: register as buffers so they move with the module but
            # receive no gradient updates.
            self.register_buffer('weight', weight)
            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        b = self.bias.to(x.dtype) if self.bias is not None else None
        flip_weight = (self.up == 1) # slightly faster
        x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        return bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)

    def extra_repr(self):
        return ' '.join([
            f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
            f'up={self.up}, down={self.down}'])
| 181 |
+
|
| 182 |
+
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    """StyleGAN2 mapping network: maps latent z and label c to intermediate
    latents w.

    Normalizes the inputs, optionally embeds the conditioning label, runs a
    stack of fully connected layers, tracks a moving average of w for
    truncation, and optionally broadcasts w to num_ws copies.
    """
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.998,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if update_emas and self.w_avg_beta is not None:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation: interpolate towards the tracked average of W.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

    def extra_repr(self):
        # num_ws may legitimately be None (no broadcasting); formatting it
        # with ':d' would raise TypeError, so format it generically.
        return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws}'
+
|
| 264 |
+
@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    """Single StyleGAN2 synthesis layer: style-modulated convolution with
    optional upsampling, per-pixel noise, and a fused bias + activation."""
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        w_dim,                          # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        channels_last   = False,        # Use channels_last format for the weights?
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        # Affine layer maps w to per-channel modulation styles (bias_init=1
        # so the initial modulation is identity).
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        if use_noise:
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
        styles = self.affine(w)

        # Select the noise source for this call.
        noise = None
        if self.use_noise and noise_mode == 'random':
            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1) # slightly faster
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
            padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        return bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)

    def extra_repr(self):
        return ' '.join([
            f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
            f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
| 326 |
+
|
| 327 |
+
@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
    """Projects feature maps to output color channels using a style-modulated
    1x1 convolution without demodulation (StyleGAN2 'tRGB')."""
    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.w_dim = w_dim
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        # Equalized learning rate gain, folded into the styles at runtime.
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))

    def forward(self, x, w, fused_modconv=True):
        styles = self.affine(w) * self.weight_gain
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
        return bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)

    def extra_repr(self):
        return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
| 350 |
+
|
| 351 |
+
@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
    """One resolution level of the StyleGAN2 synthesis network.

    The first block (in_channels == 0) starts from a learned constant; later
    blocks upsample the incoming features. Each block applies one or two
    SynthesisLayers, optionally a ToRGB layer (skip/last), and optionally a
    residual skip connection (resnet architecture).
    """
    def __init__(self,
        in_channels,                            # Number of input channels, 0 = first block.
        out_channels,                           # Number of output channels.
        w_dim,                                  # Intermediate latent (W) dimensionality.
        resolution,                             # Resolution of this block.
        img_channels,                           # Number of output color channels.
        is_last,                                # Is this the last block?
        architecture            = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter         = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp              = 256,          # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16                = False,        # Use FP16 for this block?
        fp16_channels_last      = False,        # Use channels-last memory format with FP16?
        fused_modconv_default   = True,         # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
        **layer_kwargs,                         # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            # First block: learned constant input.
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
                resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
            conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        _ = update_emas # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        ws_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(ws_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(ws_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(ws_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(ws_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(ws_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB: upsample the running image and add this block's contribution.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(ws_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
|
| 454 |
+
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    """Full StyleGAN2 synthesis network: a stack of SynthesisBlocks from 4x4
    up to img_resolution, consuming a broadcast latent tensor ws."""
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 4,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        # Resolution must be a power of two >= 4.
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.num_fp16_res = num_fp16_res
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, c=None, **block_kwargs):
        # Slice ws into per-block windows. Consecutive blocks overlap by
        # num_torgb latents: each block's ToRGB reuses the next block's
        # first latent.
        ws_per_block = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                ws_per_block.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(self.block_resolutions, ws_per_block):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img

    def extra_repr(self):
        return ' '.join([
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
            f'num_fp16_res={self.num_fp16_res:d}'])
|
| 512 |
+
|
| 513 |
+
@persistence.persistent_class
class Generator(torch.nn.Module):
    """Top-level StyleGAN2 generator: MappingNetwork followed by
    SynthesisNetwork."""
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        **synthesis_kwargs,         # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        # Build the synthesis network first: its num_ws determines how many
        # latents the mapping network must broadcast.
        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        return self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
|
pg_modules/projector.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from feature_networks.vit import forward_vit
|
| 5 |
+
from feature_networks.pretrained_builder import _make_pretrained
|
| 6 |
+
from feature_networks.constants import NORMALIZED_INCEPTION, NORMALIZED_IMAGENET, NORMALIZED_CLIP, VITS
|
| 7 |
+
from pg_modules.blocks import FeatureFusionBlock
|
| 8 |
+
|
| 9 |
+
def get_backbone_normstats(backbone):
    """Return the input normalization statistics for a backbone name.

    Args:
        backbone: Name of the pretrained feature network.

    Returns:
        dict with 'mean' and 'std' lists (per RGB channel).

    Raises:
        NotImplementedError: if the backbone is not in any known group.
    """
    if backbone in NORMALIZED_INCEPTION:
        return {
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
        }

    elif backbone in NORMALIZED_IMAGENET:
        return {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
        }

    elif backbone in NORMALIZED_CLIP:
        return {
            'mean': [0.48145466, 0.4578275, 0.40821073],
            'std': [0.26862954, 0.26130258, 0.27577711],
        }

    else:
        # Name the offending backbone so the failure is actionable.
        raise NotImplementedError(f'No normalization stats known for backbone: {backbone!r}')
+
|
| 31 |
+
def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
|
| 32 |
+
# shapes
|
| 33 |
+
out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
|
| 34 |
+
|
| 35 |
+
scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
|
| 36 |
+
scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
|
| 37 |
+
scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
|
| 38 |
+
scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
|
| 39 |
+
|
| 40 |
+
scratch.CHANNELS = out_channels
|
| 41 |
+
|
| 42 |
+
return scratch
|
| 43 |
+
|
| 44 |
+
def _make_scratch_csm(scratch, in_channels, cout, expand):
    """Attach cross-scale-mixing (CSM) FeatureFusionBlocks to `scratch`.

    Blocks are built from the deepest level (layer3, lowest=True) up to
    layer0. Sets scratch.CHANNELS to the fused output widths and returns
    scratch.
    """
    scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
    scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
    scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
    scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))

    # last refinenet does not expand to save channels in higher dimensions
    scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4
    return scratch
| 54 |
+
|
| 55 |
+
def _make_projector(im_res, backbone, cout, proj_type, expand=False):
    """Build the pretrained feature network plus the random projection head.

    proj_type: 0 = features only, 1 = + cross-channel mixing (CCM),
    2 = + cross-scale mixing (CSM). Returns (pretrained, scratch) where
    scratch is None for proj_type 0.
    """
    assert proj_type in [0, 1, 2], "Invalid projection type"

    # Pretrained feature network.
    pretrained = _make_pretrained(backbone)

    # Following Projected GAN: feature resolutions are fixed for a 256px
    # input. NOTE(review): the im_res argument is deliberately overridden
    # here, so callers' im_res has no effect — confirm this is intended.
    im_res = 256
    pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]

    if proj_type == 0:
        return pretrained, None

    # Cross-channel mixing.
    scratch = _make_scratch_ccm(nn.Module(), in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
    pretrained.CHANNELS = scratch.CHANNELS

    if proj_type == 1:
        return pretrained, scratch

    # Cross-scale mixing.
    scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)

    # CSM upsamples x2 so the feature map resolution doubles
    pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
    pretrained.CHANNELS = scratch.CHANNELS

    return pretrained, scratch
| 83 |
+
|
| 84 |
+
class F_Identity(nn.Module):
    """Identity feature network: passes the input through unchanged."""
    def forward(self, x):
        return x
| 87 |
+
|
| 88 |
+
class F_RandomProj(nn.Module):
    """Pretrained feature extractor with a random projection head.

    Extracts 4 feature maps from a frozen pretrained backbone, then
    optionally mixes them across channels (CCM) and across scales (CSM)
    with randomly initialized layers, following Projected GAN.
    """
    def __init__(
        self,
        backbone="tf_efficientnet_lite3",
        im_res=256,
        cout=64,
        expand=True,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        **kwargs,
    ):
        super().__init__()
        self.proj_type = proj_type
        self.backbone = backbone
        self.cout = cout
        self.expand = expand
        self.normstats = get_backbone_normstats(backbone)

        # build pretrained feature network and random decoder (scratch)
        self.pretrained, self.scratch = _make_projector(im_res=im_res, backbone=self.backbone, cout=self.cout,
                                                        proj_type=self.proj_type, expand=self.expand)
        self.CHANNELS = self.pretrained.CHANNELS
        self.RESOLUTIONS = self.pretrained.RESOLUTIONS

    def forward(self, x):
        # Predict the raw backbone feature maps; ViT backbones need their
        # dedicated forward helper.
        if self.backbone in VITS:
            out0, out1, out2, out3 = forward_vit(self.pretrained, x)
        else:
            out0 = self.pretrained.layer0(x)
            out1 = self.pretrained.layer1(out0)
            out2 = self.pretrained.layer2(out1)
            out3 = self.pretrained.layer3(out2)

        # Keys enumerate from the shallowest layer (where the first
        # discriminator is attached).
        feats = {'0': out0, '1': out1, '2': out2, '3': out3}
        if self.proj_type == 0:
            return feats

        # Cross-channel mixing: one 1x1 conv per level.
        ccm0 = self.scratch.layer0_ccm(feats['0'])
        ccm1 = self.scratch.layer1_ccm(feats['1'])
        ccm2 = self.scratch.layer2_ccm(feats['2'])
        ccm3 = self.scratch.layer3_ccm(feats['3'])

        feats = {'0': ccm0, '1': ccm1, '2': ccm2, '3': ccm3}
        if self.proj_type == 1:
            return feats

        # Cross-scale mixing, fused from the deepest level upward.
        csm3 = self.scratch.layer3_csm(ccm3)
        csm2 = self.scratch.layer2_csm(csm3, ccm2)
        csm1 = self.scratch.layer1_csm(csm2, ccm1)
        csm0 = self.scratch.layer0_csm(csm1, ccm0)

        return {'0': csm0, '1': csm1, '2': csm2, '3': csm3}
|