Coercer commited on
Commit
096141f
·
verified ·
1 Parent(s): aa1ff95

Upload 5 files

Browse files
Python_Infer_Utils/Swan.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+ from collections import namedtuple
5
+ import cat, pigeon
6
+ from pig import worm
7
+ import snake
8
+
9
+
10
+ ChickenFix = namedtuple('ChickenFix', ['offset', 'embedding'])
11
+ last_extra_generation_params = {}
12
+
13
+
14
+ class Chicken:
15
+ def __init__(self):
16
+ self.tokens = []
17
+ self.multipliers = []
18
+ self.fixes = []
19
+
20
+
21
+ class Dog(torch.nn.Module):
22
+ def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
23
+ super().__init__()
24
+ self.wrapped = wrapped
25
+ self.embeddings = embeddings
26
+ self.textual_inversion_key = textual_inversion_key
27
+ self.weight = self.wrapped.weight
28
+
29
+ def forward(self, input_ids):
30
+ batch_fixes = self.embeddings.fixes
31
+ self.embeddings.fixes = None
32
+
33
+ inputs_embeds = self.wrapped(input_ids)
34
+
35
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
36
+ return inputs_embeds
37
+
38
+ vecs = []
39
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
40
+ for offset, embedding in fixes:
41
+ emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
42
+ emb = emb.to(inputs_embeds)
43
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
44
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
45
+
46
+ vecs.append(tensor)
47
+
48
+ return torch.stack(vecs)
49
+
50
+
51
+ class Eagle:
52
+ def __init__(
53
+ self, text_encoder, tokenizer, chunk_length=75,
54
+ embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, pigeon_name="Original",
55
+ text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True
56
+ ):
57
+ super().__init__()
58
+
59
+ self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
60
+
61
+ if isinstance(embedding_dir, str):
62
+ self.embeddings.add_embedding_dir(embedding_dir)
63
+ self.embeddings.load_textual_inversion_embeddings()
64
+
65
+ self.embedding_key = embedding_key
66
+
67
+ self.text_encoder = text_encoder
68
+ self.tokenizer = tokenizer
69
+
70
+ self.pigeon = pigeon.get_current_option()()
71
+ self.text_projection = text_projection
72
+ self.minimal_clip_skip = minimal_clip_skip
73
+ self.clip_skip = clip_skip
74
+ self.return_pooled = return_pooled
75
+ self.final_layer_norm = final_layer_norm
76
+
77
+ self.chunk_length = chunk_length
78
+
79
+ self.id_start = self.tokenizer.bos_token_id
80
+ self.id_end = self.tokenizer.eos_token_id
81
+ self.id_pad = self.tokenizer.pad_token_id
82
+
83
+ model_embeddings = text_encoder.transformer.text_model.embeddings
84
+ model_embeddings.token_embedding = Dog(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
85
+
86
+ vocab = self.tokenizer.get_vocab()
87
+
88
+ self.comma_token = vocab.get(',</w>', None)
89
+
90
+ self.token_mults = {}
91
+
92
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
93
+ for text, ident in tokens_with_parens:
94
+ mult = 1.0
95
+ for c in text:
96
+ if c == '[':
97
+ mult /= 1.1
98
+ if c == ']':
99
+ mult *= 1.1
100
+ if c == '(':
101
+ mult *= 1.1
102
+ if c == ')':
103
+ mult /= 1.1
104
+
105
+ if mult != 1.0:
106
+ self.token_mults[ident] = mult
107
+
108
+ def empty_chunk(self):
109
+ chunk = Chicken()
110
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
111
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
112
+ return chunk
113
+
114
+ def get_target_prompt_token_count(self, token_count):
115
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
116
+
117
+ def tokenize(self, texts):
118
+ tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
119
+
120
+ return tokenized
121
+
122
+ def encode_with_transformers(self, tokens):
123
+ target_device = snake.text_encoder_device()
124
+
125
+ self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
126
+ self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)
127
+ self.text_encoder.transformer.text_model.embeddings.token_embedding = self.text_encoder.transformer.text_model.embeddings.token_embedding.to(dtype=torch.float32)
128
+
129
+ tokens = tokens.to(target_device)
130
+
131
+ outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
132
+
133
+ layer_id = - max(self.clip_skip, self.minimal_clip_skip)
134
+ z = outputs.hidden_states[layer_id]
135
+
136
+ if self.final_layer_norm:
137
+ z = self.text_encoder.transformer.text_model.final_layer_norm(z)
138
+
139
+ if self.return_pooled:
140
+ pooled_output = outputs.pooler_output
141
+
142
+ if self.text_projection and self.embedding_key != 'clip_l':
143
+ pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
144
+
145
+ z.pooled = pooled_output
146
+ return z
147
+
148
+ def tokenize_line(self, line):
149
+ parsed = cat.parse_prompt_attention(line, self.pigeon.name)
150
+
151
+ tokenized = self.tokenize([text for text, _ in parsed])
152
+
153
+ chunks = []
154
+ chunk = Chicken()
155
+ token_count = 0
156
+ last_comma = -1
157
+
158
+ def next_chunk(is_last=False):
159
+ nonlocal token_count
160
+ nonlocal last_comma
161
+ nonlocal chunk
162
+
163
+ if is_last:
164
+ token_count += len(chunk.tokens)
165
+ else:
166
+ token_count += self.chunk_length
167
+
168
+ to_add = self.chunk_length - len(chunk.tokens)
169
+ if to_add > 0:
170
+ chunk.tokens += [self.id_end] * to_add
171
+ chunk.multipliers += [1.0] * to_add
172
+
173
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
174
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
175
+
176
+ last_comma = -1
177
+ chunks.append(chunk)
178
+ chunk = Chicken()
179
+
180
+ for tokens, (text, weight) in zip(tokenized, parsed):
181
+ if text == 'BREAK' and weight == -1:
182
+ next_chunk()
183
+ continue
184
+
185
+ position = 0
186
+ while position < len(tokens):
187
+ token = tokens[position]
188
+
189
+ comma_padding_backtrack = 20
190
+
191
+ if token == self.comma_token:
192
+ last_comma = len(chunk.tokens)
193
+
194
+ elif comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= comma_padding_backtrack:
195
+ break_location = last_comma + 1
196
+
197
+ reloc_tokens = chunk.tokens[break_location:]
198
+ reloc_mults = chunk.multipliers[break_location:]
199
+
200
+ chunk.tokens = chunk.tokens[:break_location]
201
+ chunk.multipliers = chunk.multipliers[:break_location]
202
+
203
+ next_chunk()
204
+ chunk.tokens = reloc_tokens
205
+ chunk.multipliers = reloc_mults
206
+
207
+ if len(chunk.tokens) == self.chunk_length:
208
+ next_chunk()
209
+
210
+ embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, position)
211
+ if embedding is None:
212
+ chunk.tokens.append(token)
213
+ chunk.multipliers.append(weight)
214
+ position += 1
215
+ continue
216
+
217
+ emb_len = int(embedding.vectors)
218
+ if len(chunk.tokens) + emb_len > self.chunk_length:
219
+ next_chunk()
220
+
221
+ chunk.fixes.append(ChickenFix(len(chunk.tokens), embedding))
222
+
223
+ chunk.tokens += [0] * emb_len
224
+ chunk.multipliers += [weight] * emb_len
225
+ position += embedding_length_in_tokens
226
+
227
+ if chunk.tokens or not chunks:
228
+ next_chunk(is_last=True)
229
+
230
+ return chunks, token_count
231
+
232
+ def process_texts(self, texts):
233
+ token_count = 0
234
+
235
+ cache = {}
236
+ batch_chunks = []
237
+ for line in texts:
238
+ if line in cache:
239
+ chunks = cache[line]
240
+ else:
241
+ chunks, current_token_count = self.tokenize_line(line)
242
+ token_count = max(current_token_count, token_count)
243
+
244
+ cache[line] = chunks
245
+
246
+ batch_chunks.append(chunks)
247
+
248
+ return batch_chunks, token_count
249
+
250
+ def __call__(self, texts):
251
+ self.pigeon = pigeon.get_current_option()()
252
+
253
+ batch_chunks, token_count = self.process_texts(texts)
254
+
255
+ used_embeddings = {}
256
+ chunk_count = max([len(x) for x in batch_chunks])
257
+
258
+ zs = []
259
+ for i in range(chunk_count):
260
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
261
+
262
+ tokens = [x.tokens for x in batch_chunk]
263
+ multipliers = [x.multipliers for x in batch_chunk]
264
+ self.embeddings.fixes = [x.fixes for x in batch_chunk]
265
+
266
+ for fixes in self.embeddings.fixes:
267
+ for _position, embedding in fixes:
268
+ used_embeddings[embedding.name] = embedding
269
+
270
+ z = self.process_tokens(tokens, multipliers)
271
+ zs.append(z)
272
+
273
+ global last_extra_generation_params
274
+
275
+ if used_embeddings:
276
+ names = []
277
+
278
+ for name, embedding in used_embeddings.items():
279
+ print(f'[Textual Inversion] Used Embedding [{name}] in CLIP of [{self.embedding_key}]')
280
+ names.append(name.replace(":", "").replace(",", ""))
281
+
282
+ if "TI" in last_extra_generation_params:
283
+ last_extra_generation_params["TI"] += ", " + ", ".join(names)
284
+ else:
285
+ last_extra_generation_params["TI"] = ", ".join(names)
286
+
287
+ if any(x for x in texts if "(" in x or "[" in x) and self.pigeon.name != "Original":
288
+ last_extra_generation_params["Emphasis"] = self.pigeon.name
289
+
290
+ if self.return_pooled:
291
+ return torch.hstack(zs), zs[0].pooled
292
+ else:
293
+ return torch.hstack(zs)
294
+
295
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
296
+ tokens = torch.asarray(remade_batch_tokens)
297
+
298
+ if self.id_end != self.id_pad:
299
+ for batch_pos in range(len(remade_batch_tokens)):
300
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
301
+ tokens[batch_pos, index + 1:tokens.shape[1]] = self.id_pad
302
+
303
+ z = self.encode_with_transformers(tokens)
304
+
305
+ pooled = getattr(z, 'pooled', None)
306
+
307
+ self.pigeon.tokens = remade_batch_tokens
308
+ self.pigeon.multipliers = torch.asarray(batch_multipliers).to(z)
309
+ self.pigeon.z = z
310
+ self.pigeon.after_transformers()
311
+ z = self.pigeon.z
312
+
313
+ if pooled is not None:
314
+ z.pooled = pooled
315
+
316
+ return z
Python_Infer_Utils/cat.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ re_attention = re.compile(r"""
5
+ \\\(|
6
+ \\\)|
7
+ \\\[|
8
+ \\]|
9
+ \\\\|
10
+ \\|
11
+ \(|
12
+ \[|
13
+ :\s*([+-]?[.\d]+)\s*\)|
14
+ \)|
15
+ ]|
16
+ [^\\()\[\]:]+|
17
+ :
18
+ """, re.X)
19
+
20
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
21
+
22
+
23
+ def parse_prompt_attention(text, pigeon):
24
+ res = []
25
+ round_brackets = []
26
+ square_brackets = []
27
+
28
+ round_bracket_multiplier = 1.1
29
+ square_bracket_multiplier = 1 / 1.1
30
+
31
+ def multiply_range(start_position, multiplier):
32
+ for p in range(start_position, len(res)):
33
+ res[p][1] *= multiplier
34
+
35
+ if pigeon == "None":
36
+ # interpret literally
37
+ res = [[text, 1.0]]
38
+ else:
39
+ for m in re_attention.finditer(text):
40
+ text = m.group(0)
41
+ weight = m.group(1)
42
+
43
+ if text.startswith('\\'):
44
+ res.append([text[1:], 1.0])
45
+ elif text == '(':
46
+ round_brackets.append(len(res))
47
+ elif text == '[':
48
+ square_brackets.append(len(res))
49
+ elif weight is not None and round_brackets:
50
+ multiply_range(round_brackets.pop(), float(weight))
51
+ elif text == ')' and round_brackets:
52
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
53
+ elif text == ']' and square_brackets:
54
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
55
+ else:
56
+ parts = re.split(re_break, text)
57
+ for i, part in enumerate(parts):
58
+ if i > 0:
59
+ res.append(["BREAK", -1])
60
+ res.append([part, 1.0])
61
+
62
+ for pos in round_brackets:
63
+ multiply_range(pos, round_bracket_multiplier)
64
+
65
+ for pos in square_brackets:
66
+ multiply_range(pos, square_bracket_multiplier)
67
+
68
+ if len(res) == 0:
69
+ res = [["", 1.0]]
70
+
71
+ i = 0
72
+ while i + 1 < len(res):
73
+ if res[i][1] == res[i + 1][1]:
74
+ res[i][0] += res[i + 1][0]
75
+ res.pop(i + 1)
76
+ else:
77
+ i += 1
78
+
79
+ return res
Python_Infer_Utils/pig.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import base64
4
+ import json
5
+ import zlib
6
+ import numpy as np
7
+ import safetensors.torch
8
+
9
+ from PIL import Image
10
+
11
+
12
+ class EmbeddingEncoder(json.JSONEncoder):
13
+ def default(self, obj):
14
+ if isinstance(obj, torch.Tensor):
15
+ return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
16
+ return json.JSONEncoder.default(self, obj)
17
+
18
+
19
+ class EmbeddingDecoder(json.JSONDecoder):
20
+ def __init__(self, *args, **kwargs):
21
+ json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
22
+
23
+ def object_hook(self, d):
24
+ if 'TORCHTENSOR' in d:
25
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
26
+ return d
27
+
28
+
29
+ def embedding_to_b64(data):
30
+ d = json.dumps(data, cls=EmbeddingEncoder)
31
+ return base64.b64encode(d.encode())
32
+
33
+
34
+ def embedding_from_b64(data):
35
+ d = base64.b64decode(data)
36
+ return json.loads(d, cls=EmbeddingDecoder)
37
+
38
+
39
+ def lcg(m=2 ** 32, a=1664525, c=1013904223, seed=0):
40
+ while True:
41
+ seed = (a * seed + c) % m
42
+ yield seed % 255
43
+
44
+
45
+ def xor_block(block):
46
+ g = lcg()
47
+ randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
48
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
49
+
50
+
51
+ def crop_black(img, tol=0):
52
+ mask = (img > tol).all(2)
53
+ mask0, mask1 = mask.any(0), mask.any(1)
54
+ col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax()
55
+ row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax()
56
+ return img[row_start:row_end, col_start:col_end]
57
+
58
+
59
+ def extract_image_data_embed(image):
60
+ d = 3
61
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
62
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
63
+ if black_cols[0].shape[0] < 2:
64
+ print(f'{os.path.basename(getattr(image, "filename", "unknown image file"))}: no embedded information found.')
65
+ return None
66
+
67
+ data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
68
+ data_block_upper = outarr[:, black_cols[0].max() + 1:, :].astype(np.uint8)
69
+
70
+ data_block_lower = xor_block(data_block_lower)
71
+ data_block_upper = xor_block(data_block_upper)
72
+
73
+ data_block = (data_block_upper << 4) | (data_block_lower)
74
+ data_block = data_block.flatten().tobytes()
75
+
76
+ data = zlib.decompress(data_block)
77
+ return json.loads(data, cls=EmbeddingDecoder)
78
+
79
+
80
+ class Embedding:
81
+ def __init__(self, vec, name, step=None):
82
+ self.vec = vec
83
+ self.name = name
84
+ self.step = step
85
+ self.shape = None
86
+ self.vectors = 0
87
+ self.sd_checkpoint = None
88
+ self.sd_checkpoint_name = None
89
+
90
+
91
+ class DirWithTextualInversionEmbeddings:
92
+ def __init__(self, path):
93
+ self.path = path
94
+ self.mtime = None
95
+
96
+ def has_changed(self):
97
+ if not os.path.isdir(self.path):
98
+ return False
99
+
100
+ mt = os.path.getmtime(self.path)
101
+ if self.mtime is None or mt > self.mtime:
102
+ return True
103
+
104
+ def update(self):
105
+ if not os.path.isdir(self.path):
106
+ return
107
+
108
+ self.mtime = os.path.getmtime(self.path)
109
+
110
+
111
+ class worm:
112
+ def __init__(self, tokenizer, expected_shape=-1):
113
+ self.ids_lookup = {}
114
+ self.word_embeddings = {}
115
+ self.embedding_dirs = {}
116
+ self.skipped_embeddings = {}
117
+ self.expected_shape = expected_shape
118
+ self.tokenizer = tokenizer
119
+ self.fixes = []
120
+
121
+ def add_embedding_dir(self, path):
122
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
123
+
124
+ def clear_embedding_dirs(self):
125
+ self.embedding_dirs.clear()
126
+
127
+ def register_embedding(self, embedding):
128
+ return self.register_embedding_by_name(embedding, embedding.name)
129
+
130
+ def register_embedding_by_name(self, embedding, name):
131
+ ids = self.tokenizer([name], truncation=False, add_special_tokens=False)["input_ids"][0]
132
+ first_id = ids[0]
133
+ if first_id not in self.ids_lookup:
134
+ self.ids_lookup[first_id] = []
135
+ if name in self.word_embeddings:
136
+ lookup = [x for x in self.ids_lookup[first_id] if x[1].name != name]
137
+ else:
138
+ lookup = self.ids_lookup[first_id]
139
+ if embedding is not None:
140
+ lookup += [(ids, embedding)]
141
+ self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
142
+ if embedding is None:
143
+ if name in self.word_embeddings:
144
+ del self.word_embeddings[name]
145
+ if len(self.ids_lookup[first_id]) == 0:
146
+ del self.ids_lookup[first_id]
147
+ return None
148
+ self.word_embeddings[name] = embedding
149
+ return embedding
150
+
151
+ def load_from_file(self, path, filename):
152
+ name, ext = os.path.splitext(filename)
153
+ ext = ext.upper()
154
+
155
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
156
+ _, second_ext = os.path.splitext(name)
157
+ if second_ext.upper() == '.PREVIEW':
158
+ return
159
+
160
+ embed_image = Image.open(path)
161
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
162
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
163
+ name = data.get('name', name)
164
+ else:
165
+ data = extract_image_data_embed(embed_image)
166
+ if data:
167
+ name = data.get('name', name)
168
+ else:
169
+ return
170
+ elif ext in ['.BIN', '.PT']:
171
+ data = torch.load(path, map_location="cpu")
172
+ elif ext in ['.SAFETENSORS']:
173
+ data = safetensors.torch.load_file(path, device="cpu")
174
+ else:
175
+ return
176
+
177
+ if data is not None:
178
+ embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
179
+
180
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
181
+ self.register_embedding(embedding)
182
+ else:
183
+ self.skipped_embeddings[name] = embedding
184
+ else:
185
+ print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
186
+
187
+ def load_from_dir(self, embdir):
188
+ if not os.path.isdir(embdir.path):
189
+ return
190
+
191
+ for root, _, fns in os.walk(embdir.path, followlinks=True):
192
+ for fn in fns:
193
+ try:
194
+ fullfn = os.path.join(root, fn)
195
+
196
+ if os.stat(fullfn).st_size == 0:
197
+ continue
198
+
199
+ self.load_from_file(fullfn, fn)
200
+ except Exception:
201
+ print(f"Error loading embedding {fn}")
202
+ continue
203
+
204
+ def load_textual_inversion_embeddings(self):
205
+ self.ids_lookup.clear()
206
+ self.word_embeddings.clear()
207
+ self.skipped_embeddings.clear()
208
+
209
+ for embdir in self.embedding_dirs.values():
210
+ self.load_from_dir(embdir)
211
+ embdir.update()
212
+
213
+ return
214
+
215
+ def find_embedding_at_position(self, tokens, offset):
216
+ token = tokens[offset]
217
+ possible_matches = self.ids_lookup.get(token, None)
218
+
219
+ if possible_matches is None:
220
+ return None, None
221
+
222
+ for ids, embedding in possible_matches:
223
+ if tokens[offset:offset + len(ids)] == ids:
224
+ return embedding, len(ids)
225
+
226
+ return None, None
227
+
228
+
229
+ def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
230
+ if 'string_to_param' in data: # textual inversion embeddings
231
+ param_dict = data['string_to_param']
232
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
233
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
234
+ emb = next(iter(param_dict.items()))[1]
235
+ vec = emb.detach().to(dtype=torch.float32)
236
+ shape = vec.shape[-1]
237
+ vectors = vec.shape[0]
238
+ elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
239
+ vec = {k: v.detach().to(dtype=torch.float32) for k, v in data.items()}
240
+ shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
241
+ vectors = data['clip_g'].shape[0]
242
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
243
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
244
+
245
+ emb = next(iter(data.values()))
246
+ if len(emb.shape) == 1:
247
+ emb = emb.unsqueeze(0)
248
+ vec = emb.detach().to(dtype=torch.float32)
249
+ shape = vec.shape[-1]
250
+ vectors = vec.shape[0]
251
+ else:
252
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
253
+
254
+ embedding = Embedding(vec, name)
255
+ embedding.step = data.get('step', None)
256
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
257
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
258
+ embedding.vectors = vectors
259
+ embedding.shape = shape
260
+
261
+ if filepath:
262
+ embedding.filename = filepath
263
+
264
+ return embedding
Python_Infer_Utils/pigeon.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class Emphasis:
5
+ name: str = "Base"
6
+ description: str = ""
7
+ tokens: list[list[int]]
8
+ multipliers: torch.Tensor
9
+ z: torch.Tensor
10
+
11
+ def after_transformers(self):
12
+ pass
13
+
14
+
15
+ class EmphasisNone(Emphasis):
16
+ name = "None"
17
+ description = "disable the mechanism entirely and treat (:.1.1) as literal characters"
18
+
19
+
20
+ class EmphasisIgnore(Emphasis):
21
+ name = "Ignore"
22
+ description = "treat all empasised words as if they have no pigeon"
23
+
24
+
25
+ class EmphasisOriginal(Emphasis):
26
+ name = "Original"
27
+ description = "the original pigeon implementation"
28
+
29
+ def after_transformers(self):
30
+ original_mean = self.z.mean()
31
+ self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
32
+ new_mean = self.z.mean()
33
+ self.z = self.z * (original_mean / new_mean)
34
+
35
+
36
+ class EmphasisOriginalNoNorm(EmphasisOriginal):
37
+ name = "No norm"
38
+ description = "same as original, but without normalization (seems to work better for SDXL)"
39
+
40
+ def after_transformers(self):
41
+ self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
42
+
43
+
44
+ def get_current_option():
45
+ return (EmphasisOriginal)
46
+
47
+
48
+ def get_options_descriptions():
49
+ return ", ".join(f"{x.name}: {x.description}" for x in options)
50
+
51
+
52
+ options = [
53
+ EmphasisNone,
54
+ EmphasisIgnore,
55
+ EmphasisOriginal,
56
+ EmphasisOriginalNoNorm,
57
+ ]
Python_Infer_Utils/snake.py ADDED
@@ -0,0 +1,1209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cherry-picked some good parts from ComfyUI with some bad parts fixed
2
+
3
+ import sys
4
+ import time
5
+ import psutil
6
+ import torch
7
+ import platform
8
+
9
+ from enum import Enum
10
+ from backend import stream, utils
11
+ from backend.args import args
12
+
13
+
14
+ cpu = torch.device('cpu')
15
+
16
+
17
+ class VRAMState(Enum):
18
+ DISABLED = 0 # No vram present: no need to move models to vram
19
+ NO_VRAM = 1 # Very low vram: enable all the options to save vram
20
+ LOW_VRAM = 2
21
+ NORMAL_VRAM = 3
22
+ HIGH_VRAM = 4
23
+ SHARED = 5 # No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
24
+
25
+
26
+ class CPUState(Enum):
27
+ GPU = 0
28
+ CPU = 1
29
+ MPS = 2
30
+
31
+
32
+ # Determine VRAM State
33
+ vram_state = VRAMState.NORMAL_VRAM
34
+ set_vram_to = VRAMState.NORMAL_VRAM
35
+ cpu_state = CPUState.GPU
36
+
37
+ total_vram = 0
38
+
39
+ lowvram_available = True
40
+ xpu_available = False
41
+
42
+ if args.pytorch_deterministic:
43
+ print("Using deterministic algorithms for pytorch")
44
+ torch.use_deterministic_algorithms(True, warn_only=True)
45
+
46
+ directml_enabled = False
47
+ if args.directml is not None:
48
+ import torch_directml
49
+
50
+ directml_enabled = True
51
+ device_index = args.directml
52
+ if device_index < 0:
53
+ directml_device = torch_directml.device()
54
+ else:
55
+ directml_device = torch_directml.device(device_index)
56
+ print("Using directml with device: {}".format(torch_directml.device_name(device_index)))
57
+
58
+ try:
59
+ import intel_extension_for_pytorch as ipex
60
+
61
+ if torch.xpu.is_available():
62
+ xpu_available = True
63
+ except:
64
+ pass
65
+
66
+ try:
67
+ if torch.backends.mps.is_available():
68
+ cpu_state = CPUState.MPS
69
+ import torch.mps
70
+ except:
71
+ pass
72
+
73
+ if args.always_cpu:
74
+ cpu_state = CPUState.CPU
75
+
76
+
77
+ def is_intel_xpu():
78
+ global cpu_state
79
+ global xpu_available
80
+ if cpu_state == CPUState.GPU:
81
+ if xpu_available:
82
+ return True
83
+ return False
84
+
85
+
86
+ def get_torch_device():
87
+ global directml_enabled
88
+ global cpu_state
89
+ if directml_enabled:
90
+ global directml_device
91
+ return directml_device
92
+ if cpu_state == CPUState.MPS:
93
+ return torch.device("mps")
94
+ if cpu_state == CPUState.CPU:
95
+ return torch.device("cpu")
96
+ else:
97
+ if is_intel_xpu():
98
+ return torch.device("xpu", torch.xpu.current_device())
99
+ else:
100
+ return torch.device(torch.cuda.current_device())
101
+
102
+
103
+ def get_total_memory(dev=None, torch_total_too=False):
104
+ global directml_enabled
105
+ if dev is None:
106
+ dev = get_torch_device()
107
+
108
+ if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
109
+ mem_total = psutil.virtual_memory().total
110
+ mem_total_torch = mem_total
111
+ else:
112
+ if directml_enabled:
113
+ mem_total = 1024 * 1024 * 1024 # TODO
114
+ mem_total_torch = mem_total
115
+ elif is_intel_xpu():
116
+ stats = torch.xpu.memory_stats(dev)
117
+ mem_reserved = stats['reserved_bytes.all.current']
118
+ mem_total_torch = mem_reserved
119
+ mem_total = torch.xpu.get_device_properties(dev).total_memory
120
+ else:
121
+ stats = torch.cuda.memory_stats(dev)
122
+ mem_reserved = stats['reserved_bytes.all.current']
123
+ _, mem_total_cuda = torch.cuda.mem_get_info(dev)
124
+ mem_total_torch = mem_reserved
125
+ mem_total = mem_total_cuda
126
+
127
+ if torch_total_too:
128
+ return (mem_total, mem_total_torch)
129
+ else:
130
+ return mem_total
131
+
132
+
133
+ total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
134
+ total_ram = psutil.virtual_memory().total / (1024 * 1024)
135
+ print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
136
+
137
+ try:
138
+ print("pytorch version: {}".format(torch.version.__version__))
139
+ except:
140
+ pass
141
+
142
+ try:
143
+ OOM_EXCEPTION = torch.cuda.OutOfMemoryError
144
+ except:
145
+ OOM_EXCEPTION = Exception
146
+
147
+ if directml_enabled:
148
+ OOM_EXCEPTION = Exception
149
+
150
+ XFORMERS_VERSION = ""
151
+ XFORMERS_ENABLED_VAE = True
152
+ if args.disable_xformers:
153
+ XFORMERS_IS_AVAILABLE = False
154
+ else:
155
+ try:
156
+ import xformers
157
+ import xformers.ops
158
+
159
+ XFORMERS_IS_AVAILABLE = True
160
+ try:
161
+ XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
162
+ except:
163
+ pass
164
+ try:
165
+ XFORMERS_VERSION = xformers.version.__version__
166
+ print("xformers version: {}".format(XFORMERS_VERSION))
167
+ if XFORMERS_VERSION.startswith("0.0.18"):
168
+ print("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
169
+ print("Please downgrade or upgrade xformers to a different version.\n")
170
+ XFORMERS_ENABLED_VAE = False
171
+ except:
172
+ pass
173
+ except:
174
+ XFORMERS_IS_AVAILABLE = False
175
+
176
+
177
+ def is_nvidia():
178
+ global cpu_state
179
+ if cpu_state == CPUState.GPU:
180
+ if torch.version.cuda:
181
+ return True
182
+ return False
183
+
184
+
185
+ ENABLE_PYTORCH_ATTENTION = False
186
+ if args.attention_pytorch:
187
+ ENABLE_PYTORCH_ATTENTION = True
188
+ XFORMERS_IS_AVAILABLE = False
189
+
190
+ VAE_DTYPES = [torch.float32]
191
+
192
+ try:
193
+ if is_nvidia():
194
+ torch_version = torch.version.__version__
195
+ if int(torch_version[0]) >= 2:
196
+ if ENABLE_PYTORCH_ATTENTION == False and args.attention_split == False and args.attention_quad == False:
197
+ ENABLE_PYTORCH_ATTENTION = True
198
+ if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
199
+ VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
200
+ if is_intel_xpu():
201
+ if args.attention_split == False and args.attention_quad == False:
202
+ ENABLE_PYTORCH_ATTENTION = True
203
+ except:
204
+ pass
205
+
206
+ if is_intel_xpu():
207
+ VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
208
+
209
+ if args.vae_in_cpu:
210
+ VAE_DTYPES = [torch.float32]
211
+
212
+ VAE_ALWAYS_TILED = False
213
+
214
+ if ENABLE_PYTORCH_ATTENTION:
215
+ torch.backends.cuda.enable_math_sdp(True)
216
+ torch.backends.cuda.enable_flash_sdp(True)
217
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
218
+
219
+ if args.always_low_vram:
220
+ set_vram_to = VRAMState.LOW_VRAM
221
+ lowvram_available = True
222
+ elif args.always_no_vram:
223
+ set_vram_to = VRAMState.NO_VRAM
224
+ elif args.always_high_vram or args.always_gpu:
225
+ vram_state = VRAMState.HIGH_VRAM
226
+
227
+ FORCE_FP32 = False
228
+ FORCE_FP16 = False
229
+ if args.all_in_fp32:
230
+ print("Forcing FP32, if this improves things please report it.")
231
+ FORCE_FP32 = True
232
+
233
+ if args.all_in_fp16:
234
+ print("Forcing FP16.")
235
+ FORCE_FP16 = True
236
+
237
+ if lowvram_available:
238
+ if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
239
+ vram_state = set_vram_to
240
+
241
+ if cpu_state != CPUState.GPU:
242
+ vram_state = VRAMState.DISABLED
243
+
244
+ if cpu_state == CPUState.MPS:
245
+ vram_state = VRAMState.SHARED
246
+
247
+ print(f"Set vram state to: {vram_state.name}")
248
+
249
+ ALWAYS_VRAM_OFFLOAD = args.always_offload_from_vram
250
+
251
+ if ALWAYS_VRAM_OFFLOAD:
252
+ print("Always offload VRAM")
253
+
254
+ PIN_SHARED_MEMORY = args.pin_shared_memory
255
+
256
+ if PIN_SHARED_MEMORY:
257
+ print("Always pin shared GPU memory")
258
+
259
+
260
+ def get_torch_device_name(device):
261
+ if hasattr(device, 'type'):
262
+ if device.type == "cuda":
263
+ try:
264
+ allocator_backend = torch.cuda.get_allocator_backend()
265
+ except:
266
+ allocator_backend = ""
267
+ return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
268
+ else:
269
+ return "{}".format(device.type)
270
+ elif is_intel_xpu():
271
+ return "{} {}".format(device, torch.xpu.get_device_name(device))
272
+ else:
273
+ return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
274
+
275
+
276
+ try:
277
+ torch_device_name = get_torch_device_name(get_torch_device())
278
+ print("Device: {}".format(torch_device_name))
279
+ except:
280
+ torch_device_name = ''
281
+ print("Could not pick default device.")
282
+
283
+ if 'rtx' in torch_device_name.lower():
284
+ if not args.cuda_malloc:
285
+ print('Hint: your device supports --cuda-malloc for potential speed improvements.')
286
+
287
+
288
+ current_loaded_models = []
289
+
290
+
291
+ def state_dict_size(sd, exclude_device=None):
292
+ module_mem = 0
293
+ for k in sd:
294
+ t = sd[k]
295
+
296
+ if exclude_device is not None:
297
+ if t.device == exclude_device:
298
+ continue
299
+
300
+ module_mem += t.nelement() * t.element_size()
301
+ return module_mem
302
+
303
+
304
+ def state_dict_parameters(sd):
305
+ module_mem = 0
306
+ for k, v in sd.items():
307
+ module_mem += v.nelement()
308
+ return module_mem
309
+
310
+
311
+ def state_dict_dtype(state_dict):
312
+ for k, v in state_dict.items():
313
+ if hasattr(v, 'gguf_cls'):
314
+ return 'gguf'
315
+ if 'bitsandbytes__nf4' in k:
316
+ return 'nf4'
317
+ if 'bitsandbytes__fp4' in k:
318
+ return 'fp4'
319
+
320
+ dtype_counts = {}
321
+
322
+ for tensor in state_dict.values():
323
+ dtype = tensor.dtype
324
+ if dtype in dtype_counts:
325
+ dtype_counts[dtype] += 1
326
+ else:
327
+ dtype_counts[dtype] = 1
328
+
329
+ major_dtype = None
330
+ max_count = 0
331
+
332
+ for dtype, count in dtype_counts.items():
333
+ if count > max_count:
334
+ max_count = count
335
+ major_dtype = dtype
336
+
337
+ return major_dtype
338
+
339
+
340
+ def bake_gguf_model(model):
341
+ if getattr(model, 'gguf_baked', False):
342
+ return
343
+
344
+ for p in model.parameters():
345
+ gguf_cls = getattr(p, 'gguf_cls', None)
346
+ if gguf_cls is not None:
347
+ gguf_cls.bake(p)
348
+
349
+ global signal_empty_cache
350
+ signal_empty_cache = True
351
+
352
+ model.gguf_baked = True
353
+ return model
354
+
355
+
356
+ def module_size(module, exclude_device=None, include_device=None, return_split=False):
357
+ module_mem = 0
358
+ weight_mem = 0
359
+ weight_patterns = ['weight']
360
+
361
+ for k, p in module.named_parameters():
362
+ t = p.data
363
+
364
+ if exclude_device is not None:
365
+ if t.device == exclude_device:
366
+ continue
367
+
368
+ if include_device is not None:
369
+ if t.device != include_device:
370
+ continue
371
+
372
+ element_size = t.element_size()
373
+
374
+ if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
375
+ if element_size > 1:
376
+ # not quanted yet
377
+ element_size = 0.55 # a bit more than 0.5 because of quant state parameters
378
+ else:
379
+ # quanted
380
+ element_size = 1.1 # a bit more than 0.5 because of quant state parameters
381
+
382
+ module_mem += t.nelement() * element_size
383
+
384
+ if k in weight_patterns:
385
+ weight_mem += t.nelement() * element_size
386
+
387
+ if return_split:
388
+ return module_mem, weight_mem, module_mem - weight_mem
389
+
390
+ return module_mem
391
+
392
+
393
+ def module_move(module, device, recursive=True, excluded_pattens=[]):
394
+ if recursive:
395
+ return module.to(device=device)
396
+
397
+ for k, p in module.named_parameters(recurse=False, remove_duplicate=True):
398
+ if k in excluded_pattens:
399
+ continue
400
+ setattr(module, k, utils.tensor2parameter(p.to(device=device)))
401
+
402
+ return module
403
+
404
+
405
+ def build_module_profile(model, model_gpu_memory_when_using_cpu_swap):
406
+ all_modules = []
407
+ legacy_modules = []
408
+
409
+ for m in model.modules():
410
+ if hasattr(m, "parameters_manual_cast"):
411
+ m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
412
+ all_modules.append(m)
413
+ elif hasattr(m, "weight"):
414
+ m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
415
+ legacy_modules.append(m)
416
+
417
+ gpu_modules = []
418
+ gpu_modules_only_extras = []
419
+ mem_counter = 0
420
+
421
+ for m in legacy_modules.copy():
422
+ gpu_modules.append(m)
423
+ legacy_modules.remove(m)
424
+ mem_counter += m.total_mem
425
+
426
+ for m in sorted(all_modules, key=lambda x: x.extra_mem).copy():
427
+ if mem_counter + m.extra_mem < model_gpu_memory_when_using_cpu_swap:
428
+ gpu_modules_only_extras.append(m)
429
+ all_modules.remove(m)
430
+ mem_counter += m.extra_mem
431
+
432
+ cpu_modules = all_modules
433
+
434
+ for m in sorted(gpu_modules_only_extras, key=lambda x: x.weight_mem).copy():
435
+ if mem_counter + m.weight_mem < model_gpu_memory_when_using_cpu_swap:
436
+ gpu_modules.append(m)
437
+ gpu_modules_only_extras.remove(m)
438
+ mem_counter += m.weight_mem
439
+
440
+ return gpu_modules, gpu_modules_only_extras, cpu_modules
441
+
442
+
443
+ class LoadedModel:
444
+ def __init__(self, model):
445
+ self.model = model
446
+ self.model_accelerated = False
447
+ self.device = model.load_device
448
+ self.inclusive_memory = 0
449
+ self.exclusive_memory = 0
450
+
451
+ def compute_inclusive_exclusive_memory(self):
452
+ self.inclusive_memory = module_size(self.model.model, include_device=self.device)
453
+ self.exclusive_memory = module_size(self.model.model, exclude_device=self.device)
454
+ return
455
+
456
+ def model_load(self, model_gpu_memory_when_using_cpu_swap=-1):
457
+ patch_model_to = None
458
+ do_not_need_cpu_swap = model_gpu_memory_when_using_cpu_swap < 0
459
+
460
+ if do_not_need_cpu_swap:
461
+ patch_model_to = self.device
462
+
463
+ self.model.model_patches_to(self.device)
464
+ self.model.model_patches_to(self.model.model_dtype())
465
+
466
+ try:
467
+ self.real_model = self.model.forge_patch_model(patch_model_to)
468
+ self.model.current_device = self.model.load_device
469
+ except Exception as e:
470
+ self.model.forge_unpatch_model(self.model.offload_device)
471
+ self.model_unload()
472
+ raise e
473
+
474
+ if do_not_need_cpu_swap:
475
+ print('All loaded to GPU.')
476
+ else:
477
+ gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
478
+ pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
479
+
480
+ mem_counter = 0
481
+ swap_counter = 0
482
+
483
+ for m in gpu_modules:
484
+ m.to(self.device)
485
+ mem_counter += m.total_mem
486
+
487
+ for m in cpu_modules:
488
+ m.prev_parameters_manual_cast = m.parameters_manual_cast
489
+ m.parameters_manual_cast = True
490
+ m.to(self.model.offload_device)
491
+ if pin_memory:
492
+ m._apply(lambda x: x.pin_memory())
493
+ swap_counter += m.total_mem
494
+
495
+ for m in gpu_modules_only_extras:
496
+ m.prev_parameters_manual_cast = m.parameters_manual_cast
497
+ m.parameters_manual_cast = True
498
+ module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
499
+ if hasattr(m, 'weight') and m.weight is not None:
500
+ if pin_memory:
501
+ m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device).pin_memory())
502
+ else:
503
+ m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device))
504
+ mem_counter += m.extra_mem
505
+ swap_counter += m.weight_mem
506
+
507
+ swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
508
+ method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
509
+ print(f"{swap_flag} Swap Loaded ({method_flag} method): {swap_counter / (1024 * 1024):.2f} MB, GPU Loaded: {mem_counter / (1024 * 1024):.2f} MB")
510
+
511
+ self.model_accelerated = True
512
+
513
+ global signal_empty_cache
514
+ signal_empty_cache = True
515
+
516
+ bake_gguf_model(self.real_model)
517
+
518
+ self.model.refresh_loras()
519
+
520
+ if is_intel_xpu() and not args.disable_ipex_hijack:
521
+ self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
522
+
523
+ return self.real_model
524
+
525
+ def model_unload(self, avoid_model_moving=False):
526
+ if self.model_accelerated:
527
+ for m in self.real_model.modules():
528
+ if hasattr(m, "prev_parameters_manual_cast"):
529
+ m.parameters_manual_cast = m.prev_parameters_manual_cast
530
+ del m.prev_parameters_manual_cast
531
+
532
+ self.model_accelerated = False
533
+
534
+ if avoid_model_moving:
535
+ self.model.forge_unpatch_model()
536
+ else:
537
+ self.model.forge_unpatch_model(self.model.offload_device)
538
+ self.model.model_patches_to(self.model.offload_device)
539
+
540
+ def __eq__(self, other):
541
+ return self.model is other.model # and self.memory_required == other.memory_required
542
+
543
+
544
+ current_inference_memory = 1024 * 1024 * 1024
545
+
546
+
547
+ def minimum_inference_memory():
548
+ global current_inference_memory
549
+ return current_inference_memory
550
+
551
+
552
+ def unload_model_clones(model):
553
+ to_unload = []
554
+ for i in range(len(current_loaded_models)):
555
+ if model.is_clone(current_loaded_models[i].model):
556
+ to_unload = [i] + to_unload
557
+
558
+ for i in to_unload:
559
+ current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
560
+
561
+
562
+ def free_memory(memory_required, device, keep_loaded=[], free_all=False):
563
+ # this check fully unloads any 'abandoned' models
564
+ for i in range(len(current_loaded_models) - 1, -1, -1):
565
+ if sys.getrefcount(current_loaded_models[i].model) <= 2:
566
+ current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
567
+
568
+ if free_all:
569
+ memory_required = 1e30
570
+ print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
571
+ else:
572
+ print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
573
+
574
+ offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
575
+ unloaded_model = False
576
+ for i in range(len(current_loaded_models) - 1, -1, -1):
577
+ if not offload_everything:
578
+ free_memory = get_free_memory(device)
579
+ print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
580
+ if free_memory > memory_required:
581
+ break
582
+ shift_model = current_loaded_models[i]
583
+ if shift_model.device == device:
584
+ if shift_model not in keep_loaded:
585
+ m = current_loaded_models.pop(i)
586
+ print(f"Unload model {m.model.model.__class__.__name__} ", end="")
587
+ m.model_unload()
588
+ del m
589
+ unloaded_model = True
590
+
591
+ if unloaded_model:
592
+ soft_empty_cache()
593
+ else:
594
+ if vram_state != VRAMState.HIGH_VRAM:
595
+ mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
596
+ if mem_free_torch > mem_free_total * 0.25:
597
+ soft_empty_cache()
598
+
599
+ print('Done.')
600
+ return
601
+
602
+
603
+ def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
604
+ maximum_memory_available = current_free_mem - inference_memory
605
+
606
+ suggestion = max(
607
+ maximum_memory_available / 1.3,
608
+ maximum_memory_available - 1024 * 1024 * 1024 * 1.25
609
+ )
610
+
611
+ return int(max(0, suggestion))
612
+
613
+
614
+ def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
615
+ global vram_state
616
+
617
+ execution_start_time = time.perf_counter()
618
+ memory_to_free = max(minimum_inference_memory(), memory_required) + hard_memory_preservation
619
+ memory_for_inference = minimum_inference_memory() + hard_memory_preservation
620
+
621
+ models_to_load = []
622
+ models_already_loaded = []
623
+ for x in models:
624
+ loaded_model = LoadedModel(x)
625
+
626
+ if loaded_model in current_loaded_models:
627
+ index = current_loaded_models.index(loaded_model)
628
+ current_loaded_models.insert(0, current_loaded_models.pop(index))
629
+ models_already_loaded.append(loaded_model)
630
+ else:
631
+ models_to_load.append(loaded_model)
632
+
633
+ if len(models_to_load) == 0:
634
+ devs = set(map(lambda a: a.device, models_already_loaded))
635
+ for d in devs:
636
+ if d != torch.device("cpu"):
637
+ free_memory(memory_to_free, d, models_already_loaded)
638
+
639
+ moving_time = time.perf_counter() - execution_start_time
640
+ if moving_time > 0.1:
641
+ print(f'Memory cleanup has taken {moving_time:.2f} seconds')
642
+
643
+ return
644
+
645
+ for loaded_model in models_to_load:
646
+ unload_model_clones(loaded_model.model)
647
+
648
+ total_memory_required = {}
649
+ for loaded_model in models_to_load:
650
+ loaded_model.compute_inclusive_exclusive_memory()
651
+ total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
652
+
653
+ for device in total_memory_required:
654
+ if device != torch.device("cpu"):
655
+ free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
656
+
657
+ for loaded_model in models_to_load:
658
+ model = loaded_model.model
659
+ torch_dev = model.load_device
660
+ if is_device_cpu(torch_dev):
661
+ vram_set_state = VRAMState.DISABLED
662
+ else:
663
+ vram_set_state = vram_state
664
+
665
+ model_gpu_memory_when_using_cpu_swap = -1
666
+
667
+ if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
668
+ model_require = loaded_model.exclusive_memory
669
+ previously_loaded = loaded_model.inclusive_memory
670
+ current_free_mem = get_free_memory(torch_dev)
671
+ estimated_remaining_memory = current_free_mem - model_require - memory_for_inference
672
+
673
+ print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
674
+
675
+ if estimated_remaining_memory < 0:
676
+ vram_set_state = VRAMState.LOW_VRAM
677
+ model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
678
+ if previously_loaded > 0:
679
+ model_gpu_memory_when_using_cpu_swap = previously_loaded
680
+
681
+ if vram_set_state == VRAMState.NO_VRAM:
682
+ model_gpu_memory_when_using_cpu_swap = 0
683
+
684
+ loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
685
+ current_loaded_models.insert(0, loaded_model)
686
+
687
+ moving_time = time.perf_counter() - execution_start_time
688
+ print(f'Moving model(s) has taken {moving_time:.2f} seconds')
689
+
690
+ return
691
+
692
+
693
+ def load_model_gpu(model):
694
+ return load_models_gpu([model])
695
+
696
+
697
+ def cleanup_models():
698
+ to_delete = []
699
+ for i in range(len(current_loaded_models)):
700
+ if sys.getrefcount(current_loaded_models[i].model) <= 2:
701
+ to_delete = [i] + to_delete
702
+
703
+ for i in to_delete:
704
+ x = current_loaded_models.pop(i)
705
+ x.model_unload()
706
+ del x
707
+
708
+
709
+ def dtype_size(dtype):
710
+ dtype_size = 4
711
+ if dtype == torch.float16 or dtype == torch.bfloat16:
712
+ dtype_size = 2
713
+ elif dtype == torch.float32:
714
+ dtype_size = 4
715
+ else:
716
+ try:
717
+ dtype_size = dtype.itemsize
718
+ except: # Old pytorch doesn't have .itemsize
719
+ pass
720
+ return dtype_size
721
+
722
+
723
+ def unet_offload_device():
724
+ if vram_state == VRAMState.HIGH_VRAM:
725
+ return get_torch_device()
726
+ else:
727
+ return torch.device("cpu")
728
+
729
+
730
+ def unet_inital_load_device(parameters, dtype):
731
+ torch_dev = get_torch_device()
732
+ if vram_state == VRAMState.HIGH_VRAM:
733
+ return torch_dev
734
+
735
+ cpu_dev = torch.device("cpu")
736
+ if ALWAYS_VRAM_OFFLOAD:
737
+ return cpu_dev
738
+
739
+ model_size = dtype_size(dtype) * parameters
740
+
741
+ mem_dev = get_free_memory(torch_dev)
742
+ mem_cpu = get_free_memory(cpu_dev)
743
+ if mem_dev > mem_cpu and model_size < mem_dev:
744
+ return torch_dev
745
+ else:
746
+ return cpu_dev
747
+
748
+
749
+ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
750
+ if args.unet_in_bf16:
751
+ return torch.bfloat16
752
+
753
+ if args.unet_in_fp16:
754
+ return torch.float16
755
+
756
+ if args.unet_in_fp8_e4m3fn:
757
+ return torch.float8_e4m3fn
758
+
759
+ if args.unet_in_fp8_e5m2:
760
+ return torch.float8_e5m2
761
+
762
+ for candidate in supported_dtypes:
763
+ if candidate == torch.float16:
764
+ if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
765
+ return candidate
766
+ if candidate == torch.bfloat16:
767
+ if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
768
+ return candidate
769
+
770
+ return torch.float32
771
+
772
+
773
+ def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
774
+ for candidate in supported_dtypes:
775
+ if candidate == torch.float16:
776
+ if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
777
+ return candidate
778
+ if candidate == torch.bfloat16:
779
+ if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
780
+ return candidate
781
+
782
+ return torch.float32
783
+
784
+
785
+ def text_encoder_offload_device():
786
+ if args.always_gpu:
787
+ return get_torch_device()
788
+ else:
789
+ return torch.device("cpu")
790
+
791
+
792
+ def text_encoder_device():
793
+ if args.always_gpu:
794
+ return get_torch_device()
795
+ elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
796
+ if should_use_fp16(prioritize_performance=False):
797
+ return get_torch_device()
798
+ else:
799
+ return torch.device("cpu")
800
+ else:
801
+ return torch.device("cpu")
802
+
803
+
804
+ def text_encoder_dtype(device=None):
805
+ if args.clip_in_fp8_e4m3fn:
806
+ return torch.float8_e4m3fn
807
+ elif args.clip_in_fp8_e5m2:
808
+ return torch.float8_e5m2
809
+ elif args.clip_in_fp16:
810
+ return torch.float16
811
+ elif args.clip_in_fp32:
812
+ return torch.float32
813
+
814
+ if is_device_cpu(device):
815
+ return torch.float16
816
+
817
+ return torch.float16
818
+
819
+
820
+ def intermediate_device():
821
+ if args.always_gpu:
822
+ return get_torch_device()
823
+ else:
824
+ return torch.device("cpu")
825
+
826
+
827
+ def vae_device():
828
+ if args.vae_in_cpu:
829
+ return torch.device("cpu")
830
+ return get_torch_device()
831
+
832
+
833
+ def vae_offload_device():
834
+ if args.always_gpu:
835
+ return get_torch_device()
836
+ else:
837
+ return torch.device("cpu")
838
+
839
+
840
+ def vae_dtype(device=None, allowed_dtypes=[]):
841
+ global VAE_DTYPES
842
+ if args.vae_in_fp16:
843
+ return torch.float16
844
+ elif args.vae_in_bf16:
845
+ return torch.bfloat16
846
+ elif args.vae_in_fp32:
847
+ return torch.float32
848
+
849
+ for d in allowed_dtypes:
850
+ if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
851
+ return d
852
+ if d in VAE_DTYPES:
853
+ return d
854
+
855
+ return VAE_DTYPES[0]
856
+
857
+
858
+ print(f"VAE dtype preferences: {VAE_DTYPES} -> {vae_dtype()}")
859
+
860
+
861
+ def get_autocast_device(dev):
862
+ if hasattr(dev, 'type'):
863
+ return dev.type
864
+ return "cuda"
865
+
866
+
867
+ def supports_dtype(device, dtype): # TODO
868
+ if dtype == torch.float32:
869
+ return True
870
+ if is_device_cpu(device):
871
+ return False
872
+ if dtype == torch.float16:
873
+ return True
874
+ if dtype == torch.bfloat16:
875
+ return True
876
+ return False
877
+
878
+
879
+ def supports_cast(device, dtype): # TODO
880
+ if dtype == torch.float32:
881
+ return True
882
+ if dtype == torch.float16:
883
+ return True
884
+ if directml_enabled: # TODO: test this
885
+ return False
886
+ if dtype == torch.bfloat16:
887
+ return True
888
+ if is_device_mps(device):
889
+ return False
890
+ if dtype == torch.float8_e4m3fn:
891
+ return True
892
+ if dtype == torch.float8_e5m2:
893
+ return True
894
+ return False
895
+
896
+
897
+ def pick_weight_dtype(dtype, fallback_dtype, device=None):
898
+ if dtype is None:
899
+ dtype = fallback_dtype
900
+ elif dtype_size(dtype) > dtype_size(fallback_dtype):
901
+ dtype = fallback_dtype
902
+
903
+ if not supports_cast(device, dtype):
904
+ dtype = fallback_dtype
905
+
906
+ return dtype
907
+
908
+
909
+ def device_supports_non_blocking(device):
910
+ if is_device_mps(device):
911
+ return False # pytorch bug? mps doesn't support non blocking
912
+ if is_intel_xpu():
913
+ return False
914
+ if args.pytorch_deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
915
+ return False
916
+ if directml_enabled:
917
+ return False
918
+ return True
919
+
920
+
921
+ def device_should_use_non_blocking(device):
922
+ if not device_supports_non_blocking(device):
923
+ return False
924
+ return False
925
+ # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
926
+
927
+
928
+ def force_channels_last():
929
+ if args.force_channels_last:
930
+ return True
931
+
932
+ # TODO
933
+ return False
934
+
935
+
936
+ def cast_to_device(tensor, device, dtype, copy=False):
937
+ device_supports_cast = False
938
+ if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
939
+ device_supports_cast = True
940
+ elif tensor.dtype == torch.bfloat16:
941
+ if hasattr(device, 'type') and device.type.startswith("cuda"):
942
+ device_supports_cast = True
943
+ elif is_intel_xpu():
944
+ device_supports_cast = True
945
+
946
+ non_blocking = device_should_use_non_blocking(device)
947
+
948
+ if device_supports_cast:
949
+ if copy:
950
+ if tensor.device == device:
951
+ return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
952
+ return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
953
+ else:
954
+ return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
955
+ else:
956
+ return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
957
+
958
+
959
+ def xformers_enabled():
960
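+ # xformers is only used on a CUDA GPU (not XPU/DirectML/CPU) and only when XFORMERS_IS_AVAILABLE was set at startup.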
+ global directml_enabled
961
+ global cpu_state
962
+ if cpu_state != CPUState.GPU:
963
+ return False
964
+ if is_intel_xpu():
965
+ return False
966
+ if directml_enabled:
967
+ return False
968
+ return XFORMERS_IS_AVAILABLE
969
+
970
+
971
+ def xformers_enabled_vae():
972
+ enabled = xformers_enabled()
973
+ if not enabled:
974
+ return False
975
+
976
+ return XFORMERS_ENABLED_VAE
977
+
978
+
979
+ def pytorch_attention_enabled():
980
+ global ENABLE_PYTORCH_ATTENTION
981
+ return ENABLE_PYTORCH_ATTENTION
982
+
983
+
984
+ def pytorch_attention_flash_attention():
985
+ global ENABLE_PYTORCH_ATTENTION
986
+ if ENABLE_PYTORCH_ATTENTION:
987
+ # TODO: more reliable way of checking for flash attention?
988
+ if is_nvidia(): # pytorch flash attention only works on Nvidia
989
+ return True
990
+ if is_intel_xpu():
991
+ return True
992
+ return False
993
+
994
+
995
+ def force_upcast_attention_dtype():
996
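+ # Return torch.float32 when attention should be forced to fp32 (CLI flag, or the macOS Sonoma 14.5 black-image workaround), otherwise None.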
+ upcast = args.force_upcast_attention
997
+ try:
998
+ if platform.mac_ver()[0] in ['14.5']: # black image bug on macOS Sonoma 14.5
999
+ upcast = True
1000
+ except Exception:
1001
+ pass
1002
+ if upcast:
1003
+ return torch.float32
1004
+ else:
1005
+ return None
1006
+
1007
+
1008
+ def get_free_memory(dev=None, torch_free_too=False):
1009
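+ # Free memory in bytes for dev: system RAM for cpu/mps, otherwise free VRAM plus torch's reserved-but-inactive memory; optionally also return the torch-only figure.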
+ global directml_enabled
1010
+ if dev is None:
1011
+ dev = get_torch_device()
1012
+
1013
+ if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
1014
+ mem_free_total = psutil.virtual_memory().available
1015
+ mem_free_torch = mem_free_total
1016
+ else:
1017
+ if directml_enabled:
1018
+ mem_free_total = 1024 * 1024 * 1024 # placeholder; free memory is not queried for DirectML
1019
+ mem_free_torch = mem_free_total
1020
+ elif is_intel_xpu():
1021
+ stats = torch.xpu.memory_stats(dev)
1022
+ mem_active = stats['active_bytes.all.current']
1023
+ mem_reserved = stats['reserved_bytes.all.current']
1024
+ mem_free_torch = mem_reserved - mem_active
1025
+ mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
1026
+ mem_free_total = mem_free_xpu + mem_free_torch
1027
+ else:
1028
+ stats = torch.cuda.memory_stats(dev)
1029
+ mem_active = stats['active_bytes.all.current']
1030
+ mem_reserved = stats['reserved_bytes.all.current']
1031
+ mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
1032
+ mem_free_torch = mem_reserved - mem_active
1033
+ mem_free_total = mem_free_cuda + mem_free_torch
1034
+
1035
+ if torch_free_too:
1036
+ return (mem_free_total, mem_free_torch)
1037
+ else:
1038
+ return mem_free_total
1039
+
1040
+
1041
+ def cpu_mode():
1042
+ global cpu_state
1043
+ return cpu_state == CPUState.CPU
1044
+
1045
+
1046
+ def mps_mode():
1047
+ global cpu_state
1048
+ return cpu_state == CPUState.MPS
1049
+
1050
+
1051
+ def is_device_type(device, device_type):
1052
+ if hasattr(device, 'type'):
1053
+ if device.type == device_type:
1054
+ return True
1055
+ return False
1056
+
1057
+
1058
+ def is_device_cpu(device):
1059
+ return is_device_type(device, 'cpu')
1060
+
1061
+
1062
+ def is_device_mps(device):
1063
+ return is_device_type(device, 'mps')
1064
+
1065
+
1066
+ def is_device_cuda(device):
1067
+ return is_device_type(device, 'cuda')
1068
+
1069
+
1070
+ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
1071
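+ # Decide whether fp16 is appropriate for this device: force flags first, then per-backend rules, with special cases for GTX 10-series (storage only) and 16-series (disabled).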
+ global directml_enabled
1072
+
1073
+ if device is not None:
1074
+ if is_device_cpu(device):
1075
+ return False
1076
+
1077
+ if FORCE_FP16:
1078
+ return True
1079
+
1080
+ if device is not None:
1081
+ if is_device_mps(device):
1082
+ return True
1083
+
1084
+ if FORCE_FP32:
1085
+ return False
1086
+
1087
+ if directml_enabled:
1088
+ return False
1089
+
1090
+ if mps_mode():
1091
+ return True
1092
+
1093
+ if cpu_mode():
1094
+ return False
1095
+
1096
+ if is_intel_xpu():
1097
+ return True
1098
+
1099
+ if torch.version.hip:
1100
+ return True
1101
+
1102
+ props = torch.cuda.get_device_properties("cuda")
1103
+ if props.major >= 8:
1104
+ return True
1105
+
1106
+ if props.major < 6:
1107
+ return False
1108
+
1109
+ nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
1110
+ for x in nvidia_10_series:
1111
+ if x in props.name.lower():
1112
+ if manual_cast:
1113
+ # For storage dtype
1114
+ free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
1115
+ if (not prioritize_performance) or model_params * 4 > free_model_memory:
1116
+ return True
1117
+ else:
1118
+ # For computation dtype
1119
+ return False # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow.
1120
+
1121
+ if props.major < 7:
1122
+ return False
1123
+
1124
+ # FP16 is just broken on these cards
1125
+ nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
1126
+ for x in nvidia_16_series:
1127
+ if x in props.name:
1128
+ return False
1129
+
1130
+ return True
1131
+
1132
+
1133
+ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
1134
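+ # Decide whether bf16 is appropriate: always on MPS/XPU and Ampere or newer; on older CUDA devices only as a storage dtype when memory pressure requires it.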
+ if device is not None:
1135
+ if is_device_cpu(device): # TODO ? bf16 works on CPU but is extremely slow
1136
+ return False
1137
+
1138
+ if device is not None:
1139
+ if is_device_mps(device):
1140
+ return True
1141
+
1142
+ if FORCE_FP32:
1143
+ return False
1144
+
1145
+ if directml_enabled:
1146
+ return False
1147
+
1148
+ if mps_mode():
1149
+ return True
1150
+
1151
+ if cpu_mode():
1152
+ return False
1153
+
1154
+ if is_intel_xpu():
1155
+ return True
1156
+
1157
+ if device is None:
1158
+ device = torch.device("cuda")
1159
+
1160
+ props = torch.cuda.get_device_properties(device)
1161
+ if props.major >= 8:
1162
+ return True
1163
+
1164
+ if torch.cuda.is_bf16_supported():
1165
+ # Older device that nonetheless reports bf16 support.
1166
+ # In this case bf16 should only be used as a storage dtype.
1167
+ if manual_cast:
1168
+ # For storage dtype
1169
+ free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
1170
+ if (not prioritize_performance) or model_params * 4 > free_model_memory:
1171
+ return True
1172
+
1173
+ return False
1174
+
1175
+
1176
+ def can_install_bnb():
1177
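+ # bitsandbytes needs a CUDA-enabled torch build with CUDA >= 11.7.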
+ try:
1178
+ if not torch.cuda.is_available():
1179
+ return False
1180
+
1181
+ cuda_version = tuple(int(x) for x in torch.version.cuda.split('.'))
1182
+
1183
+ if cuda_version >= (11, 7):
1184
+ return True
1185
+
1186
+ return False
1187
+ except Exception:
1188
+ return False
1189
+
1190
+
1191
+ signal_empty_cache = False
1192
+
1193
+
1194
+ def soft_empty_cache(force=False):
1195
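+ # Release cached allocator memory on the active backend (MPS/XPU/CUDA); on ROCm the cache is only cleared when force=True.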
+ global cpu_state, signal_empty_cache
1196
+ if cpu_state == CPUState.MPS:
1197
+ torch.mps.empty_cache()
1198
+ elif is_intel_xpu():
1199
+ torch.xpu.empty_cache()
1200
+ elif torch.cuda.is_available():
1201
+ if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda
1202
+ torch.cuda.empty_cache()
1203
+ torch.cuda.ipc_collect()
1204
+ signal_empty_cache = False
1205
+ return
1206
+
1207
+
1208
+ def unload_all_models():
1209
+ free_memory(1e30, get_torch_device(), free_all=True)