Update app.py
Browse files
app.py
CHANGED
|
@@ -26,6 +26,7 @@ print(f"🖥️ Device: {device} | dtype: {dtype}")
|
|
| 26 |
|
| 27 |
# Lazy import (to avoid long startup if unused)
|
| 28 |
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionPipeline
|
|
|
|
| 29 |
from controlnet_aux import LineartDetector, LineartAnimeDetector
|
| 30 |
|
| 31 |
# Memory optimization
|
|
@@ -46,6 +47,22 @@ LINEART_ANIME_DETECTOR = None
|
|
| 46 |
CURRENT_T2I_PIPE = None
|
| 47 |
CURRENT_T2I_MODEL = None
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def get_pipeline(model_name: str, anime_model: bool = False):
|
| 50 |
"""Get or create a ControlNet pipeline for the given model and anime flag"""
|
| 51 |
global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
|
|
@@ -57,6 +74,11 @@ def get_pipeline(model_name: str, anime_model: bool = False):
|
|
| 57 |
print(f"✅ Reusing existing ControlNet pipeline: {model_name}, anime: {anime_model}")
|
| 58 |
return CURRENT_CONTROLNET_PIPE
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# ถ้าเป็นโมเดลใหม่ ลบอันเก่าก่อน
|
| 61 |
if CURRENT_CONTROLNET_PIPE is not None:
|
| 62 |
print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
|
|
@@ -220,14 +242,30 @@ def load_t2i_model(model_name: str):
|
|
| 220 |
torch.cuda.empty_cache()
|
| 221 |
|
| 222 |
print(f"📥 Loading T2I model: {model_name}")
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
# Optimizations
|
| 233 |
CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
|
|
@@ -265,12 +303,20 @@ def load_t2i_model(model_name: str):
|
|
| 265 |
|
| 266 |
# Retry without use_safetensors
|
| 267 |
try:
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
|
| 276 |
if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
|
|
@@ -331,6 +377,22 @@ def resize_image(image, max_size=512):
|
|
| 331 |
# ===== Functions =====
|
| 332 |
def colorize(sketch, base_model, anime_model, prompt, seed, steps, scale, cn_weight):
|
| 333 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
# โหลด pipeline ที่เหมาะสม (จะลบอันเก่าออกอัตโนมัติถ้าเปลี่ยนโมเดล)
|
| 335 |
pipe = get_pipeline(base_model, anime_model)
|
| 336 |
|
|
@@ -373,14 +435,28 @@ def t2i(prompt, model, seed, steps, scale, w, h):
|
|
| 373 |
gen = torch.Generator(device=device).manual_seed(int(seed))
|
| 374 |
|
| 375 |
with torch.inference_mode():
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
if device.type == "cuda":
|
| 386 |
torch.cuda.empty_cache()
|
|
@@ -389,6 +465,14 @@ def t2i(prompt, model, seed, steps, scale, w, h):
|
|
| 389 |
except Exception as e:
|
| 390 |
print(f"❌ Error in t2i: {e}")
|
| 391 |
error_img = Image.new('RGB', (int(w), int(h)), color='red')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
return error_img
|
| 393 |
|
| 394 |
# ===== Function to unload all models =====
|
|
@@ -446,6 +530,7 @@ def unload_all_models():
|
|
| 446 |
with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Soft()) as demo:
|
| 447 |
gr.Markdown("# 🎨 Advanced Image Generation & Editing Suite")
|
| 448 |
gr.Markdown("### Powered by Stable Diffusion & ControlNet")
|
|
|
|
| 449 |
|
| 450 |
# Add system info
|
| 451 |
if torch.cuda.is_available():
|
|
@@ -464,7 +549,7 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
|
|
| 464 |
with gr.Tab("🎨 Colorize Sketch"):
|
| 465 |
gr.Markdown("""
|
| 466 |
### Convert your sketches to colored images using ControlNet
|
| 467 |
-
|
| 468 |
""")
|
| 469 |
|
| 470 |
with gr.Row():
|
|
@@ -476,16 +561,9 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
|
|
| 476 |
|
| 477 |
with gr.Row():
|
| 478 |
base_model = gr.Dropdown(
|
| 479 |
-
choices=
|
| 480 |
-
"digiplay/ChikMix_V3",
|
| 481 |
-
"digiplay/chilloutmix_NiPrunedFp16Fix",
|
| 482 |
-
"gsdf/Counterfeit-V2.5",
|
| 483 |
-
"stablediffusionapi/anything-v5",
|
| 484 |
-
"digiplay/CleanLinearMix_nsfw",
|
| 485 |
-
"Laxhar/noobai-XL-1.1" # เพิ่มโมเดลใหม่
|
| 486 |
-
],
|
| 487 |
value="digiplay/ChikMix_V3",
|
| 488 |
-
label="Base Model"
|
| 489 |
)
|
| 490 |
anime_chk = gr.Checkbox(label="Use Anime ControlNet", value=True)
|
| 491 |
|
|
@@ -512,7 +590,8 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
|
|
| 512 |
with gr.Tab("🖼️ Text-to-Image"):
|
| 513 |
gr.Markdown("""
|
| 514 |
### Generate images from text descriptions
|
| 515 |
-
|
|
|
|
| 516 |
""")
|
| 517 |
|
| 518 |
with gr.Row():
|
|
@@ -525,14 +604,7 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
|
|
| 525 |
placeholder="e.g., a beautiful landscape with mountains and a lake at sunset, highly detailed, 4k"
|
| 526 |
)
|
| 527 |
t2i_model = gr.Dropdown(
|
| 528 |
-
choices=
|
| 529 |
-
"digiplay/ChikMix_V3",
|
| 530 |
-
"digiplay/chilloutmix_NiPrunedFp16Fix",
|
| 531 |
-
"gsdf/Counterfeit-V2.5",
|
| 532 |
-
"stablediffusionapi/anything-v5",
|
| 533 |
-
"digiplay/CleanLinearMix_nsfw",
|
| 534 |
-
"Laxhar/noobai-XL-1.1" # เพิ่มโมเดลใหม่
|
| 535 |
-
],
|
| 536 |
value="digiplay/ChikMix_V3",
|
| 537 |
label="Model"
|
| 538 |
)
|
|
@@ -543,8 +615,9 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
|
|
| 543 |
t2i_scale = gr.Slider(1, 20, 7.5, step=0.5, label="CFG Scale")
|
| 544 |
|
| 545 |
with gr.Row():
|
| 546 |
-
|
| 547 |
-
|
|
|
|
| 548 |
|
| 549 |
gen_btn = gr.Button("🖼️ Generate", variant="primary")
|
| 550 |
gen_btn.click(
|
|
|
|
| 26 |
|
| 27 |
# Lazy import (to avoid long startup if unused)
|
| 28 |
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionPipeline
|
| 29 |
+
from diffusers import StableDiffusionXLPipeline # สำหรับ SDXL models
|
| 30 |
from controlnet_aux import LineartDetector, LineartAnimeDetector
|
| 31 |
|
| 32 |
# Memory optimization
|
|
|
|
| 47 |
CURRENT_T2I_PIPE = None
|
| 48 |
CURRENT_T2I_MODEL = None
|
| 49 |
|
| 50 |
+
# Define model types
|
| 51 |
+
SDXL_MODELS = ["Laxhar/noobai-XL-1.1"] # เพิ่ม SDXL models ตรงนี้
|
| 52 |
+
SD15_MODELS = [
|
| 53 |
+
"digiplay/ChikMix_V3",
|
| 54 |
+
"digiplay/chilloutmix_NiPrunedFp16Fix",
|
| 55 |
+
"gsdf/Counterfeit-V2.5",
|
| 56 |
+
"stablediffusionapi/anything-v5",
|
| 57 |
+
"digiplay/CleanLinearMix_nsfw"
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
ALL_MODELS = SD15_MODELS + SDXL_MODELS
|
| 61 |
+
|
| 62 |
+
def is_sdxl_model(model_name: str) -> bool:
|
| 63 |
+
"""ตรวจสอบว่าโมเดลเป็น SDXL หรือไม่"""
|
| 64 |
+
return model_name in SDXL_MODELS
|
| 65 |
+
|
| 66 |
def get_pipeline(model_name: str, anime_model: bool = False):
|
| 67 |
"""Get or create a ControlNet pipeline for the given model and anime flag"""
|
| 68 |
global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
|
|
|
|
| 74 |
print(f"✅ Reusing existing ControlNet pipeline: {model_name}, anime: {anime_model}")
|
| 75 |
return CURRENT_CONTROLNET_PIPE
|
| 76 |
|
| 77 |
+
# ถ้าเป็น SDXL model ให้แจ้งเตือนว่าไม่รองรับ ControlNet
|
| 78 |
+
if is_sdxl_model(model_name):
|
| 79 |
+
print(f"⚠️ SDXL model {model_name} is not compatible with ControlNet")
|
| 80 |
+
raise ValueError(f"SDXL model {model_name} is not compatible with ControlNet. Please use SD1.5 models for ControlNet.")
|
| 81 |
+
|
| 82 |
# ถ้าเป็นโมเดลใหม่ ลบอันเก่าก่อน
|
| 83 |
if CURRENT_CONTROLNET_PIPE is not None:
|
| 84 |
print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
|
|
|
|
| 242 |
torch.cuda.empty_cache()
|
| 243 |
|
| 244 |
print(f"📥 Loading T2I model: {model_name}")
|
| 245 |
+
|
| 246 |
+
# ตรวจสอบว่าเป็น SDXL หรือ SD1.5 model
|
| 247 |
+
if is_sdxl_model(model_name):
|
| 248 |
+
# โหลด SDXL model
|
| 249 |
+
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 250 |
+
model_name,
|
| 251 |
+
torch_dtype=dtype,
|
| 252 |
+
safety_checker=None,
|
| 253 |
+
requires_safety_checker=False,
|
| 254 |
+
use_safetensors=True,
|
| 255 |
+
variant="fp16" if dtype == torch.float16 else None
|
| 256 |
+
).to(device)
|
| 257 |
+
print(f"✅ Loaded SDXL model: {model_name}")
|
| 258 |
+
else:
|
| 259 |
+
# โหลด SD1.5 model
|
| 260 |
+
CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
|
| 261 |
+
model_name,
|
| 262 |
+
torch_dtype=dtype,
|
| 263 |
+
safety_checker=None,
|
| 264 |
+
requires_safety_checker=False,
|
| 265 |
+
use_safetensors=True,
|
| 266 |
+
variant="fp16" if dtype == torch.float16 else None
|
| 267 |
+
).to(device)
|
| 268 |
+
print(f"✅ Loaded SD1.5 model: {model_name}")
|
| 269 |
|
| 270 |
# Optimizations
|
| 271 |
CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
|
|
|
|
| 303 |
|
| 304 |
# Retry without use_safetensors
|
| 305 |
try:
|
| 306 |
+
if is_sdxl_model(model_name):
|
| 307 |
+
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 308 |
+
model_name,
|
| 309 |
+
torch_dtype=dtype,
|
| 310 |
+
safety_checker=None,
|
| 311 |
+
requires_safety_checker=False
|
| 312 |
+
).to(device)
|
| 313 |
+
else:
|
| 314 |
+
CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
|
| 315 |
+
model_name,
|
| 316 |
+
torch_dtype=dtype,
|
| 317 |
+
safety_checker=None,
|
| 318 |
+
requires_safety_checker=False
|
| 319 |
+
).to(device)
|
| 320 |
|
| 321 |
CURRENT_T2I_PIPE.enable_attention_slicing(slice_size="max")
|
| 322 |
if hasattr(CURRENT_T2I_PIPE, 'vae') and hasattr(CURRENT_T2I_PIPE.vae, 'enable_slicing'):
|
|
|
|
| 377 |
# ===== Functions =====
|
| 378 |
def colorize(sketch, base_model, anime_model, prompt, seed, steps, scale, cn_weight):
|
| 379 |
try:
|
| 380 |
+
# ตรวจสอบว่าเป็น SDXL model หรือไม่
|
| 381 |
+
if is_sdxl_model(base_model):
|
| 382 |
+
error_img = Image.new('RGB', (512, 512), color='red')
|
| 383 |
+
error_msg_img = Image.new('RGB', (512, 512), color='yellow')
|
| 384 |
+
# สร้างภาพแสดงข้อความ error
|
| 385 |
+
from PIL import ImageDraw, ImageFont
|
| 386 |
+
draw = ImageDraw.Draw(error_msg_img)
|
| 387 |
+
try:
|
| 388 |
+
font = ImageFont.truetype("arial.ttf", 20)
|
| 389 |
+
except:
|
| 390 |
+
font = ImageFont.load_default()
|
| 391 |
+
draw.text((50, 200), f"SDXL model not compatible", fill="black", font=font)
|
| 392 |
+
draw.text((50, 230), f"with ControlNet", fill="black", font=font)
|
| 393 |
+
draw.text((50, 260), f"Use SD1.5 models instead", fill="black", font=font)
|
| 394 |
+
return error_img, error_msg_img
|
| 395 |
+
|
| 396 |
# โหลด pipeline ที่เหมาะสม (จะลบอันเก่าออกอัตโนมัติถ้าเปลี่ยนโมเดล)
|
| 397 |
pipe = get_pipeline(base_model, anime_model)
|
| 398 |
|
|
|
|
| 435 |
gen = torch.Generator(device=device).manual_seed(int(seed))
|
| 436 |
|
| 437 |
with torch.inference_mode():
|
| 438 |
+
# สำหรับ SDXL model ใช้ขนาดเริ่มต้นที่ใหญ่กว่า
|
| 439 |
+
if is_sdxl_model(model):
|
| 440 |
+
# SDXL ต้องการขนาดขั้นต่ำ 1024x1024 สำหรับผลลัพธ์ที่ดี
|
| 441 |
+
width = max(int(w), 512)
|
| 442 |
+
height = max(int(h), 512)
|
| 443 |
+
result = CURRENT_T2I_PIPE(
|
| 444 |
+
prompt,
|
| 445 |
+
width=width,
|
| 446 |
+
height=height,
|
| 447 |
+
num_inference_steps=int(steps),
|
| 448 |
+
guidance_scale=float(scale),
|
| 449 |
+
generator=gen
|
| 450 |
+
).images[0]
|
| 451 |
+
else:
|
| 452 |
+
result = CURRENT_T2I_PIPE(
|
| 453 |
+
prompt,
|
| 454 |
+
width=int(w),
|
| 455 |
+
height=int(h),
|
| 456 |
+
num_inference_steps=int(steps),
|
| 457 |
+
guidance_scale=float(scale),
|
| 458 |
+
generator=gen
|
| 459 |
+
).images[0]
|
| 460 |
|
| 461 |
if device.type == "cuda":
|
| 462 |
torch.cuda.empty_cache()
|
|
|
|
| 465 |
except Exception as e:
|
| 466 |
print(f"❌ Error in t2i: {e}")
|
| 467 |
error_img = Image.new('RGB', (int(w), int(h)), color='red')
|
| 468 |
+
# สร้างภาพแสดงข้อความ error
|
| 469 |
+
from PIL import ImageDraw, ImageFont
|
| 470 |
+
draw = ImageDraw.Draw(error_img)
|
| 471 |
+
try:
|
| 472 |
+
font = ImageFont.truetype("arial.ttf", 20)
|
| 473 |
+
except:
|
| 474 |
+
font = ImageFont.load_default()
|
| 475 |
+
draw.text((50, 50), f"Error: {str(e)[:50]}...", fill="white", font=font)
|
| 476 |
return error_img
|
| 477 |
|
| 478 |
# ===== Function to unload all models =====
|
|
|
|
| 530 |
with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Soft()) as demo:
|
| 531 |
gr.Markdown("# 🎨 Advanced Image Generation & Editing Suite")
|
| 532 |
gr.Markdown("### Powered by Stable Diffusion & ControlNet")
|
| 533 |
+
gr.Markdown("**Note:** SDXL models (noobai-XL) work only in Text-to-Image tab, not in ControlNet")
|
| 534 |
|
| 535 |
# Add system info
|
| 536 |
if torch.cuda.is_available():
|
|
|
|
| 549 |
with gr.Tab("🎨 Colorize Sketch"):
|
| 550 |
gr.Markdown("""
|
| 551 |
### Convert your sketches to colored images using ControlNet
|
| 552 |
+
**Note:** SDXL models are not compatible with ControlNet. Please use SD1.5 models only.
|
| 553 |
""")
|
| 554 |
|
| 555 |
with gr.Row():
|
|
|
|
| 561 |
|
| 562 |
with gr.Row():
|
| 563 |
base_model = gr.Dropdown(
|
| 564 |
+
choices=SD15_MODELS, # ใช้เฉพาะ SD1.5 models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
value="digiplay/ChikMix_V3",
|
| 566 |
+
label="Base Model (SD1.5 only)"
|
| 567 |
)
|
| 568 |
anime_chk = gr.Checkbox(label="Use Anime ControlNet", value=True)
|
| 569 |
|
|
|
|
| 590 |
with gr.Tab("🖼️ Text-to-Image"):
|
| 591 |
gr.Markdown("""
|
| 592 |
### Generate images from text descriptions
|
| 593 |
+
Supports both SD1.5 and SDXL models.
|
| 594 |
+
**Tip:** SDXL models produce higher quality but require more memory.
|
| 595 |
""")
|
| 596 |
|
| 597 |
with gr.Row():
|
|
|
|
| 604 |
placeholder="e.g., a beautiful landscape with mountains and a lake at sunset, highly detailed, 4k"
|
| 605 |
)
|
| 606 |
t2i_model = gr.Dropdown(
|
| 607 |
+
choices=ALL_MODELS, # ใช้ทั้ง SD1.5 และ SDXL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
value="digiplay/ChikMix_V3",
|
| 609 |
label="Model"
|
| 610 |
)
|
|
|
|
| 615 |
t2i_scale = gr.Slider(1, 20, 7.5, step=0.5, label="CFG Scale")
|
| 616 |
|
| 617 |
with gr.Row():
|
| 618 |
+
# สำหรับ SDXL ขอแนะนำขนาดที่ใหญ่กว่า
|
| 619 |
+
w = gr.Slider(256, 1536, 1024, step=64, label="Width")
|
| 620 |
+
h = gr.Slider(256, 1536, 1024, step=64, label="Height")
|
| 621 |
|
| 622 |
gen_btn = gr.Button("🖼️ Generate", variant="primary")
|
| 623 |
gen_btn.click(
|