Spaces:
Runtime error
Runtime error
| import os | |
| from collections import namedtuple | |
| from contextlib import closing | |
| import torch | |
| import tqdm | |
| import html | |
| import datetime | |
| import csv | |
| import safetensors.torch | |
| import numpy as np | |
| from PIL import Image, PngImagePlugin | |
| from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes | |
| from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay | |
# Lightweight record describing one prompt-template file found on disk:
# its bare file name and its full path.
TextualInversionTemplate = namedtuple("TextualInversionTemplate", "name path")

# Registry of discovered templates keyed by file name; it is cleared and
# refilled in place by list_textual_inversion_templates().
textual_inversion_templates = dict()
def list_textual_inversion_templates():
    """Rescan the templates directory and rebuild the template registry in place.

    Every file found (recursively) under
    ``shared.cmd_opts.textual_inversion_templates_dir`` is recorded as a
    ``TextualInversionTemplate`` keyed by its bare file name. If two files in
    different subdirectories share a name, the one visited later by
    ``os.walk`` wins.

    Returns:
        The module-level ``textual_inversion_templates`` dict (same object,
        mutated in place).
    """
    textual_inversion_templates.clear()
    templates_dir = shared.cmd_opts.textual_inversion_templates_dir
    for dirpath, _subdirs, filenames in os.walk(templates_dir):
        textual_inversion_templates.update(
            (filename, TextualInversionTemplate(filename, os.path.join(dirpath, filename)))
            for filename in filenames
        )
    return textual_inversion_templates
class Embedding:
    """One textual-inversion embedding plus the metadata stored alongside it.

    Args:
        vec: the embedding payload. ``save``/``checksum`` treat it as a tensor;
            ``create_embedding_from_data`` may also store a dict of tensors
            here (SDXL ``clip_g``/``clip_l`` files).
        name: name the embedding is registered under.
        step: training step the vector was saved at, or ``None``.
    """

    def __init__(self, vec, name, step=None):
        self.vec = vec
        self.name = name
        self.step = step
        # Populated by the loader: last-dimension width and number of vectors.
        self.shape = None
        self.vectors = 0
        # Memoized result of checksum().
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None
        self.filename = None
        # Full hash string and its 12-character prefix; see set_hash().
        self.hash = None
        self.shorthash = None

    def save(self, filename):
        """Write the embedding to ``filename`` via ``torch.save``.

        The vector is stored under the ``"*"`` key of ``string_to_param``.
        When ``shared.opts.save_optimizer_state`` is enabled and optimizer
        state is attached, that state is also written to
        ``<filename>.optim``, tagged with this embedding's checksum.
        """
        torch.save(
            {
                "string_to_token": {"*": 265},
                "string_to_param": {"*": self.vec},
                "name": self.name,
                "step": self.step,
                "sd_checkpoint": self.sd_checkpoint,
                "sd_checkpoint_name": self.sd_checkpoint_name,
            },
            filename,
        )

        wants_optimizer = shared.opts.save_optimizer_state and self.optimizer_state_dict is not None
        if not wants_optimizer:
            return

        torch.save(
            {'hash': self.checksum(), 'optimizer_state_dict': self.optimizer_state_dict},
            f"{filename}.optim",
        )

    def checksum(self):
        """Return a memoized 4-hex-digit fingerprint of the vector values.

        The flattened vector is scaled by 100, and each element is converted
        with ``int()`` and folded into a 32-bit rolling hash. Only the low
        16 bits are kept. Assumes ``vec`` is a tensor; a dict payload has no
        ``reshape``.
        """
        if self.cached_checksum is None:
            acc = 0
            for element in self.vec.reshape(-1) * 100:
                acc = ((acc * 281) ^ (int(element) * 997)) & 0xFFFFFFFF
            self.cached_checksum = f'{acc & 0xffff:04x}'
        return self.cached_checksum

    def set_hash(self, v):
        """Store the full hash ``v`` and keep its first 12 characters as ``shorthash``."""
        self.hash = v
        self.shorthash = v[:12]
class DirWithTextualInversionEmbeddings:
    """Tracks one embeddings directory and whether it changed since the last scan.

    Change detection only compares the directory's own modification time
    (``os.path.getmtime``) with the value recorded by ``update()``.
    """

    def __init__(self, path):
        self.path = path
        # mtime recorded by the last update(); None means never scanned.
        self.mtime = None

    def has_changed(self):
        """Return True if the directory exists and is unscanned or newer than the last scan.

        A missing directory is reported as unchanged. Always returns a bool
        (previously fell through and returned None when unchanged).
        """
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        return self.mtime is None or mt > self.mtime

    def update(self):
        """Record the directory's current mtime; do nothing if the directory is missing."""
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)
class EmbeddingDatabase:
    """Registry of textual-inversion embeddings loaded from one or more directories.

    Attributes:
        ids_lookup: first token id -> list of ``(ids, embedding)`` pairs,
            kept sorted by ``len(ids)`` (longest first).
        word_embeddings: embedding name -> embedding object.
        skipped_embeddings: name -> embedding for entries not registered.
            NOTE(review): nothing in this class adds entries; it is only
            cleared and reported.
        expected_shape: result of ``get_expected_shape()``, or -1 until synced.
        embedding_dirs: directory path -> DirWithTextualInversionEmbeddings.
        previously_displayed_embeddings: the ``(loaded names, skipped names)``
            tuples last printed, so an unchanged set is not printed again.
    """

    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.expected_shape = -1
        self.embedding_dirs = {}
        self.previously_displayed_embeddings = ()

    def add_embedding_dir(self, path):
        """Start tracking ``path`` as a directory to load embeddings from."""
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        """Forget all tracked embedding directories."""
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        """Register ``embedding`` under its own ``name`` (see register_embedding_by_name)."""
        return self.register_embedding_by_name(embedding, model, embedding.name)

    def register_embedding_by_name(self, embedding, model, name):
        """Register ``embedding`` under ``name``, or unregister ``name`` if ``embedding`` is None.

        Args:
            embedding: object with a ``name`` attribute, or None to unregister.
            model: unused while tokenization is stubbed out (see note below).
            name: key used in ``word_embeddings``.

        Returns:
            The registered embedding, or None when unregistering.
        """
        # NOTE(review): tokenization is stubbed out, so every name maps to the
        # same id sequence [0, 0, 0] and all embeddings share ids_lookup[0].
        # find_embedding_at_position can then only match the tokens [0, 0, 0],
        # and it returns the first entry in that bucket (registration order).
        ids = [0, 0, 0]  # model.cond_stage_model.tokenize([name])[0]
        first_id = ids[0]
        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []
        if name in self.word_embeddings:
            # remove old one from the lookup list
            lookup = [x for x in self.ids_lookup[first_id] if x[1].name != name]
        else:
            lookup = self.ids_lookup[first_id]
        if embedding is not None:
            lookup += [(ids, embedding)]
        # Longest id sequences first, so find_embedding_at_position prefers the longest match.
        self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
        if embedding is None:
            # unregister embedding with specified name
            if name in self.word_embeddings:
                del self.word_embeddings[name]
            if len(self.ids_lookup[first_id]) == 0:
                del self.ids_lookup[first_id]
            return None
        self.word_embeddings[name] = embedding
        return embedding

    def get_expected_shape(self):
        """Return ``shape[1]`` of the current model's encoding of the text ``","``."""
        devices.torch_npu_set_device()
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_from_file(self, path, filename):
        """Load one embedding file and register it.

        Images (.png/.webp/.jxl/.avif) are searched for embedded embedding
        data. ``*.preview.<ext>`` images and images without embedding data are
        ignored. .bin/.pt files are read with ``torch.load`` and .safetensors
        with ``safetensors.torch.load_file``, both onto the CPU. Any other
        extension is ignored.

        Args:
            path: full path of the file.
            filename: bare file name; its stem is the default embedding name.
        """
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

            # Context manager releases the image's file handle once the data is extracted.
            with Image.open(path) as embed_image:
                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                    data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                    name = data.get('name', name)
                else:
                    data = extract_image_data_embed(embed_image)
                    if data:
                        name = data.get('name', name)
                    else:
                        # if data is None, means this is not an embedding, just a preview image
                        return
        elif ext in ['.BIN', '.PT']:
            # SECURITY NOTE(review): torch.load unpickles arbitrary objects;
            # embedding files from untrusted sources can execute code here.
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            data = safetensors.torch.load_file(path, device="cpu")
        else:
            return

        if data is None:
            print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
            return

        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
        # NOTE(review): no model is passed and expected_shape is not checked,
        # so every successfully parsed embedding is registered.
        self.register_embedding(embedding, None)

    def load_from_dir(self, embdir):
        """Load every non-empty file under ``embdir.path`` recursively, following symlinks.

        A failure on one file is reported via ``errors.report`` and does not
        stop the remaining files from loading.
        """
        if not os.path.isdir(embdir.path):
            return

        for root, _, fns in os.walk(embdir.path, followlinks=True):
            for fn in fns:
                try:
                    fullfn = os.path.join(root, fn)
                    if os.stat(fullfn).st_size == 0:
                        continue
                    self.load_from_file(fullfn, fn)
                except Exception:
                    # Best effort: one bad file must not abort loading the rest.
                    errors.report(f"Error loading embedding {fn}", exc_info=True)

    def load_textual_inversion_embeddings(self, force_reload=False, sync_with_sd_model=True):
        """(Re)load embeddings from all tracked directories.

        Args:
            force_reload: reload even if no tracked directory reports a change.
            sync_with_sd_model: refresh ``expected_shape`` from the current
                model before loading.
        """
        if not force_reload and not any(embdir.has_changed() for embdir in self.embedding_dirs.values()):
            return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        if sync_with_sd_model:
            self.expected_shape = self.get_expected_shape()

        for embdir in self.embedding_dirs.values():
            self.load_from_dir(embdir)
            embdir.update()

        # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
        # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
        sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
        self.word_embeddings.clear()
        self.word_embeddings.update(sorted_word_embeddings)

        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
        if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
            self.previously_displayed_embeddings = displayed_embeddings
            print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
            if self.skipped_embeddings:
                print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        """Return ``(embedding, length)`` for the first stored id sequence found at ``tokens[offset:]``.

        Candidates are the entries bucketed under ``tokens[offset]``, tried
        longest first. Returns ``(None, None)`` when nothing matches. The
        slice is compared to stored id lists with ``==``, so matches only
        succeed when ``tokens`` is a list.
        """
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    """Create a new embedding file in ``shared.cmd_opts.embeddings_dir``.

    Args:
        name: desired embedding name. Every character except alphanumerics,
            ``.``, ``_``, ``-`` and space is removed before it is used as the
            file stem.
        num_vectors_per_token: number of vectors in the new embedding.
        overwrite_old: if false, refuse to replace an existing file.
        init_text: text whose encoding initializes the vectors; if empty or
            None the vectors stay zero.

    Returns:
        Path of the written ``.pt`` file.

    Raises:
        AssertionError: if the target file exists and ``overwrite_old`` is false.
    """
    cond_model = shared.sd_model.cond_stage_model

    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    # cond_model expects at least some text, so '*' is used as a fallback.
    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

    # Only copy if an init_text was provided; otherwise keep the vectors as zeros.
    if init_text:
        # Pick rows spread evenly across the encoded text.
        for i in range(num_vectors_per_token):
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    # Explicit check rather than `assert`, so the guard survives `python -O`;
    # AssertionError is kept for callers that already expect it.
    if not overwrite_old and os.path.exists(fn):
        raise AssertionError(f"file {fn} already exists")

    embedding = Embedding(vec, name)
    embedding.step = 0
    embedding.save(fn)

    return fn
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
    """Build an Embedding from a loaded embedding payload.

    Recognized layouts:
      * textual inversion: ``data['string_to_param']`` holds exactly one tensor;
      * SDXL: a dict containing ``clip_g`` and ``clip_l`` tensors;
      * diffusers concept: a non-empty dict whose single value is a tensor
        (a 1-D tensor is treated as one vector).

    Tensors are detached and moved to ``devices.device`` as float32.

    Args:
        data: the loaded payload.
        name: embedding name.
        filename: used only in the error message for unrecognized payloads.
        filepath: if given, stored on the embedding and hashed for ``set_hash``.

    Returns:
        The new Embedding, with ``shape`` (last-dimension width, summed for
        SDXL) and ``vectors`` (number of rows) filled in.

    Raises:
        AssertionError: a textual-inversion or diffusers payload has more than one term.
        Exception: the payload matches none of the layouts (including an empty dict).
    """
    if 'string_to_param' in data:  # textual inversion embeddings
        param_dict = data['string_to_param']
        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
        if len(param_dict) != 1:
            raise AssertionError('embedding file has multiple terms in it')
        emb = next(iter(param_dict.values()))
        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    elif isinstance(data, dict) and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
        vectors = data['clip_g'].shape[0]
    elif isinstance(data, dict) and data and isinstance(next(iter(data.values())), torch.Tensor):  # diffuser concepts
        # The emptiness check above keeps next() from raising StopIteration on {};
        # isinstance also accepts Tensor subclasses such as torch.nn.Parameter.
        if len(data) != 1:
            raise AssertionError('embedding file has multiple terms in it')

        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)
        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    else:
        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

    embedding = Embedding(vec, name)
    embedding.step = data.get('step', None)
    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
    embedding.vectors = vectors
    embedding.shape = shape

    if filepath:
        embedding.filename = filepath
        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')

    return embedding