# younginpiniti's picture
# feat: 사용자에게 표시되는 모든 정보 메시지를 한국어로 번역했습니다.
# a11abf8
"""
์Šคํ…Œ์ด๋ธ” ๋””ํ“จ์ „ WebUI - ํ—ˆ๊น…ํŽ˜์ด์Šค ์ŠคํŽ˜์ด์Šค์šฉ
Gradio ์ธํ„ฐํŽ˜์ด์Šค + REST API๋ฅผ ํ†ตํ•œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ (txt2img + img2img ์ง€์›)
"""
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler
from PIL import Image
import os
import gc
import io
import base64
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Selectable checkpoints: UI display label -> Hugging Face Hub repository id.
MODELS = {
    "๐ŸŽจ Mistoon Anime V3 (์นดํˆฐํ’ ์• ๋‹ˆ๋ฉ”์ด์…˜)": "stablediffusionapi/mistoonanime-v30",
    "๐ŸŒธ Anything V5 (์• ๋‹ˆ๋ฉ”์ด์…˜)": "stablediffusionapi/anything-v5",
    "๐Ÿ’œ Counterfeit V3 (๊ณ ํ’ˆ์งˆ ์• ๋‹ˆ๋ฉ”์ด์…˜)": "gsdf/Counterfeit-V3.0",
    "โœจ DreamShaper V8 (๋‹ค๋ชฉ์ )": "Lykon/DreamShaper",
    "๐ŸŽญ OpenJourney (Midjourney ์Šคํƒ€์ผ)": "prompthero/openjourney-v4",
    "๐Ÿ–ผ๏ธ Stable Diffusion v1.5 (๊ธฐ๋ณธ)": "runwayml/stable-diffusion-v1-5",
    "๐ŸŒŸ MeinaMix (์• ๋‹ˆ๋ฉ”์ด์…˜)": "Meina/MeinaMix_V11",
    "๐Ÿ’ซ ReV Animated (์• ๋‹ˆ๋ฉ”์ด์…˜)": "stablediffusionapi/rev-animated",
}
# Device selection: half precision on CUDA, full precision on the CPU fallback.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
print(f"๐Ÿš€ ๋””๋ฐ”์ด์Šค: {DEVICE}, ๋ฐ์ดํ„ฐ ํƒ€์ž…: {DTYPE}")
# Module-level cache describing the single pipeline currently held in memory.
current_model_id = None  # Hub repo id of the cached pipeline, or None
current_pipeline_type = None  # "txt2img" or "img2img"
pipe = None  # the cached diffusers pipeline object, or None
def clear_memory():
    """Forget the cached pipeline, then ask Python and (if present) CUDA to free memory."""
    global pipe
    # Rebinding to None drops the module's reference whether or not one was held.
    pipe = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def load_model(model_name: str, pipeline_type: str = "txt2img"):
    """
    Load (or reuse) the diffusers pipeline for a catalogue model.

    Only one pipeline is kept in memory at a time, tracked by the module
    globals ``pipe``, ``current_model_id`` and ``current_pipeline_type``.
    Asking for a different model or a different pipeline type frees the
    cached pipeline and loads the requested one.

    Args:
        model_name: Display label; must be a key of ``MODELS``.
        pipeline_type: "img2img" selects StableDiffusionImg2ImgPipeline; any
            other value (normally "txt2img") selects StableDiffusionPipeline.

    Returns:
        ``(pipeline, status_message)`` on success and ``(None, error_message)``
        on failure. Errors are reported through the tuple, never raised.
    """
    global pipe, current_model_id, current_pipeline_type
    model_id = MODELS.get(model_name)
    if model_id is None:
        return None, f"โŒ ์•Œ ์ˆ˜ ์—†๋Š” ๋ชจ๋ธ: {model_name}"
    # Fast path: the requested model and pipeline type are already resident.
    if current_model_id == model_id and current_pipeline_type == pipeline_type and pipe is not None:
        return pipe, f"โœ… {model_name} ์ด๋ฏธ ๋กœ๋“œ๋จ ({pipeline_type})"
    # Free the previous pipeline and invalidate the bookkeeping *before* loading,
    # so a failed load can never leave the globals describing a model that is
    # no longer the one held in ``pipe``.
    clear_memory()
    current_model_id = None
    current_pipeline_type = None
    print(f"๐Ÿ“ฅ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘: {model_name} ({pipeline_type})...")
    pipeline_cls = StableDiffusionImg2ImgPipeline if pipeline_type == "img2img" else StableDiffusionPipeline
    try:
        pipe = pipeline_cls.from_pretrained(
            model_id,
            torch_dtype=DTYPE,
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=False
        )
        # Use the faster DPM-Solver multistep scheduler, built from the checkpoint's scheduler config.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(DEVICE)
        # Memory optimisations: attention slicing, plus VAE slicing when supported.
        pipe.enable_attention_slicing()
        if hasattr(pipe, 'enable_vae_slicing'):
            pipe.enable_vae_slicing()
        current_model_id = model_id
        current_pipeline_type = pipeline_type
        print(f"โœ… ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ: {model_name} ({pipeline_type})")
        return pipe, f"โœ… {model_name} ๋กœ๋”ฉ ์™„๋ฃŒ!"
    except Exception as e:
        # A step after from_pretrained (scheduler swap, .to(DEVICE), ...) may have
        # failed with a half-initialised pipeline already bound to ``pipe``; drop
        # it so a later call cannot mistake it for a usable cached pipeline.
        clear_memory()
        print(f"โŒ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {e}")
        return None, f"โŒ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {str(e)}"
# Negative prompt applied to every generation; a user-supplied negative prompt
# is placed in front of it by the generation functions.
DEFAULT_NEGATIVE = ", ".join((
    "lowres",
    "bad anatomy",
    "bad hands",
    "text",
    "error",
    "missing fingers",
    "extra digit",
    "fewer digits",
    "cropped",
    "worst quality",
    "low quality",
    "normal quality",
    "jpeg artifacts",
    "signature",
    "watermark",
    "username",
    "blurry",
))
def generate_txt2img(
    model_name: str,
    prompt: str,
    negative_prompt: str = "",
    num_inference_steps: int = 25,
    guidance_scale: float = 7.5,
    width: int = 512,
    height: int = 512,
    seed: int = -1,
    progress=gr.Progress()
):
    """Text -> image handler for the Gradio txt2img tab.

    Args:
        model_name: Display label of the model (a key of ``MODELS``).
        prompt: Positive prompt; blank input returns a warning status.
        negative_prompt: Extra negative terms, placed before ``DEFAULT_NEGATIVE``.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance (CFG) scale.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed; -1 (or an empty field arriving as None) picks a random seed.
        progress: Gradio progress tracker.

    Returns:
        ``(PIL image, status message)`` on success, ``(None, status message)``
        on any failure. Errors are reported in the status string, not raised.
    """
    if not prompt or not prompt.strip():
        return None, "โš ๏ธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”!"
    progress(0.1, desc="๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
    # Keep the pipeline in a local name instead of re-reading the module-level
    # ``pipe``: another handler running in a different thread may swap the global
    # between loading and inference, which would run this request on the wrong
    # model / pipeline class.
    txt2img_pipe, status = load_model(model_name, "txt2img")
    if txt2img_pipe is None:
        return None, status
    if negative_prompt and negative_prompt.strip():
        full_negative = f"{negative_prompt}, {DEFAULT_NEGATIVE}"
    else:
        full_negative = DEFAULT_NEGATIVE
    try:
        # Seed handling lives inside the try so a bad value becomes a status message.
        if seed is None or seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        seed = int(seed)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        progress(0.3, desc="์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘...")
        print(f"๐ŸŽจ [txt2img] ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘... ํ”„๋กฌํ”„ํŠธ: {prompt[:50]}...")
        # UI sliders may deliver floats; the pipeline needs integer sizes/steps.
        result = txt2img_pipe(
            prompt=prompt,
            negative_prompt=full_negative,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            width=int(width),
            height=int(height),
            generator=generator
        )
        image = result.images[0]
        progress(1.0, desc="์™„๋ฃŒ!")
        print("โœ… [txt2img] ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์™„๋ฃŒ!")
        return image, f"โœ… ์ƒ์„ฑ ์™„๋ฃŒ! (์‹œ๋“œ: {seed})"
    except Exception as e:
        print(f"โŒ [txt2img] ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {e}")
        return None, f"โŒ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {str(e)}"
def generate_img2img(
    model_name: str,
    input_image: Image.Image,
    prompt: str,
    negative_prompt: str = "",
    strength: float = 0.75,
    num_inference_steps: int = 25,
    guidance_scale: float = 7.5,
    seed: int = -1,
    progress=gr.Progress()
):
    """Image -> image handler for the Gradio img2img tab.

    Args:
        model_name: Display label of the model (a key of ``MODELS``).
        input_image: Uploaded PIL image. It is converted to RGB, and each side is
            rounded down to a multiple of 64 (a side under 64 px becomes 512).
        prompt: Target-style prompt; blank input returns a warning status.
        negative_prompt: Extra negative terms, placed before ``DEFAULT_NEGATIVE``.
        strength: How far the result may move away from the input image.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance (CFG) scale.
        seed: RNG seed; -1 (or an empty field arriving as None) picks a random seed.
        progress: Gradio progress tracker.

    Returns:
        ``(PIL image, status message)`` on success, ``(None, status message)``
        on any failure. Errors are reported in the status string, not raised.
    """
    if input_image is None:
        return None, "โš ๏ธ ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”!"
    if not prompt or not prompt.strip():
        return None, "โš ๏ธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”!"
    progress(0.1, desc="๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
    # Local reference, not the module-level ``pipe``: another handler may swap
    # the global pipeline (e.g. to the txt2img class) before inference runs.
    img2img_pipe, status = load_model(model_name, "img2img")
    if img2img_pipe is None:
        return None, status
    if negative_prompt and negative_prompt.strip():
        full_negative = f"{negative_prompt}, {DEFAULT_NEGATIVE}"
    else:
        full_negative = DEFAULT_NEGATIVE
    try:
        # Seed handling lives inside the try so a bad value becomes a status message.
        if seed is None or seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        seed = int(seed)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        progress(0.3, desc="์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์ค‘...")
        print(f"๐Ÿ–ผ๏ธ [img2img] ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์ค‘... ํ”„๋กฌํ”„ํŠธ: {prompt[:50]}...")
        source = input_image.convert("RGB")
        # Round each side down to a multiple of 64; sides shorter than 64 px become 512.
        target_w = (source.width // 64) * 64 or 512
        target_h = (source.height // 64) * 64 or 512
        source = source.resize((target_w, target_h), Image.LANCZOS)
        result = img2img_pipe(
            prompt=prompt,
            image=source,
            negative_prompt=full_negative,
            strength=strength,
            # UI sliders may deliver floats; the scheduler needs an integer step count.
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator
        )
        image = result.images[0]
        progress(1.0, desc="์™„๋ฃŒ!")
        print("โœ… [img2img] ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์™„๋ฃŒ!")
        return image, f"โœ… ๋ณ€ํ™˜ ์™„๋ฃŒ! (์‹œ๋“œ: {seed}, ๊ฐ•๋„: {strength})"
    except Exception as e:
        print(f"โŒ [img2img] ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์‹คํŒจ: {e}")
        return None, f"โŒ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์‹คํŒจ: {str(e)}"
# Gradio web UI construction
def create_interface():
    """Build and return the Gradio Blocks app.

    The app has a header, two tabs -- text-to-image (wired to
    ``generate_txt2img``) and image-to-image (wired to ``generate_img2img``) --
    and a footer.
    """
    # Custom CSS; several of these classes are attached below via
    # ``elem_classes`` or used in the header HTML.
    custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
max-width: 1200px !important;
}
.generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
font-size: 1.2em !important;
padding: 15px 30px !important;
border-radius: 10px !important;
transition: all 0.3s ease !important;
width: 100% !important;
}
.generate-btn:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.5) !important;
}
.title {
text-align: center;
background: linear-gradient(135deg, #ff6b9d 0%, #c44569 50%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 2.8em;
font-weight: bold;
margin-bottom: 5px;
}
.subtitle {
text-align: center;
color: #888;
font-size: 1.1em;
margin-bottom: 25px;
}
.model-dropdown {
border: 2px solid #764ba2 !important;
border-radius: 8px !important;
}
.output-image {
border-radius: 12px !important;
box-shadow: 0 4px 15px rgba(0,0,0,0.1) !important;
}
.status-box {
background: linear-gradient(135deg, #f5f7fa 0%, #e4e8eb 100%);
border-radius: 8px;
padding: 10px;
text-align: center;
}
"""
    with gr.Blocks(css=custom_css, title="Stable Diffusion WebUI - Anime") as demo:
        # Header
        gr.HTML("""
<div class="title">๐ŸŒธ Anime Diffusion WebUI</div>
<div class="subtitle">์• ๋‹ˆ๋ฉ”์ด์…˜ ์Šคํƒ€์ผ ์ด๋ฏธ์ง€ ์ƒ์„ฑ๊ธฐ | Text-to-Image & Image-to-Image</div>
""")
        # txt2img and img2img are split into separate tabs
        with gr.Tabs():
            # ============================================
            # Tab 1: text -> image (txt2img)
            # ============================================
            with gr.TabItem("๐ŸŽจ ํ…์ŠคํŠธ โ†’ ์ด๋ฏธ์ง€"):
                with gr.Row():
                    # Left: input panel
                    with gr.Column(scale=1):
                        txt2img_model = gr.Dropdown(
                            label="๐Ÿค– ๋ชจ๋ธ ์„ ํƒ",
                            choices=list(MODELS.keys()),
                            value="๐ŸŽจ Mistoon Anime V3 (์นดํˆฐํ’ ์• ๋‹ˆ๋ฉ”์ด์…˜)",
                            elem_classes=["model-dropdown"]
                        )
                        txt2img_prompt = gr.Textbox(
                            label="๐Ÿ“ ํ”„๋กฌํ”„ํŠธ",
                            placeholder="1girl, anime style, beautiful, masterpiece, best quality",
                            lines=3
                        )
                        txt2img_negative = gr.Textbox(
                            label="๐Ÿšซ ๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ",
                            placeholder="์ถ”๊ฐ€ํ•  ๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ (๊ธฐ๋ณธ๊ฐ’์ด ์ž๋™ ์ ์šฉ๋ฉ๋‹ˆ๋‹ค)",
                            lines=2
                        )
                        with gr.Row():
                            txt2img_width = gr.Slider(label="๐Ÿ“ ๋„ˆ๋น„", minimum=256, maximum=768, value=512, step=64)
                            txt2img_height = gr.Slider(label="๐Ÿ“ ๋†’์ด", minimum=256, maximum=768, value=768, step=64)
                        with gr.Row():
                            txt2img_steps = gr.Slider(label="๐Ÿ”„ ์Šคํ… ์ˆ˜", minimum=10, maximum=50, value=25, step=1)
                            txt2img_guidance = gr.Slider(label="๐ŸŽฏ CFG ์Šค์ผ€์ผ", minimum=1.0, maximum=15.0, value=7.0, step=0.5)
                        txt2img_seed = gr.Number(label="๐ŸŽฒ ์‹œ๋“œ (-1 = ๋žœ๋ค)", value=-1, precision=0)
                        txt2img_btn = gr.Button("๐Ÿš€ ์ด๋ฏธ์ง€ ์ƒ์„ฑ", elem_classes=["generate-btn"])
                        txt2img_status = gr.Textbox(label="๐Ÿ“Š ์ƒํƒœ", value="ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”", interactive=False)
                    # Right: output panel
                    with gr.Column(scale=1):
                        txt2img_output = gr.Image(label="๐Ÿ–ผ๏ธ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€", type="pil", elem_classes=["output-image"])
                # Example prompts; each row fills (model, prompt, negative prompt).
                # NOTE(review): nesting reconstructed at tab level, below the
                # input/output row -- confirm against the intended layout.
                gr.Examples(
                    examples=[
                        ["๐ŸŽจ Mistoon Anime V3 (์นดํˆฐํ’ ์• ๋‹ˆ๋ฉ”์ด์…˜)", "1girl, solo, colorful, vibrant colors, cartoon style, school uniform, masterpiece", ""],
                        ["๐ŸŒธ Anything V5 (์• ๋‹ˆ๋ฉ”์ด์…˜)", "1girl, solo, long blue hair, cherry blossoms, detailed, masterpiece, best quality", ""],
                        ["๐Ÿ’œ Counterfeit V3 (๊ณ ํ’ˆ์งˆ ์• ๋‹ˆ๋ฉ”์ด์…˜)", "1girl, kimono, japanese garden, autumn leaves, ultra detailed, 8k", ""],
                    ],
                    inputs=[txt2img_model, txt2img_prompt, txt2img_negative],
                    label="๐Ÿ’ก ์˜ˆ์ œ ํ”„๋กฌํ”„ํŠธ"
                )
                # Event handler: the inputs list order matches
                # generate_txt2img's parameter order.
                txt2img_btn.click(
                    fn=generate_txt2img,
                    inputs=[txt2img_model, txt2img_prompt, txt2img_negative, txt2img_steps, txt2img_guidance, txt2img_width, txt2img_height, txt2img_seed],
                    outputs=[txt2img_output, txt2img_status]
                )
            # ============================================
            # Tab 2: image -> image (img2img)
            # ============================================
            with gr.TabItem("๐Ÿ–ผ๏ธ ์ด๋ฏธ์ง€ โ†’ ์ด๋ฏธ์ง€"):
                with gr.Row():
                    # Left: input panel
                    with gr.Column(scale=1):
                        img2img_model = gr.Dropdown(
                            label="๐Ÿค– ๋ชจ๋ธ ์„ ํƒ",
                            choices=list(MODELS.keys()),
                            value="๐ŸŽจ Mistoon Anime V3 (์นดํˆฐํ’ ์• ๋‹ˆ๋ฉ”์ด์…˜)",
                            elem_classes=["model-dropdown"]
                        )
                        img2img_input = gr.Image(
                            label="๐Ÿ“ค ์ž…๋ ฅ ์ด๋ฏธ์ง€ (์‹ค์‚ฌ, ์Šค์ผ€์น˜ ๋“ฑ)",
                            type="pil",
                            height=200
                        )
                        img2img_prompt = gr.Textbox(
                            label="๐Ÿ“ ํ”„๋กฌํ”„ํŠธ (๋ณ€ํ™˜ํ•  ์Šคํƒ€์ผ)",
                            placeholder="anime style, colorful, masterpiece, best quality",
                            lines=3
                        )
                        img2img_negative = gr.Textbox(
                            label="๐Ÿšซ ๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ",
                            placeholder="์ถ”๊ฐ€ํ•  ๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ",
                            lines=2
                        )
                        img2img_strength = gr.Slider(
                            label="๐Ÿ’ช ๋ณ€ํ™˜ ๊ฐ•๋„ (0.0=์›๋ณธ ์œ ์ง€, 1.0=์™„์ „ ๋ณ€ํ™˜)",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.75,
                            step=0.05
                        )
                        with gr.Row():
                            img2img_steps = gr.Slider(label="๐Ÿ”„ ์Šคํ… ์ˆ˜", minimum=10, maximum=50, value=25, step=1)
                            img2img_guidance = gr.Slider(label="๐ŸŽฏ CFG ์Šค์ผ€์ผ", minimum=1.0, maximum=15.0, value=7.0, step=0.5)
                        img2img_seed = gr.Number(label="๐ŸŽฒ ์‹œ๋“œ (-1 = ๋žœ๋ค)", value=-1, precision=0)
                        img2img_btn = gr.Button("๐Ÿš€ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜", elem_classes=["generate-btn"])
                        img2img_status = gr.Textbox(label="๐Ÿ“Š ์ƒํƒœ", value="์ด๋ฏธ์ง€์™€ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”", interactive=False)
                    # Right: output panel
                    with gr.Column(scale=1):
                        img2img_output = gr.Image(label="๐Ÿ–ผ๏ธ ๋ณ€ํ™˜๋œ ์ด๋ฏธ์ง€", type="pil", elem_classes=["output-image"])
                # Collapsible img2img usage guide.
                # NOTE(review): nesting reconstructed at tab level, below the
                # input/output row -- confirm against the intended layout.
                with gr.Accordion("๐Ÿ“– img2img ์‚ฌ์šฉ ๊ฐ€์ด๋“œ", open=False):
                    gr.Markdown("""
### ๐Ÿ–ผ๏ธ Image-to-Image ๋ณ€ํ™˜ ๊ฐ€์ด๋“œ
**์‚ฌ์šฉ ๋ฐฉ๋ฒ•**:
1. ๋ณ€ํ™˜ํ•˜๊ณ  ์‹ถ์€ ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค (์‹ค์‚ฌ ์‚ฌ์ง„, ์Šค์ผ€์น˜ ๋“ฑ)
2. ์›ํ•˜๋Š” ์Šคํƒ€์ผ์„ ํ”„๋กฌํ”„ํŠธ๋กœ ์ž…๋ ฅํ•ฉ๋‹ˆ๋‹ค
3. ๋ณ€ํ™˜ ๊ฐ•๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค
**๋ณ€ํ™˜ ๊ฐ•๋„ (Strength) ์„ค๋ช…**:
- `0.3` - ์›๋ณธ ์ด๋ฏธ์ง€๋ฅผ ๋งŽ์ด ์œ ์ง€ (๋ฏธ์„ธํ•œ ์Šคํƒ€์ผ ๋ณ€ํ™”)
- `0.5` - ๊ท ํ˜• ์žกํžŒ ๋ณ€ํ™˜
- `0.75` - ํ”„๋กฌํ”„ํŠธ์— ๋” ์ถฉ์‹คํ•œ ๋ณ€ํ™˜ (๊ถŒ์žฅ)
- `1.0` - ๊ฑฐ์˜ ์ƒˆ๋กœ ์ƒ์„ฑ (์›๋ณธ ๋ฌด์‹œ)
**์ถ”์ฒœ ํ”„๋กฌํ”„ํŠธ**:
- ์‹ค์‚ฌ โ†’ ์• ๋‹ˆ๋ฉ”์ด์…˜: `anime style, colorful, masterpiece`
- ์Šค์ผ€์น˜ โ†’ ์™„์„ฑ๋ณธ: `detailed illustration, colored, vibrant`
- ์ง€๋ธŒ๋ฆฌ ์Šคํƒ€์ผ: `studio ghibli style, soft colors, fantasy`
""")
                # Event handler: the inputs list order matches
                # generate_img2img's parameter order.
                img2img_btn.click(
                    fn=generate_img2img,
                    inputs=[img2img_model, img2img_input, img2img_prompt, img2img_negative, img2img_strength, img2img_steps, img2img_guidance, img2img_seed],
                    outputs=[img2img_output, img2img_status]
                )
        # Footer
        gr.HTML("""
<div style="text-align: center; margin-top: 30px; padding: 20px; color: #888; border-top: 1px solid #eee;">
<p>๐ŸŒธ Powered by Diffusers & Gradio | ๐Ÿค— Hugging Face Spaces</p>
<p style="font-size: 0.9em;">โš ๏ธ CPU ๋ชจ๋“œ์—์„œ๋Š” ์ด๋ฏธ์ง€ ์ƒ์„ฑ์— 2-5๋ถ„ ์ •๋„ ์†Œ์š”๋ฉ๋‹ˆ๋‹ค.</p>
<p style="font-size: 0.85em; color: #aaa;">์ฒซ ์‹คํ–‰ ์‹œ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ๋กœ ์ธํ•ด ์‹œ๊ฐ„์ด ๋” ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.</p>
</div>
""")
    return demo
# ================================
# REST API endpoint definitions
# ================================
# API request/response models
class GenerateRequest(BaseModel):
    # Request body for POST /api/generate (text -> image). The docstring below is
    # kept untranslated because pydantic exports it as the schema description.
    """ํ…์ŠคํŠธ โ†’ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์š”์ฒญ"""
    prompt: str = Field(..., description="์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ”„๋กฌํ”„ํŠธ")
    model_name: str = Field(default="๐ŸŽจ Mistoon Anime V3 (์นดํˆฐํ’ ์• ๋‹ˆ๋ฉ”์ด์…˜)", description="์‚ฌ์šฉํ•  ๋ชจ๋ธ ์ด๋ฆ„")
    negative_prompt: str = Field(default="", description="๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ")
    num_inference_steps: int = Field(default=25, ge=10, le=50, description="์ถ”๋ก  ์Šคํ… ์ˆ˜")
    guidance_scale: float = Field(default=7.5, ge=1.0, le=15.0, description="CFG ์Šค์ผ€์ผ")
    # Stable Diffusion pipelines reject sizes not divisible by 8; validating here
    # reports that as a request-validation error instead of a 500 from inference.
    width: int = Field(default=512, ge=256, le=768, multiple_of=8, description="์ด๋ฏธ์ง€ ๋„ˆ๋น„")
    height: int = Field(default=512, ge=256, le=768, multiple_of=8, description="์ด๋ฏธ์ง€ ๋†’์ด")
    seed: int = Field(default=-1, description="์‹œ๋“œ ๊ฐ’ (-1์ด๋ฉด ๋žœ๋ค)")
class Img2ImgRequest(BaseModel):
    # Request body for POST /api/img2img (image -> image). The docstring below is
    # kept untranslated because pydantic exports it as the schema description.
    """์ด๋ฏธ์ง€ โ†’ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์š”์ฒญ"""
    # Required fields.
    image_base64: str = Field(..., description="์ž…๋ ฅ ์ด๋ฏธ์ง€ (Base64 ์ธ์ฝ”๋”ฉ)")
    prompt: str = Field(..., description="๋ณ€ํ™˜ ํ”„๋กฌํ”„ํŠธ")
    # Optional fields (first positional Field argument is the default).
    model_name: str = Field(
        "๐ŸŽจ Mistoon Anime V3 (์นดํˆฐํ’ ์• ๋‹ˆ๋ฉ”์ด์…˜)",
        description="์‚ฌ์šฉํ•  ๋ชจ๋ธ ์ด๋ฆ„",
    )
    negative_prompt: str = Field("", description="๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ")
    strength: float = Field(0.75, ge=0.1, le=1.0, description="๋ณ€ํ™˜ ๊ฐ•๋„")
    num_inference_steps: int = Field(25, ge=10, le=50, description="์ถ”๋ก  ์Šคํ… ์ˆ˜")
    guidance_scale: float = Field(7.5, ge=1.0, le=15.0, description="CFG ์Šค์ผ€์ผ")
    seed: int = Field(-1, description="์‹œ๋“œ ๊ฐ’ (-1์ด๋ฉด ๋žœ๋ค)")
class GenerateResponse(BaseModel):
    # Response body returned by both /api/generate and /api/img2img.
    """์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‘๋‹ต"""
    success: bool
    message: str
    # Filled on success: the PNG result as Base64 text and the seed actually used.
    image_base64: Optional[str] = Field(default=None)
    seed: Optional[int] = Field(default=None)
class ModelsResponse(BaseModel):
    # Response body for GET /api/models.
    """๋ชจ๋ธ ๋ชฉ๋ก ์‘๋‹ต"""
    models: list[str] = Field(...)  # required; display labels (the keys of MODELS)
# FastAPI application hosting the REST endpoints defined below.
api_app = FastAPI(
    version="2.0.0",
    title="Anime Diffusion API",
    description="์• ๋‹ˆ๋ฉ”์ด์…˜ ์Šคํƒ€์ผ ์ด๋ฏธ์ง€ ์ƒ์„ฑ REST API (txt2img + img2img)",
)
# CORS wide open so browser clients on other origins can call the API.
# NOTE(review): FastAPI's CORS docs state allow_origins/allow_methods/allow_headers
# cannot be ["*"] when allow_credentials=True. The visible endpoints read no
# cookies, so allow_credentials=False is probably intended -- confirm no client
# relies on credentialed cross-origin requests before changing it.
api_app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
@api_app.get("/api/models", response_model=ModelsResponse)
async def get_models():
    # GET /api/models: list the display labels accepted as ``model_name``.
    # Docstring kept as-is: FastAPI publishes it as the endpoint description.
    """์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ ๋ชฉ๋ก ์กฐํšŒ"""
    labels = [label for label in MODELS]
    return ModelsResponse(models=labels)
@api_app.post("/api/generate", response_model=GenerateResponse)
async def api_generate_txt2img(request: GenerateRequest):
    # POST /api/generate: text -> image. Responds with the PNG as Base64 plus the
    # seed used; 400 for invalid input, 500 for model-load or inference failures.
    # The docstring below is kept untranslated: FastAPI publishes it as the
    # endpoint description in the OpenAPI docs.
    """
ํ…์ŠคํŠธ โ†’ ์ด๋ฏธ์ง€ ์ƒ์„ฑ API
ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ „๋‹ฌํ•˜๋ฉด Base64๋กœ ์ธ์ฝ”๋”ฉ๋œ ์ด๋ฏธ์ง€๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
"""
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”")
    if request.model_name not in MODELS:
        raise HTTPException(status_code=400, detail=f"์•Œ ์ˆ˜ ์—†๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ: {list(MODELS.keys())}")
    # Keep the pipeline in a local name: the module-level ``pipe`` may be swapped
    # by a Gradio handler running in another thread between loading and inference.
    txt2img_pipe, status = load_model(request.model_name, "txt2img")
    if txt2img_pipe is None:
        raise HTTPException(status_code=500, detail=status)
    if request.negative_prompt.strip():
        full_negative = f"{request.negative_prompt}, {DEFAULT_NEGATIVE}"
    else:
        full_negative = DEFAULT_NEGATIVE
    seed = request.seed
    if seed == -1:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    seed = int(seed)
    try:
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
    except (RuntimeError, OverflowError) as e:
        # torch rejects seeds outside its supported range: a client error, not a crash.
        raise HTTPException(status_code=400, detail=str(e)) from e
    # NOTE(review): inference runs synchronously inside an ``async def`` handler,
    # so it blocks the event loop (including the mounted Gradio UI) while it runs.
    # Left as-is because that also serialises API calls on the shared pipeline.
    try:
        print(f"๐ŸŽจ [API txt2img] ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘... ํ”„๋กฌํ”„ํŠธ: {request.prompt[:50]}...")
        result = txt2img_pipe(
            prompt=request.prompt,
            negative_prompt=full_negative,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            width=request.width,
            height=request.height,
            generator=generator
        )
        image = result.images[0]
        # Encode the PNG bytes as Base64 text for the JSON response.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        print(f"โœ… [API txt2img] ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์™„๋ฃŒ! (์‹œ๋“œ: {seed})")
        return GenerateResponse(
            success=True,
            message="์ด๋ฏธ์ง€ ์ƒ์„ฑ ์™„๋ฃŒ",
            image_base64=image_base64,
            seed=seed
        )
    except Exception as e:
        print(f"โŒ [API txt2img] ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@api_app.post("/api/img2img", response_model=GenerateResponse)
async def api_generate_img2img(request: Img2ImgRequest):
    # POST /api/img2img: image -> image. The source image may be raw Base64 or a
    # data URL ("data:image/png;base64,..."). Responds with the PNG result as
    # Base64 plus the seed used; 400 for invalid input, 500 for load/inference
    # failures. The docstring below is kept untranslated: FastAPI publishes it
    # as the endpoint description in the OpenAPI docs.
    """
์ด๋ฏธ์ง€ โ†’ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ API
Base64 ์ด๋ฏธ์ง€์™€ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ „๋‹ฌํ•˜๋ฉด ๋ณ€ํ™˜๋œ ์ด๋ฏธ์ง€๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
"""
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”")
    if request.model_name not in MODELS:
        raise HTTPException(status_code=400, detail=f"์•Œ ์ˆ˜ ์—†๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ: {list(MODELS.keys())}")
    encoded = request.image_base64
    # Strip a data-URL header: b64decode (non-validating by default) would keep
    # the header's alphabet characters and corrupt the decoded bytes.
    if encoded.startswith("data:") and "," in encoded:
        encoded = encoded.split(",", 1)[1]
    try:
        image_data = base64.b64decode(encoded)
        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"์ด๋ฏธ์ง€ ๋””์ฝ”๋”ฉ ์‹คํŒจ: {str(e)}") from e
    # Keep the pipeline in a local name: the module-level ``pipe`` may be swapped
    # by a Gradio handler running in another thread between loading and inference.
    img2img_pipe, status = load_model(request.model_name, "img2img")
    if img2img_pipe is None:
        raise HTTPException(status_code=500, detail=status)
    if request.negative_prompt.strip():
        full_negative = f"{request.negative_prompt}, {DEFAULT_NEGATIVE}"
    else:
        full_negative = DEFAULT_NEGATIVE
    seed = request.seed
    if seed == -1:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    seed = int(seed)
    try:
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
    except (RuntimeError, OverflowError) as e:
        # torch rejects seeds outside its supported range: a client error, not a crash.
        raise HTTPException(status_code=400, detail=str(e)) from e
    # NOTE(review): inference runs synchronously inside an ``async def`` handler,
    # so it blocks the event loop while it runs (see api_generate_txt2img).
    try:
        print(f"๐Ÿ–ผ๏ธ [API img2img] ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์ค‘... ํ”„๋กฌํ”„ํŠธ: {request.prompt[:50]}...")
        # Round each side down to a multiple of 64; sides shorter than 64 px become 512.
        target_w = (input_image.width // 64) * 64 or 512
        target_h = (input_image.height // 64) * 64 or 512
        input_image = input_image.resize((target_w, target_h), Image.LANCZOS)
        result = img2img_pipe(
            prompt=request.prompt,
            image=input_image,
            negative_prompt=full_negative,
            strength=request.strength,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            generator=generator
        )
        image = result.images[0]
        # Encode the PNG bytes as Base64 text for the JSON response.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        print(f"โœ… [API img2img] ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์™„๋ฃŒ! (์‹œ๋“œ: {seed})")
        return GenerateResponse(
            success=True,
            message="์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์™„๋ฃŒ",
            image_base64=image_base64,
            seed=seed
        )
    except Exception as e:
        print(f"โŒ [API img2img] ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ์‹คํŒจ: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@api_app.get("/api/health")
async def health_check():
    # GET /api/health: report the compute device and which pipeline, if any, is
    # cached. Docstring kept as-is: FastAPI publishes it as the endpoint description.
    """์„œ๋ฒ„ ์ƒํƒœ ํ™•์ธ"""
    report = {"status": "healthy", "device": DEVICE}
    report["model_loaded"] = current_model_id is not None
    report["pipeline_type"] = current_pipeline_type
    return report
# ================================
# Script entry point
# ================================
if __name__ == "__main__":
    for banner_line in (
        "๐ŸŒธ Anime Diffusion WebUI + API ์‹œ์ž‘...",
        " - txt2img: ํ…์ŠคํŠธ โ†’ ์ด๋ฏธ์ง€ ์ƒ์„ฑ",
        " - img2img: ์ด๋ฏธ์ง€ โ†’ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜",
    ):
        print(banner_line)
    demo = create_interface()
    # Mount the Gradio UI at "/" on the FastAPI app that already serves /api/*,
    # so a single uvicorn server handles both.
    app = gr.mount_gradio_app(api_app, demo, path="/")
    import uvicorn
    print("๐Ÿ“ก API ๋ฌธ์„œ: http://localhost:7860/docs")
    print("๐ŸŒ ์›น UI: http://localhost:7860/")
    uvicorn.run(app, host="0.0.0.0", port=7860)