huchukato commited on
Commit
2095c8d
·
1 Parent(s): 4f90077
Files changed (7) hide show
  1. .DS_Store +0 -0
  2. app.py +444 -126
  3. config.py +33 -0
  4. config.toml +114 -0
  5. requirements.txt +8 -6
  6. style.css +212 -0
  7. utils.py +187 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py CHANGED
@@ -1,153 +1,471 @@
 
 
1
  import gradio as gr
2
  import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "huchukato/pimp-my-pony" # Replace to the model you would like to use
 
 
 
 
 
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
- # @spaces.GPU #[uncomment to use ZeroGPU]
24
- def infer(
25
- prompt,
26
- negative_prompt,
27
- seed,
28
- randomize_seed,
29
- width,
30
- height,
31
- guidance_scale,
32
- num_inference_steps,
33
- progress=gr.Progress(track_tqdm=True),
34
- ):
35
- if randomize_seed:
36
- seed = random.randint(0, MAX_SEED)
37
-
38
- generator = torch.Generator().manual_seed(seed)
39
-
40
- image = pipe(
41
- prompt=prompt,
42
- negative_prompt=negative_prompt,
43
- guidance_scale=guidance_scale,
44
- num_inference_steps=num_inference_steps,
45
- width=width,
46
- height=height,
47
- generator=generator,
48
- ).images[0]
49
-
50
- return image, seed
51
-
52
-
53
- examples = [
54
- "source_9, source_8_up, source_7_up, source_6_up, in a loft, dramatic lighting, hard shadows, 1girl, 2b_\(nier:automata\), hairband, white hair, blindfold, covered eyes, black blindfold sleeveless dress, large breasts",
55
- "source_9, source_8_up, source_7_up, source_6_up, in a japanese house, at night, dramatic lighting, hard shadows, 1girl, gareth \(fate\), fate/grand order large breasts, solo, cleavage, covered nipples, arm under breast",
56
- "source_9, source_8_up, source_7_up, source_6_up, in a greek temple, at night, dramatic lighting, hard shadows, 1girl, kido saori, light purple hair, very long hair, bangs, aqua eyes, sleeveless dress, long dress, puffy dress, large breasts",
57
- ]
58
-
59
- css = """
60
- #col-container {
61
- margin: 0 auto;
62
- max-width: 640px;
63
- }
64
- """
65
-
66
- with gr.Blocks(css=css) as demo:
67
- with gr.Column(elem_id="col-container"):
68
- gr.Markdown(" # Text-to-Image PimpMyPony")
69
-
70
- with gr.Row():
71
- prompt = gr.Text(
72
- label="Prompt",
73
- show_label=False,
74
- max_lines=2,
75
- placeholder="Enter your prompt",
76
- container=False,
77
- )
78
 
79
- run_button = gr.Button("Run", scale=0, variant="primary")
 
 
 
80
 
81
- result = gr.Image(label="Result", show_label=False)
 
82
 
83
- with gr.Accordion("Advanced Settings", open=False):
84
- negative_prompt = gr.Text(
85
- label="Negative prompt",
86
- max_lines=1,
87
- placeholder="Enter a negative prompt",
88
- visible=False,
89
- )
90
 
91
- seed = gr.Slider(
92
- label="Seed",
93
- minimum=0,
94
- maximum=MAX_SEED,
95
- step=1,
96
- value=0,
97
- )
 
 
 
 
 
 
 
 
 
98
 
99
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
100
 
101
- with gr.Row():
102
- width = gr.Slider(
103
- label="Width",
104
- minimum=256,
105
- maximum=MAX_IMAGE_SIZE,
106
- step=32,
107
- value=832, # Replace with defaults that work for your model
108
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- height = gr.Slider(
111
- label="Height",
112
- minimum=256,
113
- maximum=MAX_IMAGE_SIZE,
114
- step=32,
115
- value=1216, # Replace with defaults that work for your model
116
- )
 
 
 
 
 
 
 
117
 
118
- with gr.Row():
119
- guidance_scale = gr.Slider(
120
- label="Guidance scale",
121
- minimum=0.0,
122
- maximum=10.0,
123
- step=0.1,
124
- value=6, # Replace with defaults that work for your model
125
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- num_inference_steps = gr.Slider(
128
- label="Number of inference steps",
129
- minimum=1,
130
- maximum=50,
131
- step=1,
132
- value=20, # Replace with defaults that work for your model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- gr.Examples(examples=examples, inputs=[prompt])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  gr.on(
137
- triggers=[run_button.click, prompt.submit],
138
- fn=infer,
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  inputs=[
140
  prompt,
141
  negative_prompt,
142
  seed,
143
- randomize_seed,
144
- width,
145
- height,
146
  guidance_scale,
147
  num_inference_steps,
 
 
 
 
 
 
 
148
  ],
149
- outputs=[result, seed],
 
 
 
150
  )
151
 
152
  if __name__ == "__main__":
153
- demo.launch()
 
1
+ import os
2
+ import gc
3
  import gradio as gr
4
  import numpy as np
 
 
 
 
5
  import torch
6
+ import json
7
+ import spaces
8
+ import config
9
+ import utils
10
+ import logging
11
+ from PIL import Image, PngImagePlugin
12
+ from datetime import datetime
13
+ from diffusers.models import AutoencoderKL
14
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
+ from config import (
16
+ MODEL,
17
+ MIN_IMAGE_SIZE,
18
+ MAX_IMAGE_SIZE,
19
+ USE_TORCH_COMPILE,
20
+ ENABLE_CPU_OFFLOAD,
21
+ OUTPUT_DIR,
22
+ DEFAULT_NEGATIVE_PROMPT,
23
+ DEFAULT_ASPECT_RATIO,
24
+ examples,
25
+ sampler_list,
26
+ aspect_ratios,
27
+ style_list,
28
+ )
29
+ import time
30
+ from typing import List, Dict, Tuple, Optional
31
 
32
+ # Enhanced logging configuration
33
+ logging.basicConfig(
34
+ level=logging.INFO,
35
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
36
+ datefmt='%Y-%m-%d %H:%M:%S'
37
+ )
38
+ logger = logging.getLogger(__name__)
39
 
40
+ # Constants
41
+ IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
42
+ HF_TOKEN = os.getenv("HF_TOKEN")
43
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # PyTorch settings for better performance and determinism
46
+ torch.backends.cudnn.deterministic = True
47
+ torch.backends.cudnn.benchmark = False
48
+ torch.backends.cuda.matmul.allow_tf32 = True
49
 
50
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
51
+ logger.info(f"Using device: {device}")
52
 
53
+ class GenerationError(Exception):
54
+ """Custom exception for generation errors"""
55
+ pass
 
 
 
 
56
 
57
+ def validate_prompt(prompt: str) -> str:
58
+ """Validate and clean up the input prompt."""
59
+ if not isinstance(prompt, str):
60
+ raise GenerationError("Prompt must be a string")
61
+ try:
62
+ # Ensure proper UTF-8 encoding/decoding
63
+ prompt = prompt.encode('utf-8').decode('utf-8')
64
+ # Add space between ! and ,
65
+ prompt = prompt.replace("!,", "! ,")
66
+ except UnicodeError:
67
+ raise GenerationError("Invalid characters in prompt")
68
+
69
+ # Only check if the prompt is completely empty or only whitespace
70
+ if not prompt or prompt.isspace():
71
+ raise GenerationError("Prompt cannot be empty")
72
+ return prompt.strip()
73
 
74
+ def validate_dimensions(width: int, height: int) -> None:
75
+ """Validate image dimensions."""
76
+ if not MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE:
77
+ raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
78
+
79
+ if not MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE:
80
+ raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
81
 
82
+ @spaces.GPU
83
+ def generate(
84
+ prompt: str,
85
+ negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
86
+ seed: int = 0,
87
+ custom_width: int = 1024,
88
+ custom_height: int = 1024,
89
+ guidance_scale: float = 6.0,
90
+ num_inference_steps: int = 25,
91
+ sampler: str = "Euler a",
92
+ aspect_ratio_selector: str = DEFAULT_ASPECT_RATIO,
93
+ style_selector: str = "(None)",
94
+ use_upscaler: bool = False,
95
+ upscaler_strength: float = 0.55,
96
+ upscale_by: float = 1.5,
97
+ add_quality_tags: bool = True,
98
+ progress: gr.Progress = gr.Progress(track_tqdm=True),
99
+ ) -> Tuple[List[str], Dict]:
100
+ """Generate images based on the given parameters."""
101
+ start_time = time.time()
102
+ upscaler_pipe = None
103
+ backup_scheduler = None
104
+
105
+ try:
106
+ # Memory management
107
+ torch.cuda.empty_cache()
108
+ gc.collect()
109
 
110
+ # Input validation
111
+ prompt = validate_prompt(prompt)
112
+ if negative_prompt:
113
+ negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
114
+
115
+ validate_dimensions(custom_width, custom_height)
116
+
117
+ # Set up generation
118
+ generator = utils.seed_everything(seed)
119
+ width, height = utils.aspect_ratio_handler(
120
+ aspect_ratio_selector,
121
+ custom_width,
122
+ custom_height,
123
+ )
124
 
125
+ # Process prompts
126
+ if add_quality_tags:
127
+ prompt = "masterpiece, high score, great score, absurdres, {prompt}".format(prompt=prompt)
128
+
129
+ prompt, negative_prompt = utils.preprocess_prompt(
130
+ styles, style_selector, prompt, negative_prompt
131
+ )
132
+
133
+ width, height = utils.preprocess_image_dimensions(width, height)
134
+
135
+ # Set up pipeline
136
+ backup_scheduler = pipe.scheduler
137
+ pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
138
+
139
+ if use_upscaler:
140
+ upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
141
+
142
+ # Prepare metadata
143
+ metadata = {
144
+ "prompt": prompt,
145
+ "negative_prompt": negative_prompt,
146
+ "resolution": f"{width} x {height}",
147
+ "guidance_scale": guidance_scale,
148
+ "num_inference_steps": num_inference_steps,
149
+ "style_preset": style_selector,
150
+ "seed": seed,
151
+ "sampler": sampler,
152
+ "Model": "PimpMyPony",
153
+ "Model hash": "e3c47aedb0",
154
+ }
155
+
156
+ if use_upscaler:
157
+ new_width = int(width * upscale_by)
158
+ new_height = int(height * upscale_by)
159
+ metadata["use_upscaler"] = {
160
+ "upscale_method": "nearest-exact",
161
+ "upscaler_strength": upscaler_strength,
162
+ "upscale_by": upscale_by,
163
+ "new_resolution": f"{new_width} x {new_height}",
164
+ }
165
+ else:
166
+ metadata["use_upscaler"] = None
167
+
168
+ logger.info(f"Starting generation with parameters: {json.dumps(metadata, indent=4)}")
169
+
170
+ # Generate images
171
+ if use_upscaler:
172
+ latents = pipe(
173
+ prompt=prompt,
174
+ negative_prompt=negative_prompt,
175
+ width=width,
176
+ height=height,
177
+ guidance_scale=guidance_scale,
178
+ num_inference_steps=num_inference_steps,
179
+ generator=generator,
180
+ output_type="latent",
181
+ ).images
182
+ upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
183
+ images = upscaler_pipe(
184
+ prompt=prompt,
185
+ negative_prompt=negative_prompt,
186
+ image=upscaled_latents,
187
+ guidance_scale=guidance_scale,
188
+ num_inference_steps=num_inference_steps,
189
+ strength=upscaler_strength,
190
+ generator=generator,
191
+ output_type="pil",
192
+ ).images
193
+ else:
194
+ images = pipe(
195
+ prompt=prompt,
196
+ negative_prompt=negative_prompt,
197
+ width=width,
198
+ height=height,
199
+ guidance_scale=guidance_scale,
200
+ num_inference_steps=num_inference_steps,
201
+ generator=generator,
202
+ output_type="pil",
203
+ ).images
204
+
205
+ # Save images
206
+ if images:
207
+ total = len(images)
208
+ image_paths = []
209
+ for idx, image in enumerate(images, 1):
210
+ progress(idx/total, desc="Saving images...")
211
+ path = utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
212
+ image_paths.append(path)
213
+ logger.info(f"Image {idx}/{total} saved as {path}")
214
 
215
+ generation_time = time.time() - start_time
216
+ logger.info(f"Generation completed successfully in {generation_time:.2f} seconds")
217
+ metadata["generation_time"] = f"{generation_time:.2f}s"
218
+
219
+ return image_paths, metadata
220
+
221
+ except GenerationError as e:
222
+ logger.warning(f"Generation validation error: {str(e)}")
223
+ raise gr.Error(str(e))
224
+ except Exception as e:
225
+ logger.exception("Unexpected error during generation")
226
+ raise gr.Error(f"Generation failed: {str(e)}")
227
+ finally:
228
+ # Cleanup
229
+ torch.cuda.empty_cache()
230
+ gc.collect()
231
+
232
+ if upscaler_pipe is not None:
233
+ del upscaler_pipe
234
+
235
+ if backup_scheduler is not None and pipe is not None:
236
+ pipe.scheduler = backup_scheduler
237
+
238
+ utils.free_memory()
239
+
240
+ # Model initialization
241
+ if torch.cuda.is_available():
242
+ try:
243
+ logger.info("Loading VAE and pipeline...")
244
+ vae = AutoencoderKL.from_pretrained(
245
+ "madebyollin/sdxl-vae-fp16-fix",
246
+ torch_dtype=torch.float16,
247
+ )
248
+ pipe = utils.load_pipeline(MODEL, device, vae=vae)
249
+ logger.info("Pipeline loaded successfully on GPU!")
250
+ except Exception as e:
251
+ logger.error(f"Error loading VAE, falling back to default: {e}")
252
+ pipe = utils.load_pipeline(MODEL, device)
253
+ else:
254
+ logger.warning("CUDA not available, running on CPU")
255
+ pipe = None
256
+
257
+ # Process styles
258
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
259
+
260
+ with gr.Blocks(css="style.css", theme="Nymbo/Nymbo_Theme_5") as demo:
261
+ gr.HTML(
262
+ """
263
+ <div class="header">
264
+ <div class="title">Pimp My Pony</div>
265
+ </div>
266
+ """,
267
+ )
268
+
269
+ with gr.Row():
270
+ with gr.Column(scale=2):
271
+ with gr.Group():
272
+ prompt = gr.Text(
273
+ label="Prompt",
274
+ max_lines=5,
275
+ placeholder="Describe what you want to generate",
276
+ info="Enter your image generation prompt here. Be specific and descriptive for better results.",
277
+ )
278
+ negative_prompt = gr.Text(
279
+ label="Negative Prompt",
280
+ max_lines=5,
281
+ placeholder="Describe what you want to avoid",
282
+ value=DEFAULT_NEGATIVE_PROMPT,
283
+ info="Specify elements you don't want in the image.",
284
  )
285
+ add_quality_tags = gr.Checkbox(
286
+ label="Quality Tags",
287
+ value=True,
288
+ info="Add quality-enhancing tags to your prompt automatically.",
289
+ )
290
+ with gr.Accordion(label="More Settings", open=False):
291
+ with gr.Group():
292
+ aspect_ratio_selector = gr.Radio(
293
+ label="Aspect Ratio",
294
+ choices=aspect_ratios,
295
+ value=DEFAULT_ASPECT_RATIO,
296
+ container=True,
297
+ info="Choose the dimensions of your image.",
298
+ )
299
+ with gr.Group(visible=False) as custom_resolution:
300
+ with gr.Row():
301
+ custom_width = gr.Slider(
302
+ label="Width",
303
+ minimum=MIN_IMAGE_SIZE,
304
+ maximum=MAX_IMAGE_SIZE,
305
+ step=8,
306
+ value=1024,
307
+ info=f"Image width (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
308
+ )
309
+ custom_height = gr.Slider(
310
+ label="Height",
311
+ minimum=MIN_IMAGE_SIZE,
312
+ maximum=MAX_IMAGE_SIZE,
313
+ step=8,
314
+ value=1024,
315
+ info=f"Image height (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
316
+ )
317
+ with gr.Group():
318
+ use_upscaler = gr.Checkbox(
319
+ label="Use Upscaler",
320
+ value=False,
321
+ info="Enable high-resolution upscaling.",
322
+ )
323
+ with gr.Row() as upscaler_row:
324
+ upscaler_strength = gr.Slider(
325
+ label="Strength",
326
+ minimum=0,
327
+ maximum=1,
328
+ step=0.05,
329
+ value=0.55,
330
+ visible=False,
331
+ info="Control how much the upscaler affects the final image.",
332
+ )
333
+ upscale_by = gr.Slider(
334
+ label="Upscale by",
335
+ minimum=1,
336
+ maximum=1.5,
337
+ step=0.1,
338
+ value=1.5,
339
+ visible=False,
340
+ info="Multiplier for the final image resolution.",
341
+ )
342
+ with gr.Accordion(label="Advanced Parameters", open=False):
343
+ with gr.Group():
344
+ style_selector = gr.Dropdown(
345
+ label="Style Preset",
346
+ interactive=True,
347
+ choices=list(styles.keys()),
348
+ value="(None)",
349
+ info="Apply a predefined style to your generation.",
350
+ )
351
+ with gr.Group():
352
+ sampler = gr.Dropdown(
353
+ label="Sampler",
354
+ choices=sampler_list,
355
+ interactive=True,
356
+ value="Euler a",
357
+ info="Different samplers can produce varying results.",
358
+ )
359
+ with gr.Group():
360
+ seed = gr.Slider(
361
+ label="Seed",
362
+ minimum=0,
363
+ maximum=utils.MAX_SEED,
364
+ step=1,
365
+ value=0,
366
+ info="Set a specific seed for reproducible results.",
367
+ )
368
+ randomize_seed = gr.Checkbox(
369
+ label="Randomize seed",
370
+ value=True,
371
+ info="Generate a new random seed for each image.",
372
+ )
373
+ with gr.Group():
374
+ with gr.Row():
375
+ guidance_scale = gr.Slider(
376
+ label="Guidance scale",
377
+ minimum=1,
378
+ maximum=12,
379
+ step=0.1,
380
+ value=6.0,
381
+ info="Higher values make the image more closely match your prompt.",
382
+ )
383
+ num_inference_steps = gr.Slider(
384
+ label="Number of inference steps",
385
+ minimum=1,
386
+ maximum=50,
387
+ step=1,
388
+ value=25,
389
+ info="More steps generally mean higher quality but slower generation.",
390
+ )
391
+
392
+ with gr.Column(scale=3):
393
+ with gr.Blocks():
394
+ run_button = gr.Button("Generate", variant="primary", elem_id="generate-button")
395
+ result = gr.Gallery(
396
+ label="Generated Images",
397
+ columns=1,
398
+ height='768px',
399
+ preview=True,
400
+ show_label=True,
401
+ )
402
+ with gr.Accordion(label="Generation Parameters", open=False):
403
+ gr_metadata = gr.JSON(
404
+ label="Image Metadata",
405
+ show_label=True,
406
+ )
407
+ gr.Examples(
408
+ examples=examples,
409
+ inputs=prompt,
410
+ outputs=[result, gr_metadata],
411
+ fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
412
+ cache_examples=CACHE_EXAMPLES,
413
+ )
414
+
415
 
416
+ use_upscaler.change(
417
+ fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
418
+ inputs=use_upscaler,
419
+ outputs=[upscaler_strength, upscale_by],
420
+ queue=False,
421
+ api_name=False,
422
+ )
423
+ aspect_ratio_selector.change(
424
+ fn=lambda x: gr.update(visible=x == "Custom"),
425
+ inputs=aspect_ratio_selector,
426
+ outputs=custom_resolution,
427
+ queue=False,
428
+ api_name=False,
429
+ )
430
+
431
+ # Combine all triggers including keyboard shortcuts
432
  gr.on(
433
+ triggers=[
434
+ prompt.submit,
435
+ negative_prompt.submit,
436
+ run_button.click,
437
+ ],
438
+ fn=utils.randomize_seed_fn,
439
+ inputs=[seed, randomize_seed],
440
+ outputs=seed,
441
+ queue=False,
442
+ api_name=False,
443
+ ).then(
444
+ fn=lambda: gr.update(interactive=False, value="Generating..."),
445
+ outputs=run_button,
446
+ ).then(
447
+ fn=generate,
448
  inputs=[
449
  prompt,
450
  negative_prompt,
451
  seed,
452
+ custom_width,
453
+ custom_height,
 
454
  guidance_scale,
455
  num_inference_steps,
456
+ sampler,
457
+ aspect_ratio_selector,
458
+ style_selector,
459
+ use_upscaler,
460
+ upscaler_strength,
461
+ upscale_by,
462
+ add_quality_tags,
463
  ],
464
+ outputs=[result, gr_metadata],
465
+ ).then(
466
+ fn=lambda: gr.update(interactive=True, value="Generate"),
467
+ outputs=run_button,
468
  )
469
 
470
  if __name__ == "__main__":
471
+ demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
config.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tomli
3
+ from typing import Dict, Any
4
+
5
+ def fix_escaping(text: str) -> str:
6
+ # When JSON is loaded, \\\\ becomes \\ automatically
7
+ # So we don't need to do any transformation
8
+ return text
9
+
10
+ def load_config() -> Dict[str, Any]:
11
+ config_path = os.path.join(os.path.dirname(__file__), 'config.toml')
12
+ with open(config_path, 'rb') as f:
13
+ config = tomli.load(f)
14
+ return config
15
+
16
+ # Load configuration
17
+ config = load_config()
18
+
19
+ # Export variables for backward compatibility
20
+ MODEL = os.getenv("MODEL", config['model']['path'])
21
+ MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", config['model']['min_image_size']))
22
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", config['model']['max_image_size']))
23
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", str(config['model']['use_torch_compile'])).lower() == "true"
24
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", str(config['model']['enable_cpu_offload'])).lower() == "true"
25
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", config['model']['output_dir'])
26
+
27
+ DEFAULT_NEGATIVE_PROMPT = config['prompts']['default_negative']
28
+ DEFAULT_ASPECT_RATIO = config['prompts']['default_aspect_ratio']
29
+
30
+ examples = config['prompts']['examples']
31
+ sampler_list = config['samplers']['list']
32
+ aspect_ratios = config['aspect_ratios']['list']
33
+ style_list = config['styles']
config.toml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [model]
2
+ path = "huchukato/pimp-my-pony"
3
+ min_image_size = 512
4
+ max_image_size = 2048
5
+ use_torch_compile = false
6
+ enable_cpu_offload = false
7
+ output_dir = "./outputs"
8
+
9
+ [prompts]
10
+ default_negative = "score_5, score_4, score_3, worst quality, low quality, ugly, malformed, bad anatomy, grayscale"
11
+ default_aspect_ratio = "832 x 1216"
12
+ examples = [
13
+ "source_9, source_8_up, source_7_up, 1girl, souryuu asuka langley, neon genesis evangelion, eyepatch, red plugsuit, sitting, on throne, crossed legs, head tilt, holding weapon, lance of longinus \\\\(evangelion\\\\), cowboy shot, depth of field, faux traditional media, painterly, impressionism, photo background",
14
+ "source_9, source_8_up, source_7_up, 1boy, vash the stampede, trigun stampede, red jacket, sunglasses, gun, hand on own hip, aiming, standing, looking at viewer, upper body, desert, cliff, cowboy shot",
15
+ "source_9, source_8_up, source_7_up, 1girl, vertin \\(reverse:1999\\), reverse:1999, black umbrella, headwear, suitcase, looking at viewer, rain, night, city, bridge, from side, dutch angle, upper body",
16
+ "source_9, source_8_up, source_7_up, 1girl, 1boy, c.c., lelouch vi britannia, year 2024, code geass, ascot, bare shoulders, black choker, black hair, blue flower, blue rose, breasts, bush, choker, closed mouth, collarbone, couple, dress, flower, frills, green hair, hetero, long hair, long sleeves, looking at another, medium breasts, off-shoulder dress, off shoulder, parted lips, purple ascot, purple eyes, rose, short hair, sitting, straight hair, yellow eyes",
17
+ "source_9, source_8_up, source_7_up, 1girl, hatsune miku, vocaloid, blue eyes, blue hair, bowl, can, chopsticks, collared shirt, detached sleeves, eating, elbow rest, fish \\(food\\), food, holding, holding chopsticks, katsudon \\(food\\), long hair, long sleeves, looking at viewer, meal, nail polish, necktie, noodles, onigiri, plate, ramen, sashimi, shirt, shrimp, shrimp tempura, sleeveless, sleeveless shirt, solo, spring onion, tempura, twintails",
18
+ "source_9, source_8_up, source_7_up, 4girls, multiple girls, gotoh hitori, ijichi nijika, kita ikuyo, yamada ryo, bocchi the rock!, ahoge, black shirt, blank eyes, blonde hair, blue eyes, blue hair, brown sweater, collared shirt, cube hair ornament, detached ahoge, empty eyes, green eyes, hair ornament, hairclip, kessoku band, long sleeves, looking at viewer, medium hair, mole, mole under eye, one side up, pink hair, pink track suit, red eyes, red hair, sailor collar, school uniform, serafuku, shirt, shuka high school uniform, side ahoge, side ponytail, sweater, sweater vest, track suit, white shirt, yellow eyes, painterly, impressionism, faux traditional media, v, double v, waving",
19
+ "source_9, source_8_up, source_7_up, 1other, solo, outdoors, sky, arm up, night, earth, helmet, outstretched arm, star \\(sky\\), night sky, full moon, floating, starry sky, reaching, jumping, space, cowboy shot, ambiguous gender, spacesuit, moonlight, space helmet, astronaut, horror, black and white, monochromatic, high contrast, abstract background, dutch angle, dark, depth of field, chromatic aberration, faux traditional media"
20
+ ]
21
+
22
+ [samplers]
23
+ list = [
24
+ "DPM++ 2M Karras",
25
+ "DPM++ SDE Karras",
26
+ "DPM++ 2M SDE Karras",
27
+ "Euler",
28
+ "Euler a",
29
+ "DDIM"
30
+ ]
31
+
32
+ [aspect_ratios]
33
+ list = [
34
+ "1024 x 1024",
35
+ "1152 x 896",
36
+ "896 x 1152",
37
+ "1216 x 832",
38
+ "832 x 1216",
39
+ "1344 x 768",
40
+ "768 x 1344",
41
+ "1536 x 640",
42
+ "640 x 1536",
43
+ "Custom"
44
+ ]
45
+
46
+ [[styles]]
47
+ name = "(None)"
48
+ prompt = "{prompt}"
49
+ negative_prompt = ""
50
+
51
+ [[styles]]
52
+ name = "Anim4gine"
53
+ prompt = "{prompt}, depth of field, faux traditional media, painterly, impressionism, photo background"
54
+ negative_prompt = ""
55
+
56
+ [[styles]]
57
+ name = "Painting"
58
+ prompt = "{prompt}, painterly, painting (medium)"
59
+ negative_prompt = ""
60
+
61
+ [[styles]]
62
+ name = "Pixel art"
63
+ prompt = "{prompt}, pixel art"
64
+ negative_prompt = ""
65
+
66
+ [[styles]]
67
+ name = "1980s"
68
+ prompt = "{prompt}, 1980s (style), retro artstyle"
69
+ negative_prompt = ""
70
+
71
+ [[styles]]
72
+ name = "1990s"
73
+ prompt = "{prompt}, 1990s (style), retro artstyle"
74
+ negative_prompt = ""
75
+
76
+ [[styles]]
77
+ name = "2000s"
78
+ prompt = "{prompt}, 2000s (style), retro artstyle"
79
+ negative_prompt = ""
80
+
81
+ [[styles]]
82
+ name = "Toon"
83
+ prompt = "{prompt}, toon (style)"
84
+ negative_prompt = ""
85
+
86
+ [[styles]]
87
+ name = "Lineart"
88
+ prompt = "{prompt}, lineart, thick lineart"
89
+ negative_prompt = ""
90
+
91
+ [[styles]]
92
+ name = "Art Nouveau"
93
+ prompt = "{prompt}, art nouveau"
94
+ negative_prompt = ""
95
+
96
+ [[styles]]
97
+ name = "Western Comics"
98
+ prompt = "{prompt}, western comics (style)"
99
+ negative_prompt = ""
100
+
101
+ [[styles]]
102
+ name = "3D"
103
+ prompt = "{prompt}, 3d"
104
+ negative_prompt = ""
105
+
106
+ [[styles]]
107
+ name = "Realistic"
108
+ prompt = "{prompt}, realistic, photorealistic"
109
+ negative_prompt = ""
110
+
111
+ [[styles]]
112
+ name = "Neonpunk"
113
+ prompt = "{prompt}, neonpunk"
114
+ negative_prompt = ""
requirements.txt CHANGED
@@ -1,6 +1,8 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
 
 
 
1
+ accelerate>=1.2.1
2
+ diffusers>=0.32.1
3
+ gradio==4.44.1
4
+ hf-transfer>=0.1.9
5
+ spaces>=0.32.0
6
+ torch>=2.4.0
7
+ transformers>=4.48.0
8
+ tomli>=2.0.1
style.css ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --title-font-size: clamp(1.5rem, 6vw, 3rem);
3
+ --subtitle-font-size: clamp(1rem, 2vw, 1.2rem);
4
+ --text-color: #fff;
5
+ --font-family: 'Helvetica Neue', sans-serif;
6
+ --gradient-primary: linear-gradient(45deg, #4EACEF, #b428b2);
7
+ --primary-color: #b61383;
8
+ --primary-hover: #910f63;
9
+ --discord-color: #910ed2;
10
+ --discord-hover: #600962;
11
+ --border-radius: 12px;
12
+ --box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
13
+ }
14
+
15
+ body {
16
+ font-family: var(--font-family);
17
+ color: var(--text-color);
18
+ margin: 0;
19
+ padding: 0;
20
+ min-height: 100vh;
21
+ background-color: #f5f5f5;
22
+ }
23
+
24
+ .header {
25
+ text-align: center;
26
+ padding: 1rem 0;
27
+ margin-bottom: 0.5rem;
28
+ }
29
+
30
+ .title {
31
+ font-size: var(--title-font-size);
32
+ font-weight: 700;
33
+ text-transform: uppercase;
34
+ margin-bottom: 0.25rem;
35
+ background-image: var(--gradient-primary);
36
+ -webkit-text-fill-color: transparent;
37
+ -webkit-background-clip: text;
38
+ background-clip: text;
39
+ display: inline-block;
40
+ }
41
+
42
+ .subtitle {
43
+ font-size: var(--subtitle-font-size);
44
+ color: #999;
45
+ margin-bottom: 0.5rem;
46
+ }
47
+
48
+ .status {
49
+ display: none;
50
+ }
51
+
52
+ #duplicate-button {
53
+ margin: 1.5rem auto 2.5rem;
54
+ color: #fff;
55
+ background: #1565c0;
56
+ border-radius: 100vh;
57
+ padding: 0.75rem 1.5rem;
58
+ font-weight: 500;
59
+ box-shadow: 0 2px 8px rgba(21, 101, 192, 0.25);
60
+ transition: all 0.2s ease;
61
+ display: block;
62
+ }
63
+
64
+ #duplicate-button:hover {
65
+ background: #1976d2;
66
+ box-shadow: 0 4px 12px rgba(21, 101, 192, 0.35);
67
+ transform: translateY(-1px);
68
+ }
69
+
70
+ .contain {
71
+ max-width: 80%;
72
+ margin: 2rem auto;
73
+ padding: 2rem 1.5rem;
74
+ }
75
+
76
+ /* Component styling */
77
+ .gr-box {
78
+ border-radius: var(--border-radius);
79
+ border: 1px solid #e0e0e0;
80
+ background: #ffffff;
81
+ box-shadow: var(--box-shadow);
82
+ transition: box-shadow 0.2s ease;
83
+ }
84
+
85
+ .gr-box:hover {
86
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
87
+ }
88
+
89
+ .gr-button.primary {
90
+ background: var(--primary-color);
91
+ border-radius: var(--border-radius);
92
+ padding: 0.8rem 2rem;
93
+ font-weight: 500;
94
+ box-shadow: 0 2px 8px rgba(21, 101, 192, 0.25);
95
+ transition: all 0.2s ease;
96
+ text-transform: uppercase;
97
+ letter-spacing: 0.5px;
98
+ }
99
+
100
+ .gr-button.primary:hover {
101
+ background: var(--primary-hover);
102
+ box-shadow: 0 4px 12px rgba(21, 101, 192, 0.35);
103
+ transform: translateY(-1px);
104
+ }
105
+
106
+ /* Form elements */
107
+ .gr-form {
108
+ background: #fff;
109
+ padding: 1.5rem;
110
+ border-radius: var(--border-radius);
111
+ box-shadow: var(--box-shadow);
112
+ }
113
+
114
+ .gr-input, .gr-textarea {
115
+ border: 1px solid #e0e0e0;
116
+ border-radius: 8px;
117
+ padding: 0.8rem;
118
+ transition: all 0.2s ease;
119
+ }
120
+
121
+ .gr-input:focus, .gr-textarea:focus {
122
+ border-color: var(--primary-color);
123
+ box-shadow: 0 0 0 2px rgba(21, 101, 192, 0.1);
124
+ }
125
+
126
+ /* Accordion styling */
127
+ .gr-accordion {
128
+ border: none;
129
+ margin: 1rem 0;
130
+ }
131
+
132
+ .gr-accordion-header {
133
+ background: #f8f9fa;
134
+ border-radius: var(--border-radius);
135
+ padding: 1rem;
136
+ font-weight: 500;
137
+ }
138
+
139
+ /* Gallery styling */
140
+ .gr-gallery {
141
+ background: #fff;
142
+ padding: 1rem;
143
+ border-radius: var(--border-radius);
144
+ box-shadow: var(--box-shadow);
145
+ }
146
+
147
+ /* Discord button */
148
+ .discord-btn {
149
+ display: inline-flex;
150
+ align-items: center;
151
+ justify-content: center;
152
+ background-color: var(--discord-color);
153
+ color: white !important;
154
+ text-decoration: none;
155
+ padding: 12px 24px;
156
+ border-radius: var(--border-radius);
157
+ transition: all 0.3s ease;
158
+ margin-top: 1rem;
159
+ font-size: 16px;
160
+ font-weight: 500;
161
+ width: 100%;
162
+ border: none;
163
+ cursor: pointer;
164
+ box-shadow: 0 2px 8px rgba(88, 101, 242, 0.25);
165
+ }
166
+
167
+ .discord-btn:hover {
168
+ background-color: var(--discord-hover);
169
+ transform: translateY(-2px);
170
+ box-shadow: 0 4px 12px rgba(88, 101, 242, 0.4);
171
+ }
172
+
173
+ .discord-icon {
174
+ width: 24px;
175
+ height: 24px;
176
+ margin-right: 12px;
177
+ }
178
+
179
+ .discord-text {
180
+ letter-spacing: 0.5px;
181
+ }
182
+
183
+ /* Tooltips */
184
+ .gr-form small {
185
+ color: #666;
186
+ font-size: 0.875rem;
187
+ margin-top: 0.25rem;
188
+ display: block;
189
+ }
190
+
191
+ /* Responsive layout */
192
+ @media (max-width: 768px) {
193
+ .contain {
194
+ max-width: 90%;
195
+ padding: 1rem;
196
+ }
197
+
198
+ .gr-box {
199
+ margin: 0.5rem 0;
200
+ }
201
+
202
+ .gr-button.primary {
203
+ width: 100%;
204
+ }
205
+ }
206
+
207
+ @media (min-width: 1200px) {
208
+ .contain {
209
+ max-width: 1400px;
210
+ padding: 2.5rem 2rem;
211
+ }
212
+ }
utils.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import json
6
+ import torch
7
+ import uuid
8
+ from PIL import Image, PngImagePlugin
9
+ from datetime import datetime
10
+ from dataclasses import dataclass
11
+ from typing import Callable, Dict, Optional, Tuple, Any, List
12
+ from diffusers import (
13
+ DDIMScheduler,
14
+ DPMSolverMultistepScheduler,
15
+ DPMSolverSinglestepScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ EulerDiscreteScheduler,
18
+ AutoencoderKL,
19
+ StableDiffusionXLPipeline,
20
+ )
21
+ import logging
22
+
23
+ MAX_SEED = np.iinfo(np.int32).max
24
+
25
+
26
+ @dataclass
27
+ class StyleConfig:
28
+ prompt: str
29
+ negative_prompt: str
30
+
31
+
32
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
33
+ if randomize_seed:
34
+ seed = random.randint(0, MAX_SEED)
35
+ return seed
36
+
37
+
38
+ def seed_everything(seed: int) -> torch.Generator:
39
+ torch.manual_seed(seed)
40
+ torch.cuda.manual_seed_all(seed)
41
+ np.random.seed(seed)
42
+ generator = torch.Generator()
43
+ generator.manual_seed(seed)
44
+ return generator
45
+
46
+
47
+ def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
48
+ if aspect_ratio == "Custom":
49
+ return None
50
+ width, height = aspect_ratio.split(" x ")
51
+ return int(width), int(height)
52
+
53
+
54
+ def aspect_ratio_handler(
55
+ aspect_ratio: str, custom_width: int, custom_height: int
56
+ ) -> Tuple[int, int]:
57
+ if aspect_ratio == "Custom":
58
+ return custom_width, custom_height
59
+ else:
60
+ width, height = parse_aspect_ratio(aspect_ratio)
61
+ return width, height
62
+
63
+
64
+ def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
65
+ scheduler_factory_map = {
66
+ "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
67
+ scheduler_config, use_karras_sigmas=True
68
+ ),
69
+ "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
70
+ scheduler_config, use_karras_sigmas=True
71
+ ),
72
+ "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
73
+ scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
74
+ ),
75
+ "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
76
+ "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
77
+ scheduler_config
78
+ ),
79
+ "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
80
+ }
81
+ return scheduler_factory_map.get(name, lambda: None)()
82
+
83
+
84
+ def free_memory() -> None:
85
+ """Free up GPU and system memory."""
86
+ if torch.cuda.is_available():
87
+ torch.cuda.empty_cache()
88
+ torch.cuda.ipc_collect()
89
+ gc.collect()
90
+
91
+
92
+ def preprocess_prompt(
93
+ style_dict,
94
+ style_name: str,
95
+ positive: str,
96
+ negative: str = "",
97
+ add_style: bool = True,
98
+ ) -> Tuple[str, str]:
99
+ p, n = style_dict.get(style_name, style_dict["(None)"])
100
+
101
+ if add_style and positive.strip():
102
+ formatted_positive = p.format(prompt=positive)
103
+ else:
104
+ formatted_positive = positive
105
+
106
+ combined_negative = n
107
+ if negative.strip():
108
+ if combined_negative:
109
+ combined_negative += ", " + negative
110
+ else:
111
+ combined_negative = negative
112
+
113
+ return formatted_positive, combined_negative
114
+
115
+
116
+ def common_upscale(
117
+ samples: torch.Tensor,
118
+ width: int,
119
+ height: int,
120
+ upscale_method: str,
121
+ ) -> torch.Tensor:
122
+ return torch.nn.functional.interpolate(
123
+ samples, size=(height, width), mode=upscale_method
124
+ )
125
+
126
+
127
+ def upscale(
128
+ samples: torch.Tensor, upscale_method: str, scale_by: float
129
+ ) -> torch.Tensor:
130
+ width = round(samples.shape[3] * scale_by)
131
+ height = round(samples.shape[2] * scale_by)
132
+ return common_upscale(samples, width, height, upscale_method)
133
+
134
+
135
+ def preprocess_image_dimensions(width, height):
136
+ if width % 8 != 0:
137
+ width = width - (width % 8)
138
+ if height % 8 != 0:
139
+ height = height - (height % 8)
140
+ return width, height
141
+
142
+
143
+ def save_image(image, metadata, output_dir, is_colab):
144
+ if is_colab:
145
+ current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
146
+ filename = f"image_{current_time}.png"
147
+ else:
148
+ filename = str(uuid.uuid4()) + ".png"
149
+ os.makedirs(output_dir, exist_ok=True)
150
+ filepath = os.path.join(output_dir, filename)
151
+ metadata_str = json.dumps(metadata)
152
+ info = PngImagePlugin.PngInfo()
153
+ info.add_text("parameters", metadata_str)
154
+ image.save(filepath, "PNG", pnginfo=info)
155
+ return filepath
156
+
157
+
158
+ def is_google_colab():
159
+ try:
160
+ import google.colab
161
+ return True
162
+ except:
163
+ return False
164
+
165
+
166
+ def load_pipeline(model_name: str, device: torch.device, hf_token: Optional[str] = None, vae: Optional[AutoencoderKL] = None) -> Any:
167
+ """Load the Stable Diffusion pipeline."""
168
+ try:
169
+ pipeline = (
170
+ StableDiffusionXLPipeline.from_single_file
171
+ if model_name.endswith(".safetensors")
172
+ else StableDiffusionXLPipeline.from_pretrained
173
+ )
174
+
175
+ pipe = pipeline(
176
+ model_name,
177
+ vae=vae,
178
+ torch_dtype=torch.float16,
179
+ custom_pipeline="lpw_stable_diffusion_xl",
180
+ use_safetensors=True,
181
+ add_watermarker=False
182
+ )
183
+ pipe.to(device)
184
+ return pipe
185
+ except Exception as e:
186
+ logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
187
+ raise