RioShiina committed on
Commit 992e30a · verified · 1 Parent(s): ae50e07

Upload 3 files

Files changed (3)
  1. README.md +9 -13
  2. app.py +515 -0
  3. requirements.txt +11 -0
README.md CHANGED
@@ -1,13 +1,9 @@
- ---
- title: Animated SDXL T2I With LoRAs
- emoji: 🐠
- colorFrom: green
- colorTo: purple
- sdk: gradio
- sdk_version: 5.42.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Animated-SDXL-T2I-with-LoRAs
+ emoji: 🖼
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ app_file: app.py
+ pinned: true
+ ---
app.py ADDED
@@ -0,0 +1,515 @@
+ import spaces
+ import gradio as gr
+ import numpy as np
+ import PIL.Image
+ from PIL import Image, PngImagePlugin
+ import random
+ from diffusers import (
+     StableDiffusionXLPipeline,
+     EulerAncestralDiscreteScheduler,
+     EulerDiscreteScheduler,
+     DPMSolverMultistepScheduler,
+     DDIMScheduler,
+     UniPCMultistepScheduler,
+     HeunDiscreteScheduler,
+     LMSDiscreteScheduler,
+ )
+ import torch
+ from compel import Compel, ReturnedEmbeddingsType
+ import requests
+ import os
+ import re
+ import gc
+ from huggingface_hub import hf_hub_download
+
+ # This dummy function is required to pass the Hugging Face Spaces startup check for GPU apps.
+ @spaces.GPU(duration=60)
+ def dummy_gpu_for_startup():
+     print("Dummy function for startup check executed. This is normal.")
+     return "Startup check passed."
+
+ # --- Constants ---
+ MAX_LORAS = 5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MAX_SEED = np.iinfo(np.int64).max
+ MAX_IMAGE_SIZE = 1216
+ SAMPLER_MAP = {
+     "Euler a": EulerAncestralDiscreteScheduler,
+     "Euler": EulerDiscreteScheduler,
+     "DPM++ 2M Karras": DPMSolverMultistepScheduler,
+     "DDIM": DDIMScheduler,
+     "UniPC": UniPCMultistepScheduler,
+     "Heun": HeunDiscreteScheduler,
+     "LMS": LMSDiscreteScheduler,
+ }
+ SCHEDULE_TYPE_MAP = ["Default", "Karras", "Uniform", "SGM Uniform"]
+ DEFAULT_SCHEDULE_TYPE = "Default"
+ DEFAULT_SAMPLER = "Euler a"
+ DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
+ DOWNLOAD_DIR = "/tmp/loras"
+ os.makedirs(DOWNLOAD_DIR, exist_ok=True)
+
+ # --- Model Lists ---
+ MODEL_LIST = [
+     "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
+     "Laxhar/noobai-XL-Vpred-1.0",
+     "John6666/hassaku-xl-illustrious-v30-sdxl",
+     "RedRayz/hikari_noob_v-pred_1.2.2",
+     "bluepen5805/noob_v_pencil-XL",
+     "Laxhar/noobai-XL-1.1"
+ ]
+
+ # --- List of V-Prediction Models ---
+ V_PREDICTION_MODELS = [
+     "Laxhar/noobai-XL-Vpred-1.0",
+     "RedRayz/hikari_noob_v-pred_1.2.2",
+     "bluepen5805/noob_v_pencil-XL"
+ ]
+
+ # --- Dictionary for single-file models now stores the filename ---
+ SINGLE_FILE_MODELS = {
+     "bluepen5805/noob_v_pencil-XL": "noob_v_pencil-XL-v3.0.0.safetensors"
+ }
+
+ # --- Model Hash to Name Mapping ---
+ HASH_TO_MODEL_MAP = {
+     "bdb59bac77": "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
+     "ea349eeae8": "Laxhar/noobai-XL-Vpred-1.0",
+     "b4fb5f829a": "John6666/hassaku-xl-illustrious-v30-sdxl",
+     "6681e8e4b1": "Laxhar/noobai-XL-1.1",
+     "90b7911a78": "bluepen5805/noob_v_pencil-XL",
+     "874170688a": "RedRayz/hikari_noob_v-pred_1.2.2"
+ }
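+ # These ten-character hashes are assumed to be the "Model hash" values that
+ # A1111-style PNG metadata records for each checkpoint; _parse_parameters
+ # extracts that field so "Send to Txt-to-Image" can map an uploaded image
+ # back to an entry in MODEL_LIST.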
+
+ def get_civitai_file_info(version_id):
+     """Gets the file metadata for a model version via the Civitai API."""
+     api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
+     try:
+         # A timeout keeps a stalled API call from hanging the request.
+         response = requests.get(api_url, timeout=30)
+         response.raise_for_status()
+         data = response.json()
+         # Prefer a .safetensors file when one is listed.
+         for file_data in data.get('files', []):
+             if file_data['name'].endswith('.safetensors'):
+                 return file_data
+         if data.get('files'):
+             return data['files'][0]
+         return None
+     except Exception as e:
+         print(f"Could not get file info from Civitai API: {e}")
+         return None
+
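+ # Illustrative shape of the Civitai payload consumed above; only 'files[].name'
+ # and 'files[].downloadUrl' are read, and the exact schema is not guaranteed:
+ #   {"files": [{"name": "example.safetensors",
+ #               "downloadUrl": "https://civitai.com/api/download/models/<version_id>"}]}
+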
+ def download_file(url, save_path, api_key=None, progress=None, desc=""):
+     """Downloads a file, skipping if it already exists."""
+     if os.path.exists(save_path):
+         return f"File already exists: {os.path.basename(save_path)}"
+
+     headers = {}
+     if api_key and api_key.strip():
+         headers['Authorization'] = f'Bearer {api_key}'
+
+     try:
+         if progress:
+             progress(0, desc=desc)
+         response = requests.get(url, stream=True, headers=headers, timeout=60)
+         response.raise_for_status()
+
+         total_size = int(response.headers.get('content-length', 0))
+
+         with open(save_path, "wb") as f:
+             downloaded = 0
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
+                 if progress and total_size > 0:
+                     downloaded += len(chunk)
+                     progress(downloaded / total_size, desc=desc)
+
+         return f"Successfully downloaded: {os.path.basename(save_path)}"
+     except Exception as e:
+         # Remove any partially written file so a retry starts clean.
+         if os.path.exists(save_path):
+             os.remove(save_path)
+         return f"Download failed for {os.path.basename(save_path)}: {e}"
+
+ def process_long_prompt(compel_proc, prompt, negative_prompt=""):
+     """Encodes both prompts with compel; returns (None, None) so the caller can fall back to plain prompts."""
+     try:
+         conditioning, pooled = compel_proc([prompt, negative_prompt])
+         return conditioning, pooled
+     except Exception as e:
+         print(f"Compel encoding failed, falling back to plain prompts: {e}")
+         return None, None
+
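+ # Note on weighting syntax: compel parses its own forms, e.g. "(low quality)1.2"
+ # or trailing "+"/"-" tokens, rather than the A1111-style "(low quality:1.2)"
+ # used in DEFAULT_NEGATIVE_PROMPT above; A1111-style weights are assumed to pass
+ # through as literal prompt text here (per compel's documented syntax, not
+ # verified in this app).
+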
+ def pre_download_base_model(model_name, progress=gr.Progress(track_tqdm=True)):
+     if not model_name:
+         return "Please select a base model to download."
+
+     status_log = []
+     try:
+         progress(0, desc=f"Starting download for: {model_name}")
+
+         if model_name in SINGLE_FILE_MODELS:
+             filename = SINGLE_FILE_MODELS[model_name]
+             print(f"Pre-downloading single file: {filename} from repo: {model_name}")
+             local_path = hf_hub_download(repo_id=model_name, filename=filename)
+             pipe = StableDiffusionXLPipeline.from_single_file(
+                 local_path,
+                 torch_dtype=torch.float16,
+                 use_safetensors=True
+             )
+         else:
+             print(f"Pre-downloading diffusers model: {model_name}")
+             pipe = StableDiffusionXLPipeline.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.float16,
+                 use_safetensors=True
+             )
+
+         status_log.append(f"✅ Successfully downloaded {model_name}")
+         del pipe
+     except Exception as e:
+         status_log.append(f"❌ Failed to download {model_name}: {e}")
+     finally:
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     return "\n".join(status_log)
+
+ def pre_download_loras(civitai_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
+     # lora_data alternates (civitai_id, weight) pairs; only the IDs matter for downloading.
+     civitai_ids = lora_data[0::2]
+     status_log = []
+
+     active_lora_ids = [cid for cid in civitai_ids if cid and cid.strip()]
+     if not active_lora_ids:
+         return "No LoRA IDs provided to download."
+
+     for i, civitai_id in enumerate(active_lora_ids):
+         version_id = civitai_id.strip()
+         progress(i / len(active_lora_ids), desc=f"Getting URL for LoRA ID: {version_id}")
+
+         local_lora_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
+
+         file_info = get_civitai_file_info(version_id)
+         if not file_info:
+             status_log.append(f"* LoRA ID {version_id}: Could not get file info from Civitai.")
+             continue
+
+         download_url = file_info.get('downloadUrl')
+         if not download_url:
+             status_log.append(f"* LoRA ID {version_id}: Could not get download link.")
+             continue
+
+         status = download_file(
+             download_url,
+             local_lora_path,
+             api_key=civitai_api_key,
+             progress=progress,
+             desc=f"Downloading LoRA ID: {version_id}"
+         )
+         status_log.append(f"* LoRA ID {version_id}: {status}")
+
+     return "\n".join(status_log)
+
+ def _infer_logic(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
+                  sampler, schedule_type,
+                  civitai_api_key,
+                  *lora_data,
+                  progress=gr.Progress(track_tqdm=True)):
+
+     pipe = None
+     try:
+         progress(0, desc=f"Loading model: {base_model_name}")
+
+         if base_model_name in SINGLE_FILE_MODELS:
+             filename = SINGLE_FILE_MODELS[base_model_name]
+             print(f"Loading single file: {filename} from repo: {base_model_name}")
+             local_path = hf_hub_download(repo_id=base_model_name, filename=filename)
+             pipe = StableDiffusionXLPipeline.from_single_file(
+                 local_path,
+                 torch_dtype=torch.float16,
+                 use_safetensors=True
+             )
+         else:
+             print(f"Loading diffusers model: {base_model_name}")
+             pipe = StableDiffusionXLPipeline.from_pretrained(
+                 base_model_name,
+                 torch_dtype=torch.float16,
+                 use_safetensors=True
+             )
+         pipe.to(device)
+
+         batch_size = int(batch_size)
+         seed = int(seed)
+
+         pipe.unload_lora_weights()
+
+         scheduler_class = SAMPLER_MAP.get(sampler, EulerAncestralDiscreteScheduler)
+         # The scheduler config is a FrozenDict; copy it into a plain dict before editing.
+         scheduler_config = dict(pipe.scheduler.config)
+
+         if base_model_name in V_PREDICTION_MODELS:
+             scheduler_config['prediction_type'] = 'v_prediction'
+         else:
+             scheduler_config['prediction_type'] = 'epsilon'
+
+         scheduler_kwargs = {}
+         if schedule_type == "Default" and sampler == "DPM++ 2M Karras":
+             scheduler_kwargs['use_karras_sigmas'] = True
+         elif schedule_type == "Karras":
+             scheduler_kwargs['use_karras_sigmas'] = True
+         elif schedule_type == "Uniform":
+             scheduler_kwargs['use_karras_sigmas'] = False
+         elif schedule_type == "SGM Uniform":
+             # Approximation: diffusers schedulers expose no 'sgm_uniform' option;
+             # trailing timestep spacing is the closest documented equivalent.
+             scheduler_kwargs['timestep_spacing'] = 'trailing'
+
+         pipe.scheduler = scheduler_class.from_config(scheduler_config, **scheduler_kwargs)
+
+         compel_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
+         compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+                         returned_embeddings_type=compel_type, requires_pooled=[False, True], truncate_long_prompts=False)
+
+         civitai_ids, lora_scales = lora_data[0::2], lora_data[1::2]
+         lora_params = list(zip(civitai_ids, lora_scales))
+         active_loras, active_lora_names_for_meta = [], []
+
+         for i, (civitai_id, lora_scale) in enumerate(lora_params):
+             if civitai_id and civitai_id.strip() and lora_scale > 0:
+                 version_id = civitai_id.strip()
+                 local_lora_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
+
+                 if not os.path.exists(local_lora_path):
+                     file_info = get_civitai_file_info(version_id)
+                     if not file_info:
+                         print(f"Could not get file info for Civitai ID {version_id}, skipping.")
+                         continue
+
+                     download_url = file_info.get('downloadUrl')
+                     if download_url:
+                         download_file(download_url, local_lora_path, api_key=civitai_api_key, progress=progress, desc=f"Downloading LoRA ID {version_id}")
+                     else:
+                         print(f"Could not get download link for Civitai ID {version_id} during inference, skipping.")
+                         continue
+
+                 if not os.path.exists(local_lora_path):
+                     print(f"LoRA file for ID {version_id} not found, skipping.")
+                     continue
+
+                 adapter_name = f"lora_{i+1}"
+                 progress((i * 0.1) + 0.05, desc=f"Loading LoRA (ID: {version_id})")
+                 pipe.load_lora_weights(local_lora_path, adapter_name=adapter_name)
+                 active_loras.append((adapter_name, lora_scale))
+                 active_lora_names_for_meta.append(f"LoRA {i+1} (ID: {version_id}, Weight: {lora_scale})")
+
+         if active_loras:
+             adapter_names, adapter_weights = zip(*active_loras)
+             pipe.set_adapters(list(adapter_names), list(adapter_weights))
+
+         conditioning, pooled = process_long_prompt(compel, prompt, negative_prompt)
+
+         pipe_args = {
+             "guidance_scale": guidance_scale,
+             "num_inference_steps": num_inference_steps,
+             "width": width,
+             "height": height,
+         }
+
+         output_images = []
+         loras_string = f"LoRAs: [{', '.join(active_lora_names_for_meta)}]" if active_lora_names_for_meta else ""
+
+         for i in range(batch_size):
+             progress(i / batch_size, desc=f"Generating image {i+1}/{batch_size}")
+
+             # The first image honors an explicit seed; the rest of the batch is randomized.
+             if i == 0 and seed != -1:
+                 current_seed = seed
+             else:
+                 current_seed = random.randint(0, MAX_SEED)
+
+             generator = torch.Generator(device=device).manual_seed(current_seed)
+             pipe_args["generator"] = generator
+
+             if conditioning is not None:
+                 image = pipe(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2], **pipe_args).images[0]
+             else:
+                 image = pipe(prompt=prompt, negative_prompt=negative_prompt, **pipe_args).images[0]
+
+             # Stash the generation parameters on the PIL image so the PNG Info tab can read them back.
+             params_string = f"{prompt}\nNegative prompt: {negative_prompt}\n"
+             params_string += f"Steps: {num_inference_steps}, Sampler: {sampler}, Schedule type: {schedule_type}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {base_model_name}, {loras_string}".rstrip(", ")
+             image.info = {'parameters': params_string}
+             output_images.append(image)
+
+         return output_images
+
+     except Exception as e:
+         print(f"An error occurred during generation: {e}")
+         raise gr.Error(f"Generation failed: {e}")
+     finally:
+         if pipe is not None:
+             pipe.disable_lora()
+             del pipe
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+ def infer(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
+           sampler, schedule_type,
+           civitai_api_key,
+           zero_gpu_duration,
+           *lora_data,
+           progress=gr.Progress(track_tqdm=True)):
+
+     # Clamp the requested ZeroGPU duration to the 120-second maximum stated in the UI.
+     duration = 60
+     if zero_gpu_duration and int(zero_gpu_duration) > 0:
+         duration = min(int(zero_gpu_duration), 120)
+
+     print(f"Using ZeroGPU duration: {duration} seconds")
+
+     # Wrap the worker at call time so the user-chosen duration applies to this run.
+     decorated_infer_logic = spaces.GPU(duration=duration)(_infer_logic)
+
+     return decorated_infer_logic(
+         base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
+         sampler, schedule_type, civitai_api_key, *lora_data, progress=progress
+     )
+
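+ # When a fixed budget suffices, the static decorator form already used by
+ # dummy_gpu_for_startup above is the simpler equivalent (sketch):
+ #
+ #     @spaces.GPU(duration=120)
+ #     def infer(...): ...
+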
+ def _parse_parameters(params_text):
+     data = {'lora_ids': [''] * MAX_LORAS, 'lora_scales': [0.0] * MAX_LORAS}
+     lines = params_text.strip().split('\n')
+     data['prompt'] = lines[0]
+     data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip() if len(lines) > 1 and lines[1].startswith("Negative prompt:") else ""
+     params_line = lines[2] if len(lines) > 2 else ""
+
+     def find_param(key, default, cast_type=str):
+         match = re.search(fr"\b{key}: ([^,]+?)(,|$)", params_line)
+         if match:
+             try:
+                 return cast_type(match.group(1).strip())
+             except (ValueError, TypeError):
+                 return default
+         return default
+
+     data['steps'] = find_param("Steps", 28, int)
+     data['sampler'] = find_param("Sampler", DEFAULT_SAMPLER)
+     data['schedule_type'] = find_param("Schedule type", DEFAULT_SCHEDULE_TYPE)
+     data['cfg_scale'] = find_param("CFG scale", 7.0, float)
+     data['seed'] = find_param("Seed", -1, int)
+     data['base_model'] = find_param("Base Model", MODEL_LIST[0])
+     data['model_hash'] = find_param("Model hash", None)
+
+     size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
+     data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
+     if loras_match := re.search(r"LoRAs: \[(.+?)\]", params_line):
+         for i, (lora_id, lora_scale) in enumerate(re.findall(r"ID: (\d+), Weight: ([\d.]+)", loras_match.group(1))):
+             if i < MAX_LORAS:
+                 data['lora_ids'][i] = lora_id
+                 data['lora_scales'][i] = float(lora_scale)
+     return data
+
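+ # Metadata layout this parser expects, mirroring the params_string written in
+ # _infer_logic above (values illustrative):
+ #
+ #   1girl, solo, smile
+ #   Negative prompt: monochrome, (low quality, worst quality:1.2)
+ #   Steps: 28, Sampler: Euler a, Schedule type: Default, CFG scale: 7.0, Seed: 123456, Size: 1024x1024, Base Model: Laxhar/noobai-XL-Vpred-1.0, LoRAs: [LoRA 1 (ID: 1834914, Weight: 0.8)]
+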
+ def get_png_info(image):
+     if image is None:
+         return "", "", "Please upload an image first."
+     params = image.info.get('parameters', None)
+     if not params:
+         return "", "", "No metadata found in the image."
+     try:
+         parsed_data = _parse_parameters(params)
+         lines = params.strip().split('\n')
+         other_params_text = lines[2] if len(lines) > 2 else ""
+         other_params_display = "\n".join([p.strip() for p in other_params_text.split(',')])
+
+         return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_display
+     except Exception as e:
+         return "", "", f"Error parsing metadata: {e}\n\nRaw metadata:\n{params}"
+
+ def send_info_to_txt2img(image):
+     if image is None or not (params := image.info.get('parameters', '')):
+         return [gr.update()] * (12 + MAX_LORAS * 2 + 1)
+
+     data = _parse_parameters(params)
+
+     model_from_hash = HASH_TO_MODEL_MAP.get(data.get('model_hash'))
+     final_base_model = model_from_hash if model_from_hash else data.get('base_model', MODEL_LIST[0])
+
+     sampler_from_png = data.get('sampler', DEFAULT_SAMPLER)
+     final_sampler = sampler_from_png if sampler_from_png in SAMPLER_MAP else DEFAULT_SAMPLER
+
+     schedule_from_png = data.get('schedule_type', DEFAULT_SCHEDULE_TYPE)
+     final_schedule_type = schedule_from_png if schedule_from_png in SCHEDULE_TYPE_MAP else DEFAULT_SCHEDULE_TYPE
+
+     updates = [final_base_model, data['prompt'], data['negative_prompt'], data['seed'], gr.update(), gr.update(), data['width'], data['height'],
+                data['cfg_scale'], data['steps'], final_sampler, final_schedule_type]
+
+     for i in range(MAX_LORAS):
+         updates.extend([data['lora_ids'][i], data['lora_scales'][i]])
+     updates.append(gr.Tabs(selected=0))
+     return updates
+
+ with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
+     gr.Markdown("# Animated SDXL T2I with LoRAs")
+     with gr.Tabs(elem_id="tabs_container") as tabs:
+         with gr.TabItem("txt2img", id=0):
+             gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading the base model and LoRAs before clicking 'Run' can maximize your ZeroGPU time.</div>")
+             with gr.Column(elem_id="col-container"):
+                 with gr.Row():
+                     with gr.Column(scale=3):
+                         base_model_name = gr.Dropdown(label="Base Model", choices=MODEL_LIST, value="Laxhar/noobai-XL-Vpred-1.0")
+                     with gr.Column(scale=2):
+                         with gr.Row():
+                             predownload_base_model_button = gr.Button("Pre-download Base Model")
+                             predownload_lora_button = gr.Button("Pre-download LoRAs")
+                     with gr.Column(scale=1, min_width=100):
+                         run_button = gr.Button("Run", variant="primary")
+
+                 predownload_status = gr.Markdown("")
+                 prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
+                 negative_prompt = gr.Text(label="Negative prompt", lines=3, placeholder="Enter a negative prompt", value=DEFAULT_NEGATIVE_PROMPT)
+
+                 # --- UI Layout ---
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         with gr.Row():
+                             width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                             height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                         with gr.Row():
+                             sampler = gr.Dropdown(label="Sampling method", choices=list(SAMPLER_MAP.keys()), value=DEFAULT_SAMPLER)
+                             schedule_type = gr.Dropdown(label="Schedule type", choices=SCHEDULE_TYPE_MAP, value=DEFAULT_SCHEDULE_TYPE)
+                         with gr.Row():
+                             guidance_scale = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
+                             num_inference_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
+
+                     with gr.Column(scale=1):
+                         result = gr.Gallery(label="Result", show_label=False, elem_id="result_gallery", columns=2, object_fit="contain", height="auto")
+
+                 with gr.Row():
+                     seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
+                     batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
+                     zero_gpu_duration = gr.Number(
+                         label="ZeroGPU Duration (s)",
+                         value=None,
+                         placeholder="Default: 60s",
+                         info="Optional: leave empty for the default (60s); maximum 120"
+                     )
+
+                 with gr.Accordion("LoRA Settings", open=False):
+                     gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Misuse may lead to service disruption. Thank you for your cooperation.")
+                     civitai_api_key = gr.Textbox(label="Optional Civitai API Key", info="Get from your Civitai account settings...", placeholder="Enter your Civitai API Key here", type="password", show_label=True)
+                     gr.Markdown("Find the Model Version ID in the LoRA page URL (e.g., `modelVersionId=12345`) and fill it in below.")
+                     lora_rows, lora_civitai_id_inputs, lora_scale_inputs = [], [], []
+                     for i in range(MAX_LORAS):
+                         with gr.Row(visible=(i == 0)) as row:
+                             lora_civitai_id = gr.Textbox(label=f"LoRA {i+1} - Civitai Model Version ID", placeholder="e.g.: 1834914")
+                             lora_scale = gr.Slider(label=f"Weight {i+1}", minimum=0.0, maximum=2.0, step=0.05, value=0.0)
+                         lora_rows.append(row)
+                         lora_civitai_id_inputs.append(lora_civitai_id)
+                         lora_scale_inputs.append(lora_scale)
+                     with gr.Row():
+                         add_lora_button = gr.Button("✚ Add LoRA", variant="secondary")
+                     lora_count_state = gr.State(value=1)
+                     # Flatten to (id_1, scale_1, id_2, scale_2, ...) to match the *lora_data parameters.
+                     all_lora_inputs = [item for pair in zip(lora_civitai_id_inputs, lora_scale_inputs) for item in pair]
+
+         with gr.TabItem("PNG Info", id=1):
+             with gr.Column(elem_id="col-container"):
+                 gr.Markdown("Upload a generated image to view its generation data.")
+                 info_image_input = gr.Image(type="pil", label="Upload Image")
+                 with gr.Row():
+                     info_get_button = gr.Button("Get Info", variant="secondary")
+                     send_to_txt2img_button = gr.Button("Send to Txt-to-Image", variant="primary")
+                 gr.Markdown("### Positive Prompt")
+                 info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
+                 gr.Markdown("### Negative Prompt")
+                 info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
+                 gr.Markdown("### Other Parameters")
+                 info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
+
+     gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by RioShiina with ❤</div>")
+
+     def add_lora_row(current_count):
+         current_count = int(current_count)
+         if current_count < MAX_LORAS:
+             updates = {lora_count_state: current_count + 1, lora_rows[current_count]: gr.Row(visible=True)}
+             # Hide the button once every LoRA slot is visible.
+             if current_count + 1 == MAX_LORAS:
+                 updates[add_lora_button] = gr.Button(visible=False)
+             return updates
+         return {lora_count_state: current_count}
+
+     add_lora_button.click(fn=add_lora_row, inputs=[lora_count_state], outputs=[lora_count_state, add_lora_button] + lora_rows)
+
+     predownload_base_model_button.click(fn=pre_download_base_model, inputs=[base_model_name], outputs=[predownload_status])
+     predownload_lora_button.click(fn=pre_download_loras, inputs=[civitai_api_key, *all_lora_inputs], outputs=[predownload_status])
+
+     run_button.click(fn=infer,
+                      inputs=[base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, civitai_api_key, zero_gpu_duration, *all_lora_inputs],
+                      outputs=[result])
+
+     info_get_button.click(fn=get_png_info, inputs=[info_image_input], outputs=[info_prompt_output, info_neg_prompt_output, info_params_output])
+
+     txt2img_outputs = [base_model_name, prompt, negative_prompt, seed, batch_size, zero_gpu_duration, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, *all_lora_inputs, tabs]
+     send_to_txt2img_button.click(fn=send_info_to_txt2img, inputs=[info_image_input], outputs=txt2img_outputs)
+
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ accelerate
+ diffusers
+ invisible_watermark
+ torch
+ transformers
+ xformers
+ compel
+ pydantic==2.10.6
+ gradio==5.12.0
+ requests
+ peft