# Hosting-page residue (not part of the module): uploader "feizhengcong",
# commit message "Upload 198 files", commit 074c857.
import os
import torch
# Decodes the image without passing through the upscaler. The resulting image will be the same size as the latent
# Thanks to Kevin Turner (https://github.com/keturn) we have a shortcut to look at the decoded image!
def make_linear_decode(model_version, device='cuda:0'):
    """Build a cheap latent-to-RGB preview function.

    The returned callable multiplies the 4 latent channels by a fixed 4x3
    matrix instead of running a decoder network, so the output keeps the
    latent's spatial size.

    Args:
        model_version: Model identifier; only names starting with "sd-v1"
            are supported.
        device: Device on which the factor matrix is initially created.

    Returns:
        A function ``linear_decode(latent)`` that maps a
        ``(batch, 4, height, width)`` tensor to ``(batch, 3, height, width)``.

    Raises:
        ValueError: If ``model_version`` is not a recognized model name.
    """
    v1_4_rgb_latent_factors = [
        #   R       G       B
        [ 0.298,  0.207,  0.208],  # L1
        [ 0.187,  0.286,  0.173],  # L2
        [-0.158,  0.189,  0.264],  # L3
        [-0.184, -0.271, -0.473],  # L4
    ]

    if not model_version.startswith("sd-v1"):
        raise ValueError(f"Model name {model_version} not recognized.")
    rgb_latent_factors = torch.tensor(v1_4_rgb_latent_factors, device=device)

    def linear_decode(latent):
        # Match the latent's dtype/device so half-precision (or otherwise
        # differently-typed) latents do not fail in the matmul; this is a
        # no-op when they already agree.
        factors = rgb_latent_factors.to(device=latent.device, dtype=latent.dtype)
        # Channels-last for the matmul, then back to channels-first.
        latent_image = latent.permute(0, 2, 3, 1) @ factors
        return latent_image.permute(0, 3, 1, 2)

    return linear_decode
# Known checkpoints: expected SHA-256 digest, direct download URL, and whether
# the hosting repository requires an authenticated download.
_MODEL_MAP = {
    "512-base-ema.ckpt": {
        'sha256': 'd635794c1fedfdfa261e065370bea59c651fc9bfa65dc6d67ad29e11869a1824',
        'url': 'https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt',
        'requires_login': True,
    },
    "v1-5-pruned.ckpt": {
        'sha256': 'e1441589a6f3c5a53f5f54d0975a18a7feb7cdf0b0dee276dfc3331ae376a053',
        'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt',
        'requires_login': True,
    },
    "v1-5-pruned-emaonly.ckpt": {
        'sha256': 'cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516',
        'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt',
        'requires_login': True,
    },
    "sd-v1-4-full-ema.ckpt": {
        'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',
        # "resolve" serves the file itself; a "blob" URL returns the HTML page.
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-4.ckpt": {
        'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
        'requires_login': True,
    },
    "sd-v1-3-full-ema.ckpt": {
        'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-3.ckpt": {
        'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',
        'requires_login': True,
    },
    "sd-v1-2-full-ema.ckpt": {
        'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-2.ckpt": {
        'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',
        'requires_login': True,
    },
    "sd-v1-1-full-ema.ckpt": {
        'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-1.ckpt": {
        'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',
        'requires_login': True,
    },
    "robo-diffusion-v1.ckpt": {
        'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',
        'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',
        'requires_login': False,
    },
    "wd-v1-3-float16.ckpt": {
        'sha256': '4afab9126057859b34d13d6207d90221d0b017b7580469ea70cee37757a29edd',
        'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt',
        'requires_login': False,
    },
}

# Read/write granularity for hashing and downloading checkpoint files, so the
# whole file never has to be held in memory at once.
_CHECKPOINT_CHUNK_SIZE = 1024 * 1024


def _running_in_colab():
    """Return True when the active IPython shell reports Google Colab."""
    try:
        ipy = get_ipython()  # only defined when running under IPython
    except NameError:
        ipy = 'could not get_ipython'
    return 'google.colab' in str(ipy)


def _sha256_of_file(path):
    """Return the hex SHA-256 digest of the file at ``path``, read in chunks."""
    import hashlib

    digest = hashlib.sha256()
    with open(path, "rb") as checkpoint_file:
        for chunk in iter(lambda: checkpoint_file.read(_CHECKPOINT_CHUNK_SIZE), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download_checkpoint(url, destination, auth=None):
    """Stream ``url`` to ``destination``, raising on any non-200 response.

    The payload is written to a temporary ``.part`` file and renamed into
    place only after the transfer finishes, so an interrupted download never
    leaves a truncated checkpoint at ``destination``.
    """
    import requests

    with requests.get(url, auth=auth, stream=True) as response:
        request_status = response.status_code

        # inform user of errors
        if request_status == 403:
            raise ConnectionRefusedError("You have not accepted the license for this model.")
        elif request_status == 404:
            raise ConnectionError("Could not make contact with server")
        elif request_status != 200:
            raise ConnectionError(f"Some other error has ocurred - response code: {request_status}")

        # write to model path
        os.makedirs(os.path.dirname(destination) or ".", exist_ok=True)
        partial_path = f"{destination}.part"
        try:
            with open(partial_path, 'wb') as model_file:
                for chunk in response.iter_content(chunk_size=_CHECKPOINT_CHUNK_SIZE):
                    model_file.write(chunk)
            os.replace(partial_path, destination)
        finally:
            # Only still present when the transfer or rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)


def _load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True, print_flag=False):
    """Instantiate ``config.model``, load weights from ``ckpt``, move to ``device``."""
    import torch
    from ldm.util import instantiate_from_config

    print(f"..loading model")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    pl_sd = torch.load(ckpt, map_location=device)
    if print_flag and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if print_flag and verbose:
        if len(m) > 0:
            print("missing keys:")
            print(m)
        if len(u) > 0:
            print("unexpected keys:")
            print(u)

    if half_precision:
        model = model.half().to(device)
    else:
        model = model.to(device)
    model.eval()
    return model


def load_model(root, load_on_run_all=True, check_sha256=True):
    """Locate (downloading if necessary), verify and load a model checkpoint.

    Args:
        root: Settings object. Reads ``model_config``, ``custom_config_path``,
            ``configs_path``, ``model_checkpoint``, ``custom_checkpoint_path``,
            ``models_path`` and ``half_precision``.
        load_on_run_all: When false, the checkpoint is located/verified but
            not loaded.
        check_sha256: When true, known (non-"custom") checkpoints are compared
            against their expected SHA-256 digest; a mismatch prevents loading.

    Returns:
        ``(model, device)``. ``device`` is CUDA when available, otherwise CPU.
        ``model`` is ``None`` when loading was skipped (``load_on_run_all`` is
        false, the checkpoint is missing, or its hash did not match).

    Raises:
        ConnectionRefusedError: The download server answered 403.
        ConnectionError: The download server answered any other non-200 status.
    """
    import torch
    from omegaconf import OmegaConf
    from transformers import logging

    logging.set_verbosity_error()

    path_extend = "deforum-stable-diffusion" if _running_in_colab() else ""

    # config path
    if root.model_config == "custom":
        ckpt_config_path = root.custom_config_path
    else:
        ckpt_config_path = os.path.join(root.configs_path, root.model_config)
    if os.path.exists(ckpt_config_path):
        print(f"{ckpt_config_path} exists")
    else:
        print(f"Warning: {ckpt_config_path} does not exist.")
        ckpt_config_path = os.path.join(path_extend, "configs", root.model_config)
        print(f"Using {ckpt_config_path} instead.")
    ckpt_config_path = os.path.abspath(ckpt_config_path)

    # checkpoint path or download
    if root.model_checkpoint == "custom":
        ckpt_path = root.custom_checkpoint_path
    else:
        ckpt_path = os.path.join(root.models_path, root.model_checkpoint)
    # "custom" and unknown checkpoint names have no entry; an empty mapping
    # routes them to the "please download" branch instead of a KeyError.
    model_info = _MODEL_MAP.get(root.model_checkpoint, {})
    ckpt_valid = True
    if os.path.exists(ckpt_path):
        pass
    elif 'url' in model_info:
        auth = None
        # CLI dialogue to authenticate download
        if model_info['requires_login']:
            print("This model requires an authentication token")
            print("Please ensure you have accepted the terms of service before continuing.")
            username = input("[What is your huggingface username?]: ")
            token = input("[What is your huggingface token?]: ")
            # Passed as HTTP basic auth rather than spliced into the URL, so
            # special characters are safe and the token stays out of the URL.
            auth = (username, token)

        # contact server for model
        print(f"..attempting to download {root.model_checkpoint}...this may take a while")
        _download_checkpoint(model_info['url'], ckpt_path, auth=auth)
    else:
        print(f"Please download model checkpoint and place in {ckpt_path}")
        ckpt_valid = False

    print(f"config_path: {ckpt_config_path}")
    print(f"ckpt_path: {ckpt_path}")

    if check_sha256 and root.model_checkpoint != "custom" and ckpt_valid:
        print("..checking sha256")
        expected_hash = model_info.get("sha256")
        try:
            actual_hash = _sha256_of_file(ckpt_path)
        except OSError:
            actual_hash = None
        if expected_hash is None or actual_hash is None:
            # Best effort: an unverifiable checkpoint is still allowed to load.
            print("..could not verify model integrity")
        elif expected_hash == actual_hash:
            print("..hash is correct")
        else:
            print("..hash in not correct")
            ckpt_valid = False

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = None
    if load_on_run_all and ckpt_valid:
        local_config = OmegaConf.load(ckpt_config_path)
        model = _load_model_from_config(
            local_config,
            ckpt_path,
            device=device,
            half_precision=root.half_precision,
        )
        autoencoder_version = "sd-v1"  # TODO this will be different for different models
        model.linear_decode = make_linear_decode(autoencoder_version, device)

    return model, device
def get_model_output_paths(root):
    """Resolve, create and return the model and output directories.

    Starts from ``root.models_path`` / ``root.output_path``. When running in
    Google Colab with ``root.mount_google_drive`` set, Google Drive is mounted
    and ``root.models_path_gdrive`` / ``root.output_path_gdrive`` are used
    instead; if that fails, both default paths are kept.

    Args:
        root: Settings object providing the path attributes named above.

    Returns:
        ``(models_path, output_path)`` as absolute paths. Both directories
        are created if they do not exist.
    """
    models_path = root.models_path
    output_path = root.output_path

    #@markdown **Google Drive Path Variables (Optional)**
    force_remount = False

    try:
        ipy = get_ipython()  # only defined when running under IPython
    except NameError:
        ipy = 'could not get_ipython'

    if 'google.colab' in str(ipy) and root.mount_google_drive:
        from google.colab import drive  # type: ignore
        try:
            drive_path = "/content/drive"
            drive.mount(drive_path, force_remount=force_remount)
            # Assign both together: if either attribute lookup fails, neither
            # default is overwritten, matching the "reverting" message below.
            models_path, output_path = root.models_path_gdrive, root.output_path_gdrive
        except Exception:
            print("..error mounting drive or with drive path variables")
            print("..reverting to default path variables")

    models_path = os.path.abspath(models_path)
    output_path = os.path.abspath(output_path)
    os.makedirs(models_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    print(f"models_path: {models_path}")
    print(f"output_path: {output_path}")

    return models_path, output_path