| from functools import reduce |
| from modules import shared |
| from modules import extra_networks |
| from modules import prompt_parser |
| from modules import sd_hijack |
|
|
| |
| |
def get_token_count(text, steps, is_positive: bool = True, return_tokens = False):
    """Get the token count and max length for a given prompt text.

    The prompt is expanded into its scheduled variants and the variant whose
    token count is highest is the one reported.

    Arguments:
        text: the prompt text
        steps: step count handed to the prompt schedule parser
        is_positive: when True, the text is first split with
            prompt_parser.get_multicond_prompt_list; otherwise it is used as
            a single prompt
        return_tokens: accepted for backward compatibility only; it is not
            read and no tokens are returned

    Returns:
        token_count: int - The total number of tokens in the prompt text
        max_length: int - The maximum length of the prompt text
        (the pair reported by sd_hijack.model_hijack.get_prompt_lengths for
        the variant with the largest token count)
    """
    try:
        text, _ = extra_networks.parse_prompt(text)

        if is_positive:
            _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        else:
            prompt_flat_list = [text]

        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
    except Exception:
        # Deliberate best effort: if any parsing step fails, count the text
        # as one unscheduled prompt instead of propagating the error.
        prompt_schedules = [[[steps, text]]]

    # Flatten every schedule into a plain list of prompt texts. A nested
    # comprehension avoids reduce()'s TypeError on an empty sequence and the
    # repeated list copies of pairwise concatenation.
    prompts = [prompt_text for schedule in prompt_schedules for _step, prompt_text in schedule]
    if not prompts:
        # No scheduled variants were produced; fall back to the raw text so
        # max() below never sees an empty sequence.
        prompts = [text]

    token_count, max_length = max(
        (sd_hijack.model_hijack.get_prompt_lengths(prompt) for prompt in prompts),
        key=lambda lengths: lengths[0],
    )
    return token_count, max_length
|
|
|
|
def tokenize_prompt(text):
    """Tokenize the given prompt text using the current clip model.

    Arguments:
        text: str or iterable of str - a single prompt, or several prompts
            to process together

    Returns:
        (batch_chunks, token_count) as produced by clip.process_texts, or
        (None, None) when sd_hijack.model_hijack has no usable `clip`.
    """
    clip = getattr(sd_hijack.model_hijack, 'clip', None)
    if clip is None:
        return None, None

    # A single string is wrapped into a one-element list; anything else is
    # treated as an iterable of prompts. Previously non-str input left
    # `prompts` unbound and raised UnboundLocalError.
    prompts = [text] if isinstance(text, str) else list(text)

    batch_chunks, token_count = clip.process_texts(prompts)
    return batch_chunks, token_count
|
|
|
|
def decode_tokenized_prompt(tokens):
    """Decode a tokenized prompt using the current clip model.

    Arguments:
        tokens: iterable of token ids to look up in clip.tokenizer.decoder

    Returns:
        A list with one `[index, token, decoded_token]` entry per input
        token, in input order, or None when sd_hijack.model_hijack has no
        usable `clip`.
    """
    clip = getattr(sd_hijack.model_hijack, 'clip', None)
    if clip is None:
        return None

    decoded_prompt = []
    for position, token_id in enumerate(tokens):
        # Each entry pairs the position and raw id with its decoder lookup.
        decoded_prompt.append([position, token_id, clip.tokenizer.decoder[token_id]])
    return decoded_prompt