import torch

from .. import shared
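
# Holds a single textual-inversion embedding (Embedding) and the database
# (EmbeddingDatabase) used to look embeddings up by the token ids of their
# names while prompts are processed.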


class Embedding:
    def __init__(self, vec, name, step=None):
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None
        self.vectors = 0
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None
        self.filename = None

        self.shape = vec.shape[-1]
        self.vectors = vec.shape[0]
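
    # Write the embedding tensor and its metadata to disk; if enabled, also
    # save the optimizer state to "<filename>.optim", keyed by the embedding's
    # checksum so it can be matched back up on load.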
    def save(self, filename):
        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, f"{filename}.optim")
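
    # Cheap deterministic hash of the embedding weights, cached after the
    # first call; used above to pair a saved optimizer state with its embedding.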
    def checksum(self):
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum
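

# Keeps track of all loaded embeddings and indexes them by the first token id
# of their tokenized name, so prompt tokens can be matched against them quickly.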
class EmbeddingDatabase:
    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.expected_shape = -1
        self.embedding_dirs = {}
        self.previously_displayed_embeddings = ()
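
    # Store the embedding by name and add it to ids_lookup under the first
    # token id of its tokenized name; candidates are kept sorted longest-first
    # so the most specific token sequence wins during matching.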
    def register_embedding(self, embedding, model):
        self.word_embeddings[embedding.name] = embedding

        ids = model.tokenize([embedding.name])[0]
        first_id = ids[0]

        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []

        self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

        return embedding
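
    # Given a token sequence and an offset into it, return the registered
    # embedding whose token ids start at that offset along with how many
    # tokens it spans, or (None, None) if nothing matches.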
    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None