|
|
import importlib
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from tqdm import tqdm
|
|
|
from inspect import isfunction
|
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
import hashlib
|
|
|
import requests
|
|
|
import os
|
|
|
|
|
|
# Single registry of the downloadable assets:
#   name -> (remote URL, local file name, expected MD5 hex digest).
# Keeping the three attributes side by side guarantees that the public maps
# derived below always share exactly the same keys.
_PRETRAINED_ASSETS = {
    'vggishish_lpaps': (
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
        'vggishish16.pt',
        '197040c524a07ccacf7715d7080a80bd',
    ),
    'vggishish_mean_std_melspec_10s_22050hz': (
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt',
        'train_means_stds_melspec_10s_22050hz.txt',
        'f449c6fd0e248936c16f6d22492bb625',
    ),
    'melception': (
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt',
        'melception-21-05-10T09-28-40.pt',
        'a71a41041e945b457c7d3d814bbcf72d',
    ),
}

# Asset name -> URL the file is downloaded from.
URL_MAP = {name: url for name, (url, _filename, _md5) in _PRETRAINED_ASSETS.items()}

# Asset name -> file name used when the asset is stored locally.
CKPT_MAP = {name: filename for name, (_url, filename, _md5) in _PRETRAINED_ASSETS.items()}

# Asset name -> MD5 hex digest the downloaded file is expected to have.
MD5_MAP = {name: md5 for name, (_url, _filename, md5) in _PRETRAINED_ASSETS.items()}
|
|
|
|
|
|
|
|
|
def download(url, local_path, chunk_size=1024):
    """Stream the resource at ``url`` into ``local_path`` while showing a progress bar.

    Args:
        url: Address fetched with ``requests.get(..., stream=True)``.
        local_path: Destination file. Missing parent directories are created.
        chunk_size: Number of bytes requested per iteration over the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status. The
            check happens before the destination is opened, so an error page
            is never written to ``local_path``.
    """
    # A bare file name has an empty directory part, and os.makedirs("") raises.
    parent_dir = os.path.dirname(local_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        # Servers may omit the header; 0 then means "unknown total".
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the bytes actually received: the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
|
|
|
|
|
|
|
|
|
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is consumed in pieces of ``chunk_size`` bytes so that large
    checkpoints are never held in memory in one go.

    Args:
        path: File to hash; opened in binary mode.
        chunk_size: Number of bytes read per step (1 MiB by default).

    Returns:
        The digest as a lowercase hexadecimal string.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
def log_txt_as_img(wh, xc, size=10):
    """Render every string of ``xc`` as black text on a white image.

    Args:
        wh: ``(width, height)`` of each rendered image, as passed to ``Image.new``.
        xc: Sequence of strings; one image is produced per entry.
        size: Font size handed to ``ImageFont.truetype``.

    Returns:
        A tensor stacking one channel-first array per string, with pixel
        values rescaled from ``[0, 255]`` to ``[-1, 1]``.
    """
    # The font and the wrap width depend only on the arguments, so they are
    # computed once instead of once per caption.
    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
    # Wrap at 40 characters per 256 pixels of width. Never go below one
    # character: a zero step would make the range() below raise ValueError.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: the (blank) image is still appended below.
            print("Cant encode string for logging. Skipping.")

        # HWC -> CHW, then map [0, 255] onto [-1, 1].
        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)

    return torch.tensor(np.stack(txts))
|
|
|
|
|
|
|
|
|
def ismap(x):
    """Tell whether ``x`` is a 4-D tensor whose second dimension is larger than 3.

    Anything that is not a ``torch.Tensor`` yields ``False``.
    """
    if isinstance(x, torch.Tensor):
        return x.dim() == 4 and x.shape[1] > 3
    return False
|
|
|
|
|
|
|
|
|
def isimage(x):
    """Tell whether ``x`` is a 4-D tensor whose second dimension is 3 or 1.

    Anything that is not a ``torch.Tensor`` yields ``False``.
    """
    if isinstance(x, torch.Tensor):
        if x.dim() != 4:
            return False
        channels = x.shape[1]
        return channels == 3 or channels == 1
    return False
|
|
|
|
|
|
|
|
|
def exists(x):
    """Return ``True`` for every value except ``None`` (falsy values such as ``0`` count as existing)."""
    is_missing = x is None
    return not is_missing
|
|
|
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    When ``d`` is a plain Python function (as judged by ``inspect.isfunction``)
    it is called without arguments and its result is returned; any other ``d``
    is returned unchanged.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
|
|
|
|
|
def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except the first (batch) one.

    Taken from:
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
|
|
|
|
|
|
|
|
|
def count_params(model, verbose=False):
    """Return the total number of elements over ``model.parameters()``.

    With ``verbose`` set, the count is also printed in millions together with
    the model's class name.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    if not verbose:
        return total_params

    print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
|
|
|
|
|
|
|
|
|
def instantiate_from_config(config, reload=False):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted path resolved through ``get_obj_from_str``;
    the resolved callable is invoked with ``config["params"]`` as keyword
    arguments (no arguments when that key is absent). The marker strings
    ``'__is_first_stage__'`` and ``"__is_unconditional__"`` produce ``None``.

    Args:
        config: Mapping with a ``target`` entry, or one of the marker strings.
        reload: Forwarded to ``get_obj_from_str``.

    Raises:
        KeyError: If ``config`` has no ``target`` and is not a marker string.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"], reload=reload)
        kwargs = config.get("params", dict())
        return factory(**kwargs)

    # No target: only the two marker values are acceptable.
    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
|
|
|
|
|
|
|
|
|
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``package.module.Name`` to the attribute it names.

    Everything before the last dot is imported as a module and the part after
    it is looked up on that module. With ``reload`` set, the module is reloaded
    before the lookup.
    """
    module_path, attr_name = string.rsplit(".", 1)
    loaded = importlib.import_module(module_path)
    if reload:
        importlib.reload(loaded)
        # Fetch the module again after the reload before reading the attribute.
        loaded = importlib.import_module(module_path, package=None)
    return getattr(loaded, attr_name)
|
|
|
|
|
|
def get_ckpt_path(name, root, check=False):
    """Return the local path of asset ``name`` under ``root``, downloading it when needed.

    A download happens when the file is missing, or when ``check`` is true and
    the existing file's MD5 differs from ``MD5_MAP[name]``. Every fresh
    download is verified against the expected MD5.

    Args:
        name: Key of ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
        root: Directory the asset is stored in.
        check: Also verify the MD5 of an already existing file.

    Returns:
        Path of the local file.

    Raises:
        AssertionError: If ``name`` is unknown, or if the downloaded file does
            not match the expected MD5 (the actual digest is the message).
    """
    # Explicit raises instead of `assert`, so the checks survive `python -O`.
    # AssertionError is kept so callers keep seeing the same exception type.
    if name not in URL_MAP:
        raise AssertionError(f"Unknown checkpoint name: {name!r}")

    path = os.path.join(root, CKPT_MAP[name])
    needs_download = not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name])
    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != MD5_MAP[name]:
            # Do not leave a corrupt file behind: a later call with
            # check=False would otherwise return it without any verification.
            os.remove(path)
            raise AssertionError(md5)
    return path
|
|
|
|
|
|
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    """Load the weights stored under ``model_name`` in a checkpoint into ``cur_model``.

    Two layouts of ``checkpoint["state_dict"]`` are handled:

    * flat: keys contain dots and the wanted entries start with
      ``"<model_name>."``; that prefix is stripped;
    * nested: no key contains a dot and ``model_name`` (or its first dotted
      component) indexes a sub-dictionary; any remaining components of
      ``model_name`` are treated as a key prefix inside it and stripped.

    Args:
        cur_model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        ckpt_base_dir: A checkpoint file, or a directory that is handed to
            ``get_last_checkpoint``.
        model_name: Selector of the weights inside the checkpoint (see above).
        force: When no checkpoint is found, raise if true, otherwise only
            print a message.
        strict: Forwarded to ``load_state_dict``. When false, entries whose
            shape differs from the model's current parameter are dropped first.

    Raises:
        AssertionError: If no checkpoint is found and ``force`` is true.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        # NOTE(review): get_last_checkpoint is neither defined nor imported in
        # the visible part of this file -- confirm it is in scope at runtime.
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)

    if checkpoint is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise: `assert False` would vanish under `python -O`
            # and make force=True a silent no-op.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    state_dict = checkpoint["state_dict"]
    if any('.' in k for k in state_dict):
        # Flat layout: keep the "<model_name>." entries and strip the prefix.
        state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
                      if k.startswith(f'{model_name}.')}
    elif '.' not in model_name:
        # Nested layout, plain name: the sub-dictionary is the state dict.
        state_dict = state_dict[model_name]
    else:
        # Nested layout, dotted name: the first component picks the
        # sub-dictionary, the rest is a key prefix inside it.
        base_model_name = model_name.split('.')[0]
        rest_model_name = model_name[len(base_model_name) + 1:]
        state_dict = {
            k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items()
            if k.startswith(f'{rest_model_name}.')}

    if not strict:
        cur_model_state_dict = cur_model.state_dict()
        unmatched_keys = []
        for key, param in state_dict.items():
            if key in cur_model_state_dict:
                new_param = cur_model_state_dict[key]
                if new_param.shape != param.shape:
                    unmatched_keys.append(key)
                    print("| Unmatched keys: ", key, new_param.shape, param.shape)
        # Deleted after the scan so the dict is not mutated while iterating.
        for key in unmatched_keys:
            del state_dict[key]

    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{model_name}' from '{ckpt_path}'.")
|
|
|
|
|
|
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.

    Args:
        shape: Shape of the result; ``shape[0]`` is the batch size. Tuples and lists are both accepted.
        generator: One generator for the whole tensor, or a list of generators with one entry per
            batch element. A single-element list is treated like a lone generator.
        device: Device of the returned tensor (``torch.device`` or device string); CPU when omitted.
        dtype: Forwarded to ``torch.randn``.
        layout: Forwarded to ``torch.randn``; ``torch.strided`` when omitted.

    Returns:
        The sampled tensor, moved to ``device``.

    Raises:
        ValueError: If a CUDA generator is combined with a different target device type.
    """
    # Local import: this module defines no module-level logger.
    import logging

    # Device the samples are drawn on; may differ from the target device below.
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    # Normalised to torch.device so `.type` works for device strings as well.
    device = torch.device(device) if device else torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            # A CPU generator can only sample on the CPU; the result is moved afterwards.
            rand_device = "cpu"
            if device.type != "mps":
                logging.getLogger(__name__).info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One sample per generator. tuple() keeps the concatenation valid when
        # `shape` was given as a list.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents