medallo commited on
Commit
b6ff6dc
·
verified ·
1 Parent(s): d49415d

Upload loraiterate.py

Browse files
Files changed (1) hide show
  1. loraiterate.py +275 -0
loraiterate.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import copy
4
+ import random
5
+ import math
6
+
7
+ import gradio as gr
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+ from modules import sd_samplers, errors, scripts, images, sd_models
11
+ from modules.paths_internal import roboto_ttf_file
12
+ from modules.processing import Processed, process_images
13
+ from modules.shared import state, cmd_opts, opts
14
+ from pathlib import Path
15
+
16
+ lora_dir = Path(cmd_opts.lora_dir).resolve()
17
+
18
+
19
+ def allowed_path(path):
20
+ return Path(path).resolve().is_relative_to(lora_dir)
21
+
22
+
23
+ def get_base_path(is_use_custom_path, custom_path):
24
+ return lora_dir.joinpath(custom_path) if is_use_custom_path else lora_dir
25
+
26
+
27
+ def is_directory_contain_lora(path):
28
+ try:
29
+ if allowed_path(path):
30
+ safetensor_files = [f for f in os.listdir(
31
+ path) if f.endswith('.safetensors')]
32
+ return len(safetensor_files) > 0
33
+ except FileNotFoundError:
34
+ pass
35
+ except Exception as e:
36
+ print(e)
37
+ return False
38
+
39
+
40
+ def get_directories(base_path, include_root=True):
41
+ directories = ["/"] if include_root else []
42
+ try:
43
+ if allowed_path(base_path):
44
+ for entry in os.listdir(base_path):
45
+ full_path = os.path.join(base_path, entry)
46
+ if os.path.isdir(full_path):
47
+ if is_directory_contain_lora(full_path):
48
+ directories.append(entry)
49
+ nested_directories = get_directories(
50
+ full_path, include_root=False)
51
+ directories.extend([os.path.join(entry, d)
52
+ for d in nested_directories])
53
+ except FileNotFoundError:
54
+ pass
55
+ except Exception as e:
56
+ print(e)
57
+ return directories
58
+
59
+
60
+ def read_json_file(file_path):
61
+ with open(file_path, 'r') as file:
62
+ return json.load(file)
63
+
64
+
65
+ def get_lora_name(lora_path):
66
+ if opts.lora_preferred_name == "Filename":
67
+ lora_name = lora_path.stem
68
+ else:
69
+ metadata = sd_models.read_metadata_from_safetensors(lora_path)
70
+ lora_name = metadata.get('ss_output_name', lora_path.stem)
71
+ return lora_name
72
+
73
+
74
+ def get_lora_prompt(lora_path, json_path):
75
+ with open(json_path, 'r', encoding='utf-8') as file:
76
+ data = json.load(file)
77
+ preferred_weight = data.get("preferred weight", 1)
78
+ activation_text = data.get("activation text", "")
79
+ try:
80
+ if float(preferred_weight) == 0:
81
+ preferred_weight = 1
82
+ except:
83
+ preferred_weight = 1
84
+ lora_name = get_lora_name(lora_path)
85
+ return f"<lora:{lora_name}:{preferred_weight}>, {activation_text},"
86
+
87
+
88
+ def image_grid_with_text(imgs, texts, rows=None, cols=None, font_path=None, font_size=20, text_color="#FFFFFF", stroke_color="#000000", stroke_width=2, add_text=True):
89
+ if rows is None:
90
+ rows = round(math.sqrt(len(imgs)))
91
+ cols = math.ceil(len(imgs) / rows) if cols is None else cols
92
+ w, h = imgs[0].size
93
+ grid = Image.new('RGB', (cols * w, rows * h), 'black')
94
+ for i, img in enumerate(imgs):
95
+ grid.paste(img, (i % cols * w, i // cols * h))
96
+ if add_text:
97
+ draw = ImageDraw.Draw(grid)
98
+ try:
99
+ font = ImageFont.truetype(font_path, font_size) if font_path and os.path.exists(
100
+ font_path) else ImageFont.truetype(roboto_ttf_file, font_size)
101
+ except:
102
+ font = ImageFont.truetype(roboto_ttf_file, font_size)
103
+ for i, text in enumerate(texts):
104
+ x = (i % cols) * w
105
+ y = (i // cols) * h
106
+ for dx, dy in [(j, k) for j in range(-stroke_width, stroke_width+1) for k in range(-stroke_width, stroke_width+1)]:
107
+ draw.text((x+5+dx, y+5+dy), text, font=font, fill=stroke_color)
108
+ draw.text((x+5, y+5), text, font=font, fill=text_color)
109
+ return grid
110
+
111
+
112
+ class Script(scripts.Script):
113
+ def title(self):
114
+ return "Apply on every Lora"
115
+
116
+ def ui(self, is_img2img):
117
+ def build_lora_tree(base_path):
118
+ tree = {"__root__": {"name": base_path.name, "children": {}}}
119
+ for root, dirs, files in os.walk(base_path):
120
+ rel_path = os.path.relpath(root, base_path)
121
+ current_node = tree["__root__"]
122
+ if rel_path != ".":
123
+ for part in rel_path.split(os.sep):
124
+ current_node = current_node["children"].setdefault(
125
+ part, {"name": part, "children": {}, "loras": []})
126
+
127
+ loras = [f[:-12] for f in files if f.endswith(".safetensors")]
128
+ current_node["loras"] = loras
129
+ return tree["__root__"]
130
+
131
+ def update_tree(is_use_custom, custom_path):
132
+ base_path = get_base_path(is_use_custom, custom_path)
133
+ return gr.Tree.update(value=build_lora_tree(base_path))
134
+
135
+ with gr.Column():
136
+ base_dir_checkbox = gr.Checkbox(
137
+ label="Use Custom Lora path", value=False)
138
+ base_dir_textbox = gr.Textbox(
139
+ label="Lora directory", visible=False)
140
+ with gr.Row():
141
+ lora_dir_dropdown = gr.Dropdown(
142
+ label="LORA Directory",
143
+ choices=["/"] + get_directories(lora_dir),
144
+ value="/",
145
+ interactive=True
146
+ )
147
+ refresh_btn = gr.Button("🔄", variant="tool")
148
+
149
+ lora_checkboxes = gr.CheckboxGroup(
150
+ label="Select LoRAs",
151
+ interactive=True
152
+ )
153
+
154
+ def update_directory(current_dir):
155
+ base_path = lora_dir.joinpath(current_dir.lstrip('/'))
156
+ loras = []
157
+ if allowed_path(base_path):
158
+ for root, _, files in os.walk(base_path):
159
+ for file in files:
160
+ if file.endswith(('.safetensors', '.pt')):
161
+ rel_path = os.path.relpath(root, lora_dir)
162
+ loras.append(
163
+ f"{rel_path}/{file}" if rel_path != '.' else file)
164
+ return gr.CheckboxGroup.update(choices=loras)
165
+
166
+ def scan_loras(current_dir):
167
+ return update_directory(current_dir)
168
+
169
+ lora_dir_dropdown.change(
170
+ fn=scan_loras,
171
+ inputs=[lora_dir_dropdown],
172
+ outputs=lora_checkboxes
173
+ )
174
+ refresh_btn.click(
175
+ fn=lambda: scan_loras(lora_dir_dropdown.value),
176
+ outputs=lora_checkboxes
177
+ )
178
+ prompt_lines = gr.Textbox(label="Prompts (one per line)", lines=5)
179
+ lora_tags_position_radio = gr.Radio(
180
+ ["Prepend", "Append"], value="Prepend", label="LoRA Tags Position")
181
+ checkbox_save_grid = gr.Checkbox(
182
+ label="Save grid image", value=True)
183
+ font_path = gr.Textbox(label="Custom Font Path")
184
+
185
+ with gr.Row():
186
+ use_random_seed = gr.Checkbox(
187
+ label="Random seed", value=True)
188
+ use_fixed_seed = gr.Checkbox(label="Fixed seed", value=False)
189
+
190
+ file_upload = gr.File(
191
+ label="Load prompts from file", file_types=[".txt"], type='binary')
192
+
193
+ def load_prompt_file(file, current_prompts):
194
+ if file is None:
195
+ return None, current_prompts, gr.update()
196
+ lines = [x.strip() for x in file.decode(
197
+ 'utf8', errors='ignore').split("\n")]
198
+ return None, "\n".join(lines), gr.update(lines=max(7, len(lines)))
199
+
200
+ file_upload.change(
201
+ fn=load_prompt_file,
202
+ inputs=[file_upload, prompt_lines],
203
+ outputs=[file_upload, prompt_lines, prompt_lines],
204
+ show_progress=False
205
+ )
206
+
207
+ base_dir_checkbox.change(
208
+ fn=lambda is_use, path: get_base_path(is_use, path),
209
+ inputs=[base_dir_checkbox, base_dir_textbox],
210
+ outputs=lora_dir_dropdown
211
+ )
212
+
213
+ return [base_dir_checkbox, base_dir_textbox, lora_checkboxes, prompt_lines, lora_tags_position_radio, checkbox_save_grid, font_path]
214
+
215
+ def run(self, p, is_use_custom_path, custom_path, lora_checkboxes, prompt_lines, lora_tags_position, is_save_grid, font_path):
216
+ selected_loras = [
217
+ str(lora_dir.joinpath(lora))
218
+ for lora in lora_checkboxes
219
+ if lora.endswith(('.safetensors', '.pt'))
220
+ ]
221
+
222
+ if not selected_loras or not prompt_lines:
223
+ return Processed(p, [], p.seed, "No LoRAs or prompts selected")
224
+
225
+ prompts = [line.strip()
226
+ for line in prompt_lines.splitlines() if line.strip()]
227
+ combinations = [(lora, prompt)
228
+ for lora in selected_loras for prompt in prompts]
229
+
230
+ state.job_count = len(combinations)
231
+ result_images = []
232
+ all_prompts = []
233
+ infotexts = []
234
+ grid_texts = []
235
+
236
+ for lora_path, prompt in combinations:
237
+ if state.interrupted:
238
+ break
239
+
240
+ current_p = copy.copy(p)
241
+ lora_file = Path(lora_path)
242
+ json_file = lora_file.with_suffix('.json')
243
+
244
+ try:
245
+ lora_tags = get_lora_prompt(
246
+ lora_file, json_file) if json_file.exists() else f"<lora:{lora_file.stem}:1>,"
247
+ except Exception as e:
248
+ print(f"Error loading Lora {lora_file}: {str(e)}")
249
+ continue
250
+
251
+ final_prompt = f"{lora_tags} {prompt}" if lora_tags_position == "Prepend" else f"{prompt} {lora_tags}"
252
+ current_p.prompt = final_prompt
253
+
254
+ proc = process_images(current_p)
255
+ result_images.extend(proc.images)
256
+ all_prompts.extend(proc.all_prompts)
257
+ infotexts.extend(proc.infotexts)
258
+ grid_texts.extend(
259
+ [f"{lora_file.stem}\n{prompt}"] * len(proc.images))
260
+
261
+ if is_save_grid and len(result_images) > 1:
262
+ rows = round(math.sqrt(len(result_images)))
263
+ grid_image = image_grid_with_text(
264
+ result_images, grid_texts,
265
+ rows=rows,
266
+ font_path=font_path,
267
+ text_color="#FFFFFF",
268
+ stroke_color="#000000",
269
+ stroke_width=2
270
+ )
271
+ images.save_image(grid_image, p.outpath_grids,
272
+ "grid", grid=True, p=p)
273
+ result_images.insert(0, grid_image)
274
+
275
+ return Processed(p, result_images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)