File size: 9,200 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""This module provides utility functions."""
from scripts.Logger import Logger
import os
import re
import requests
from typing import List

import modules
from modules.sd_models import read_state_dict
from modules.sd_models_config import (find_checkpoint_config, config_default, config_sd2, config_sd2v, config_sd2_inpainting,
                                      config_depth_model, config_unclip, config_unopenclip, config_inpainting, config_instruct_pix2pix, config_alt_diffusion)

import sys
# Prepend "<directory containing this file>/scripts" to the module search path.
# NOTE(review): this runs after `from scripts.Logger import Logger` above, so it
# cannot influence that import; presumably it exists for imports performed
# later by other modules - confirm before relying on or removing it.
sys.path.insert(0, os.path.join(os.path.dirname(
    os.path.abspath(__file__)), "scripts"))


class Utils():
    """
        Methods that are needed in different classes.

        Covers parsing of the ``,``-separated checkpoint list and the
        ``;``-separated prompt list, the ``@index:N`` / ``@version:TAG``
        annotations attached to their entries, and loading of the help
        markdown file.
    """

    # Tags written after "@version:" by add_model_version_to_string and
    # stripped again by remove_model_version_from_string.
    _VERSION_TAGS = (
        "sd1",
        "sd2",
        "sd2v",
        "sd2-inpainting",
        "depth",
        "unclip",
        "unopenclip",
        "sd1-inpainting",
        "pix2pix",
        "alt",
    )
    # Longest tags are tried first so that e.g. "@version:sd2v" is removed as a
    # whole instead of matching "@version:sd2" and leaving a stray "v" behind.
    _VERSION_TAG_PATTERN = re.compile(
        "@version:(?:"
        + "|".join(re.escape(tag) for tag in sorted(_VERSION_TAGS, key=len, reverse=True))
        + ")"
    )
    # Name of the help file, both in the extension root and in the GitHub repo.
    _HELP_MD_NAME = "HelpBatchCheckpointsPrompt.md"

    def __init__(self) -> None:
        self.logger = Logger()
        self.logger.debug = False
        # Parent of the directory containing this file.
        script_path = os.path.dirname(
            os.path.dirname(os.path.abspath(__file__)))
        self.held_md_file_name = os.path.join(script_path, self._HELP_MD_NAME)
        # The URL must use the bare file name; the local absolute path (and an
        # extra ".md") does not belong in it.
        self.held_md_url = (
            "https://raw.githubusercontent.com/h43lb1t0/BatchCheckpointPrompt/main/"
            f"{self._HELP_MD_NAME}")

    def split_prompts(self, text: str) -> List[str]:
        """Split the prompts by the ; and remove empty strings and newlines

        Args:
            text (str): the input string
        Returns:
            List[str]: a list of prompts, with newlines removed and surrounding
            whitespace stripped; empty or whitespace-only entries are dropped
        """
        # ``prompt.strip()`` is non-empty exactly when the entry is neither
        # empty nor whitespace-only.
        return [prompt.replace('\n', '').strip()
                for prompt in text.split(";")
                if prompt.strip()]

    def remove_index_from_string(self, input: str) -> str:
        """Remove the index from the string

        Args:
            input (str): the input string (every ``@index:N`` tag is removed)
        Returns:
            str: the string without the index, stripped of surrounding whitespace
        """
        return re.sub(r"@index:\d+", "", input).strip()

    def remove_model_version_from_string(self, checkpoints_text: str) -> str:
        """Remove the model version from the string

        Args:
            checkpoints_text (str): the input string with all checkpoints
        Returns:
            str: the string with every known ``@version:TAG`` annotation removed
        """
        return self._VERSION_TAG_PATTERN.sub("", checkpoints_text)

    def get_clean_checkpoint_path(self, checkpoint: str) -> str:
        """Remove the checkpoint hash from the filename

        Args:
            checkpoint (str): the input string with a `` [hash]`` suffix
        Returns:
            str: the string without the hash, stripped of surrounding whitespace
        """
        return re.sub(r' \[.*?\]', '', checkpoint).strip()

    def getCheckpointListFromInput(self, checkpoints_text: str, clean: bool = True) -> List[str]:
        """Get a list of checkpoints from the input string

        Version annotations are always removed.

        Args:
            checkpoints_text (str): the input string with all checkpoints
            clean (bool): also remove the index and hash from the string
        Returns:
            List[str]: a list of checkpoints; empty or whitespace-only entries
            are dropped
        """
        self.logger.debug_log(f"checkpoints: {checkpoints_text}")
        checkpoints_text = self.remove_model_version_from_string(checkpoints_text)
        if clean:
            checkpoints_text = self.remove_index_from_string(checkpoints_text)
            checkpoints_text = self.get_clean_checkpoint_path(checkpoints_text)
        return [checkpoint.replace('\n', '').strip()
                for checkpoint in checkpoints_text.split(",")
                if checkpoint.strip()]

    def get_help_md(self) -> str:
        """Gets the help md file.
        If the file is not found locally, downloads it from the GitHub
        repository and caches it next to the extension.

        Returns:
            str: the help md file as a string, or a short fallback message if it
            could neither be read nor downloaded
        """
        fallback = "could not get help file. Check Github for more information"
        if os.path.isfile(self.held_md_file_name):
            with open(self.held_md_file_name, encoding="utf-8", errors="replace") as f:
                return f.read()

        self.logger.debug_log("downloading help md")
        try:
            # Without a timeout an unreachable host could block the caller forever.
            result = requests.get(self.held_md_url, timeout=10)
        except requests.RequestException as err:
            self.logger.debug_log(f"could not download help md: {err}")
            return fallback
        if result.status_code != 200:
            self.logger.debug_log(
                f"help md download failed with status {result.status_code}")
            return fallback

        md = result.content.decode("utf-8", errors="replace")
        try:
            with open(self.held_md_file_name, "wb") as file:
                file.write(result.content)
        except OSError as err:
            # Caching is best-effort (e.g. read-only install); still return the text.
            self.logger.debug_log(f"could not cache help md: {err}")
        return md

    def add_index_to_string(self, text: str, is_checkpoint: bool = True) -> str:
        """Add the index to the string

        Args:
            text (str): the input string
            is_checkpoint (bool): if the string is a checkpoint list (True) or a
                prompt list (False)
        Returns:
            str: the string with the index; checkpoints end in ``",\\n"``,
            prompts in ``";\\n\\n"``
        """
        if is_checkpoint:
            checkpoint_list = self.getCheckpointListFromInput(text)
            return "".join(
                f"{self.remove_index_from_string(checkpoint)} @index:{i},\n"
                for i, checkpoint in enumerate(checkpoint_list))
        prompt_list = self.split_prompts(text)
        return "".join(
            f"{self.remove_index_from_string(prompt)} @index:{i};\n\n"
            for i, prompt in enumerate(prompt_list))

    def add_model_version_to_string(self, checkpoints_text: str) -> str:
        """Add the model version to the string.
        EXPERIMENTAL!

        Each checkpoint's state dict is read to detect its config; the config
        is mapped to a ``@version:`` tag. Configs that match none of the known
        ones are written out unchanged. Checkpoints that cannot be resolved are
        kept without a version tag.

        Args:
            checkpoints_text (str): the input string with all checkpoints
        Returns:
            str: the string with the model version
        """
        # Ordered (config, tag) pairs; the first matching config wins, exactly
        # like the if/elif chain this replaces.
        version_tags = (
            (config_default, "sd1"),
            (config_sd2, "sd2"),
            (config_sd2v, "sd2v"),
            (config_sd2_inpainting, "sd2-inpainting"),
            (config_depth_model, "depth"),
            (config_unclip, "unclip"),
            (config_unopenclip, "unopenclip"),
            (config_inpainting, "sd1-inpainting"),
            (config_instruct_pix2pix, "pix2pix"),
            (config_alt_diffusion, "alt"),
        )
        checkpoints_not_cleaned = self.getCheckpointListFromInput(
            checkpoints_text, clean=False)
        checkpoints = self.getCheckpointListFromInput(checkpoints_text)
        entries = []
        for checkpoint, raw_checkpoint in zip(checkpoints, checkpoints_not_cleaned):
            checkpoint_partly_cleaned = raw_checkpoint.replace(
                "\n", "").replace(",", "")
            info = modules.sd_models.get_closet_checkpoint_match(checkpoint)
            if info is None:
                # Unknown checkpoint: keep it, just without a version tag.
                self.logger.debug_log(f"no checkpoint found for {checkpoint}")
                entries.append(f"{checkpoint_partly_cleaned},\n\n")
                continue
            state_dict = read_state_dict(info.filename)
            config = find_checkpoint_config(state_dict, None)
            version_string = next(
                (tag for known, tag in version_tags if config == known), config)
            entries.append(
                f"{checkpoint_partly_cleaned} @version:{version_string},\n\n")
        return "".join(entries)

    def remove_element_at_index(self, checkpoints: str, prompts: str, index: List[int]) -> List[str]:
        """Remove the element at the given index from the string

        Args:
            checkpoints (str): the input string with all checkpoints
            prompts (str): the input string with all prompts
            index (List[int]): the indices to remove (duplicates are ignored)
        Returns:
            List[str]: a list with the new (re-indexed) checkpoints and prompts,
            or the unchanged inputs if nothing could be removed
        """
        checkpoints_list = self.getCheckpointListFromInput(checkpoints)
        prompts_list = self.split_prompts(prompts)
        if not (len(checkpoints_list) == len(prompts_list)
                or len(prompts_list) - len(index) <= 0):
            self.logger.debug_log(
                f"checkpoints and prompts are not the same length cp: {len(checkpoints_list)} p: {len(prompts_list)}")
            return [checkpoints, prompts]
        if not index:
            self.logger.debug_log("no index to remove")
            return [checkpoints, prompts]
        # Both lists are popped, so the index must be valid for the shorter one.
        if max(index) > min(len(checkpoints_list), len(prompts_list)) - 1:
            self.logger.debug_log("index is out of range")
            return [checkpoints, prompts]

        # Pop from the highest index down so earlier removals do not shift the
        # positions of the elements still to be removed.
        for i in sorted(set(index), reverse=True):
            checkpoints_list.pop(i)
            prompts_list.pop(i)
        result = [self.add_index_to_string(",".join(checkpoints_list), True),
                  self.add_index_to_string(";".join(prompts_list), False)]
        self.logger.debug_log(f"result: {result}")
        return result