AkinoKaze committed on
Commit
c2a3a8b
·
verified ·
1 Parent(s): 44c9a5b

Upload sd_models.py

Browse files
Files changed (1) hide show
  1. sd_models.py +1046 -0
sd_models.py ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import importlib
3
+ import os
4
+ import sys
5
+ import threading
6
+ import enum
7
+
8
+ import gc
9
+ import torch
10
+
11
def clear_cuda():
    """Collect Python garbage and, when CUDA is available, release cached GPU memory
    and reset the peak-memory counters."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.reset_peak_memory_stats()
17
+
18
+ import torch
19
+ import re
20
+ import safetensors.torch
21
+ from omegaconf import OmegaConf, ListConfig
22
+ from urllib import request
23
+ import ldm.modules.midas as midas
24
+
25
+ from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
26
+ from modules.timer import Timer
27
+ from modules.shared import opts
28
+ import tomesd
29
+ import numpy as np
30
+
31
+ model_dir = "Stable-diffusion"
32
+ model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
33
+
34
+ checkpoints_list = {}
35
+ checkpoint_aliases = {}
36
+ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
37
+ checkpoints_loaded = collections.OrderedDict()
38
+
39
+
40
class ModelType(enum.Enum):
    """Family of a Stable Diffusion checkpoint, as detected by set_model_type()."""
    SD1 = 1
    SD2 = 2
    SDXL = 3
    SSD = 4  # SDXL-derived variant detected by a missing middle-block attention weight
    SD3 = 5
46
+
47
+
48
def replace_key(d, key, new_key, value):
    """Store value under new_key in d, placing new_key at the position key occupied.

    When key is not present, new_key is simply appended at the end.
    d is modified in place and also returned.
    """
    order = list(d.keys())
    d[new_key] = value

    if key not in order:
        return d

    order[order.index(key)] = new_key
    reordered = {name: d[name] for name in order}

    d.clear()
    d.update(reordered)
    return d
64
+
65
+
66
class CheckpointInfo:
    """Information about one checkpoint file on disk: display names, lookup aliases,
    hashes and (for .safetensors) embedded metadata."""

    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)
        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None

        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        # Derive a display name relative to whichever model directory contains the file.
        if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
            name = abspath.replace(abs_ckpt_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(filename)

        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        def read_metadata():
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)

            return metadata

        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                # fix: message previously lacked the file name ("reading metadata for (unknown)")
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)  # legacy short hash, cheap to compute

        # sha256 is only taken from the hash cache here; calculate_shorthash() computes it on demand
        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

        # Every string under which this checkpoint can be looked up in checkpoint_aliases.
        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

    def register(self):
        """Add this checkpoint to the global registry and all of its aliases to the alias map."""
        checkpoints_list[self.title] = self
        for id in self.ids:
            checkpoint_aliases[id] = self

    def calculate_shorthash(self):
        """Compute the full sha256 (potentially slow), update title/aliases to include it,
        and return the 10-character short hash (None if hashing failed)."""
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        shorthash = self.sha256[0:10]
        if self.shorthash == shorthash:
            return self.shorthash

        self.shorthash = shorthash

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

        old_title = self.title
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        # Keep the checkpoint at its old position in the ordered registry while renaming its key.
        replace_key(checkpoints_list, old_title, self.title, self)
        self.register()

        return self.shorthash
139
+
140
+
141
+ try:
142
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
143
+ from transformers import logging, CLIPModel # noqa: F401
144
+
145
+ logging.set_verbosity_error()
146
+ except Exception:
147
+ pass
148
+
149
+
150
def setup_model():
    """called once at startup to do various one-time tasks related to SD models"""

    # make sure the checkpoint directory exists before anything tries to scan it
    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()
    patch_given_betas()
157
+
158
+
159
def checkpoint_tiles(use_short=False):
    """Return the titles (short titles when use_short is set) of all registered checkpoints."""
    attribute = 'short_title' if use_short else 'title'
    return [getattr(info, attribute) for info in checkpoints_list.values()]
161
+
162
+
163
def list_models():
    """Rescan the model directories (and the --ckpt file) and rebuild the global
    checkpoints_list / checkpoint_aliases registries from scratch."""
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    # Only schedule the default SD 1.5 download when the user kept the default
    # --ckpt value, that file does not exist yet, and downloading is not disabled.
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
        model_url = None
        expected_sha256 = None
    else:
        model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
        expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256)

    if os.path.exists(cmd_ckpt):
        # an explicitly given --ckpt file is registered and becomes the selected checkpoint
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
188
+
189
+
190
+ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
191
+
192
+
193
def get_closet_checkpoint_match(search_string):
    """Find the checkpoint best matching search_string: an exact alias wins; otherwise
    the checkpoint with the shortest title containing the string (tried once more with
    any trailing "[hash]" stripped). Returns None when nothing matches."""
    if not search_string:
        return None

    exact = checkpoint_aliases.get(search_string, None)
    if exact is not None:
        return exact

    def shortest_containing(needle):
        candidates = sorted((info for info in checkpoints_list.values() if needle in info.title), key=lambda info: len(info.title))
        return candidates[0] if candidates else None

    match = shortest_containing(search_string)
    if match is not None:
        return match

    return shortest_containing(re.sub(re_strip_checksum, '', search_string))
211
+
212
+
213
def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""
    import hashlib

    try:
        with open(filename, "rb") as file:
            digest = hashlib.sha256()
            # hash only 64 KiB starting at the 1 MiB mark — fast but collision-prone
            file.seek(0x100000)
            digest.update(file.read(0x10000))
            return digest.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'
226
+
227
+
228
def select_checkpoint():
    """Raises `FileNotFoundError` if no checkpoints are found."""
    model_checkpoint = shared.opts.sd_model_checkpoint

    requested = checkpoint_aliases.get(model_checkpoint, None)
    if requested is not None:
        return requested

    if not checkpoints_list:
        # build a helpful message listing every location that was searched
        parts = ["No checkpoints found. When searching for checkpoints, looked at:"]
        if shared.cmd_opts.ckpt is not None:
            parts.append(f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}")
        parts.append(f"\n - directory {model_path}")
        if shared.cmd_opts.ckpt_dir is not None:
            parts.append(f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}")
        parts.append("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations.")
        raise FileNotFoundError("".join(parts))

    fallback = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {fallback.title}", file=sys.stderr)

    return fallback
251
+
252
+
253
+ checkpoint_dict_replacements_sd1 = {
254
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
255
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
256
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
257
+ }
258
+
259
+ checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
260
+ 'conditioner.embedders.0.': 'cond_stage_model.',
261
+ }
262
+
263
+
264
def transform_checkpoint_dict_key(k, replacements):
    """Rewrite state-dict key k by substituting any matching prefix per replacements
    ({old_prefix: new_prefix}); replacements are applied in dict order."""
    for old_prefix, new_prefix in replacements.items():
        if k.startswith(old_prefix):
            k = new_prefix + k[len(old_prefix):]
    return k
270
+
271
+
272
def get_state_dict_from_checkpoint(pl_sd):
    """Extract the weights dict from a loaded checkpoint and normalize its keys in place.

    Unwraps a nested "state_dict" entry, then rewrites key prefixes — SD 2.x Turbo
    (SGM layout, detected via its ln_final weight of size 1024) gets its own mapping,
    everything else uses the SD1 mapping. Returns the same (mutated) dict.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)  # drop a possible leftover nested copy

    turbo_key = 'conditioner.embedders.0.model.ln_final.weight'
    is_sd2_turbo = turbo_key in pl_sd and pl_sd[turbo_key].size()[0] == 1024

    replacements = checkpoint_dict_replacements_sd2_turbo if is_sd2_turbo else checkpoint_dict_replacements_sd1

    converted = {}
    for key, tensor in pl_sd.items():
        new_key = transform_checkpoint_dict_key(key, replacements)
        if new_key is not None:
            converted[new_key] = tensor

    pl_sd.clear()
    pl_sd.update(converted)

    return pl_sd
292
+
293
+
294
def read_metadata_from_safetensors(filename):
    """Read the __metadata__ dict from a .safetensors file header.

    String values that look like JSON objects are decoded into dicts when possible.
    Returns {} (after reporting) if the header JSON cannot be parsed.
    Raises AssertionError when the file does not look like a safetensors file.
    """
    import json

    with open(filename, mode="rb") as file:
        # safetensors layout: 8-byte little-endian header length, then the JSON header.
        metadata_len = file.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = file.read(2)

        # fix: messages previously said "(unknown)" instead of the actual file name
        assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"

        res = {}

        try:
            json_data = json_start + file.read(metadata_len - 2)
            json_obj = json.loads(json_data)
            for k, v in json_obj.get("__metadata__", {}).items():
                res[k] = v
                if isinstance(v, str) and v[0:1] == '{':
                    # some tools store nested JSON as a string; decode it when valid
                    try:
                        res[k] = json.loads(v)
                    except Exception:
                        pass
        except Exception:
            errors.report(f"Error reading metadata from file: {filename}", exc_info=True)

        return res
320
+
321
+
322
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Load a checkpoint (.safetensors or pickled .ckpt) and return its normalized state dict.

    map_location, when given, overrides the configured weight-load device.
    """
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            # read the whole file into memory instead of mmapping it;
            # fix: the original leaked the file handle (open(...).read() without closing)
            with open(checkpoint_file, 'rb') as file:
                pl_sd = safetensors.torch.load(file.read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        # NOTE(review): torch.load on a .ckpt unpickles arbitrary code — only load trusted files
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd
340
+
341
+
342
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return checkpoint_info's state dict, serving it from the in-memory LRU cache when present."""
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info in checkpoints_loaded:
        # use checkpoint cache; bump the entry to most-recently-used
        print(f"Loading weights [{sd_model_hash}] from cache")
        checkpoints_loaded.move_to_end(checkpoint_info)
        return checkpoints_loaded[checkpoint_info]

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
    state_dict = read_state_dict(checkpoint_info.filename)
    timer.record("load weights from disk")

    return state_dict
358
+
359
+
360
class SkipWritingToConfig:
    """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""

    skip = False
    previous = None

    def __enter__(self):
        cls = SkipWritingToConfig
        self.previous = cls.skip
        cls.skip = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # restore whatever the flag was when we entered, so nesting works
        SkipWritingToConfig.skip = self.previous
373
+
374
+
375
def check_fp8(model):
    """Decide whether fp8 weight storage should be used for model.

    Returns None when there is no model, False on the mps device (fp8 is never
    enabled there), otherwise follows the fp8_storage option ("Enable", or
    "Enable for SDXL" for SDXL models).
    """
    if model is None:
        return None
    if devices.get_optimal_device_name() == "mps":
        return False
    if shared.opts.fp8_storage == "Enable":
        return True
    if getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
        return True
    return False
387
+
388
+
389
def set_model_type(model, state_dict):
    """Classify model as SD1/SD2/SDXL/SSD/SD3 from its state dict and attributes,
    setting the is_* boolean flags and model.model_type."""
    model.is_sd1 = False
    model.is_sd2 = False
    model.is_sdxl = False
    model.is_ssd = False
    model.is_sd3 = False

    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
        # only SD3 checkpoints carry an x_embedder projection weight
        model.is_sd3 = True
        model.model_type = ModelType.SD3
        return

    if hasattr(model, 'conditioner'):
        model.is_sdxl = True
        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' in state_dict.keys():
            model.model_type = ModelType.SDXL
        else:
            # SDXL-like model missing the middle-block attention weight — treated as SSD
            model.is_ssd = True
            model.model_type = ModelType.SSD
        return

    if hasattr(model.cond_stage_model, 'model'):
        model.is_sd2 = True
        model.model_type = ModelType.SD2
    else:
        model.is_sd1 = True
        model.model_type = ModelType.SD1
413
+
414
+
415
def set_model_fields(model):
    """Fill in attributes the model class may not define; latent_channels defaults to 4."""
    if not hasattr(model, 'latent_channels'):
        model.latent_channels = 4
418
+
419
+
420
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Load state_dict into model and apply dtype/fp8/VAE handling per command line
    options and settings.

    When state_dict is None, the weights are loaded from checkpoint_info (possibly
    from the in-memory cache). Also classifies the model (set_model_type), manages
    the checkpoint cache, applies half()/float()/fp8 conversions, overrides the
    alpha schedule, and loads the resolved VAE. Timings are recorded on timer.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if devices.fp8:
        # prevent model to load state dict in fp8
        model.half()

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    set_model_type(model, state_dict)
    set_model_fields(model)

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    if model.is_ssd:
        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model; .copy() so later in-place edits don't affect the cache
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    if hasattr(model, "before_load_weights"):
        model.before_load_weights(state_dict)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    if hasattr(model, "after_load_weights"):
        model.after_load_weights(state_dict)

    del state_dict

    # Set is_sdxl_inpaint flag.
    # Checks Unet structure to detect inpaint model. The inpaint model's
    # checkpoint state_dict does not contain the key
    # 'diffusion_model.input_blocks.0.0.weight'.
    diffusion_model_input = model.model.state_dict().get(
        'diffusion_model.input_blocks.0.0.weight'
    )
    model.is_sdxl_inpaint = (
        model.is_sdxl and
        diffusion_model_input is not None and
        diffusion_model_input.shape[1] == 9
    )

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if shared.cmd_opts.no_half:
        model.float()
        model.alphas_cumprod_original = model.alphas_cumprod
        devices.dtype_unet = torch.float32
        assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
        timer.record("apply float()")
    else:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        # detach alphas_cumprod so half() does not downcast the schedule, then restore it
        alphas_cumprod = model.alphas_cumprod
        model.alphas_cumprod = None
        model.half()
        model.alphas_cumprod = alphas_cumprod
        model.alphas_cumprod_original = alphas_cumprod
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        devices.dtype_unet = torch.float16
        timer.record("apply half()")

    apply_alpha_schedule_override(model)

    # drop stale cached fp16 copies from a previous fp8 pass
    for module in model.modules():
        if hasattr(module, 'fp16_weight'):
            del module.fp16_weight
        if hasattr(module, 'fp16_bias'):
            del module.fp16_bias

    if check_fp8(model):
        devices.fp8 = True
        # keep the VAE out of the fp8 conversion
        first_stage = model.first_stage_model
        model.first_stage_model = None
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
                    # keep an fp16 CPU copy so weights can be restored losslessly
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
        devices.fp8 = False

    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
550
+
551
+
552
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    # keep a reference to the original loader so the wrapper can delegate to it
    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        # download the weights on first use, then fall through to the real loader
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                os.mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
593
+
594
+
595
def patch_given_betas():
    """Patch ldm's DDPM.register_schedule so beta schedules arriving as OmegaConf
    ListConfig objects are converted to numpy arrays before use."""
    import ldm.models.diffusion.ddpm

    def patched_register_schedule(*args, **kwargs):
        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""

        # args[0] is self; args[1] is the given_betas positional argument
        if isinstance(args[1], ListConfig):
            args = (args[0], np.array(args[1]), *args[2:])

        original_register_schedule(*args, **kwargs)

    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
607
+
608
+
609
def repair_config(sd_config, state_dict=None):
    """Normalize a loaded OmegaConf model config in place so it matches the current
    command line options and available features. state_dict is currently unused here
    but kept for interface compatibility with callers."""

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    if hasattr(sd_config.model.params, 'unet_config'):
        if shared.cmd_opts.no_half:
            sd_config.model.params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
            sd_config.model.params.unet_config.params.use_fp16 = True

    if hasattr(sd_config.model.params, 'first_stage_config'):
        # fall back to plain attention when xformers is not available
        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)

    # Do not use checkpoint for inference.
    # This helps prevent extra performance overhead on checking parameters.
    # The perf overhead is about 100ms/it on 4090 for SDXL.
    if hasattr(sd_config.model.params, "network_config"):
        sd_config.model.params.network_config.params.use_checkpoint = False
    if hasattr(sd_config.model.params, "unet_config"):
        sd_config.model.params.unet_config.params.use_checkpoint = False
635
+
636
+
637
+
638
def rescale_zero_terminal_snr_abar(alphas_cumprod):
    """Rescale a cumulative-alpha schedule so the final timestep has (near) zero terminal SNR.

    Works in sqrt(alpha_bar) space: shift so the last value is zero, scale so the
    first value is unchanged, then square back. The last entry is set to a tiny
    positive constant instead of exactly zero. The input tensor is not modified;
    a new tensor is returned.
    """
    sqrt_abar = alphas_cumprod.sqrt()

    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # shift so the terminal value is zero, then scale to restore the initial value
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    rescaled = sqrt_abar ** 2
    rescaled[-1] = 4.8973451890853435e-08
    return rescaled
655
+
656
+
657
def apply_alpha_schedule_override(sd_model, p=None):
    """
    Applies an override to the alpha schedule of the model according to settings.
    - downcasts the alpha schedule to half precision
    - rescales the alpha schedule to have zero terminal SNR

    When p (a processing object) is given, the applied overrides are recorded in
    its extra_generation_params so they show up in generation metadata.
    """

    if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
        return

    # always restart from the pristine schedule so overrides do not stack across calls
    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)

    if opts.use_downcasted_alpha_bar:
        if p is not None:
            p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
        sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)

    if opts.sd_noise_schedule == "Zero Terminal SNR":
        if p is not None:
            p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
        sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
678
+
679
+
680
+ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
681
+ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
682
+ sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
683
+ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
684
+
685
+
686
class SdModelData:
    """Holds the currently active SD model and the list of loaded models (most recently used first)."""

    def __init__(self):
        self.sd_model = None  # currently active model, or None
        self.loaded_sd_models = []  # models kept in memory, most recently used first
        self.was_loaded_at_least_once = False  # once True, get_sd_model never triggers a load again
        self.lock = threading.Lock()  # serializes the lazy first-time load

    def get_sd_model(self):
        """Return the active model, lazily loading it on first access.

        Returns None when loading failed (the error is reported, not raised).
        """
        if self.was_loaded_at_least_once:
            return self.sd_model

        if self.sd_model is None:
            with self.lock:
                # re-check under the lock: another thread may have finished loading meanwhile
                if self.sd_model is not None or self.was_loaded_at_least_once:
                    return self.sd_model

                try:
                    load_model()

                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
                    print("Stable diffusion model failed to load", file=sys.stderr)
                    self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v, already_loaded=False):
        """Make v the active model and move it to the front of loaded_sd_models.

        When already_loaded is True, also restores v's VAE bookkeeping into sd_vae.
        """
        self.sd_model = v
        if already_loaded:
            sd_vae.base_vae = getattr(v, "base_vae", None)
            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
            sd_vae.checkpoint_info = v.sd_checkpoint_info

        try:
            self.loaded_sd_models.remove(v)
        except ValueError:
            # v was not in the list yet; nothing to remove
            pass

        if v is not None:
            self.loaded_sd_models.insert(0, v)
727
+
728
+
729
+ model_data = SdModelData()
730
+
731
+
732
def get_empty_cond(sd_model):
    """Compute the conditioning for an empty prompt with no extra networks active."""
    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'get_learned_conditioning'):
        cond = sd_model.get_learned_conditioning([""])
    else:
        cond = sd_model.cond_stage_model([""])

    # some models return a dict of conditioning tensors; take the cross-attention part
    if isinstance(cond, dict):
        cond = cond['crossattn']

    return cond
746
+
747
+
748
def send_model_to_cpu(m):
    """Move model m off the GPU — via the lowvram offload path when active,
    otherwise directly — then collect garbage. Accepts None."""
    if m is not None:
        if m.lowvram:
            lowvram.send_everything_to_cpu()
        else:
            m.to(devices.cpu)

    devices.torch_gc()
756
+
757
+
758
def model_target_device(m):
    """Device the model should live on: CPU when lowvram offloading will manage it,
    the main device otherwise."""
    return devices.cpu if lowvram.is_needed(m) else devices.device
763
+
764
+
765
def send_model_to_device(m):
    """Apply lowvram offloading hooks to m; unless offloaded, move it to the main device."""
    lowvram.apply(m)

    if not m.lowvram:
        m.to(shared.device)
770
+
771
+
772
def send_model_to_trash(m):
    """Release m's weights by moving them to the meta device, then collect garbage."""
    m.to(device="meta")
    devices.torch_gc()
775
+
776
+
777
def instantiate_from_config(config, state_dict=None):
    """Construct the object named by config["target"] with config["params"] as kwargs.

    When state_dict is given and the params contain an explicit state_dict set to
    None, the provided state_dict is injected in its place.
    """
    constructor = get_obj_from_str(config["target"])
    params = {**config.get("params", {})}

    inject = state_dict and "state_dict" in params and params["state_dict"] is None
    if inject:
        params["state_dict"] = state_dict

    return constructor(**params)
786
+
787
+
788
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like "pkg.mod.Name" to that attribute, importing the
    module (and optionally reloading it first)."""
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    return getattr(importlib.import_module(module_name, package=None), attr_name)
794
+
795
+
796
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """Create a new SD model from checkpoint_info (or the selected checkpoint),
    load its weights, hijack it, register it as the active model and return it.

    already_loaded_state_dict, when given, skips reading the checkpoint from disk.
    The previously active model, if any, is discarded first.
    """
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    clear_cuda()  # clear memory before loading the model

    if model_data.sd_model:
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

    timer.record("unload existing model")

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    # if any known CLIP weight is in the checkpoint, CLIP does not need downloading
    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config, state_dict)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        # fast path: build the model on the meta device with weight init disabled
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
            with sd_disable_initialization.InitializeOnMeta():
                sd_model = instantiate_from_config(sd_config.model, state_dict)

    except Exception as e:
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        # fallback: retry without DisableInitialization
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

        with sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model, state_dict)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    if shared.cmd_opts.no_half:
        weight_dtype_conversion = None
    else:
        # convert everything to fp16 except the VAE and the alpha schedule
        weight_dtype_conversion = {
            'first_stage_model': None,
            'alphas_cumprod': None,
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model
888
+
889
+
890
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
    If not, returns the model that can be used to load weights from checkpoint_info's file.
    If no such model exists, returns None.
    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    # Fast path: the requested checkpoint is already the active model.
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if shared.opts.sd_checkpoints_keep_in_cpu:
        send_model_to_cpu(sd_model)
        timer.record("send model to cpu")

    already_loaded = None
    # Walk the cache from newest to oldest so deleting by index stays valid.
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue  # never evict the model we are about to reuse

        # Evict models beyond the configured cache limit (limit 0 disables eviction).
        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            del model_data.loaded_sd_models[i]
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded, already_loaded=True)

        if not SkipWritingToConfig.skip:
            # Persist the newly selected checkpoint in user options.
            shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
            shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        sd_vae.reload_vae_weights(already_loaded)
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        # Room left in the cache: load the checkpoint as an additional model.
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        # Cache full: reuse the most recently cached model as the target for new weights.
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        # Restore the VAE bookkeeping that belongs to the reused model.
        sd_vae.base_vae = getattr(sd_model, "base_vae", None)
        sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
        sd_vae.checkpoint_info = sd_model.sd_checkpoint_info

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None
950
+
951
+
952
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
    """Switch the active model to the checkpoint described by info (or the selected one).

    Reuses an already-loaded model when possible; otherwise loads weights into
    an existing compatible model, or rebuilds the model when the config differs.
    Returns the resulting model (also registered in model_data).
    """
    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if check_fp8(sd_model) != devices.fp8:
            # load from state dict again to prevent extra numerical errors
            forced_reload = True
        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
            return sd_model  # already the requested checkpoint; nothing to do

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        # Detach optimizations/hijacks before replacing the model's weights.
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        # Config mismatch: the existing architecture cannot hold these weights,
        # so discard it and build a new model from scratch.
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        # Roll back to the previous checkpoint, then re-raise the failure.
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        # Re-apply the hijack and notify scripts whether or not loading succeeded.
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        if not sd_model.lowvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)
    sd_unet.apply_unet()

    return sd_model
1015
+
1016
+
1017
def unload_model_weights(sd_model=None, info=None):
    """Move a model off the compute device to the CPU.

    Operates on the given sd_model, falling back to the shared active model.
    info is accepted for interface compatibility and is not used.
    Returns sd_model exactly as passed in (possibly None).
    """
    target = sd_model or shared.sd_model
    send_model_to_cpu(target)

    return sd_model
1021
+
1022
+
1023
def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.
    """

    previous_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)

    # The requested ratio is already in effect; nothing to do.
    if previous_ratio == token_merging_ratio:
        return

    # Clear any existing patch before applying a new configuration.
    if previous_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        patch_options = dict(
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False,
        )
        tomesd.apply_patch(sd_model, **patch_options)

    sd_model.applied_token_merged_ratio = token_merging_ratio