Coercer committed
Commit b155c46 · verified · 1 Parent(s): 9a81093

Delete Python_Infer_Utils/snake.py

Files changed (1)
  1. Python_Infer_Utils/snake.py +0 -1209
Python_Infer_Utils/snake.py DELETED
@@ -1,1209 +0,0 @@
1
- # Cherry-picked some good parts from ComfyUI with some bad parts fixed
2
-
3
- import sys
4
- import time
5
- import psutil
6
- import torch
7
- import platform
8
-
9
- from enum import Enum
10
- from backend import stream, utils
11
- from backend.args import args
12
-
13
-
14
- cpu = torch.device('cpu')
15
-
16
-
17
- class VRAMState(Enum):
18
- DISABLED = 0 # No vram present: no need to move models to vram
19
- NO_VRAM = 1 # Very low vram: enable all the options to save vram
20
- LOW_VRAM = 2
21
- NORMAL_VRAM = 3
22
- HIGH_VRAM = 4
23
- SHARED = 5 # No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
24
-
25
-
26
- class CPUState(Enum):
27
- GPU = 0
28
- CPU = 1
29
- MPS = 2
30
-
31
-
32
- # Determine VRAM State
33
- vram_state = VRAMState.NORMAL_VRAM
34
- set_vram_to = VRAMState.NORMAL_VRAM
35
- cpu_state = CPUState.GPU
36
-
37
- total_vram = 0
38
-
39
- lowvram_available = True
40
- xpu_available = False
41
-
42
- if args.pytorch_deterministic:
43
- print("Using deterministic algorithms for pytorch")
44
- torch.use_deterministic_algorithms(True, warn_only=True)
45
-
46
- directml_enabled = False
47
- if args.directml is not None:
48
- import torch_directml
49
-
50
- directml_enabled = True
51
- device_index = args.directml
52
- if device_index < 0:
53
- directml_device = torch_directml.device()
54
- else:
55
- directml_device = torch_directml.device(device_index)
56
- print("Using directml with device: {}".format(torch_directml.device_name(device_index)))
57
-
58
- try:
59
- import intel_extension_for_pytorch as ipex
60
-
61
- if torch.xpu.is_available():
62
- xpu_available = True
63
- except:
64
- pass
65
-
66
- try:
67
- if torch.backends.mps.is_available():
68
- cpu_state = CPUState.MPS
69
- import torch.mps
70
- except:
71
- pass
72
-
73
- if args.always_cpu:
74
- cpu_state = CPUState.CPU
75
-
76
-
77
- def is_intel_xpu():
78
- global cpu_state
79
- global xpu_available
80
- if cpu_state == CPUState.GPU:
81
- if xpu_available:
82
- return True
83
- return False
84
-
85
-
86
- def get_torch_device():
87
- global directml_enabled
88
- global cpu_state
89
- if directml_enabled:
90
- global directml_device
91
- return directml_device
92
- if cpu_state == CPUState.MPS:
93
- return torch.device("mps")
94
- if cpu_state == CPUState.CPU:
95
- return torch.device("cpu")
96
- else:
97
- if is_intel_xpu():
98
- return torch.device("xpu", torch.xpu.current_device())
99
- else:
100
- return torch.device(torch.cuda.current_device())
101
-
102
-
103
- def get_total_memory(dev=None, torch_total_too=False):
104
- global directml_enabled
105
- if dev is None:
106
- dev = get_torch_device()
107
-
108
- if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
109
- mem_total = psutil.virtual_memory().total
110
- mem_total_torch = mem_total
111
- else:
112
- if directml_enabled:
113
- mem_total = 1024 * 1024 * 1024 # TODO
114
- mem_total_torch = mem_total
115
- elif is_intel_xpu():
116
- stats = torch.xpu.memory_stats(dev)
117
- mem_reserved = stats['reserved_bytes.all.current']
118
- mem_total_torch = mem_reserved
119
- mem_total = torch.xpu.get_device_properties(dev).total_memory
120
- else:
121
- stats = torch.cuda.memory_stats(dev)
122
- mem_reserved = stats['reserved_bytes.all.current']
123
- _, mem_total_cuda = torch.cuda.mem_get_info(dev)
124
- mem_total_torch = mem_reserved
125
- mem_total = mem_total_cuda
126
-
127
- if torch_total_too:
128
- return (mem_total, mem_total_torch)
129
- else:
130
- return mem_total
131
-
132
-
133
- total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
134
- total_ram = psutil.virtual_memory().total / (1024 * 1024)
135
- print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
136
-
137
- try:
138
- print("pytorch version: {}".format(torch.version.__version__))
139
- except:
140
- pass
141
-
142
- try:
143
- OOM_EXCEPTION = torch.cuda.OutOfMemoryError
144
- except:
145
- OOM_EXCEPTION = Exception
146
-
147
- if directml_enabled:
148
- OOM_EXCEPTION = Exception
149
-
150
- XFORMERS_VERSION = ""
151
- XFORMERS_ENABLED_VAE = True
152
- if args.disable_xformers:
153
- XFORMERS_IS_AVAILABLE = False
154
- else:
155
- try:
156
- import xformers
157
- import xformers.ops
158
-
159
- XFORMERS_IS_AVAILABLE = True
160
- try:
161
- XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
162
- except:
163
- pass
164
- try:
165
- XFORMERS_VERSION = xformers.version.__version__
166
- print("xformers version: {}".format(XFORMERS_VERSION))
167
- if XFORMERS_VERSION.startswith("0.0.18"):
168
- print("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
169
- print("Please downgrade or upgrade xformers to a different version.\n")
170
- XFORMERS_ENABLED_VAE = False
171
- except:
172
- pass
173
- except:
174
- XFORMERS_IS_AVAILABLE = False
175
-
176
-
177
- def is_nvidia():
178
- global cpu_state
179
- if cpu_state == CPUState.GPU:
180
- if torch.version.cuda:
181
- return True
182
- return False
183
-
184
-
185
- ENABLE_PYTORCH_ATTENTION = False
186
- if args.attention_pytorch:
187
- ENABLE_PYTORCH_ATTENTION = True
188
- XFORMERS_IS_AVAILABLE = False
189
-
190
- VAE_DTYPES = [torch.float32]
191
-
192
- try:
193
- if is_nvidia():
194
- torch_version = torch.version.__version__
195
- if int(torch_version[0]) >= 2:
196
- if ENABLE_PYTORCH_ATTENTION == False and args.attention_split == False and args.attention_quad == False:
197
- ENABLE_PYTORCH_ATTENTION = True
198
- if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
199
- VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
200
- if is_intel_xpu():
201
- if args.attention_split == False and args.attention_quad == False:
202
- ENABLE_PYTORCH_ATTENTION = True
203
- except:
204
- pass
205
-
206
- if is_intel_xpu():
207
- VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
208
-
209
- if args.vae_in_cpu:
210
- VAE_DTYPES = [torch.float32]
211
-
212
- VAE_ALWAYS_TILED = False
213
-
214
- if ENABLE_PYTORCH_ATTENTION:
215
- torch.backends.cuda.enable_math_sdp(True)
216
- torch.backends.cuda.enable_flash_sdp(True)
217
- torch.backends.cuda.enable_mem_efficient_sdp(True)
218
-
219
- if args.always_low_vram:
220
- set_vram_to = VRAMState.LOW_VRAM
221
- lowvram_available = True
222
- elif args.always_no_vram:
223
- set_vram_to = VRAMState.NO_VRAM
224
- elif args.always_high_vram or args.always_gpu:
225
- vram_state = VRAMState.HIGH_VRAM
226
-
227
- FORCE_FP32 = False
228
- FORCE_FP16 = False
229
- if args.all_in_fp32:
230
- print("Forcing FP32, if this improves things please report it.")
231
- FORCE_FP32 = True
232
-
233
- if args.all_in_fp16:
234
- print("Forcing FP16.")
235
- FORCE_FP16 = True
236
-
237
- if lowvram_available:
238
- if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
239
- vram_state = set_vram_to
240
-
241
- if cpu_state != CPUState.GPU:
242
- vram_state = VRAMState.DISABLED
243
-
244
- if cpu_state == CPUState.MPS:
245
- vram_state = VRAMState.SHARED
246
-
247
- print(f"Set vram state to: {vram_state.name}")
248
-
249
- ALWAYS_VRAM_OFFLOAD = args.always_offload_from_vram
250
-
251
- if ALWAYS_VRAM_OFFLOAD:
252
- print("Always offload VRAM")
253
-
254
- PIN_SHARED_MEMORY = args.pin_shared_memory
255
-
256
- if PIN_SHARED_MEMORY:
257
- print("Always pin shared GPU memory")
258
-
259
-
260
- def get_torch_device_name(device):
261
- if hasattr(device, 'type'):
262
- if device.type == "cuda":
263
- try:
264
- allocator_backend = torch.cuda.get_allocator_backend()
265
- except:
266
- allocator_backend = ""
267
- return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
268
- else:
269
- return "{}".format(device.type)
270
- elif is_intel_xpu():
271
- return "{} {}".format(device, torch.xpu.get_device_name(device))
272
- else:
273
- return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
274
-
275
-
276
- try:
277
- torch_device_name = get_torch_device_name(get_torch_device())
278
- print("Device: {}".format(torch_device_name))
279
- except:
280
- torch_device_name = ''
281
- print("Could not pick default device.")
282
-
283
- if 'rtx' in torch_device_name.lower():
284
- if not args.cuda_malloc:
285
- print('Hint: your device supports --cuda-malloc for potential speed improvements.')
286
-
287
-
288
- current_loaded_models = []
289
-
290
-
291
- def state_dict_size(sd, exclude_device=None):
292
- module_mem = 0
293
- for k in sd:
294
- t = sd[k]
295
-
296
- if exclude_device is not None:
297
- if t.device == exclude_device:
298
- continue
299
-
300
- module_mem += t.nelement() * t.element_size()
301
- return module_mem
302
-
303
-
304
- def state_dict_parameters(sd):
305
- module_mem = 0
306
- for k, v in sd.items():
307
- module_mem += v.nelement()
308
- return module_mem
309
-
310
-
311
- def state_dict_dtype(state_dict):
312
- for k, v in state_dict.items():
313
- if hasattr(v, 'gguf_cls'):
314
- return 'gguf'
315
- if 'bitsandbytes__nf4' in k:
316
- return 'nf4'
317
- if 'bitsandbytes__fp4' in k:
318
- return 'fp4'
319
-
320
- dtype_counts = {}
321
-
322
- for tensor in state_dict.values():
323
- dtype = tensor.dtype
324
- if dtype in dtype_counts:
325
- dtype_counts[dtype] += 1
326
- else:
327
- dtype_counts[dtype] = 1
328
-
329
- major_dtype = None
330
- max_count = 0
331
-
332
- for dtype, count in dtype_counts.items():
333
- if count > max_count:
334
- max_count = count
335
- major_dtype = dtype
336
-
337
- return major_dtype
338
-
339
-
340
- def bake_gguf_model(model):
341
- if getattr(model, 'gguf_baked', False):
342
- return
343
-
344
- for p in model.parameters():
345
- gguf_cls = getattr(p, 'gguf_cls', None)
346
- if gguf_cls is not None:
347
- gguf_cls.bake(p)
348
-
349
- global signal_empty_cache
350
- signal_empty_cache = True
351
-
352
- model.gguf_baked = True
353
- return model
354
-
355
-
356
- def module_size(module, exclude_device=None, include_device=None, return_split=False):
357
- module_mem = 0
358
- weight_mem = 0
359
- weight_patterns = ['weight']
360
-
361
- for k, p in module.named_parameters():
362
- t = p.data
363
-
364
- if exclude_device is not None:
365
- if t.device == exclude_device:
366
- continue
367
-
368
- if include_device is not None:
369
- if t.device != include_device:
370
- continue
371
-
372
- element_size = t.element_size()
373
-
374
- if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
375
- if element_size > 1:
376
- # not quanted yet
377
- element_size = 0.55 # a bit more than 0.5 because of quant state parameters
378
- else:
379
- # quanted
380
- element_size = 1.1 # a bit more than 1 because of quant state parameters
381
-
382
- module_mem += t.nelement() * element_size
383
-
384
- if k in weight_patterns:
385
- weight_mem += t.nelement() * element_size
386
-
387
- if return_split:
388
- return module_mem, weight_mem, module_mem - weight_mem
389
-
390
- return module_mem
391
-
392
-
393
- def module_move(module, device, recursive=True, excluded_pattens=[]):
394
- if recursive:
395
- return module.to(device=device)
396
-
397
- for k, p in module.named_parameters(recurse=False, remove_duplicate=True):
398
- if k in excluded_pattens:
399
- continue
400
- setattr(module, k, utils.tensor2parameter(p.to(device=device)))
401
-
402
- return module
403
-
404
-
405
- def build_module_profile(model, model_gpu_memory_when_using_cpu_swap):
406
- all_modules = []
407
- legacy_modules = []
408
-
409
- for m in model.modules():
410
- if hasattr(m, "parameters_manual_cast"):
411
- m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
412
- all_modules.append(m)
413
- elif hasattr(m, "weight"):
414
- m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
415
- legacy_modules.append(m)
416
-
417
- gpu_modules = []
418
- gpu_modules_only_extras = []
419
- mem_counter = 0
420
-
421
- for m in legacy_modules.copy():
422
- gpu_modules.append(m)
423
- legacy_modules.remove(m)
424
- mem_counter += m.total_mem
425
-
426
- for m in sorted(all_modules, key=lambda x: x.extra_mem).copy():
427
- if mem_counter + m.extra_mem < model_gpu_memory_when_using_cpu_swap:
428
- gpu_modules_only_extras.append(m)
429
- all_modules.remove(m)
430
- mem_counter += m.extra_mem
431
-
432
- cpu_modules = all_modules
433
-
434
- for m in sorted(gpu_modules_only_extras, key=lambda x: x.weight_mem).copy():
435
- if mem_counter + m.weight_mem < model_gpu_memory_when_using_cpu_swap:
436
- gpu_modules.append(m)
437
- gpu_modules_only_extras.remove(m)
438
- mem_counter += m.weight_mem
439
-
440
- return gpu_modules, gpu_modules_only_extras, cpu_modules
441
-
442
-
443
- class LoadedModel:
444
- def __init__(self, model):
445
- self.model = model
446
- self.model_accelerated = False
447
- self.device = model.load_device
448
- self.inclusive_memory = 0
449
- self.exclusive_memory = 0
450
-
451
- def compute_inclusive_exclusive_memory(self):
452
- self.inclusive_memory = module_size(self.model.model, include_device=self.device)
453
- self.exclusive_memory = module_size(self.model.model, exclude_device=self.device)
454
- return
455
-
456
- def model_load(self, model_gpu_memory_when_using_cpu_swap=-1):
457
- patch_model_to = None
458
- do_not_need_cpu_swap = model_gpu_memory_when_using_cpu_swap < 0
459
-
460
- if do_not_need_cpu_swap:
461
- patch_model_to = self.device
462
-
463
- self.model.model_patches_to(self.device)
464
- self.model.model_patches_to(self.model.model_dtype())
465
-
466
- try:
467
- self.real_model = self.model.forge_patch_model(patch_model_to)
468
- self.model.current_device = self.model.load_device
469
- except Exception as e:
470
- self.model.forge_unpatch_model(self.model.offload_device)
471
- self.model_unload()
472
- raise e
473
-
474
- if do_not_need_cpu_swap:
475
- print('All loaded to GPU.')
476
- else:
477
- gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
478
- pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
479
-
480
- mem_counter = 0
481
- swap_counter = 0
482
-
483
- for m in gpu_modules:
484
- m.to(self.device)
485
- mem_counter += m.total_mem
486
-
487
- for m in cpu_modules:
488
- m.prev_parameters_manual_cast = m.parameters_manual_cast
489
- m.parameters_manual_cast = True
490
- m.to(self.model.offload_device)
491
- if pin_memory:
492
- m._apply(lambda x: x.pin_memory())
493
- swap_counter += m.total_mem
494
-
495
- for m in gpu_modules_only_extras:
496
- m.prev_parameters_manual_cast = m.parameters_manual_cast
497
- m.parameters_manual_cast = True
498
- module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
499
- if hasattr(m, 'weight') and m.weight is not None:
500
- if pin_memory:
501
- m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device).pin_memory())
502
- else:
503
- m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device))
504
- mem_counter += m.extra_mem
505
- swap_counter += m.weight_mem
506
-
507
- swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
508
- method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
509
- print(f"{swap_flag} Swap Loaded ({method_flag} method): {swap_counter / (1024 * 1024):.2f} MB, GPU Loaded: {mem_counter / (1024 * 1024):.2f} MB")
510
-
511
- self.model_accelerated = True
512
-
513
- global signal_empty_cache
514
- signal_empty_cache = True
515
-
516
- bake_gguf_model(self.real_model)
517
-
518
- self.model.refresh_loras()
519
-
520
- if is_intel_xpu() and not args.disable_ipex_hijack:
521
- self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
522
-
523
- return self.real_model
524
-
525
- def model_unload(self, avoid_model_moving=False):
526
- if self.model_accelerated:
527
- for m in self.real_model.modules():
528
- if hasattr(m, "prev_parameters_manual_cast"):
529
- m.parameters_manual_cast = m.prev_parameters_manual_cast
530
- del m.prev_parameters_manual_cast
531
-
532
- self.model_accelerated = False
533
-
534
- if avoid_model_moving:
535
- self.model.forge_unpatch_model()
536
- else:
537
- self.model.forge_unpatch_model(self.model.offload_device)
538
- self.model.model_patches_to(self.model.offload_device)
539
-
540
- def __eq__(self, other):
541
- return self.model is other.model # and self.memory_required == other.memory_required
542
-
543
-
544
- current_inference_memory = 1024 * 1024 * 1024
545
-
546
-
547
- def minimum_inference_memory():
548
- global current_inference_memory
549
- return current_inference_memory
550
-
551
-
552
- def unload_model_clones(model):
553
- to_unload = []
554
- for i in range(len(current_loaded_models)):
555
- if model.is_clone(current_loaded_models[i].model):
556
- to_unload = [i] + to_unload
557
-
558
- for i in to_unload:
559
- current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
560
-
561
-
562
- def free_memory(memory_required, device, keep_loaded=[], free_all=False):
563
- # this check fully unloads any 'abandoned' models
564
- for i in range(len(current_loaded_models) - 1, -1, -1):
565
- if sys.getrefcount(current_loaded_models[i].model) <= 2:
566
- current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
567
-
568
- if free_all:
569
- memory_required = 1e30
570
- print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
571
- else:
572
- print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
573
-
574
- offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
575
- unloaded_model = False
576
- for i in range(len(current_loaded_models) - 1, -1, -1):
577
- if not offload_everything:
578
- free_memory = get_free_memory(device)
579
- print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
580
- if free_memory > memory_required:
581
- break
582
- shift_model = current_loaded_models[i]
583
- if shift_model.device == device:
584
- if shift_model not in keep_loaded:
585
- m = current_loaded_models.pop(i)
586
- print(f"Unload model {m.model.model.__class__.__name__} ", end="")
587
- m.model_unload()
588
- del m
589
- unloaded_model = True
590
-
591
- if unloaded_model:
592
- soft_empty_cache()
593
- else:
594
- if vram_state != VRAMState.HIGH_VRAM:
595
- mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
596
- if mem_free_torch > mem_free_total * 0.25:
597
- soft_empty_cache()
598
-
599
- print('Done.')
600
- return
601
-
602
-
603
- def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
604
- maximum_memory_available = current_free_mem - inference_memory
605
-
606
- suggestion = max(
607
- maximum_memory_available / 1.3,
608
- maximum_memory_available - 1024 * 1024 * 1024 * 1.25
609
- )
610
-
611
- return int(max(0, suggestion))
612
-
613
-
614
- def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
615
- global vram_state
616
-
617
- execution_start_time = time.perf_counter()
618
- memory_to_free = max(minimum_inference_memory(), memory_required) + hard_memory_preservation
619
- memory_for_inference = minimum_inference_memory() + hard_memory_preservation
620
-
621
- models_to_load = []
622
- models_already_loaded = []
623
- for x in models:
624
- loaded_model = LoadedModel(x)
625
-
626
- if loaded_model in current_loaded_models:
627
- index = current_loaded_models.index(loaded_model)
628
- current_loaded_models.insert(0, current_loaded_models.pop(index))
629
- models_already_loaded.append(loaded_model)
630
- else:
631
- models_to_load.append(loaded_model)
632
-
633
- if len(models_to_load) == 0:
634
- devs = set(map(lambda a: a.device, models_already_loaded))
635
- for d in devs:
636
- if d != torch.device("cpu"):
637
- free_memory(memory_to_free, d, models_already_loaded)
638
-
639
- moving_time = time.perf_counter() - execution_start_time
640
- if moving_time > 0.1:
641
- print(f'Memory cleanup has taken {moving_time:.2f} seconds')
642
-
643
- return
644
-
645
- for loaded_model in models_to_load:
646
- unload_model_clones(loaded_model.model)
647
-
648
- total_memory_required = {}
649
- for loaded_model in models_to_load:
650
- loaded_model.compute_inclusive_exclusive_memory()
651
- total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
652
-
653
- for device in total_memory_required:
654
- if device != torch.device("cpu"):
655
- free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
656
-
657
- for loaded_model in models_to_load:
658
- model = loaded_model.model
659
- torch_dev = model.load_device
660
- if is_device_cpu(torch_dev):
661
- vram_set_state = VRAMState.DISABLED
662
- else:
663
- vram_set_state = vram_state
664
-
665
- model_gpu_memory_when_using_cpu_swap = -1
666
-
667
- if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
668
- model_require = loaded_model.exclusive_memory
669
- previously_loaded = loaded_model.inclusive_memory
670
- current_free_mem = get_free_memory(torch_dev)
671
- estimated_remaining_memory = current_free_mem - model_require - memory_for_inference
672
-
673
- print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
674
-
675
- if estimated_remaining_memory < 0:
676
- vram_set_state = VRAMState.LOW_VRAM
677
- model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
678
- if previously_loaded > 0:
679
- model_gpu_memory_when_using_cpu_swap = previously_loaded
680
-
681
- if vram_set_state == VRAMState.NO_VRAM:
682
- model_gpu_memory_when_using_cpu_swap = 0
683
-
684
- loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
685
- current_loaded_models.insert(0, loaded_model)
686
-
687
- moving_time = time.perf_counter() - execution_start_time
688
- print(f'Moving model(s) has taken {moving_time:.2f} seconds')
689
-
690
- return
691
-
692
-
693
- def load_model_gpu(model):
694
- return load_models_gpu([model])
695
-
696
-
697
- def cleanup_models():
698
- to_delete = []
699
- for i in range(len(current_loaded_models)):
700
- if sys.getrefcount(current_loaded_models[i].model) <= 2:
701
- to_delete = [i] + to_delete
702
-
703
- for i in to_delete:
704
- x = current_loaded_models.pop(i)
705
- x.model_unload()
706
- del x
707
-
708
-
709
- def dtype_size(dtype):
710
- dtype_size = 4
711
- if dtype == torch.float16 or dtype == torch.bfloat16:
712
- dtype_size = 2
713
- elif dtype == torch.float32:
714
- dtype_size = 4
715
- else:
716
- try:
717
- dtype_size = dtype.itemsize
718
- except: # Old pytorch doesn't have .itemsize
719
- pass
720
- return dtype_size
721
-
722
-
723
- def unet_offload_device():
724
- if vram_state == VRAMState.HIGH_VRAM:
725
- return get_torch_device()
726
- else:
727
- return torch.device("cpu")
728
-
729
-
730
- def unet_inital_load_device(parameters, dtype):
731
- torch_dev = get_torch_device()
732
- if vram_state == VRAMState.HIGH_VRAM:
733
- return torch_dev
734
-
735
- cpu_dev = torch.device("cpu")
736
- if ALWAYS_VRAM_OFFLOAD:
737
- return cpu_dev
738
-
739
- model_size = dtype_size(dtype) * parameters
740
-
741
- mem_dev = get_free_memory(torch_dev)
742
- mem_cpu = get_free_memory(cpu_dev)
743
- if mem_dev > mem_cpu and model_size < mem_dev:
744
- return torch_dev
745
- else:
746
- return cpu_dev
747
-
748
-
749
- def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
750
- if args.unet_in_bf16:
751
- return torch.bfloat16
752
-
753
- if args.unet_in_fp16:
754
- return torch.float16
755
-
756
- if args.unet_in_fp8_e4m3fn:
757
- return torch.float8_e4m3fn
758
-
759
- if args.unet_in_fp8_e5m2:
760
- return torch.float8_e5m2
761
-
762
- for candidate in supported_dtypes:
763
- if candidate == torch.float16:
764
- if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
765
- return candidate
766
- if candidate == torch.bfloat16:
767
- if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
768
- return candidate
769
-
770
- return torch.float32
771
-
772
-
773
- def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
774
- for candidate in supported_dtypes:
775
- if candidate == torch.float16:
776
- if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
777
- return candidate
778
- if candidate == torch.bfloat16:
779
- if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
780
- return candidate
781
-
782
- return torch.float32
783
-
784
-
785
- def text_encoder_offload_device():
786
- if args.always_gpu:
787
- return get_torch_device()
788
- else:
789
- return torch.device("cpu")
790
-
791
-
792
- def text_encoder_device():
793
- if args.always_gpu:
794
- return get_torch_device()
795
- elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
796
- if should_use_fp16(prioritize_performance=False):
797
- return get_torch_device()
798
- else:
799
- return torch.device("cpu")
800
- else:
801
- return torch.device("cpu")
802
-
803
-
804
- def text_encoder_dtype(device=None):
805
- if args.clip_in_fp8_e4m3fn:
806
- return torch.float8_e4m3fn
807
- elif args.clip_in_fp8_e5m2:
808
- return torch.float8_e5m2
809
- elif args.clip_in_fp16:
810
- return torch.float16
811
- elif args.clip_in_fp32:
812
- return torch.float32
813
-
814
- if is_device_cpu(device):
815
- return torch.float16
816
-
817
- return torch.float16
818
-
819
-
820
- def intermediate_device():
821
- if args.always_gpu:
822
- return get_torch_device()
823
- else:
824
- return torch.device("cpu")
825
-
826
-
827
- def vae_device():
828
- if args.vae_in_cpu:
829
- return torch.device("cpu")
830
- return get_torch_device()
831
-
832
-
833
- def vae_offload_device():
834
- if args.always_gpu:
835
- return get_torch_device()
836
- else:
837
- return torch.device("cpu")
838
-
839
-
840
- def vae_dtype(device=None, allowed_dtypes=[]):
841
- global VAE_DTYPES
842
- if args.vae_in_fp16:
843
- return torch.float16
844
- elif args.vae_in_bf16:
845
- return torch.bfloat16
846
- elif args.vae_in_fp32:
847
- return torch.float32
848
-
849
- for d in allowed_dtypes:
850
- if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
851
- return d
852
- if d in VAE_DTYPES:
853
- return d
854
-
855
- return VAE_DTYPES[0]
856
-
857
-
858
- print(f"VAE dtype preferences: {VAE_DTYPES} -> {vae_dtype()}")
859
-
860
-
861
- def get_autocast_device(dev):
862
- if hasattr(dev, 'type'):
863
- return dev.type
864
- return "cuda"
865
-
866
-
867
- def supports_dtype(device, dtype): # TODO
868
- if dtype == torch.float32:
869
- return True
870
- if is_device_cpu(device):
871
- return False
872
- if dtype == torch.float16:
873
- return True
874
- if dtype == torch.bfloat16:
875
- return True
876
- return False
877
-
878
-
879
- def supports_cast(device, dtype): # TODO
880
- if dtype == torch.float32:
881
- return True
882
- if dtype == torch.float16:
883
- return True
884
- if directml_enabled: # TODO: test this
885
- return False
886
- if dtype == torch.bfloat16:
887
- return True
888
- if is_device_mps(device):
889
- return False
890
- if dtype == torch.float8_e4m3fn:
891
- return True
892
- if dtype == torch.float8_e5m2:
893
- return True
894
- return False
895
-
896
-
897
- def pick_weight_dtype(dtype, fallback_dtype, device=None):
898
- if dtype is None:
899
- dtype = fallback_dtype
900
- elif dtype_size(dtype) > dtype_size(fallback_dtype):
901
- dtype = fallback_dtype
902
-
903
- if not supports_cast(device, dtype):
904
- dtype = fallback_dtype
905
-
906
- return dtype
907
-
908
-
909
- def device_supports_non_blocking(device):
910
- if is_device_mps(device):
911
- return False # pytorch bug? mps doesn't support non blocking
912
- if is_intel_xpu():
913
- return False
914
- if args.pytorch_deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
915
- return False
916
- if directml_enabled:
917
- return False
918
- return True
919
-
920
-
921
- def device_should_use_non_blocking(device):
922
- if not device_supports_non_blocking(device):
923
- return False
924
- return False
925
- # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
926
-
927
-
928
- def force_channels_last():
929
- if args.force_channels_last:
930
- return True
931
-
932
- # TODO
933
- return False
934
-
935
-
936
- def cast_to_device(tensor, device, dtype, copy=False):
937
- device_supports_cast = False
938
- if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
939
- device_supports_cast = True
940
- elif tensor.dtype == torch.bfloat16:
941
- if hasattr(device, 'type') and device.type.startswith("cuda"):
942
- device_supports_cast = True
943
- elif is_intel_xpu():
944
- device_supports_cast = True
945
-
946
- non_blocking = device_should_use_non_blocking(device)
947
-
948
- if device_supports_cast:
949
- if copy:
950
- if tensor.device == device:
951
- return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
952
- return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
953
- else:
954
- return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
955
- else:
956
- return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
957
-
958
-
959
- def xformers_enabled():
960
- global directml_enabled
961
- global cpu_state
962
- if cpu_state != CPUState.GPU:
963
- return False
964
- if is_intel_xpu():
965
- return False
966
- if directml_enabled:
967
- return False
968
- return XFORMERS_IS_AVAILABLE
969
-
970
-
971
- def xformers_enabled_vae():
972
- enabled = xformers_enabled()
973
- if not enabled:
974
- return False
975
-
976
- return XFORMERS_ENABLED_VAE
977
-
978
-
979
- def pytorch_attention_enabled():
980
- global ENABLE_PYTORCH_ATTENTION
981
- return ENABLE_PYTORCH_ATTENTION
982
-
983
-
984
- def pytorch_attention_flash_attention():
985
- global ENABLE_PYTORCH_ATTENTION
986
- if ENABLE_PYTORCH_ATTENTION:
987
- # TODO: more reliable way of checking for flash attention?
988
- if is_nvidia(): # pytorch flash attention only works on Nvidia
989
- return True
990
- if is_intel_xpu():
991
- return True
992
- return False
993
-
994
-
995
- def force_upcast_attention_dtype():
996
- upcast = args.force_upcast_attention
997
- try:
998
- if platform.mac_ver()[0] in ['14.5']: # black image bug on OSX Sonoma 14.5
999
- upcast = True
1000
- except:
1001
- pass
1002
- if upcast:
1003
- return torch.float32
1004
- else:
1005
- return None
1006
-
1007
-
1008
- def get_free_memory(dev=None, torch_free_too=False):
1009
- global directml_enabled
1010
- if dev is None:
1011
- dev = get_torch_device()
1012
-
1013
- if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
1014
- mem_free_total = psutil.virtual_memory().available
1015
- mem_free_torch = mem_free_total
1016
- else:
1017
- if directml_enabled:
1018
- mem_free_total = 1024 * 1024 * 1024
1019
- mem_free_torch = mem_free_total
1020
- elif is_intel_xpu():
1021
- stats = torch.xpu.memory_stats(dev)
1022
- mem_active = stats['active_bytes.all.current']
1023
- mem_reserved = stats['reserved_bytes.all.current']
1024
- mem_free_torch = mem_reserved - mem_active
1025
- mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
1026
- mem_free_total = mem_free_xpu + mem_free_torch
1027
- else:
1028
- stats = torch.cuda.memory_stats(dev)
1029
- mem_active = stats['active_bytes.all.current']
1030
- mem_reserved = stats['reserved_bytes.all.current']
1031
- mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
1032
- mem_free_torch = mem_reserved - mem_active
1033
- mem_free_total = mem_free_cuda + mem_free_torch
1034
-
1035
- if torch_free_too:
1036
- return (mem_free_total, mem_free_torch)
1037
- else:
1038
- return mem_free_total
1039
-
1040
-
1041
- def cpu_mode():
1042
- global cpu_state
1043
- return cpu_state == CPUState.CPU
1044
-
1045
-
1046
- def mps_mode():
1047
- global cpu_state
1048
- return cpu_state == CPUState.MPS
1049
-
1050
-
1051
- def is_device_type(device, type):
1052
- if hasattr(device, 'type'):
1053
- if (device.type == type):
1054
- return True
1055
- return False
1056
-
1057
-
1058
- def is_device_cpu(device):
1059
- return is_device_type(device, 'cpu')
1060
-
1061
-
1062
- def is_device_mps(device):
1063
- return is_device_type(device, 'mps')
1064
-
1065
-
1066
- def is_device_cuda(device):
1067
- return is_device_type(device, 'cuda')
1068
-
1069
-
1070
- def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
1071
- global directml_enabled
1072
-
1073
- if device is not None:
1074
- if is_device_cpu(device):
1075
- return False
1076
-
1077
- if FORCE_FP16:
1078
- return True
1079
-
1080
- if device is not None:
1081
- if is_device_mps(device):
1082
- return True
1083
-
1084
- if FORCE_FP32:
1085
- return False
1086
-
1087
- if directml_enabled:
1088
- return False
1089
-
1090
- if mps_mode():
1091
- return True
1092
-
1093
- if cpu_mode():
1094
- return False
1095
-
1096
- if is_intel_xpu():
1097
- return True
1098
-
1099
- if torch.version.hip:
1100
- return True
1101
-
1102
- props = torch.cuda.get_device_properties("cuda")
1103
- if props.major >= 8:
1104
- return True
1105
-
1106
- if props.major < 6:
1107
- return False
1108
-
1109
- nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
1110
- for x in nvidia_10_series:
1111
- if x in props.name.lower():
1112
- if manual_cast:
1113
- # For storage dtype
1114
- free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
1115
- if (not prioritize_performance) or model_params * 4 > free_model_memory:
1116
- return True
1117
- else:
1118
- # For computation dtype
1119
- return False # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow.
1120
-
1121
- if props.major < 7:
1122
- return False
1123
-
1124
- # FP16 is just broken on these cards
1125
- nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
1126
- for x in nvidia_16_series:
1127
- if x in props.name:
1128
- return False
1129
-
1130
- return True
1131
-
1132
-
1133
- def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
1134
- if device is not None:
1135
- if is_device_cpu(device): # TODO ? bf16 works on CPU but is extremely slow
1136
- return False
1137
-
1138
- if device is not None:
1139
- if is_device_mps(device):
1140
- return True
1141
-
1142
- if FORCE_FP32:
1143
- return False
1144
-
1145
- if directml_enabled:
1146
- return False
1147
-
1148
- if mps_mode():
1149
- return True
1150
-
1151
- if cpu_mode():
1152
- return False
1153
-
1154
- if is_intel_xpu():
1155
- return True
1156
-
1157
- if device is None:
1158
- device = torch.device("cuda")
1159
-
1160
- props = torch.cuda.get_device_properties(device)
1161
- if props.major >= 8:
1162
- return True
1163
-
1164
- if torch.cuda.is_bf16_supported():
1165
- # This is an older device, but bf16 is still reported as supported.
1166
- # So in this case bf16 should only be used as a storage dtype
1167
- if manual_cast:
1168
- # For storage dtype
1169
- free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
1170
- if (not prioritize_performance) or model_params * 4 > free_model_memory:
1171
- return True
1172
-
1173
- return False
1174
-
1175
-
1176
- def can_install_bnb():
1177
- try:
1178
- if not torch.cuda.is_available():
1179
- return False
1180
-
1181
- cuda_version = tuple(int(x) for x in torch.version.cuda.split('.'))
1182
-
1183
- if cuda_version >= (11, 7):
1184
- return True
1185
-
1186
- return False
1187
- except:
1188
- return False
1189
-
1190
-
1191
- signal_empty_cache = False
1192
-
1193
-
1194
- def soft_empty_cache(force=False):
1195
- global cpu_state, signal_empty_cache
1196
- if cpu_state == CPUState.MPS:
1197
- torch.mps.empty_cache()
1198
- elif is_intel_xpu():
1199
- torch.xpu.empty_cache()
1200
- elif torch.cuda.is_available():
1201
- if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda
1202
- torch.cuda.empty_cache()
1203
- torch.cuda.ipc_collect()
1204
- signal_empty_cache = False
1205
- return
1206
-
1207
-
1208
- def unload_all_models():
1209
- free_memory(1e30, get_torch_device(), free_all=True)