smzerbe committed on
Commit
d034647
·
verified ·
1 Parent(s): 3e15aef

Update modules/sd_models.py

Browse files
Files changed (1) hide show
  1. modules/sd_models.py +1037 -1034
modules/sd_models.py CHANGED
@@ -1,1034 +1,1037 @@
1
- import collections
2
- import importlib
3
- import os
4
- import sys
5
- import threading
6
- import enum
7
-
8
- import torch
9
- import re
10
- import safetensors.torch
11
- from omegaconf import OmegaConf, ListConfig
12
- from urllib import request
13
- import ldm.modules.midas as midas
14
-
15
- from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
16
- from modules.timer import Timer
17
- from modules.shared import opts
18
- import tomesd
19
- import numpy as np
20
-
21
# Name of the checkpoint sub-directory inside the models directory.
model_dir = "Stable-diffusion"
# Absolute path of the directory that holds checkpoint files.
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

# Registered checkpoints, keyed by title (filled by CheckpointInfo.register).
checkpoints_list = {}
# Every identifier a checkpoint can be looked up by -> its CheckpointInfo.
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
# State dicts kept in memory, keyed by CheckpointInfo; insertion order is used
# to evict the oldest entry first and to move reused entries to the end.
checkpoints_loaded = collections.OrderedDict()
28
-
29
-
30
class ModelType(enum.Enum):
    """Model families that `set_model_type` can assign to a loaded model."""

    SD1 = 1
    SD2 = 2
    SDXL = 3
    SSD = 4
    SD3 = 5
36
-
37
-
38
def replace_key(d, key, new_key, value):
    """Store `value` under `new_key` in `d`, taking over the position of `key`.

    The dict is modified in place and also returned. When `key` is not present,
    `new_key` is simply added (or overwritten) and the order is left alone.
    """
    original_order = list(d)

    d[new_key] = value

    try:
        position = original_order.index(key)
    except ValueError:
        return d

    original_order[position] = new_key

    # Rebuild in the original order with the old key swapped for the new one.
    reordered = {name: d[name] for name in original_order}

    d.clear()
    d.update(reordered)
    return d
54
-
55
-
56
class CheckpointInfo:
    """Describes one checkpoint file: its names, hashes and lookup identifiers."""

    def __init__(self, filename):
        """Collect names, metadata and hashes for the checkpoint at `filename`.

        The display name is the path relative to the `--ckpt-dir` directory or
        to `model_path` when the file lives under one of them, otherwise the
        bare file name.
        """
        self.filename = filename
        abspath = os.path.abspath(filename)
        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None

        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        # Strip only the leading directory prefix; str.replace() would also
        # remove later occurrences of the same text inside the path.
        if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
            name = abspath[len(abs_ckpt_dir):]
        elif abspath.startswith(model_path):
            name = abspath[len(model_path):]
        else:
            name = os.path.basename(filename)

        if name.startswith(("\\", "/")):
            name = name[1:]

        def read_metadata():
            # Only invoked by the cache helper; the thumbnail is split off and
            # kept on the instance instead of in the returned metadata.
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)

            return metadata

        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                # Metadata is optional: report the problem and keep going.
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

        # Every string this checkpoint can be referred to by.
        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

    def register(self):
        """Add this checkpoint to `checkpoints_list` and all its ids to `checkpoint_aliases`."""
        checkpoints_list[self.title] = self
        for alias in self.ids:
            checkpoint_aliases[alias] = self

    def calculate_shorthash(self):
        """Compute the sha256 of the file and refresh titles/ids derived from it.

        Returns the 10-character short hash, or None when no sha256 could be
        obtained.
        """
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        shorthash = self.sha256[0:10]
        if self.shorthash == shorthash:
            # Nothing changed, titles and ids are already up to date.
            return self.shorthash

        self.shorthash = shorthash

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

        old_title = self.title
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        # Keep the entry at the same position in checkpoints_list under its new title.
        replace_key(checkpoints_list, old_title, self.title, self)
        self.register()

        return self.shorthash
129
-
130
-
131
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    # Best effort only: failing to import or configure transformers must not
    # prevent this module from loading.
    pass
138
-
139
-
140
def setup_model():
    """Startup hook: ensure the checkpoint directory exists and install the midas/betas patches."""
    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()
    patch_given_betas()
147
-
148
-
149
def checkpoint_tiles(use_short=False):
    """Return the titles (or short titles when `use_short`) of all registered checkpoints."""
    attribute = "short_title" if use_short else "title"
    return [getattr(info, attribute) for info in checkpoints_list.values()]
151
-
152
-
153
def list_models():
    """Rebuild `checkpoints_list` and `checkpoint_aliases` from disk.

    Registers the file given with `--ckpt` (when it exists) and every model
    returned by `modelloader.load_models`. A download URL for the default
    model is only passed to the loader when downloading is allowed, `--ckpt`
    still points at `shared.sd_model_file`, and that file does not exist.
    """
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt

    def cmd_ckpt_exists():
        # os.path.exists(None) raises TypeError, so guard the unset case.
        return cmd_ckpt is not None and os.path.exists(cmd_ckpt)

    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or cmd_ckpt_exists():
        model_url = None
        expected_sha256 = None
    else:
        model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
        expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256)

    if cmd_ckpt_exists():
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        # The explicitly requested checkpoint becomes the selected one.
        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
178
-
179
-
180
# Matches a trailing " [....]" suffix such as the hash appended to titles.
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")


def get_closet_checkpoint_match(search_string):
    """Find the registered checkpoint that best matches `search_string`.

    Tries, in order: an exact alias hit, the shortest title containing the
    string, and the shortest title containing the string with its trailing
    bracketed suffix removed. Returns None when nothing matches or the search
    string is empty.
    """
    if not search_string:
        return None

    exact = checkpoint_aliases.get(search_string, None)
    if exact is not None:
        return exact

    def shortest_title_containing(fragment):
        candidates = [info for info in checkpoints_list.values() if fragment in info.title]
        # min() keeps the first of equally short titles, like a stable sort would.
        return min(candidates, key=lambda info: len(info.title), default=None)

    match = shortest_title_containing(search_string)
    if match is not None:
        return match

    return shortest_title_containing(re.sub(re_strip_checksum, '', search_string))
201
-
202
-
203
def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions

    Hashes the 0x10000 bytes starting at offset 0x100000 and returns the first
    8 hex characters of the sha256 digest, or 'NOFILE' if the file is missing.
    """
    import hashlib

    try:
        with open(filename, "rb") as file:
            file.seek(0x100000)
            chunk = file.read(0x10000)
    except FileNotFoundError:
        return 'NOFILE'

    return hashlib.sha256(chunk).hexdigest()[0:8]
216
-
217
-
218
def select_checkpoint():
    """Return the CheckpointInfo for the configured checkpoint, or a fallback.

    Looks up `shared.opts.sd_model_checkpoint` in `checkpoint_aliases`; when it
    is unknown, the first registered checkpoint is returned and a notice is
    printed to stderr.

    Raises `FileNotFoundError` if no checkpoints are found.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if not checkpoints_list:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        # Start the closing sentence on its own line instead of gluing it to the last path.
        error_message += "\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info
241
-
242
-
243
# Key-prefix rewrites applied to checkpoints that are not detected as SD 2.1
# Turbo: inserts the "text_model." level into the text-encoder key prefixes.
checkpoint_dict_replacements_sd1 = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}

checkpoint_dict_replacements_sd2_turbo = {  # Converts SD 2.1 Turbo from SGM to LDM format.
    'conditioner.embedders.0.': 'cond_stage_model.',
}
252
-
253
-
254
def transform_checkpoint_dict_key(k, replacements):
    """Rewrite the prefix of key `k` according to `replacements`.

    `replacements` maps an old prefix to its substitute. Entries are applied in
    dict order, each one against the result of the previous rewrite.
    """
    for prefix, substitute in replacements.items():
        if not k.startswith(prefix):
            continue
        k = substitute + k[len(prefix):]

    return k
260
-
261
-
262
def get_state_dict_from_checkpoint(pl_sd):
    """Unwrap a loaded checkpoint and normalise its key names.

    Uses the nested "state_dict" entry when there is one, renames keys with the
    SD 2.1 Turbo table or the SD1 table, and returns the same dict object with
    its contents replaced by the renamed entries.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    turbo_marker = 'conditioner.embedders.0.model.ln_final.weight'
    is_sd2_turbo = turbo_marker in pl_sd and pl_sd[turbo_marker].size()[0] == 1024
    replacements = checkpoint_dict_replacements_sd2_turbo if is_sd2_turbo else checkpoint_dict_replacements_sd1

    renamed = {}
    for old_key, tensor in pl_sd.items():
        new_key = transform_checkpoint_dict_key(old_key, replacements)
        if new_key is not None:
            renamed[new_key] = tensor

    # Update in place so the caller's dict object is the one returned.
    pl_sd.clear()
    pl_sd.update(renamed)

    return pl_sd
282
-
283
-
284
def read_metadata_from_safetensors(filename):
    """Read the `__metadata__` mapping from the JSON header of a safetensors file.

    The first 8 bytes are taken as the little-endian header length. String
    values that start with '{' are decoded as JSON when possible. Problems
    while parsing the header are reported and yield whatever was collected so
    far.

    Raises AssertionError when the file does not start like a safetensors file.
    """
    import json

    with open(filename, mode="rb") as file:
        header_size = int.from_bytes(file.read(8), "little")
        json_start = file.read(2)

        assert header_size > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"

        res = {}

        try:
            header = json.loads(json_start + file.read(header_size - 2))
            for key, value in header.get("__metadata__", {}).items():
                res[key] = value
                if isinstance(value, str) and value.startswith('{'):
                    try:
                        res[key] = json.loads(value)
                    except Exception:
                        # Not valid JSON after all; keep the raw string.
                        pass
        except Exception:
            errors.report(f"Error reading metadata from file: {filename}", exc_info=True)

        return res
310
-
311
-
312
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Load a checkpoint file and return its normalised state dict.

    `.safetensors` files are loaded with safetensors, either via `load_file` or,
    when `shared.opts.disable_mmap_load_safetensors` is set, by reading the
    whole file and moving each tensor to the target device. Any other
    extension goes through `torch.load`.

    Args:
        checkpoint_file: path of the checkpoint.
        print_global_state: print the "global_step" entry when present.
        map_location: device override; falls back to
            `shared.weight_load_location` (and, for safetensors, to the optimal
            device name).
    """
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            # Close the handle deterministically instead of leaking it.
            with open(checkpoint_file, 'rb') as file:
                pl_sd = safetensors.torch.load(file.read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        # NOTE(review): torch.load unpickles the file; only safe for trusted
        # checkpoints unless a restricted unpickler is installed elsewhere.
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd
330
-
331
-
332
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return the state dict for `checkpoint_info`, from the in-memory cache when possible.

    A cache hit also marks the entry as most recently used; a miss reads the
    file from disk. Progress is recorded on `timer`.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info not in checkpoints_loaded:
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
        state_dict = read_state_dict(checkpoint_info.filename)
        timer.record("load weights from disk")
        return state_dict

    # use checkpoint cache
    print(f"Loading weights [{sd_model_hash}] from cache")
    # move to end as latest
    checkpoints_loaded.move_to_end(checkpoint_info)
    return checkpoints_loaded[checkpoint_info]
348
-
349
-
350
class SkipWritingToConfig:
    """Context manager that stops load_model_weights from writing the checkpoint name to the config.

    The flag lives on the class; each instance remembers the value it found on
    entry and puts it back on exit, so nested use works.
    """

    skip = False
    previous = None

    def __enter__(self):
        self.previous, SkipWritingToConfig.skip = SkipWritingToConfig.skip, True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        SkipWritingToConfig.skip = self.previous
363
-
364
-
365
def check_fp8(model):
    """Decide whether fp8 storage should be used for `model`.

    Returns None when there is no model, False on the "mps" device, and
    otherwise True only if the `fp8_storage` option is "Enable", or is
    "Enable for SDXL" and the model is flagged `is_sdxl`.
    """
    if model is None:
        return None

    if devices.get_optimal_device_name() == "mps":
        return False

    if shared.opts.fp8_storage == "Enable":
        return True

    if getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
        return True

    return False
377
-
378
-
379
def set_model_type(model, state_dict):
    """Set the `is_*` flags and `model_type` on `model` from its structure and weights.

    Checks, in order: an SD3-specific key in `state_dict`, a `conditioner`
    attribute (SDXL, or SSD when a middle-block attention key is absent), a
    `model` attribute on the text encoder (SD2), and falls back to SD1.
    """
    for flag in ("is_sd1", "is_sd2", "is_sdxl", "is_ssd", "is_sd3"):
        setattr(model, flag, False)

    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
        model.is_sd3 = True
        model.model_type = ModelType.SD3
        return

    if hasattr(model, 'conditioner'):
        model.is_sdxl = True

        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' in state_dict.keys():
            model.model_type = ModelType.SDXL
        else:
            # SSD keeps the SDXL flag and additionally sets is_ssd.
            model.is_ssd = True
            model.model_type = ModelType.SSD
        return

    if hasattr(model.cond_stage_model, 'model'):
        model.is_sd2 = True
        model.model_type = ModelType.SD2
        return

    model.is_sd1 = True
    model.model_type = ModelType.SD1
403
-
404
-
405
def set_model_fields(model):
    """Give `model` a `latent_channels` attribute of 4 unless it already has one."""
    if hasattr(model, 'latent_channels'):
        return

    model.latent_channels = 4
408
-
409
-
410
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Load the weights of `checkpoint_info` into `model` and configure its precision.

    Steps, in order: record the checkpoint in the options (unless
    SkipWritingToConfig is active), obtain the state dict if none was passed,
    classify the model, optionally cache the state dict, apply the weights,
    convert to float/half/fp8 according to the command line and options, set
    bookkeeping attributes on the model, and finally load the VAE.

    Args:
        model: the model object to populate; modified in place.
        checkpoint_info: the checkpoint being loaded.
        state_dict: already loaded weights, or None to load them here.
        timer: receives a `record(label)` call after each stage.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if devices.fp8:
        # prevent model to load state dict in fp8
        model.half()

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    set_model_type(model, state_dict)
    set_model_fields(model)

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    if model.is_ssd:
        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    # Optional hooks a model class may define around weight loading.
    if hasattr(model, "before_load_weights"):
        model.before_load_weights(state_dict)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    if hasattr(model, "after_load_weights"):
        model.after_load_weights(state_dict)

    # Drop the local reference; the cache above holds its own shallow copy.
    del state_dict

    # Set is_sdxl_inpaint flag.
    # Checks Unet structure to detect inpaint model. The inpaint model's
    # checkpoint state_dict does not contain the key
    # 'diffusion_model.input_blocks.0.0.weight'.
    diffusion_model_input = model.model.state_dict().get(
        'diffusion_model.input_blocks.0.0.weight'
    )
    model.is_sdxl_inpaint = (
        model.is_sdxl and
        diffusion_model_input is not None and
        diffusion_model_input.shape[1] == 9
    )

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if shared.cmd_opts.no_half:
        model.float()
        model.alphas_cumprod_original = model.alphas_cumprod
        devices.dtype_unet = torch.float32
        assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
        timer.record("apply float()")
    else:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        # Detach alphas_cumprod while calling half() so it is not converted,
        # then put it (and the temporarily removed sub-models) back.
        alphas_cumprod = model.alphas_cumprod
        model.alphas_cumprod = None
        model.half()
        model.alphas_cumprod = alphas_cumprod
        model.alphas_cumprod_original = alphas_cumprod
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        devices.dtype_unet = torch.float16
        timer.record("apply half()")

    apply_alpha_schedule_override(model)

    # Remove fp16 copies left on modules (they are created in the fp8 branch below).
    for module in model.modules():
        if hasattr(module, 'fp16_weight'):
            del module.fp16_weight
        if hasattr(module, 'fp16_bias'):
            del module.fp16_bias

    if check_fp8(model):
        devices.fp8 = True
        # Detach the VAE so the loop over model.modules() does not convert it.
        first_stage = model.first_stage_model
        model.first_stage_model = None
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
                    # Keep half-precision CPU copies before converting to fp8.
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
        devices.fp8 = False

    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    # Reset VAE state, then resolve and load the VAE for this checkpoint.
    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
540
-
541
-
542
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.

    The original loader is kept as `midas.api.load_model_inner`. Weights are
    downloaded to a temporary ".part" file and renamed once complete, so an
    interrupted download is not mistaken for a finished model on the next run.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in list(midas.api.ISL_PATHS.items()):
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        """Download the weights for `model_type` if missing, then call the original loader."""
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            # Creates missing parents and tolerates a concurrent creation.
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            partial_path = f"{path}.part"
            try:
                request.urlretrieve(midas_urls[model_type], partial_path)
                os.replace(partial_path, path)
            finally:
                # Only still present when the download or rename failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
583
-
584
-
585
def patch_given_betas():
    """Patch DDPM.register_schedule so an OmegaConf list as second argument is passed on as a numpy array."""
    import ldm.models.diffusion.ddpm

    def patched_register_schedule(*args, **kwargs):
        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""

        second_argument = args[1]
        if isinstance(second_argument, ListConfig):
            args = (args[0], np.array(second_argument), *args[2:])

        original_register_schedule(*args, **kwargs)

    # patches.patch hands back the function that was replaced.
    ddpm_class = ldm.models.diffusion.ddpm.DDPM
    original_register_schedule = patches.patch(__name__, ddpm_class, 'register_schedule', patched_register_schedule)
597
-
598
-
599
def repair_config(sd_config, state_dict=None):
    """Adjust a loaded model config in place to match command line options.

    Sets a default for `use_ema`, forces the unet `use_fp16` flag according to
    the precision options, replaces the xformers VAE attention type when
    xformers is unavailable, redirects the karlo stats path, and switches off
    `use_checkpoint`. `state_dict` is accepted but not used.
    """
    params = sd_config.model.params

    if not hasattr(params, "use_ema"):
        params.use_ema = False

    if hasattr(params, 'unet_config'):
        if shared.cmd_opts.no_half:
            params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
            params.unet_config.params.use_fp16 = True

    if hasattr(params, 'first_stage_config'):
        ddconfig = params.first_stage_config.params.ddconfig
        if getattr(ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
            ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(params, "noise_aug_config") and hasattr(params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        noise_aug_params = params.noise_aug_config.params
        noise_aug_params.clip_stats_path = noise_aug_params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)

    # Do not use checkpoint for inference.
    # This helps prevent extra performance overhead on checking parameters.
    # The perf overhead is about 100ms/it on 4090 for SDXL.
    for config_name in ("network_config", "unet_config"):
        if hasattr(params, config_name):
            getattr(params, config_name).params.use_checkpoint = False
625
-
626
-
627
-
628
def rescale_zero_terminal_snr_abar(alphas_cumprod):
    """Return a rescaled copy of `alphas_cumprod` whose last entry is (almost) zero.

    Works on the square roots: shifts them so the last one is zero, scales them
    so the first one keeps its value, squares the result and finally sets the
    last entry to a tiny positive constant. The input tensor is not modified.
    """
    root = alphas_cumprod.sqrt()

    # Endpoints of the original schedule.
    first = root[0]
    last = root[-1]

    # Shift so the last timestep is zero, then scale the first one back.
    shifted = root - last
    rescaled = shifted * (first / (first - last))

    # Back from square roots to the cumulative products.
    alphas_bar = rescaled ** 2
    alphas_bar[-1] = 4.8973451890853435e-08
    return alphas_bar
645
-
646
-
647
def apply_alpha_schedule_override(sd_model, p=None):
    """
    Applies an override to the alpha schedule of the model according to settings.
    - downcasts the alpha schedule to half precision
    - rescales the alpha schedule to have zero terminal SNR

    Starts from `alphas_cumprod_original` every time. When `p` is given, the
    applied overrides are noted in `p.extra_generation_params`. Does nothing
    for models lacking either alphas attribute.
    """

    if not (hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original')):
        return

    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)

    if opts.use_downcasted_alpha_bar:
        if p is not None:
            p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
        downcasted = sd_model.alphas_cumprod.half()
        sd_model.alphas_cumprod = downcasted.to(shared.device)

    if opts.sd_noise_schedule == "Zero Terminal SNR":
        if p is not None:
            p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
        rescaled = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod)
        sd_model.alphas_cumprod = rescaled.to(shared.device)
668
-
669
-
670
# State-dict keys whose presence is used by load_model to decide that the
# checkpoint already contains text-encoder (CLIP) weights.
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
674
-
675
-
676
class SdModelData:
    """Holds the current model, the list of loaded models and a lock for first-time loading."""

    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        """Return the current model, loading one first if none was ever loaded.

        A failed load is reported on stderr and results in None.
        """
        if self.was_loaded_at_least_once or self.sd_model is not None:
            return self.sd_model

        with self.lock:
            # Another thread may have finished loading while we waited.
            if self.sd_model is not None or self.was_loaded_at_least_once:
                return self.sd_model

            try:
                load_model()

            except Exception as e:
                errors.display(e, "loading stable diffusion model", full_traceback=True)
                print("", file=sys.stderr)
                print("Stable diffusion model failed to load", file=sys.stderr)
                self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v, already_loaded=False):
        """Make `v` the current model and move it to the front of `loaded_sd_models`.

        With `already_loaded`, the VAE module state is taken from attributes of `v`.
        """
        self.sd_model = v
        if already_loaded:
            sd_vae.base_vae = getattr(v, "base_vae", None)
            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
            sd_vae.checkpoint_info = v.sd_checkpoint_info

        models = self.loaded_sd_models
        try:
            models.remove(v)
        except ValueError:
            # Not in the list yet.
            pass

        if v is not None:
            models.insert(0, v)


model_data = SdModelData()
720
-
721
-
722
def get_empty_cond(sd_model):
    """Return the conditioning `sd_model` produces for an empty prompt.

    Uses `get_learned_conditioning` when the model has it, otherwise calls the
    text encoder directly. For dict results only the 'crossattn' entry is
    returned.
    """
    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'get_learned_conditioning'):
        encode = sd_model.get_learned_conditioning
    else:
        encode = sd_model.cond_stage_model

    cond = encode([""])

    return cond['crossattn'] if isinstance(cond, dict) else cond
736
-
737
-
738
def send_model_to_cpu(m):
    """Move model `m` (if any) to the CPU, via lowvram when the model uses it, then run torch_gc."""
    if m is None:
        pass
    elif m.lowvram:
        lowvram.send_everything_to_cpu()
    else:
        m.to(devices.cpu)

    devices.torch_gc()
746
-
747
-
748
def model_target_device(m):
    """Return the CPU device when lowvram is needed for `m`, otherwise the main device."""
    return devices.cpu if lowvram.is_needed(m) else devices.device
753
-
754
-
755
def send_model_to_device(m):
    """Apply lowvram handling to `m`; move it to `shared.device` unless its lowvram flag is set."""
    lowvram.apply(m)

    if m.lowvram:
        return

    m.to(shared.device)
760
-
761
-
762
def send_model_to_trash(m):
    """Discard the weights of `m` by moving it to the "meta" device, then run torch_gc."""
    m.to(device="meta")
    devices.torch_gc()
765
-
766
-
767
def instantiate_from_config(config, state_dict=None):
    """Build the object named by `config["target"]` with `config["params"]` as keyword arguments.

    When the params contain a "state_dict" entry set to None and a non-empty
    `state_dict` is given, it is passed in that slot.
    """
    constructor = get_obj_from_str(config["target"])

    kwargs = {**config.get("params", {})}

    wants_state_dict = "state_dict" in kwargs and kwargs["state_dict"] is None
    if state_dict and wants_state_dict:
        kwargs["state_dict"] = state_dict

    return constructor(**kwargs)
776
-
777
-
778
def get_obj_from_str(string, reload=False):
    """Resolve a dotted "package.module.Name" path to the named attribute.

    With `reload`, the module is re-imported before the attribute is fetched.
    """
    module_name, attribute_name = string.rsplit(".", 1)

    if reload:
        importlib.reload(importlib.import_module(module_name))

    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attribute_name)
784
-
785
-
786
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """Create a model for a checkpoint, load its weights and make it the current model.

    Args:
        checkpoint_info: checkpoint to load; chosen by select_checkpoint() when omitted.
        already_loaded_state_dict: weights to use instead of reading them again.

    Returns:
        The newly created model, which has also been stored in `model_data`.
    """
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    # Discard the current model before building the new one.
    if model_data.sd_model:
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

    timer.record("unload existing model")

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    # True when any known text-encoder weight key is present in the checkpoint.
    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config, state_dict)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        # Fast path: construct the model with initialization disabled.
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
            with sd_disable_initialization.InitializeOnMeta():
                sd_model = instantiate_from_config(sd_config.model, state_dict)

    except Exception as e:
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

        # Fallback: same construction without DisableInitialization.
        with sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model, state_dict)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    if shared.cmd_opts.no_half:
        weight_dtype_conversion = None
    else:
        # Keys are name prefixes; None is given for the VAE and alphas_cumprod,
        # '' supplies float16 for everything else.
        # NOTE(review): exact matching rules live in LoadStateDictOnMeta — confirm there.
        weight_dtype_conversion = {
            'first_stage_model': None,
            'alphas_cumprod': None,
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)

    timer.record("load weights from state dict")

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    # Precompute and store the conditioning for an empty prompt on the model.
    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model
876
-
877
-
878
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
    If not, returns the model that can be used to load weights from checkpoint_info's file.
    If no such model exists, returns None.
    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    # the current model already holds the requested checkpoint
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if shared.opts.sd_checkpoints_keep_in_cpu:
        send_model_to_cpu(sd_model)
        timer.record("send model to cpu")

    already_loaded = None
    # walk indices from the end so that deleting entry i does not shift the entries still to be visited
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue

        # a limit of 0 (or less) disables unloading altogether
        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            del model_data.loaded_sd_models[i]
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded, already_loaded=True)

        if not SkipWritingToConfig.skip:
            shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
            shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        sd_vae.reload_vae_weights(already_loaded)
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        # there is still room for another model: load the checkpoint into a brand new one
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        # no room left: take the last model of the list so the caller can load the new weights into it
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        sd_vae.base_vae = getattr(sd_model, "base_vae", None)
        sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
        sd_vae.checkpoint_info = sd_model.sd_checkpoint_info

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None
938
-
939
-
940
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
    """
    Makes the checkpoint described by info (or, when info is not given, the one picked by select_checkpoint()) the active model.

    Reuses an already loaded model when possible, builds a new model via load_model() when the
    checkpoint needs a different config, and otherwise loads the new weights into the existing model.
    Returns the resulting model.
    """
    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if check_fp8(sd_model) != devices.fp8:
            # load from state dict again to prevent extra numerical errors
            forced_reload = True
        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
            return sd_model

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        # detach the unet override and the hijack before the weights are replaced
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    # a different config means the existing model object cannot take these weights; build a new one
    if sd_model is None or checkpoint_config != sd_model.used_config:
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        # NOTE(review): the extent of this finally block was reconstructed from a capture without indentation - confirm
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        if not sd_model.lowvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)
    sd_unet.apply_unet()

    return sd_model
1003
-
1004
-
1005
def unload_model_weights(sd_model=None, info=None):
    """
    Moves sd_model (or shared.sd_model when sd_model is not given) to CPU.

    info is accepted but not used. Returns the sd_model argument exactly as it was passed in.
    """
    send_model_to_cpu(sd_model or shared.sd_model)

    return sd_model
1009
-
1010
-
1011
def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.

    Does nothing when the requested ratio is the one already applied; otherwise removes the
    previous patch (if any), applies a new one for a positive ratio, and remembers the ratio
    on the model as applied_token_merged_ratio.
    """

    applied_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
    if applied_ratio == token_merging_ratio:
        return

    if applied_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        patch_options = dict(
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False,
        )
        tomesd.apply_patch(sd_model, **patch_options)

    sd_model.applied_token_merged_ratio = token_merging_ratio
 
 
 
 
1
+ import collections
2
+ import importlib
3
+ import os
4
+ import sys
5
+ import threading
6
+ import enum
7
+
8
+ import torch
9
+ import re
10
+ import safetensors.torch
11
+ from omegaconf import OmegaConf, ListConfig
12
+ from urllib import request
13
+ import ldm.modules.midas as midas
14
+
15
+ from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
16
+ from modules.timer import Timer
17
+ from modules.shared import opts
18
+ import tomesd
19
+ import numpy as np
20
+
21
# sub-directory of the models directory that holds checkpoints, and its absolute path
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

# title -> CheckpointInfo, filled by CheckpointInfo.register()
checkpoints_list = {}
# every identifier listed in CheckpointInfo.ids -> CheckpointInfo
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
# CheckpointInfo -> cached state dict; ordered so that the oldest entry can be dropped first
checkpoints_loaded = collections.OrderedDict()
28
+
29
+
30
class ModelType(enum.Enum):
    """Model families; set_model_type() stores one of these as model.model_type."""

    SD1 = 1
    SD2 = 2
    SDXL = 3
    SSD = 4
    SD3 = 5
36
+
37
+
38
def replace_key(d, key, new_key, value):
    """
    Stores value under new_key in d. When key was present, new_key takes over key's
    place in the iteration order and the entry for key goes away.

    d is modified in place and also returned.
    """
    original_order = list(d)
    d[new_key] = value

    if key not in original_order:
        return d

    reordered = {}
    for existing in original_order:
        name = new_key if existing == key else existing
        reordered[name] = d[name]

    d.clear()
    d.update(reordered)
    return d
54
+
55
+
56
class CheckpointInfo:
    """Describes one checkpoint file: its display names, hashes, safetensors metadata and lookup identifiers."""

    def __init__(self, filename):
        """Derives names, hashes and identifiers for the checkpoint stored at filename."""
        self.filename = filename
        abspath = os.path.abspath(filename)
        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None

        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        # name is the path relative to the checkpoint directory; only the leading directory
        # prefix is cut off, so the same text occurring again deeper in the path is kept
        if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
            name = abspath[len(abs_ckpt_dir):]
        elif abspath.startswith(model_path):
            name = abspath[len(model_path):]
        else:
            name = os.path.basename(filename)

        if name.startswith(("\\", "/")):
            name = name[1:]

        def read_metadata():
            # runs only when the cache asks for fresh data; the thumbnail is kept on the object and not in the returned metadata
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)

            return metadata

        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                # unreadable metadata is reported but does not prevent the checkpoint from being listed
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

    def register(self):
        """Adds this checkpoint to checkpoints_list under its title and to checkpoint_aliases under every id."""
        checkpoints_list[self.title] = self
        for alias in self.ids:
            checkpoint_aliases[alias] = self

    def calculate_shorthash(self):
        """
        Computes the sha256 of the file and updates shorthash, titles, ids and the registries to match.

        Returns the shorthash, or None when no sha256 could be obtained.
        """
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        shorthash = self.sha256[0:10]
        if self.shorthash == shorthash:
            return self.shorthash

        self.shorthash = shorthash

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

        old_title = self.title
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        # keep the position of the entry in checkpoints_list while switching it to the new title
        replace_key(checkpoints_list, old_title, self.title, self)
        self.register()

        return self.shorthash
129
+
130
+
131
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    # best effort only: failing to quiet the message must not stop the module from importing
    pass
138
+
139
+
140
def setup_model():
    """called once at startup to do various one-time tasks related to SD models"""

    # make sure the checkpoint directory exists
    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()
    patch_given_betas()
147
+
148
+
149
def checkpoint_tiles(use_short=False):
    """Returns the titles (or short titles when use_short is set) of all registered checkpoints."""
    titles = []
    for info in checkpoints_list.values():
        titles.append(info.short_title if use_short else info.title)
    return titles
151
+
152
+
153
def list_models():
    """
    Rebuilds checkpoints_list and checkpoint_aliases from the --ckpt file and the files found by modelloader.

    A download URL for the default model is passed to modelloader only when downloading is allowed,
    --ckpt is left at shared.sd_model_file and that file does not exist.
    """
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
        model_url = None
        expected_sha256 = None
    else:
        model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
        expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256)

    # os.path.exists raises TypeError for None, so the None case (which the elif below allows for) is excluded first
    if cmd_ckpt is not None and os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
178
+
179
+
180
# matches a trailing "[...]" group together with the whitespace around it
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
181
+
182
+
183
def get_closet_checkpoint_match(search_string):
    """
    Finds the checkpoint that best matches search_string.

    An exact alias wins; otherwise the registered checkpoint with the shortest title containing
    search_string is used, and failing that the same search is repeated with a trailing
    "[...]" part removed from search_string. Returns None when nothing matches.
    """
    if not search_string:
        return None

    exact = checkpoint_aliases.get(search_string, None)
    if exact is not None:
        return exact

    def shortest_title_containing(needle):
        candidates = [info for info in checkpoints_list.values() if needle in info.title]
        if not candidates:
            return None
        # min() keeps the first of equally short titles, same as taking [0] of a stable sort
        return min(candidates, key=lambda x: len(x.title))

    match = shortest_title_containing(search_string)
    if match:
        return match

    match = shortest_title_containing(re.sub(re_strip_checksum, '', search_string))
    if match:
        return match

    return None
201
+
202
+
203
def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""
    import hashlib

    try:
        with open(filename, "rb") as file:
            # only the 0x10000 bytes found at offset 0x100000 are hashed
            file.seek(0x100000)
            chunk = file.read(0x10000)
    except FileNotFoundError:
        return 'NOFILE'

    return hashlib.sha256(chunk).hexdigest()[0:8]
216
+
217
+
218
def select_checkpoint():
    """
    Returns the CheckpointInfo for the checkpoint chosen in settings, falling back to the
    first registered checkpoint when the chosen one is not known.

    Raises `FileNotFoundError` if no checkpoints are found.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        # start the closing sentence on its own line so it does not run into the last listed path
        error_message += "\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info
241
+
242
+
243
# key prefix -> replacement prefix, used by get_state_dict_from_checkpoint() for checkpoints that are not detected as SD 2.1 Turbo
checkpoint_dict_replacements_sd1 = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}

checkpoint_dict_replacements_sd2_turbo = {  # Converts SD 2.1 Turbo from SGM to LDM format.
    'conditioner.embedders.0.': 'cond_stage_model.',
}
252
+
253
+
254
def transform_checkpoint_dict_key(k, replacements):
    """
    Rewrites the prefix of key k according to replacements (prefix -> new prefix).

    Every entry is tried in order against the current value of the key, so a later entry
    may rewrite the result of an earlier one.
    """
    for prefix, new_prefix in replacements.items():
        if not k.startswith(prefix):
            continue
        k = new_prefix + k[len(prefix):]

    return k
260
+
261
+
262
def get_state_dict_from_checkpoint(pl_sd):
    """
    Unwraps the "state_dict" entry of pl_sd when there is one and renames its keys using the
    SD 2.1 Turbo or the SD1 replacement table.

    The (unwrapped) dict is rewritten in place and returned.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    turbo_marker = 'conditioner.embedders.0.model.ln_final.weight'
    if turbo_marker in pl_sd and pl_sd[turbo_marker].size()[0] == 1024:
        replacements = checkpoint_dict_replacements_sd2_turbo
    else:
        replacements = checkpoint_dict_replacements_sd1

    renamed = {}
    for old_key, tensor in pl_sd.items():
        new_key = transform_checkpoint_dict_key(old_key, replacements)
        if new_key is None:
            continue
        renamed[new_key] = tensor

    pl_sd.clear()
    pl_sd.update(renamed)

    return pl_sd
282
+
283
+
284
def read_metadata_from_safetensors(filename):
    """
    Reads the "__metadata__" section of the JSON header of a safetensors file.

    String values that begin with "{" are additionally parsed as JSON when possible.
    Problems while parsing the header are reported and an empty (or partial) dict is returned;
    a file that does not start like a safetensors file fails the assertion.
    """
    import json

    with open(filename, mode="rb") as file:
        # the file begins with the little-endian length of the JSON header
        header_size = int.from_bytes(file.read(8), "little")
        header_start = file.read(2)

        assert header_size > 2 and header_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"

        parsed = {}

        try:
            header = json.loads(header_start + file.read(header_size-2))
            for key, value in header.get("__metadata__", {}).items():
                parsed[key] = value
                if not (isinstance(value, str) and value.startswith('{')):
                    continue
                try:
                    parsed[key] = json.loads(value)
                except Exception:
                    # not valid JSON after all; keep the plain string
                    pass
        except Exception:
            errors.report(f"Error reading metadata from file: {filename}", exc_info=True)

        return parsed
310
+
311
+
312
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """
    Loads the state dict stored in checkpoint_file.

    Files with a .safetensors extension are read with safetensors (memory-mapped unless
    disable_mmap_load_safetensors is set); everything else goes through torch.load.
    map_location overrides the device the tensors are loaded to. With print_global_state set,
    the "global_step" entry is printed when present. The result is passed through
    get_state_dict_from_checkpoint() before being returned.
    """
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            # read the whole file up front; the handle is closed as soon as the bytes are in memory
            with open(checkpoint_file, 'rb') as file:
                pl_sd = safetensors.torch.load(file.read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        # NOTE(review): torch.load unpickles the file, which can execute code from an untrusted
        # checkpoint - confirm that unpickling is restricted elsewhere before relying on this path
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd
330
+
331
+
332
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """
    Returns the state dict for checkpoint_info, taken from the checkpoints_loaded cache when
    it is there and read from disk otherwise.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info not in checkpoints_loaded:
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
        state_dict = read_state_dict(checkpoint_info.filename)
        timer.record("load weights from disk")

        return state_dict

    # use checkpoint cache
    print(f"Loading weights [{sd_model_hash}] from cache")
    # move to end as latest
    checkpoints_loaded.move_to_end(checkpoint_info)
    return checkpoints_loaded[checkpoint_info]
348
+
349
+
350
class SkipWritingToConfig:
    """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""

    # class-wide flag checked by the loading code
    skip = False
    # value of the flag at the time the block was entered
    previous = None

    def __enter__(self):
        # the right-hand side is evaluated first, so previous receives the old flag value
        self.previous, SkipWritingToConfig.skip = SkipWritingToConfig.skip, True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # put back whatever was in effect before, which keeps nested use working
        SkipWritingToConfig.skip = self.previous
363
+
364
+
365
def check_fp8(model):
    """
    Tells whether fp8 storage should be used for model according to the fp8_storage setting.

    Returns None when model is None, False on the "mps" device, True for the "Enable" setting,
    True for "Enable for SDXL" when model.is_sdxl is set, and False otherwise.
    """
    if model is None:
        return None

    if devices.get_optimal_device_name() == "mps":
        return False

    if shared.opts.fp8_storage == "Enable":
        return True

    if getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
        return True

    return False
377
+
378
+
379
def set_model_type(model, state_dict):
    """
    Sets the is_sd1 / is_sd2 / is_sdxl / is_ssd / is_sd3 flags and model_type on model,
    judging by keys of state_dict and attributes of the model.
    """
    for flag in ("is_sd1", "is_sd2", "is_sdxl", "is_ssd", "is_sd3"):
        setattr(model, flag, False)

    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
        model.is_sd3 = True
        model.model_type = ModelType.SD3
        return

    if hasattr(model, 'conditioner'):
        model.is_sdxl = True

        # SSD is flagged in addition to SDXL when this attention weight is absent
        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' in state_dict.keys():
            model.model_type = ModelType.SDXL
        else:
            model.is_ssd = True
            model.model_type = ModelType.SSD
        return

    if hasattr(model.cond_stage_model, 'model'):
        model.is_sd2 = True
        model.model_type = ModelType.SD2
        return

    model.is_sd1 = True
    model.model_type = ModelType.SD1
403
+
404
+
405
def set_model_fields(model):
    """Gives model a latent_channels attribute of 4 unless it already has one."""
    if hasattr(model, 'latent_channels'):
        return

    model.latent_channels = 4
408
+
409
+
410
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """
    Loads the weights of checkpoint_info into model and applies the configured dtype handling.

    state_dict may be None, in which case it is obtained through get_checkpoint_state_dict().
    Besides the weights this sets the model type flags, the half/float/fp8 conversion, the
    checkpoint bookkeeping attributes on the model, and finally loads the VAE.
    Each stage is recorded on timer.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if devices.fp8:
        # prevent model to load state dict in fp8
        # NOTE(review): nesting of the cuda check under the fp8 check was reconstructed from a capture without indentation - confirm
        if torch.cuda.is_available():
            model.half()


    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    set_model_type(model, state_dict)
    set_model_fields(model)

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    if model.is_ssd:
        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    # optional hooks a model may provide around the actual weight loading
    if hasattr(model, "before_load_weights"):
        model.before_load_weights(state_dict)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    if hasattr(model, "after_load_weights"):
        model.after_load_weights(state_dict)

    del state_dict

    # Set is_sdxl_inpaint flag.
    # Checks Unet structure to detect inpaint model. The inpaint model's
    # checkpoint state_dict does not contain the key
    # 'diffusion_model.input_blocks.0.0.weight'.
    diffusion_model_input = model.model.state_dict().get(
        'diffusion_model.input_blocks.0.0.weight'
    )
    model.is_sdxl_inpaint = (
        model.is_sdxl and
        diffusion_model_input is not None and
        diffusion_model_input.shape[1] == 9
    )

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if shared.cmd_opts.no_half:
        model.float()
        model.alphas_cumprod_original = model.alphas_cumprod
        devices.dtype_unet = torch.float32
        assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
        timer.record("apply float()")
    else:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        # alphas_cumprod is detached while half() runs and put back untouched afterwards
        alphas_cumprod = model.alphas_cumprod
        model.alphas_cumprod = None
        if torch.cuda.is_available():
            model.half()
        model.alphas_cumprod = alphas_cumprod
        model.alphas_cumprod_original = alphas_cumprod
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        devices.dtype_unet = torch.float16
        timer.record("apply half()")

    apply_alpha_schedule_override(model)

    # drop fp16 copies that an earlier fp8 load may have attached to the modules
    for module in model.modules():
        if hasattr(module, 'fp16_weight'):
            del module.fp16_weight
        if hasattr(module, 'fp16_bias'):
            del module.fp16_bias

    if check_fp8(model):
        devices.fp8 = True
        # the VAE is detached for the duration of the loop so its modules are not converted
        first_stage = model.first_stage_model
        model.first_stage_model = None
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
        devices.fp8 = False

    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    # forget the VAE state of the previous checkpoint, then load the VAE resolved for this one
    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
543
+
544
+
545
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.

    Calling it more than once is harmless: the wrapper is installed only once.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    # wrapping the wrapper would make load_model_inner point at the wrapper itself,
    # and the wrapper would then call itself endlessly
    if getattr(midas.api.load_model, "autodownload_enabled", False):
        return

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            # creates missing parent directories too and tolerates the directory appearing in the meantime
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    # marker checked above on later calls
    load_model_wrapper.autodownload_enabled = True
    midas.api.load_model = load_model_wrapper
586
+
587
+
588
def patch_given_betas():
    """Replaces DDPM.register_schedule with a version that accepts an Omegaconf ListConfig as its second positional argument."""
    import ldm.models.diffusion.ddpm

    def patched_register_schedule(*args, **kwargs):
        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""

        # args[0] is the instance; args[1] is converted to a numpy array when it is a ListConfig
        if isinstance(args[1], ListConfig):
            args = (args[0], np.array(args[1]), *args[2:])

        original_register_schedule(*args, **kwargs)

    # patches.patch hands back the function that was replaced, which the wrapper above calls
    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
600
+
601
+
602
def repair_config(sd_config, state_dict=None):
    """
    Adjusts sd_config.model.params in place to fit the command line options and the environment:
    default for use_ema, unet fp16 flag, VAE attention type, karlo directory, and gradient checkpointing off.

    state_dict is accepted but not used.
    """
    params = sd_config.model.params

    if not hasattr(params, "use_ema"):
        params.use_ema = False

    if hasattr(params, 'unet_config'):
        if shared.cmd_opts.no_half:
            params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
            params.unet_config.params.use_fp16 = True

    if hasattr(params, 'first_stage_config'):
        ddconfig = params.first_stage_config.params.ddconfig
        if getattr(ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
            ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(params, "noise_aug_config") and hasattr(params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        noise_aug_params = params.noise_aug_config.params
        noise_aug_params.clip_stats_path = noise_aug_params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)

    # Do not use checkpoint for inference.
    # This helps prevent extra performance overhead on checking parameters.
    # The perf overhead is about 100ms/it on 4090 for SDXL.
    for config_name in ("network_config", "unet_config"):
        if hasattr(params, config_name):
            getattr(params, config_name).params.use_checkpoint = False
628
+
629
+
630
+
631
def rescale_zero_terminal_snr_abar(alphas_cumprod):
    """
    Returns a rescaled copy of alphas_cumprod: its square root is shifted so that the last
    entry becomes zero and scaled so that the first entry keeps its value, then squared again.
    The last entry of the result is set to a tiny positive constant.
    """
    sqrt_abar = alphas_cumprod.sqrt()

    # values at both ends before any change
    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # shift so the last timestep is zero, then scale so the first timestep is back to the old value
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    rescaled = sqrt_abar ** 2  # revert sqrt
    rescaled[-1] = 4.8973451890853435e-08
    return rescaled
648
+
649
+
650
def apply_alpha_schedule_override(sd_model, p=None):
    """
    Applies an override to the alpha schedule of the model according to settings.
    - downcasts the alpha schedule to half precision
    - rescales the alpha schedule to have zero terminal SNR

    The schedule is always rebuilt from alphas_cumprod_original. When p is given, the options
    that took effect are written to p.extra_generation_params. Models lacking either
    alphas_cumprod or alphas_cumprod_original are left alone.
    """

    if not (hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original')):
        return

    record_params = p is not None

    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)

    if opts.use_downcasted_alpha_bar:
        if record_params:
            p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
        sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)

    if opts.sd_noise_schedule == "Zero Terminal SNR":
        if record_params:
            p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
        sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
671
+
672
+
673
# state dict key names of text encoder weights, one per model family
# NOTE(review): presumably used to tell model families apart by state dict contents - the users are outside this section, confirm
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
677
+
678
+
679
class SdModelData:
    """Keeps the active model and the list of models that are currently loaded; loads a model on first access."""

    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        """Returns the active model, loading one first when nothing has been loaded yet. May return None when loading fails."""
        if self.was_loaded_at_least_once:
            return self.sd_model

        if self.sd_model is not None:
            return self.sd_model

        with self.lock:
            # another thread may have finished loading while this one waited for the lock
            if self.sd_model is not None or self.was_loaded_at_least_once:
                return self.sd_model

            try:
                load_model()

            except Exception as e:
                errors.display(e, "loading stable diffusion model", full_traceback=True)
                print("", file=sys.stderr)
                print("Stable diffusion model failed to load", file=sys.stderr)
                self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v, already_loaded=False):
        """
        Makes v the active model and moves it to the front of loaded_sd_models (None is never listed).
        With already_loaded set, the VAE state kept in sd_vae is taken from v as well.
        """
        self.sd_model = v
        if already_loaded:
            sd_vae.base_vae = getattr(v, "base_vae", None)
            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
            sd_vae.checkpoint_info = v.sd_checkpoint_info

        try:
            self.loaded_sd_models.remove(v)
        except ValueError:
            # v was not in the list
            pass

        if v is None:
            return

        self.loaded_sd_models.insert(0, v)
720
+
721
+
722
# Module-level instance used by the model loading functions in this module.
model_data = SdModelData()
723
+
724
+
725
def get_empty_cond(sd_model):
    """
    Returns the conditioning the model produces for the empty prompt "".

    Uses sd_model.get_learned_conditioning when the model has it, otherwise calls
    sd_model.cond_stage_model directly. If the result is a dict, only its 'crossattn'
    entry is returned.
    """

    # Activate extra networks with an empty set on a throwaway processing object before encoding.
    throwaway_p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(throwaway_p, {})

    if hasattr(sd_model, 'get_learned_conditioning'):
        encode = sd_model.get_learned_conditioning
    else:
        encode = sd_model.cond_stage_model

    cond = encode([""])

    return cond['crossattn'] if isinstance(cond, dict) else cond
739
+
740
+
741
def send_model_to_cpu(m):
    """
    Moves model `m` to the CPU and then runs devices.torch_gc().

    Models flagged with `lowvram` are handled by lowvram.send_everything_to_cpu();
    others are moved with m.to(devices.cpu). `m` may be None, in which case only
    the garbage collection step runs.
    """

    if m is not None and m.lowvram:
        lowvram.send_everything_to_cpu()
    elif m is not None:
        m.to(devices.cpu)

    devices.torch_gc()
749
+
750
+
751
def model_target_device(m):
    """
    Returns devices.cpu when lowvram.is_needed(m) is true, otherwise devices.device.
    """

    return devices.cpu if lowvram.is_needed(m) else devices.device
756
+
757
+
758
def send_model_to_device(m):
    """
    Applies lowvram handling to `m` and, unless the model is flagged `lowvram`,
    moves it to shared.device.
    """

    lowvram.apply(m)

    if m.lowvram:
        return

    m.to(shared.device)
763
+
764
+
765
def send_model_to_trash(m):
    """
    Discards model `m`: moves it to the "meta" device (presumably to release its
    weight storage) and then runs devices.torch_gc().
    """

    m.to(device="meta")
    devices.torch_gc()
768
+
769
+
770
def instantiate_from_config(config, state_dict=None):
    """
    Builds the object described by `config`.

    config["target"] is a dotted path resolved with get_obj_from_str(); the optional
    config "params" mapping is passed to it as keyword arguments. When `state_dict`
    is non-empty and the params contain a "state_dict" entry that is None, that entry
    is filled in with `state_dict` (on a copy; `config` itself is not modified).
    """

    target = get_obj_from_str(config["target"])

    kwargs = {**config.get("params", {})}

    expects_state_dict = "state_dict" in kwargs and kwargs["state_dict"] is None
    if state_dict and expects_state_dict:
        kwargs["state_dict"] = state_dict

    return target(**kwargs)
779
+
780
+
781
def get_obj_from_str(string, reload=False):
    """
    Resolves a dotted path such as "package.module.Name" to the object it names.

    Everything before the last dot is imported as a module and the last component is
    looked up on it. With reload=True the module is reloaded before the lookup.
    Raises ValueError if `string` contains no dot.
    """

    module_name, attr_name = string.rsplit(".", 1)

    module_obj = importlib.import_module(module_name)
    if reload:
        # importlib.reload returns the module object now registered in sys.modules.
        module_obj = importlib.reload(module_obj)

    return getattr(module_obj, attr_name)
787
+
788
+
789
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """
    Creates a new model object for a checkpoint, loads its weights and makes it the active model.

    checkpoint_info defaults to select_checkpoint(). If already_loaded_state_dict is given it is
    used instead of reading the checkpoint file again. Any currently active model is discarded first.
    Model construction is first attempted with initialization disabled; if that raises, the error
    is displayed and construction is retried without DisableInitialization.
    Returns the newly loaded model, which is also registered via model_data.set_sd_model().
    """

    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    if model_data.sd_model:
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

    timer.record("unload existing model")

    state_dict = already_loaded_state_dict if already_loaded_state_dict is not None else get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    # The text encoder counts as bundled when any known CLIP weight key is in the checkpoint.
    known_clip_keys = (sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight)
    clip_is_included_into_sd = any(key in state_dict for key in known_clip_keys)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config, state_dict)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip), sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model, state_dict)
    except Exception as e:
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

        with sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model, state_dict)

    # Remembered so reload_model_weights() can tell whether a new checkpoint needs a different config.
    sd_model.used_config = checkpoint_config

    timer.record("create model")

    # Keyed by weight-name prefix; presumably None means "leave dtype unchanged" and '' is the default - TODO confirm in LoadStateDictOnMeta.
    weight_dtype_conversion = None if shared.cmd_opts.no_half else {
        'first_stage_model': None,
        'alphas_cumprod': None,
        '': torch.float16,
    }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)

    timer.record("load weights from state dict")

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)
    timer.record("hijack")

    sd_model.eval()
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    # Reload embeddings after model load as they may or may not fit the model
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)
    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model
879
+
880
+
881
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Looks for the checkpoint described by checkpoint_info among model_data.loaded_sd_models.

    - If sd_model already holds that checkpoint, it is returned unchanged.
    - If another loaded model holds it, that model is sent to the device, made the active model and returned.
    - Otherwise, if the sd_checkpoints_limit setting leaves room, the checkpoint is loaded as a new model via load_model().
    - Otherwise the last model in the list is removed from it and returned so that new weights can be loaded into it.
    - If there is nothing to reuse, None is returned.

    Along the way the current model is moved to CPU when sd_checkpoints_keep_in_cpu is set, and
    models over the sd_checkpoints_limit setting are discarded.
    """

    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if shared.opts.sd_checkpoints_keep_in_cpu:
        send_model_to_cpu(sd_model)
        timer.record("send model to cpu")

    already_loaded = None
    loaded = model_data.loaded_sd_models

    # Walk from the end so that deleting entry `idx` does not shift entries still to be visited.
    for idx in range(len(loaded) - 1, -1, -1):
        candidate = loaded[idx]
        if candidate.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = candidate
            continue

        if len(loaded) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(loaded)} over the limit of {shared.opts.sd_checkpoints_limit}: {candidate.sd_checkpoint_info.title}")
            del loaded[idx]
            send_model_to_trash(candidate)
            timer.record("send model to trash")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded, already_loaded=True)

        if not SkipWritingToConfig.skip:
            shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
            shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        sd_vae.reload_vae_weights(already_loaded)
        return model_data.sd_model

    if shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        # Cleared first so that load_model() does not discard the currently active model.
        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model

    if model_data.loaded_sd_models:
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        sd_vae.base_vae = getattr(sd_model, "base_vae", None)
        sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
        sd_vae.checkpoint_info = sd_model.sd_checkpoint_info

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model

    return None
941
+
942
+
943
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
    """
    Switches the model to the checkpoint given by `info` (default: select_checkpoint()).

    sd_model defaults to model_data.sd_model. Nothing is done when the model already holds the
    requested checkpoint, unless forced_reload is set or the model's fp8 state differs from
    devices.fp8. An already loaded model is reused when possible; a completely new model is
    created with load_model() when none can be reused or the checkpoint needs a different config.
    Otherwise the new weights are loaded into the existing model object.

    If loading the weights fails, the previous checkpoint's weights are loaded back and the
    exception is re-raised. Returns the model that ends up holding the checkpoint.
    """

    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    sd_model = sd_model or model_data.sd_model

    # Stays None when the previous model load failed and there is no model at all.
    current_checkpoint_info = None
    if sd_model is not None:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if check_fp8(sd_model) != devices.fp8:
            # load from state dict again to prevent extra numerical errors
            forced_reload = True
        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
            return sd_model

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    # A model built from a different config cannot take these weights; build a fresh one instead.
    if sd_model is None or checkpoint_config != sd_model.used_config:
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        # Runs on success and on failure alike, so the model is hijacked again either way.
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        if not sd_model.lowvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)
    sd_unet.apply_unet()

    return sd_model
1006
+
1007
+
1008
def unload_model_weights(sd_model=None, info=None):
    """
    Moves a model to the CPU: `sd_model` if given, otherwise shared.sd_model.

    `info` is accepted but not used. The return value is the `sd_model` argument exactly
    as passed in, so it is None when the function fell back to shared.sd_model.
    """

    target = sd_model or shared.sd_model
    send_model_to_cpu(target)

    return sd_model
1012
+
1013
+
1014
def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.

    The ratio last applied is remembered on the model as `applied_token_merged_ratio`.
    Nothing happens when the requested ratio equals it. Otherwise any existing patch is
    removed, a new one is applied if the requested ratio is above zero, and the
    remembered ratio is updated.
    """

    previous_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
    if previous_ratio == token_merging_ratio:
        return

    if previous_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        patch_options = {
            "ratio": token_merging_ratio,
            "use_rand": False,  # can cause issues with some samplers
            "merge_attn": True,
            "merge_crossattn": False,
            "merge_mlp": False,
        }
        tomesd.apply_patch(sd_model, **patch_options)

    sd_model.applied_token_merged_ratio = token_merging_ratio