K1Z3M1112 committed on
Commit
fdb4331
·
verified ·
1 Parent(s): ff79d24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -43
app.py CHANGED
@@ -26,6 +26,7 @@ print(f"🖥️ Device: {device} | dtype: {dtype}")
26
 
27
  # Lazy import (to avoid long startup if unused)
28
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionPipeline
 
29
  from controlnet_aux import LineartDetector, LineartAnimeDetector
30
 
31
  # Memory optimization
@@ -46,6 +47,22 @@ LINEART_ANIME_DETECTOR = None
46
  CURRENT_T2I_PIPE = None
47
  CURRENT_T2I_MODEL = None
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def get_pipeline(model_name: str, anime_model: bool = False):
50
  """Get or create a ControlNet pipeline for the given model and anime flag"""
51
  global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
@@ -57,6 +74,11 @@ def get_pipeline(model_name: str, anime_model: bool = False):
57
  print(f"✅ Reusing existing ControlNet pipeline: {model_name}, anime: {anime_model}")
58
  return CURRENT_CONTROLNET_PIPE
59
 
 
 
 
 
 
60
  # ถ้าเป็นโมเดลใหม่ ลบอันเก่าก่อน
61
  if CURRENT_CONTROLNET_PIPE is not None:
62
  print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
@@ -220,14 +242,30 @@ def load_t2i_model(model_name: str):
220
  torch.cuda.empty_cache()
221
 
222
  print(f"📥 Loading T2I model: {model_name}")
223
- CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
224
- model_name,
225
- torch_dtype=dtype,
226
- safety_checker=None,
227
- requires_safety_checker=False,
228
- use_safetensors=True,
229
- variant="fp16" if dtype == torch.float16 else None
230
- ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # Optimizations
233
  CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
@@ -265,12 +303,20 @@ def load_t2i_model(model_name: str):
265
 
266
  # Retry without use_safetensors
267
  try:
268
- CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
269
- model_name,
270
- torch_dtype=dtype,
271
- safety_checker=None,
272
- requires_safety_checker=False
273
- ).to(device)
 
 
 
 
 
 
 
 
274
 
275
  CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
276
  if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
@@ -331,6 +377,22 @@ def resize_image(image, max_size=512):
331
  # ===== Functions =====
332
  def colorize(sketch, base_model, anime_model, prompt, seed, steps, scale, cn_weight):
333
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  # โหลด pipeline ที่เหมาะสม (จะลบอันเก่าออกอัตโนมัติถ้าเปลี่ยนโมเดล)
335
  pipe = get_pipeline(base_model, anime_model)
336
 
@@ -373,14 +435,28 @@ def t2i(prompt, model, seed, steps, scale, w, h):
373
  gen = torch.Generator(device=device).manual_seed(int(seed))
374
 
375
  with torch.inference_mode():
376
- result = CURRENT_T2I_PIPE(
377
- prompt,
378
- width=int(w),
379
- height=int(h),
380
- num_inference_steps=int(steps),
381
- guidance_scale=float(scale),
382
- generator=gen
383
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
  if device.type == "cuda":
386
  torch.cuda.empty_cache()
@@ -389,6 +465,14 @@ def t2i(prompt, model, seed, steps, scale, w, h):
389
  except Exception as e:
390
  print(f"❌ Error in t2i: {e}")
391
  error_img = Image.new('RGB', (int(w), int(h)), color='red')
 
 
 
 
 
 
 
 
392
  return error_img
393
 
394
  # ===== Function to unload all models =====
@@ -446,6 +530,7 @@ def unload_all_models():
446
  with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Soft()) as demo:
447
  gr.Markdown("# 🎨 Advanced Image Generation & Editing Suite")
448
  gr.Markdown("### Powered by Stable Diffusion & ControlNet")
 
449
 
450
  # Add system info
451
  if torch.cuda.is_available():
@@ -464,7 +549,7 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
464
  with gr.Tab("🎨 Colorize Sketch"):
465
  gr.Markdown("""
466
  ### Convert your sketches to colored images using ControlNet
467
- Upload a sketch or line art, and the AI will automatically colorize it based on your prompt.
468
  """)
469
 
470
  with gr.Row():
@@ -476,16 +561,9 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
476
 
477
  with gr.Row():
478
  base_model = gr.Dropdown(
479
- choices=[
480
- "digiplay/ChikMix_V3",
481
- "digiplay/chilloutmix_NiPrunedFp16Fix",
482
- "gsdf/Counterfeit-V2.5",
483
- "stablediffusionapi/anything-v5",
484
- "digiplay/CleanLinearMix_nsfw",
485
- "Laxhar/noobai-XL-1.1" # เพิ่มโมเดลใหม่
486
- ],
487
  value="digiplay/ChikMix_V3",
488
- label="Base Model"
489
  )
490
  anime_chk = gr.Checkbox(label="Use Anime ControlNet", value=True)
491
 
@@ -512,7 +590,8 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
512
  with gr.Tab("🖼️ Text-to-Image"):
513
  gr.Markdown("""
514
  ### Generate images from text descriptions
515
- Describe what you want to see, and the AI will create it for you.
 
516
  """)
517
 
518
  with gr.Row():
@@ -525,14 +604,7 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
525
  placeholder="e.g., a beautiful landscape with mountains and a lake at sunset, highly detailed, 4k"
526
  )
527
  t2i_model = gr.Dropdown(
528
- choices=[
529
- "digiplay/ChikMix_V3",
530
- "digiplay/chilloutmix_NiPrunedFp16Fix",
531
- "gsdf/Counterfeit-V2.5",
532
- "stablediffusionapi/anything-v5",
533
- "digiplay/CleanLinearMix_nsfw",
534
- "Laxhar/noobai-XL-1.1" # เพิ่มโมเดลใหม่
535
- ],
536
  value="digiplay/ChikMix_V3",
537
  label="Model"
538
  )
@@ -543,8 +615,9 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
543
  t2i_scale = gr.Slider(1, 20, 7.5, step=0.5, label="CFG Scale")
544
 
545
  with gr.Row():
546
- w = gr.Slider(256, 1024, 512, step=64, label="Width")
547
- h = gr.Slider(256, 1024, 768, step=64, label="Height")
 
548
 
549
  gen_btn = gr.Button("🖼️ Generate", variant="primary")
550
  gen_btn.click(
 
26
 
27
  # Lazy import (to avoid long startup if unused)
28
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionPipeline
29
+ from diffusers import StableDiffusionXLPipeline # สำหรับ SDXL models
30
  from controlnet_aux import LineartDetector, LineartAnimeDetector
31
 
32
  # Memory optimization
 
47
  CURRENT_T2I_PIPE = None
48
  CURRENT_T2I_MODEL = None
49
 
50
+ # Define model types
51
+ SDXL_MODELS = ["Laxhar/noobai-XL-1.1"] # เพิ่ม SDXL models ตรงนี้
52
+ SD15_MODELS = [
53
+ "digiplay/ChikMix_V3",
54
+ "digiplay/chilloutmix_NiPrunedFp16Fix",
55
+ "gsdf/Counterfeit-V2.5",
56
+ "stablediffusionapi/anything-v5",
57
+ "digiplay/CleanLinearMix_nsfw"
58
+ ]
59
+
60
+ ALL_MODELS = SD15_MODELS + SDXL_MODELS
61
+
62
+ def is_sdxl_model(model_name: str) -> bool:
63
+ """ตรวจสอบว่าโมเดลเป็น SDXL หรือไม่"""
64
+ return model_name in SDXL_MODELS
65
+
66
  def get_pipeline(model_name: str, anime_model: bool = False):
67
  """Get or create a ControlNet pipeline for the given model and anime flag"""
68
  global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
 
74
  print(f"✅ Reusing existing ControlNet pipeline: {model_name}, anime: {anime_model}")
75
  return CURRENT_CONTROLNET_PIPE
76
 
77
+ # ถ้าเป็น SDXL model ให้แจ้งเตือนว่าไม่รองรับ ControlNet
78
+ if is_sdxl_model(model_name):
79
+ print(f"⚠️ SDXL model {model_name} is not compatible with ControlNet")
80
+ raise ValueError(f"SDXL model {model_name} is not compatible with ControlNet. Please use SD1.5 models for ControlNet.")
81
+
82
  # ถ้าเป็นโมเดลใหม่ ลบอันเก่าก่อน
83
  if CURRENT_CONTROLNET_PIPE is not None:
84
  print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
 
242
  torch.cuda.empty_cache()
243
 
244
  print(f"📥 Loading T2I model: {model_name}")
245
+
246
+ # ตรวจสอบว่าเป็น SDXL หรือ SD1.5 model
247
+ if is_sdxl_model(model_name):
248
+ # โหลด SDXL model
249
+ CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
250
+ model_name,
251
+ torch_dtype=dtype,
252
+ safety_checker=None,
253
+ requires_safety_checker=False,
254
+ use_safetensors=True,
255
+ variant="fp16" if dtype == torch.float16 else None
256
+ ).to(device)
257
+ print(f"✅ Loaded SDXL model: {model_name}")
258
+ else:
259
+ # โหลด SD1.5 model
260
+ CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
261
+ model_name,
262
+ torch_dtype=dtype,
263
+ safety_checker=None,
264
+ requires_safety_checker=False,
265
+ use_safetensors=True,
266
+ variant="fp16" if dtype == torch.float16 else None
267
+ ).to(device)
268
+ print(f"✅ Loaded SD1.5 model: {model_name}")
269
 
270
  # Optimizations
271
  CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
 
303
 
304
  # Retry without use_safetensors
305
  try:
306
+ if is_sdxl_model(model_name):
307
+ CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
308
+ model_name,
309
+ torch_dtype=dtype,
310
+ safety_checker=None,
311
+ requires_safety_checker=False
312
+ ).to(device)
313
+ else:
314
+ CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
315
+ model_name,
316
+ torch_dtype=dtype,
317
+ safety_checker=None,
318
+ requires_safety_checker=False
319
+ ).to(device)
320
 
321
  CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
322
  if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
 
377
  # ===== Functions =====
378
  def colorize(sketch, base_model, anime_model, prompt, seed, steps, scale, cn_weight):
379
  try:
380
+ # ตรวจสอบว่าเป็น SDXL model หรือไม่
381
+ if is_sdxl_model(base_model):
382
+ error_img = Image.new('RGB', (512, 512), color='red')
383
+ error_msg_img = Image.new('RGB', (512, 512), color='yellow')
384
+ # สร้างภาพแสดงข้อความ error
385
+ from PIL import ImageDraw, ImageFont
386
+ draw = ImageDraw.Draw(error_msg_img)
387
+ try:
388
+ font = ImageFont.truetype("arial.ttf", 20)
389
+ except:
390
+ font = ImageFont.load_default()
391
+ draw.text((50, 200), f"SDXL model not compatible", fill="black", font=font)
392
+ draw.text((50, 230), f"with ControlNet", fill="black", font=font)
393
+ draw.text((50, 260), f"Use SD1.5 models instead", fill="black", font=font)
394
+ return error_img, error_msg_img
395
+
396
  # โหลด pipeline ที่เหมาะสม (จะลบอันเก่าออกอัตโนมัติถ้าเปลี่ยนโมเดล)
397
  pipe = get_pipeline(base_model, anime_model)
398
 
 
435
  gen = torch.Generator(device=device).manual_seed(int(seed))
436
 
437
  with torch.inference_mode():
438
+ # สำหรับ SDXL model ใช้ขนาดเริ่มต้นที่ใหญ่กว่า
439
+ if is_sdxl_model(model):
440
+ # SDXL ต้องการขนาดขั้นต่ำ 1024x1024 สำหรับผลลัพธ์ที่ดี
441
+ width = max(int(w), 512)
442
+ height = max(int(h), 512)
443
+ result = CURRENT_T2I_PIPE(
444
+ prompt,
445
+ width=width,
446
+ height=height,
447
+ num_inference_steps=int(steps),
448
+ guidance_scale=float(scale),
449
+ generator=gen
450
+ ).images[0]
451
+ else:
452
+ result = CURRENT_T2I_PIPE(
453
+ prompt,
454
+ width=int(w),
455
+ height=int(h),
456
+ num_inference_steps=int(steps),
457
+ guidance_scale=float(scale),
458
+ generator=gen
459
+ ).images[0]
460
 
461
  if device.type == "cuda":
462
  torch.cuda.empty_cache()
 
465
  except Exception as e:
466
  print(f"❌ Error in t2i: {e}")
467
  error_img = Image.new('RGB', (int(w), int(h)), color='red')
468
+ # สร้างภาพแสดงข้อความ error
469
+ from PIL import ImageDraw, ImageFont
470
+ draw = ImageDraw.Draw(error_img)
471
+ try:
472
+ font = ImageFont.truetype("arial.ttf", 20)
473
+ except:
474
+ font = ImageFont.load_default()
475
+ draw.text((50, 50), f"Error: {str(e)[:50]}...", fill="white", font=font)
476
  return error_img
477
 
478
  # ===== Function to unload all models =====
 
530
  with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Soft()) as demo:
531
  gr.Markdown("# 🎨 Advanced Image Generation & Editing Suite")
532
  gr.Markdown("### Powered by Stable Diffusion & ControlNet")
533
+ gr.Markdown("**Note:** SDXL models (noobai-XL) work only in Text-to-Image tab, not in ControlNet")
534
 
535
  # Add system info
536
  if torch.cuda.is_available():
 
549
  with gr.Tab("🎨 Colorize Sketch"):
550
  gr.Markdown("""
551
  ### Convert your sketches to colored images using ControlNet
552
+ **Note:** SDXL models are not compatible with ControlNet. Please use SD1.5 models only.
553
  """)
554
 
555
  with gr.Row():
 
561
 
562
  with gr.Row():
563
  base_model = gr.Dropdown(
564
+ choices=SD15_MODELS, # ใช้เฉพาะ SD1.5 models
 
 
 
 
 
 
 
565
  value="digiplay/ChikMix_V3",
566
+ label="Base Model (SD1.5 only)"
567
  )
568
  anime_chk = gr.Checkbox(label="Use Anime ControlNet", value=True)
569
 
 
590
  with gr.Tab("🖼️ Text-to-Image"):
591
  gr.Markdown("""
592
  ### Generate images from text descriptions
593
+ Supports both SD1.5 and SDXL models.
594
+ **Tip:** SDXL models produce higher quality but require more memory.
595
  """)
596
 
597
  with gr.Row():
 
604
  placeholder="e.g., a beautiful landscape with mountains and a lake at sunset, highly detailed, 4k"
605
  )
606
  t2i_model = gr.Dropdown(
607
+ choices=ALL_MODELS, # ใช้ทั้ง SD1.5 และ SDXL
 
 
 
 
 
 
 
608
  value="digiplay/ChikMix_V3",
609
  label="Model"
610
  )
 
615
  t2i_scale = gr.Slider(1, 20, 7.5, step=0.5, label="CFG Scale")
616
 
617
  with gr.Row():
618
+ # สำหรับ SDXL ขอแนะนำขนาดที่ใหญ่กว่า
619
+ w = gr.Slider(256, 1536, 1024, step=64, label="Width")
620
+ h = gr.Slider(256, 1536, 1024, step=64, label="Height")
621
 
622
  gen_btn = gr.Button("🖼️ Generate", variant="primary")
623
  gen_btn.click(