|
|
import os |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
def make_linear_decode(model_version, device='cuda:0'):
    """Build a cheap latent -> RGB preview decoder.

    Instead of running the full autoencoder decoder, every 4-channel latent
    pixel is projected to 3 output channels with a fixed 4x3 matrix.

    Args:
        model_version: Autoencoder version string. Only versions starting
            with ``"sd-v1"`` are recognized.
        device: Device on which the projection matrix is initially created.

    Returns:
        A function ``linear_decode(latent)`` that maps a rank-4
        ``(N, 4, H, W)`` tensor to an ``(N, 3, H, W)`` tensor.

    Raises:
        ValueError: if ``model_version`` is not recognized (a subclass of
            ``Exception``, so existing ``except Exception`` callers still work).
    """
    # One row per latent channel, one column per output (R, G, B) channel.
    v1_4_rgb_latent_factors = [
        [ 0.298,  0.207,  0.208],
        [ 0.187,  0.286,  0.173],
        [-0.158,  0.189,  0.264],
        [-0.184, -0.271, -0.473],
    ]

    if model_version.startswith("sd-v1"):
        rgb_latent_factors = torch.Tensor(v1_4_rgb_latent_factors).to(device)
    else:
        raise ValueError(f"Model name {model_version} not recognized.")

    def linear_decode(latent):
        # Match the latent's dtype/device: the model may run in half precision,
        # and a float32 matrix would otherwise make the matmul fail.
        factors = rgb_latent_factors.to(dtype=latent.dtype, device=latent.device)
        # (N, 4, H, W) -> (N, H, W, 4) @ (4, 3) -> (N, H, W, 3) -> (N, 3, H, W)
        latent_image = latent.permute(0, 2, 3, 1) @ factors
        return latent_image.permute(0, 3, 1, 2)

    return linear_decode
|
|
|
|
|
def _running_in_colab():
    """Return True if this code is executing inside a Google Colab IPython kernel."""
    try:
        ipy = get_ipython()  # noqa: F821 - only defined when running under IPython
    except NameError:
        return False
    return 'google.colab' in str(ipy)


def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of *path*, reading it in ``chunk_size`` pieces.

    Streaming keeps memory flat even for multi-gigabyte checkpoints.
    """
    import hashlib

    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download_checkpoint(url, dest_path, checkpoint_name, requires_login):
    """Stream *url* to *dest_path*, prompting for Hugging Face credentials if needed.

    Raises:
        ConnectionRefusedError: on HTTP 403 (license not accepted).
        ConnectionError: on HTTP 404 or any other non-200 status.
    """
    import requests

    auth = None
    if requires_login:
        print("This model requires an authentication token")
        print("Please ensure you have accepted the terms of service before continuing.")
        username = input("[What is your huggingface username?]: ")
        token = input("[What is your huggingface token?]: ")
        # Same HTTP Basic auth as embedding "user:token@" in the URL, but the
        # token no longer appears in the URL string (and hence in exception
        # messages), and special characters are handled correctly.
        auth = (username, token)

    print(f"..attempting to download {checkpoint_name}...this may take a while")
    with requests.get(url, auth=auth, stream=True, timeout=60) as response:
        request_status = response.status_code
        if request_status == 403:
            raise ConnectionRefusedError("You have not accepted the license for this model.")
        elif request_status == 404:
            raise ConnectionError("Could not make contact with server")
        elif request_status != 200:
            raise ConnectionError(f"Some other error has ocurred - response code: {request_status}")

        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)

        # Write to a temporary file and rename only on success. Otherwise an
        # interrupted download would leave a truncated file at dest_path that
        # the next run would treat as an existing checkpoint.
        part_path = dest_path + ".part"
        try:
            with open(part_path, "wb") as model_file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    model_file.write(chunk)
            os.replace(part_path, dest_path)
        except BaseException:
            if os.path.exists(part_path):
                os.remove(part_path)
            raise


def load_model(root, load_on_run_all=True, check_sha256=True):
    """Locate (downloading if needed), verify and load a Stable Diffusion checkpoint.

    Args:
        root: Settings object. Attributes read here: ``model_config``,
            ``custom_config_path``, ``configs_path``, ``model_checkpoint``,
            ``custom_checkpoint_path``, ``models_path`` and ``half_precision``.
        load_on_run_all: When False, the checkpoint is located and verified
            but not loaded, and ``(None, device)`` is returned.
        check_sha256: Compare known (non-"custom") checkpoints against their
            expected SHA-256 digest; a mismatch prevents loading.

    Returns:
        ``(model, device)``. ``model`` is the model in eval mode on ``device``
        with a ``linear_decode`` preview function attached, or ``None`` when
        ``load_on_run_all`` is False. ``device`` is CUDA when available,
        otherwise CPU.

    Raises:
        ConnectionRefusedError, ConnectionError: if a download fails.
        RuntimeError: if loading was requested but the checkpoint is missing
            or failed its hash check.
    """
    import torch
    from ldm.util import instantiate_from_config
    from omegaconf import OmegaConf
    from transformers import logging

    logging.set_verbosity_error()

    # In Colab the repository is cloned into a subdirectory of the working dir.
    path_extend = "deforum-stable-diffusion" if _running_in_colab() else ""

    # NOTE: download URLs must use the "/resolve/" endpoint; "/blob/" serves the
    # HTML file-viewer page instead of the raw checkpoint.
    model_map = {
        "512-base-ema.ckpt": {
            'sha256': 'd635794c1fedfdfa261e065370bea59c651fc9bfa65dc6d67ad29e11869a1824',
            'url': 'https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt',
            'requires_login': True,
        },
        "v1-5-pruned.ckpt": {
            'sha256': 'e1441589a6f3c5a53f5f54d0975a18a7feb7cdf0b0dee276dfc3331ae376a053',
            'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt',
            'requires_login': True,
        },
        "v1-5-pruned-emaonly.ckpt": {
            'sha256': 'cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516',
            'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt',
            'requires_login': True,
        },
        "sd-v1-4-full-ema.ckpt": {
            'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt',
            'requires_login': True,
        },
        "sd-v1-4.ckpt": {
            'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
            'requires_login': True,
        },
        "sd-v1-3-full-ema.ckpt": {
            'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3-full-ema.ckpt',
            'requires_login': True,
        },
        "sd-v1-3.ckpt": {
            'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',
            'requires_login': True,
        },
        "sd-v1-2-full-ema.ckpt": {
            'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2-full-ema.ckpt',
            'requires_login': True,
        },
        "sd-v1-2.ckpt": {
            'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',
            'requires_login': True,
        },
        "sd-v1-1-full-ema.ckpt": {
            'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',
            'requires_login': True,
        },
        "sd-v1-1.ckpt": {
            'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',
            'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',
            'requires_login': True,
        },
        "robo-diffusion-v1.ckpt": {
            'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',
            'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',
            'requires_login': False,
        },
        "wd-v1-3-float16.ckpt": {
            'sha256': '4afab9126057859b34d13d6207d90221d0b017b7580469ea70cee37757a29edd',
            'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt',
            'requires_login': False,
        },
    }

    # --- Resolve the model config path -------------------------------------
    if root.model_config == "custom":
        ckpt_config_path = root.custom_config_path
    else:
        ckpt_config_path = os.path.join(root.configs_path, root.model_config)

    if os.path.exists(ckpt_config_path):
        print(f"{ckpt_config_path} exists")
    else:
        print(f"Warning: {ckpt_config_path} does not exist.")
        ckpt_config_path = os.path.join(path_extend, "configs", root.model_config)
        print(f"Using {ckpt_config_path} instead.")

    ckpt_config_path = os.path.abspath(ckpt_config_path)

    # --- Locate (or download) the checkpoint -------------------------------
    if root.model_checkpoint == "custom":
        ckpt_path = root.custom_checkpoint_path
    else:
        ckpt_path = os.path.join(root.models_path, root.model_checkpoint)

    # "custom" and unrecognized checkpoint names have no entry; .get() avoids
    # the KeyError a direct lookup would raise for them.
    model_info = model_map.get(root.model_checkpoint, {})
    ckpt_valid = True

    if os.path.exists(ckpt_path):
        pass
    elif 'url' in model_info:
        _download_checkpoint(
            model_info['url'],
            ckpt_path,
            root.model_checkpoint,
            model_info.get('requires_login', False),
        )
    else:
        print(f"Please download model checkpoint and place in {ckpt_path}")
        ckpt_valid = False

    print(f"config_path: {ckpt_config_path}")
    print(f"ckpt_path: {ckpt_path}")

    # --- Verify integrity (best effort) ------------------------------------
    if check_sha256 and root.model_checkpoint != "custom" and ckpt_valid:
        expected_sha256 = model_info.get("sha256")
        if expected_sha256 is None:
            print("..could not verify model integrity")
        else:
            print("..checking sha256")
            try:
                actual_sha256 = _sha256_of_file(ckpt_path)
            except OSError:
                print("..could not verify model integrity")
            else:
                if actual_sha256 == expected_sha256:
                    print("..hash is correct")
                else:
                    print("..hash in not correct")
                    ckpt_valid = False

    def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True, print_flag=False):
        """Instantiate ``config.model``, load weights from *ckpt*, move it to *device* and set eval mode."""
        # Deserialize straight onto the target device; a hard-coded "cuda"
        # would crash on CPU-only machines.
        map_location = device
        print(f"..loading model")
        pl_sd = torch.load(ckpt, map_location=map_location)
        if "global_step" in pl_sd:
            if print_flag:
                print(f"Global Step: {pl_sd['global_step']}")
        # Lightning checkpoints nest weights under "state_dict"; also accept a
        # bare state dict.
        sd = pl_sd.get("state_dict", pl_sd)
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        # Drop the checkpoint tensors before precision/device conversion to
        # lower peak memory.
        del pl_sd, sd
        if print_flag:
            if len(m) > 0 and verbose:
                print("missing keys:")
                print(m)
            if len(u) > 0 and verbose:
                print("unexpected keys:")
                print(u)

        if half_precision:
            model = model.half().to(device)
        else:
            model = model.to(device)
        model.eval()
        return model

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if not load_on_run_all:
        return None, device

    if not ckpt_valid:
        raise RuntimeError(
            f"Checkpoint {ckpt_path} is missing or failed verification; model not loaded."
        )

    local_config = OmegaConf.load(ckpt_config_path)
    model = load_model_from_config(
        local_config,
        str(ckpt_path),
        device=device,
        half_precision=root.half_precision,
    )

    # TODO: the autoencoder version may differ for non-v1 models.
    autoencoder_version = "sd-v1"
    model.linear_decode = make_linear_decode(autoencoder_version, device)

    return model, device
|
|
|
|
|
|
|
|
def get_model_output_paths(root):
    """Resolve, create and return the model and output directories.

    The defaults are ``root.models_path`` and ``root.output_path``. When running
    inside Google Colab with ``root.mount_google_drive`` set, Google Drive is
    mounted at ``/content/drive`` and ``root.models_path_gdrive`` /
    ``root.output_path_gdrive`` are used instead. If mounting, or reading either
    Drive path, fails, both defaults are kept.

    Both directories are made absolute, created if missing, and printed.

    Returns:
        ``(models_path, output_path)`` as absolute path strings.
    """
    models_path = root.models_path
    output_path = root.output_path

    force_remount = False

    try:
        ipy = get_ipython()  # noqa: F821 - only defined when running under IPython
    except NameError:
        ipy = 'could not get_ipython'

    if 'google.colab' in str(ipy) and root.mount_google_drive:
        from google.colab import drive
        try:
            drive_path = "/content/drive"
            drive.mount(drive_path, force_remount=force_remount)
            # Read both Drive paths before assigning either, so a missing
            # attribute cannot leave one path on Drive and the other local.
            gdrive_models_path = root.models_path_gdrive
            gdrive_output_path = root.output_path_gdrive
        except Exception:
            # Best effort: fall back to local paths. (Exception rather than a
            # bare except, so Ctrl-C during the interactive mount still works.)
            print("..error mounting drive or with drive path variables")
            print("..reverting to default path variables")
        else:
            models_path = gdrive_models_path
            output_path = gdrive_output_path

    models_path = os.path.abspath(models_path)
    output_path = os.path.abspath(output_path)
    os.makedirs(models_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    print(f"models_path: {models_path}")
    print(f"output_path: {output_path}")

    return models_path, output_path
|
|
|