| | """ |
| | The main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing |
| | a photo of <concept>_0 <concept>_1 ... and so on |
| | and instead just do |
| | a photo of <concept> |
| | which gets translated to the above. This needs to work for both inference and training. |
| | For inference, |
| | the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with |
| | it's underlying vectors |
| | For training, |
| | we would want to abstract away some logic like |
| | 1. Adding tokens |
| | 2. Updating gradient mask |
| | 3. Saving embeddings |
| | to our Util class here. |
| | so |
| | TODO: |
| | 1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x |
| | 2. have mechanism for adding tokens x |
| | 3. have mech for saving emebeddings x |
| | 4. get mask to update x |
| | 5. Loading tokens from embedding x |
| | 6. Integrate to training x |
| | 7. Test |
| | """ |
| |
|
| | import copy |
| | import random |
| |
|
| | from transformers import CLIPTokenizer |
| |
|
| |
|
| | class MultiTokenCLIPTokenizer(CLIPTokenizer): |
| | def __init__(self, *args, **kwargs): |
| | super().__init__(*args, **kwargs) |
| | self.token_map = {} |
| |
|
| | def try_adding_tokens(self, placeholder_token, *args, **kwargs): |
        """Add `placeholder_token` to the vocabulary, raising if it is already there."""
        num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

    def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
        """Register `placeholder_token` as `num_vec_per_token` underlying tokens.

        For `num_vec_per_token > 1`, the tokens "<placeholder>_0" ... "<placeholder>_{n-1}"
        are added to the vocabulary and recorded in `token_map`.
        """
        # Refuse placeholders that overlap with an existing one, since the plain string
        # replacement in `replace_placeholder_tokens_in_text` would be ambiguous.
        # Checking up front also avoids mutating the vocabulary before raising.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has a placeholder token {token} that can get confused with"
                    f" {placeholder_token}. Keep placeholder tokens independent."
                )
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
| | """ |
| | Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder |
| | can encode them |
| | vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119 |
| | where shuffling tokens were found to force the model to learn the concepts more descriptively. |
| | """ |
| | if isinstance(text, list): |
| | output = [] |
| | for i in range(len(text)): |
| | output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle)) |
| | return output |
| | for placeholder_token in self.token_map: |
| | if placeholder_token in text: |
| | tokens = self.token_map[placeholder_token] |
| | tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)] |
| | if vector_shuffle: |
| | tokens = copy.copy(tokens) |
| | random.shuffle(tokens) |
| | text = text.replace(placeholder_token, " ".join(tokens)) |
| | return text |
| |
|
| | def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs): |
        return super().__call__(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )

    def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        return super().encode(
            self.replace_placeholder_tokens_in_text(
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
            ),
            *args,
            **kwargs,
        )
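# --- Training-side helpers (sketch) ------------------------------------------
# The module docstring promises utilities for updating a gradient mask and for
# saving/loading embeddings, but they are not implemented in this file. The
# functions below are a minimal sketch under the assumption that the text
# encoder is a `transformers` model exposing `get_input_embeddings()`; the names
# `build_embedding_grad_mask`, `save_placeholder_embeddings`, and
# `load_placeholder_embeddings` are illustrative, not an existing API.
import torch


def build_embedding_grad_mask(tokenizer, text_encoder):
    # Boolean mask over embedding rows: True only for placeholder-token rows.
    # Multiplying the embedding gradient by this mask keeps every pre-trained
    # row frozen while the concept vectors keep training.
    mask = torch.zeros(len(tokenizer), dtype=torch.bool)
    for tokens in tokenizer.token_map.values():
        ids = tokenizer.convert_tokens_to_ids(tokens)
        mask[torch.tensor(ids)] = True
    return mask


def save_placeholder_embeddings(tokenizer, text_encoder, path):
    # Save only the learned concept vectors, keyed by token string, so that they
    # can be re-applied to a fresh text encoder later.
    weight = text_encoder.get_input_embeddings().weight
    learned = {
        token: weight[tokenizer.convert_tokens_to_ids(token)].detach().cpu()
        for tokens in tokenizer.token_map.values()
        for token in tokens
    }
    torch.save(learned, path)


def load_placeholder_embeddings(tokenizer, text_encoder, path):
    # Inverse of `save_placeholder_embeddings`: copy the saved vectors back into
    # an embedding matrix that has already been resized to `len(tokenizer)`.
    learned = torch.load(path)
    weight = text_encoder.get_input_embeddings().weight
    with torch.no_grad():
        for token, vector in learned.items():
            weight[tokenizer.convert_tokens_to_ids(token)] = vector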
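# --- Usage sketch -------------------------------------------------------------
# A minimal end-to-end example. The checkpoint name "openai/clip-vit-large-patch14"
# and the placeholder "<cat-toy>" are illustrative; any CLIP tokenizer/text-encoder
# pair should work the same way.
if __name__ == "__main__":
    from transformers import CLIPTextModel

    tokenizer = MultiTokenCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Register "<cat-toy>" as three underlying tokens: <cat-toy>_0 ... <cat-toy>_2.
    tokenizer.add_placeholder_tokens("<cat-toy>", num_vec_per_token=3)
    # Grow the embedding matrix so the new token ids have (trainable) rows.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # The user mentions the concept once; the overridden __call__ expands it to
    # "a photo of <cat-toy>_0 <cat-toy>_1 <cat-toy>_2" before tokenizing.
    # `vector_shuffle=True` would randomize the token order during training.
    batch = tokenizer("a photo of <cat-toy>", return_tensors="pt")
    print(batch.input_ids)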