feat: ์ฌ์ฉ์์๊ฒ ํ์๋๋ ๋ชจ๋ ์ ๋ณด ๋ฉ์์ง๋ฅผ ํ๊ตญ์ด๋ก ๋ฒ์ญํ์ต๋๋ค.
a11abf8
| """ | |
| ์คํ ์ด๋ธ ๋ํจ์ WebUI - ํ๊น ํ์ด์ค ์คํ์ด์ค์ฉ | |
| Gradio ์ธํฐํ์ด์ค + REST API๋ฅผ ํตํ ์ด๋ฏธ์ง ์์ฑ (txt2img + img2img ์ง์) | |
| """ | |
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler | |
| from PIL import Image | |
| import os | |
| import gc | |
| import io | |
| import base64 | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| # ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ธ ๋ชฉ๋ก | |
| MODELS = { | |
| "๐จ Mistoon Anime V3 (์นดํฐํ ์ ๋๋ฉ์ด์ )": "stablediffusionapi/mistoonanime-v30", | |
| "๐ธ Anything V5 (์ ๋๋ฉ์ด์ )": "stablediffusionapi/anything-v5", | |
| "๐ Counterfeit V3 (๊ณ ํ์ง ์ ๋๋ฉ์ด์ )": "gsdf/Counterfeit-V3.0", | |
| "โจ DreamShaper V8 (๋ค๋ชฉ์ )": "Lykon/DreamShaper", | |
| "๐ญ OpenJourney (Midjourney ์คํ์ผ)": "prompthero/openjourney-v4", | |
| "๐ผ๏ธ Stable Diffusion v1.5 (๊ธฐ๋ณธ)": "runwayml/stable-diffusion-v1-5", | |
| "๐ MeinaMix (์ ๋๋ฉ์ด์ )": "Meina/MeinaMix_V11", | |
| "๐ซ ReV Animated (์ ๋๋ฉ์ด์ )": "stablediffusionapi/rev-animated", | |
| } | |
| # ๋๋ฐ์ด์ค ์ค์ | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 | |
| print(f"๐ ๋๋ฐ์ด์ค: {DEVICE}, ๋ฐ์ดํฐ ํ์ : {DTYPE}") | |
| # ํ์ฌ ๋ก๋๋ ๋ชจ๋ธ ์ ๋ณด | |
| current_model_id = None | |
| current_pipeline_type = None # "txt2img" ๋๋ "img2img" | |
| pipe = None | |
| def clear_memory(): | |
| """๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ""" | |
| global pipe | |
| if pipe is not None: | |
| del pipe | |
| pipe = None | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| def load_model(model_name: str, pipeline_type: str = "txt2img"): | |
| """ | |
| ๋ชจ๋ธ ๋ก๋ ํจ์ | |
| Args: | |
| model_name: ๋ชจ๋ธ ์ด๋ฆ | |
| pipeline_type: "txt2img" ๋๋ "img2img" | |
| """ | |
| global pipe, current_model_id, current_pipeline_type | |
| model_id = MODELS.get(model_name) | |
| if model_id is None: | |
| return None, f"โ ์ ์ ์๋ ๋ชจ๋ธ: {model_name}" | |
| # ์ด๋ฏธ ๊ฐ์ ๋ชจ๋ธ๊ณผ ํ์ดํ๋ผ์ธ ํ์ ์ด ๋ก๋๋์ด ์์ผ๋ฉด ์คํต | |
| if current_model_id == model_id and current_pipeline_type == pipeline_type and pipe is not None: | |
| return pipe, f"โ {model_name} ์ด๋ฏธ ๋ก๋๋จ ({pipeline_type})" | |
| # ๊ธฐ์กด ๋ชจ๋ธ ์ ๋ฆฌ | |
| clear_memory() | |
| print(f"๐ฅ ๋ชจ๋ธ ๋ก๋ฉ ์ค: {model_name} ({pipeline_type})...") | |
| try: | |
| # ํ์ดํ๋ผ์ธ ํ์ ์ ๋ฐ๋ผ ๋ค๋ฅธ ํด๋์ค ์ฌ์ฉ | |
| if pipeline_type == "img2img": | |
| pipe = StableDiffusionImg2ImgPipeline.from_pretrained( | |
| model_id, | |
| torch_dtype=DTYPE, | |
| safety_checker=None, | |
| requires_safety_checker=False, | |
| use_safetensors=False | |
| ) | |
| else: | |
| pipe = StableDiffusionPipeline.from_pretrained( | |
| model_id, | |
| torch_dtype=DTYPE, | |
| safety_checker=None, | |
| requires_safety_checker=False, | |
| use_safetensors=False | |
| ) | |
| # ๋น ๋ฅธ ์ค์ผ์ค๋ฌ ์ฌ์ฉ | |
| pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) | |
| # ๋๋ฐ์ด์ค๋ก ์ด๋ | |
| pipe = pipe.to(DEVICE) | |
| # ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ | |
| pipe.enable_attention_slicing() | |
| if hasattr(pipe, 'enable_vae_slicing'): | |
| pipe.enable_vae_slicing() | |
| current_model_id = model_id | |
| current_pipeline_type = pipeline_type | |
| print(f"โ ๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ: {model_name} ({pipeline_type})") | |
| return pipe, f"โ {model_name} ๋ก๋ฉ ์๋ฃ!" | |
| except Exception as e: | |
| print(f"โ ๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {e}") | |
| return None, f"โ ๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {str(e)}" | |
| # ๊ธฐ๋ณธ ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ | |
| DEFAULT_NEGATIVE = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" | |
| def generate_txt2img( | |
| model_name: str, | |
| prompt: str, | |
| negative_prompt: str = "", | |
| num_inference_steps: int = 25, | |
| guidance_scale: float = 7.5, | |
| width: int = 512, | |
| height: int = 512, | |
| seed: int = -1, | |
| progress=gr.Progress() | |
| ): | |
| """ํ ์คํธ โ ์ด๋ฏธ์ง ์์ฑ ํจ์""" | |
| global pipe | |
| if not prompt.strip(): | |
| return None, "โ ๏ธ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์!" | |
| # ๋ชจ๋ธ ๋ก๋ | |
| progress(0.1, desc="๋ชจ๋ธ ๋ก๋ฉ ์ค...") | |
| pipe, status = load_model(model_name, "txt2img") | |
| if pipe is None: | |
| return None, status | |
| # ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ ์ค์ | |
| if negative_prompt.strip(): | |
| full_negative = f"{negative_prompt}, {DEFAULT_NEGATIVE}" | |
| else: | |
| full_negative = DEFAULT_NEGATIVE | |
| # ์๋ ์ค์ | |
| if seed == -1: | |
| seed = torch.randint(0, 2**32 - 1, (1,)).item() | |
| generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) | |
| try: | |
| progress(0.3, desc="์ด๋ฏธ์ง ์์ฑ ์ค...") | |
| print(f"๐จ [txt2img] ์ด๋ฏธ์ง ์์ฑ ์ค... ํ๋กฌํํธ: {prompt[:50]}...") | |
| result = pipe( | |
| prompt=prompt, | |
| negative_prompt=full_negative, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| width=width, | |
| height=height, | |
| generator=generator | |
| ) | |
| image = result.images[0] | |
| progress(1.0, desc="์๋ฃ!") | |
| print("โ [txt2img] ์ด๋ฏธ์ง ์์ฑ ์๋ฃ!") | |
| return image, f"โ ์์ฑ ์๋ฃ! (์๋: {seed})" | |
| except Exception as e: | |
| print(f"โ [txt2img] ์ด๋ฏธ์ง ์์ฑ ์คํจ: {e}") | |
| return None, f"โ ์ด๋ฏธ์ง ์์ฑ ์คํจ: {str(e)}" | |
| def generate_img2img( | |
| model_name: str, | |
| input_image: Image.Image, | |
| prompt: str, | |
| negative_prompt: str = "", | |
| strength: float = 0.75, | |
| num_inference_steps: int = 25, | |
| guidance_scale: float = 7.5, | |
| seed: int = -1, | |
| progress=gr.Progress() | |
| ): | |
| """์ด๋ฏธ์ง โ ์ด๋ฏธ์ง ๋ณํ ํจ์""" | |
| global pipe | |
| if input_image is None: | |
| return None, "โ ๏ธ ์ด๋ฏธ์ง๋ฅผ ์ ๋ก๋ํด์ฃผ์ธ์!" | |
| if not prompt.strip(): | |
| return None, "โ ๏ธ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์!" | |
| # ๋ชจ๋ธ ๋ก๋ (img2img ํ์ดํ๋ผ์ธ) | |
| progress(0.1, desc="๋ชจ๋ธ ๋ก๋ฉ ์ค...") | |
| pipe, status = load_model(model_name, "img2img") | |
| if pipe is None: | |
| return None, status | |
| # ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ ์ค์ | |
| if negative_prompt.strip(): | |
| full_negative = f"{negative_prompt}, {DEFAULT_NEGATIVE}" | |
| else: | |
| full_negative = DEFAULT_NEGATIVE | |
| # ์๋ ์ค์ | |
| if seed == -1: | |
| seed = torch.randint(0, 2**32 - 1, (1,)).item() | |
| generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) | |
| try: | |
| progress(0.3, desc="์ด๋ฏธ์ง ๋ณํ ์ค...") | |
| print(f"๐ผ๏ธ [img2img] ์ด๋ฏธ์ง ๋ณํ ์ค... ํ๋กฌํํธ: {prompt[:50]}...") | |
| # ์ ๋ ฅ ์ด๋ฏธ์ง๋ฅผ RGB๋ก ๋ณํํ๊ณ ํฌ๊ธฐ ์กฐ์ | |
| input_image = input_image.convert("RGB") | |
| # ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ 64์ ๋ฐฐ์๋ก ์กฐ์ (SD ์๊ตฌ์ฌํญ) | |
| w, h = input_image.size | |
| w = (w // 64) * 64 | |
| h = (h // 64) * 64 | |
| if w == 0: | |
| w = 512 | |
| if h == 0: | |
| h = 512 | |
| input_image = input_image.resize((w, h), Image.LANCZOS) | |
| result = pipe( | |
| prompt=prompt, | |
| image=input_image, | |
| negative_prompt=full_negative, | |
| strength=strength, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| generator=generator | |
| ) | |
| image = result.images[0] | |
| progress(1.0, desc="์๋ฃ!") | |
| print("โ [img2img] ์ด๋ฏธ์ง ๋ณํ ์๋ฃ!") | |
| return image, f"โ ๋ณํ ์๋ฃ! (์๋: {seed}, ๊ฐ๋: {strength})" | |
| except Exception as e: | |
| print(f"โ [img2img] ์ด๋ฏธ์ง ๋ณํ ์คํจ: {e}") | |
| return None, f"โ ์ด๋ฏธ์ง ๋ณํ ์คํจ: {str(e)}" | |
| # Gradio ์ธํฐํ์ด์ค ์์ฑ | |
| def create_interface(): | |
| """Gradio ์น ์ธํฐํ์ด์ค ์์ฑ""" | |
| # ์ปค์คํ CSS | |
| custom_css = """ | |
| .gradio-container { | |
| font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| max-width: 1200px !important; | |
| } | |
| .generate-btn { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| border: none !important; | |
| color: white !important; | |
| font-weight: bold !important; | |
| font-size: 1.2em !important; | |
| padding: 15px 30px !important; | |
| border-radius: 10px !important; | |
| transition: all 0.3s ease !important; | |
| width: 100% !important; | |
| } | |
| .generate-btn:hover { | |
| transform: translateY(-2px) !important; | |
| box-shadow: 0 6px 20px rgba(102, 126, 234, 0.5) !important; | |
| } | |
| .title { | |
| text-align: center; | |
| background: linear-gradient(135deg, #ff6b9d 0%, #c44569 50%, #764ba2 100%); | |
| -webkit-background-clip: text; | |
| -webkit-text-fill-color: transparent; | |
| font-size: 2.8em; | |
| font-weight: bold; | |
| margin-bottom: 5px; | |
| } | |
| .subtitle { | |
| text-align: center; | |
| color: #888; | |
| font-size: 1.1em; | |
| margin-bottom: 25px; | |
| } | |
| .model-dropdown { | |
| border: 2px solid #764ba2 !important; | |
| border-radius: 8px !important; | |
| } | |
| .output-image { | |
| border-radius: 12px !important; | |
| box-shadow: 0 4px 15px rgba(0,0,0,0.1) !important; | |
| } | |
| .status-box { | |
| background: linear-gradient(135deg, #f5f7fa 0%, #e4e8eb 100%); | |
| border-radius: 8px; | |
| padding: 10px; | |
| text-align: center; | |
| } | |
| """ | |
| with gr.Blocks(css=custom_css, title="Stable Diffusion WebUI - Anime") as demo: | |
| # ํค๋ | |
| gr.HTML(""" | |
| <div class="title">๐ธ Anime Diffusion WebUI</div> | |
| <div class="subtitle">์ ๋๋ฉ์ด์ ์คํ์ผ ์ด๋ฏธ์ง ์์ฑ๊ธฐ | Text-to-Image & Image-to-Image</div> | |
| """) | |
| # ํญ์ผ๋ก txt2img / img2img ๋ถ๋ฆฌ | |
| with gr.Tabs(): | |
| # ============================================ | |
| # ํญ 1: ํ ์คํธ โ ์ด๋ฏธ์ง (txt2img) | |
| # ============================================ | |
| with gr.TabItem("๐จ ํ ์คํธ โ ์ด๋ฏธ์ง"): | |
| with gr.Row(): | |
| # ์ผ์ชฝ: ์ ๋ ฅ ํจ๋ | |
| with gr.Column(scale=1): | |
| txt2img_model = gr.Dropdown( | |
| label="๐ค ๋ชจ๋ธ ์ ํ", | |
| choices=list(MODELS.keys()), | |
| value="๐จ Mistoon Anime V3 (์นดํฐํ ์ ๋๋ฉ์ด์ )", | |
| elem_classes=["model-dropdown"] | |
| ) | |
| txt2img_prompt = gr.Textbox( | |
| label="๐ ํ๋กฌํํธ", | |
| placeholder="1girl, anime style, beautiful, masterpiece, best quality", | |
| lines=3 | |
| ) | |
| txt2img_negative = gr.Textbox( | |
| label="๐ซ ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ", | |
| placeholder="์ถ๊ฐํ ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ (๊ธฐ๋ณธ๊ฐ์ด ์๋ ์ ์ฉ๋ฉ๋๋ค)", | |
| lines=2 | |
| ) | |
| with gr.Row(): | |
| txt2img_width = gr.Slider(label="๐ ๋๋น", minimum=256, maximum=768, value=512, step=64) | |
| txt2img_height = gr.Slider(label="๐ ๋์ด", minimum=256, maximum=768, value=768, step=64) | |
| with gr.Row(): | |
| txt2img_steps = gr.Slider(label="๐ ์คํ ์", minimum=10, maximum=50, value=25, step=1) | |
| txt2img_guidance = gr.Slider(label="๐ฏ CFG ์ค์ผ์ผ", minimum=1.0, maximum=15.0, value=7.0, step=0.5) | |
| txt2img_seed = gr.Number(label="๐ฒ ์๋ (-1 = ๋๋ค)", value=-1, precision=0) | |
| txt2img_btn = gr.Button("๐ ์ด๋ฏธ์ง ์์ฑ", elem_classes=["generate-btn"]) | |
| txt2img_status = gr.Textbox(label="๐ ์ํ", value="ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์", interactive=False) | |
| # ์ค๋ฅธ์ชฝ: ์ถ๋ ฅ ํจ๋ | |
| with gr.Column(scale=1): | |
| txt2img_output = gr.Image(label="๐ผ๏ธ ์์ฑ๋ ์ด๋ฏธ์ง", type="pil", elem_classes=["output-image"]) | |
| # ์์ | |
| gr.Examples( | |
| examples=[ | |
| ["๐จ Mistoon Anime V3 (์นดํฐํ ์ ๋๋ฉ์ด์ )", "1girl, solo, colorful, vibrant colors, cartoon style, school uniform, masterpiece", ""], | |
| ["๐ธ Anything V5 (์ ๋๋ฉ์ด์ )", "1girl, solo, long blue hair, cherry blossoms, detailed, masterpiece, best quality", ""], | |
| ["๐ Counterfeit V3 (๊ณ ํ์ง ์ ๋๋ฉ์ด์ )", "1girl, kimono, japanese garden, autumn leaves, ultra detailed, 8k", ""], | |
| ], | |
| inputs=[txt2img_model, txt2img_prompt, txt2img_negative], | |
| label="๐ก ์์ ํ๋กฌํํธ" | |
| ) | |
| # ์ด๋ฒคํธ ํธ๋ค๋ฌ | |
| txt2img_btn.click( | |
| fn=generate_txt2img, | |
| inputs=[txt2img_model, txt2img_prompt, txt2img_negative, txt2img_steps, txt2img_guidance, txt2img_width, txt2img_height, txt2img_seed], | |
| outputs=[txt2img_output, txt2img_status] | |
| ) | |
| # ============================================ | |
| # ํญ 2: ์ด๋ฏธ์ง โ ์ด๋ฏธ์ง (img2img) | |
| # ============================================ | |
| with gr.TabItem("๐ผ๏ธ ์ด๋ฏธ์ง โ ์ด๋ฏธ์ง"): | |
| with gr.Row(): | |
| # ์ผ์ชฝ: ์ ๋ ฅ ํจ๋ | |
| with gr.Column(scale=1): | |
| img2img_model = gr.Dropdown( | |
| label="๐ค ๋ชจ๋ธ ์ ํ", | |
| choices=list(MODELS.keys()), | |
| value="๐จ Mistoon Anime V3 (์นดํฐํ ์ ๋๋ฉ์ด์ )", | |
| elem_classes=["model-dropdown"] | |
| ) | |
| img2img_input = gr.Image( | |
| label="๐ค ์ ๋ ฅ ์ด๋ฏธ์ง (์ค์ฌ, ์ค์ผ์น ๋ฑ)", | |
| type="pil", | |
| height=200 | |
| ) | |
| img2img_prompt = gr.Textbox( | |
| label="๐ ํ๋กฌํํธ (๋ณํํ ์คํ์ผ)", | |
| placeholder="anime style, colorful, masterpiece, best quality", | |
| lines=3 | |
| ) | |
| img2img_negative = gr.Textbox( | |
| label="๐ซ ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ", | |
| placeholder="์ถ๊ฐํ ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ", | |
| lines=2 | |
| ) | |
| img2img_strength = gr.Slider( | |
| label="๐ช ๋ณํ ๊ฐ๋ (0.0=์๋ณธ ์ ์ง, 1.0=์์ ๋ณํ)", | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=0.75, | |
| step=0.05 | |
| ) | |
| with gr.Row(): | |
| img2img_steps = gr.Slider(label="๐ ์คํ ์", minimum=10, maximum=50, value=25, step=1) | |
| img2img_guidance = gr.Slider(label="๐ฏ CFG ์ค์ผ์ผ", minimum=1.0, maximum=15.0, value=7.0, step=0.5) | |
| img2img_seed = gr.Number(label="๐ฒ ์๋ (-1 = ๋๋ค)", value=-1, precision=0) | |
| img2img_btn = gr.Button("๐ ์ด๋ฏธ์ง ๋ณํ", elem_classes=["generate-btn"]) | |
| img2img_status = gr.Textbox(label="๐ ์ํ", value="์ด๋ฏธ์ง์ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์", interactive=False) | |
| # ์ค๋ฅธ์ชฝ: ์ถ๋ ฅ ํจ๋ | |
| with gr.Column(scale=1): | |
| img2img_output = gr.Image(label="๐ผ๏ธ ๋ณํ๋ ์ด๋ฏธ์ง", type="pil", elem_classes=["output-image"]) | |
| # img2img ๊ฐ์ด๋ | |
| with gr.Accordion("๐ img2img ์ฌ์ฉ ๊ฐ์ด๋", open=False): | |
| gr.Markdown(""" | |
| ### ๐ผ๏ธ Image-to-Image ๋ณํ ๊ฐ์ด๋ | |
| **์ฌ์ฉ ๋ฐฉ๋ฒ**: | |
| 1. ๋ณํํ๊ณ ์ถ์ ์ด๋ฏธ์ง๋ฅผ ์ ๋ก๋ํฉ๋๋ค (์ค์ฌ ์ฌ์ง, ์ค์ผ์น ๋ฑ) | |
| 2. ์ํ๋ ์คํ์ผ์ ํ๋กฌํํธ๋ก ์ ๋ ฅํฉ๋๋ค | |
| 3. ๋ณํ ๊ฐ๋๋ฅผ ์กฐ์ ํฉ๋๋ค | |
| **๋ณํ ๊ฐ๋ (Strength) ์ค๋ช **: | |
| - `0.3` - ์๋ณธ ์ด๋ฏธ์ง๋ฅผ ๋ง์ด ์ ์ง (๋ฏธ์ธํ ์คํ์ผ ๋ณํ) | |
| - `0.5` - ๊ท ํ ์กํ ๋ณํ | |
| - `0.75` - ํ๋กฌํํธ์ ๋ ์ถฉ์คํ ๋ณํ (๊ถ์ฅ) | |
| - `1.0` - ๊ฑฐ์ ์๋ก ์์ฑ (์๋ณธ ๋ฌด์) | |
| **์ถ์ฒ ํ๋กฌํํธ**: | |
| - ์ค์ฌ โ ์ ๋๋ฉ์ด์ : `anime style, colorful, masterpiece` | |
| - ์ค์ผ์น โ ์์ฑ๋ณธ: `detailed illustration, colored, vibrant` | |
| - ์ง๋ธ๋ฆฌ ์คํ์ผ: `studio ghibli style, soft colors, fantasy` | |
| """) | |
| # ์ด๋ฒคํธ ํธ๋ค๋ฌ | |
| img2img_btn.click( | |
| fn=generate_img2img, | |
| inputs=[img2img_model, img2img_input, img2img_prompt, img2img_negative, img2img_strength, img2img_steps, img2img_guidance, img2img_seed], | |
| outputs=[img2img_output, img2img_status] | |
| ) | |
| # ํธํฐ | |
| gr.HTML(""" | |
| <div style="text-align: center; margin-top: 30px; padding: 20px; color: #888; border-top: 1px solid #eee;"> | |
| <p>๐ธ Powered by Diffusers & Gradio | ๐ค Hugging Face Spaces</p> | |
| <p style="font-size: 0.9em;">โ ๏ธ CPU ๋ชจ๋์์๋ ์ด๋ฏธ์ง ์์ฑ์ 2-5๋ถ ์ ๋ ์์๋ฉ๋๋ค.</p> | |
| <p style="font-size: 0.85em; color: #aaa;">์ฒซ ์คํ ์ ๋ชจ๋ธ ๋ค์ด๋ก๋๋ก ์ธํด ์๊ฐ์ด ๋ ๊ฑธ๋ฆด ์ ์์ต๋๋ค.</p> | |
| </div> | |
| """) | |
| return demo | |
| # ================================ | |
| # REST API ์๋ํฌ์ธํธ ์ ์ | |
| # ================================ | |
| # API ์์ฒญ/์๋ต ๋ชจ๋ธ ์ ์ | |
| class GenerateRequest(BaseModel): | |
| """ํ ์คํธ โ ์ด๋ฏธ์ง ์์ฑ ์์ฒญ""" | |
| prompt: str = Field(..., description="์ด๋ฏธ์ง ์์ฑ ํ๋กฌํํธ") | |
| model_name: str = Field(default="๐จ Mistoon Anime V3 (์นดํฐํ ์ ๋๋ฉ์ด์ )", description="์ฌ์ฉํ ๋ชจ๋ธ ์ด๋ฆ") | |
| negative_prompt: str = Field(default="", description="๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ") | |
| num_inference_steps: int = Field(default=25, ge=10, le=50, description="์ถ๋ก ์คํ ์") | |
| guidance_scale: float = Field(default=7.5, ge=1.0, le=15.0, description="CFG ์ค์ผ์ผ") | |
| width: int = Field(default=512, ge=256, le=768, description="์ด๋ฏธ์ง ๋๋น") | |
| height: int = Field(default=512, ge=256, le=768, description="์ด๋ฏธ์ง ๋์ด") | |
| seed: int = Field(default=-1, description="์๋ ๊ฐ (-1์ด๋ฉด ๋๋ค)") | |
| class Img2ImgRequest(BaseModel): | |
| """์ด๋ฏธ์ง โ ์ด๋ฏธ์ง ๋ณํ ์์ฒญ""" | |
| image_base64: str = Field(..., description="์ ๋ ฅ ์ด๋ฏธ์ง (Base64 ์ธ์ฝ๋ฉ)") | |
| prompt: str = Field(..., description="๋ณํ ํ๋กฌํํธ") | |
| model_name: str = Field(default="๐จ Mistoon Anime V3 (์นดํฐํ ์ ๋๋ฉ์ด์ )", description="์ฌ์ฉํ ๋ชจ๋ธ ์ด๋ฆ") | |
| negative_prompt: str = Field(default="", description="๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ") | |
| strength: float = Field(default=0.75, ge=0.1, le=1.0, description="๋ณํ ๊ฐ๋") | |
| num_inference_steps: int = Field(default=25, ge=10, le=50, description="์ถ๋ก ์คํ ์") | |
| guidance_scale: float = Field(default=7.5, ge=1.0, le=15.0, description="CFG ์ค์ผ์ผ") | |
| seed: int = Field(default=-1, description="์๋ ๊ฐ (-1์ด๋ฉด ๋๋ค)") | |
| class GenerateResponse(BaseModel): | |
| """์ด๋ฏธ์ง ์์ฑ ์๋ต""" | |
| success: bool | |
| message: str | |
| image_base64: Optional[str] = None | |
| seed: Optional[int] = None | |
| class ModelsResponse(BaseModel): | |
| """๋ชจ๋ธ ๋ชฉ๋ก ์๋ต""" | |
| models: list[str] | |
| # FastAPI ์ฑ ์์ฑ | |
| api_app = FastAPI( | |
| title="Anime Diffusion API", | |
| description="์ ๋๋ฉ์ด์ ์คํ์ผ ์ด๋ฏธ์ง ์์ฑ REST API (txt2img + img2img)", | |
| version="2.0.0" | |
| ) | |
| # CORS ์ค์ ์ถ๊ฐ (์ธ๋ถ ํธ์ถ ํ์ฉ) | |
| api_app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| async def get_models(): | |
| """์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ธ ๋ชฉ๋ก ์กฐํ""" | |
| return ModelsResponse(models=list(MODELS.keys())) | |
| async def api_generate_txt2img(request: GenerateRequest): | |
| """ | |
| ํ ์คํธ โ ์ด๋ฏธ์ง ์์ฑ API | |
| ํ๋กฌํํธ๋ฅผ ์ ๋ฌํ๋ฉด Base64๋ก ์ธ์ฝ๋ฉ๋ ์ด๋ฏธ์ง๋ฅผ ๋ฐํํฉ๋๋ค. | |
| """ | |
| global pipe | |
| if not request.prompt.strip(): | |
| raise HTTPException(status_code=400, detail="ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์") | |
| if request.model_name not in MODELS: | |
| raise HTTPException(status_code=400, detail=f"์ ์ ์๋ ๋ชจ๋ธ์ ๋๋ค. ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ธ: {list(MODELS.keys())}") | |
| # ๋ชจ๋ธ ๋ก๋ | |
| pipe, status = load_model(request.model_name, "txt2img") | |
| if pipe is None: | |
| raise HTTPException(status_code=500, detail=status) | |
| # ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ ์ค์ | |
| if request.negative_prompt.strip(): | |
| full_negative = f"{request.negative_prompt}, {DEFAULT_NEGATIVE}" | |
| else: | |
| full_negative = DEFAULT_NEGATIVE | |
| # ์๋ ์ค์ | |
| seed = request.seed | |
| if seed == -1: | |
| seed = torch.randint(0, 2**32 - 1, (1,)).item() | |
| generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) | |
| try: | |
| print(f"๐จ [API txt2img] ์ด๋ฏธ์ง ์์ฑ ์ค... ํ๋กฌํํธ: {request.prompt[:50]}...") | |
| result = pipe( | |
| prompt=request.prompt, | |
| negative_prompt=full_negative, | |
| num_inference_steps=request.num_inference_steps, | |
| guidance_scale=request.guidance_scale, | |
| width=request.width, | |
| height=request.height, | |
| generator=generator | |
| ) | |
| image = result.images[0] | |
| # ์ด๋ฏธ์ง๋ฅผ Base64๋ก ์ธ์ฝ๋ฉ | |
| buffer = io.BytesIO() | |
| image.save(buffer, format="PNG") | |
| buffer.seek(0) | |
| image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") | |
| print(f"โ [API txt2img] ์ด๋ฏธ์ง ์์ฑ ์๋ฃ! (์๋: {seed})") | |
| return GenerateResponse( | |
| success=True, | |
| message="์ด๋ฏธ์ง ์์ฑ ์๋ฃ", | |
| image_base64=image_base64, | |
| seed=seed | |
| ) | |
| except Exception as e: | |
| print(f"โ [API txt2img] ์ด๋ฏธ์ง ์์ฑ ์คํจ: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def api_generate_img2img(request: Img2ImgRequest): | |
| """ | |
| ์ด๋ฏธ์ง โ ์ด๋ฏธ์ง ๋ณํ API | |
| Base64 ์ด๋ฏธ์ง์ ํ๋กฌํํธ๋ฅผ ์ ๋ฌํ๋ฉด ๋ณํ๋ ์ด๋ฏธ์ง๋ฅผ ๋ฐํํฉ๋๋ค. | |
| """ | |
| global pipe | |
| if not request.prompt.strip(): | |
| raise HTTPException(status_code=400, detail="ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์") | |
| if request.model_name not in MODELS: | |
| raise HTTPException(status_code=400, detail=f"์ ์ ์๋ ๋ชจ๋ธ์ ๋๋ค. ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ธ: {list(MODELS.keys())}") | |
| # Base64 ์ด๋ฏธ์ง ๋์ฝ๋ฉ | |
| try: | |
| image_data = base64.b64decode(request.image_base64) | |
| input_image = Image.open(io.BytesIO(image_data)).convert("RGB") | |
| except Exception as e: | |
| raise HTTPException(status_code=400, detail=f"์ด๋ฏธ์ง ๋์ฝ๋ฉ ์คํจ: {str(e)}") | |
| # ๋ชจ๋ธ ๋ก๋ (img2img) | |
| pipe, status = load_model(request.model_name, "img2img") | |
| if pipe is None: | |
| raise HTTPException(status_code=500, detail=status) | |
| # ๋ค๊ฑฐํฐ๋ธ ํ๋กฌํํธ ์ค์ | |
| if request.negative_prompt.strip(): | |
| full_negative = f"{request.negative_prompt}, {DEFAULT_NEGATIVE}" | |
| else: | |
| full_negative = DEFAULT_NEGATIVE | |
| # ์๋ ์ค์ | |
| seed = request.seed | |
| if seed == -1: | |
| seed = torch.randint(0, 2**32 - 1, (1,)).item() | |
| generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) | |
| try: | |
| print(f"๐ผ๏ธ [API img2img] ์ด๋ฏธ์ง ๋ณํ ์ค... ํ๋กฌํํธ: {request.prompt[:50]}...") | |
| # ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ 64์ ๋ฐฐ์๋ก ์กฐ์ | |
| w, h = input_image.size | |
| w = (w // 64) * 64 | |
| h = (h // 64) * 64 | |
| if w == 0: | |
| w = 512 | |
| if h == 0: | |
| h = 512 | |
| input_image = input_image.resize((w, h), Image.LANCZOS) | |
| result = pipe( | |
| prompt=request.prompt, | |
| image=input_image, | |
| negative_prompt=full_negative, | |
| strength=request.strength, | |
| num_inference_steps=request.num_inference_steps, | |
| guidance_scale=request.guidance_scale, | |
| generator=generator | |
| ) | |
| image = result.images[0] | |
| # ์ด๋ฏธ์ง๋ฅผ Base64๋ก ์ธ์ฝ๋ฉ | |
| buffer = io.BytesIO() | |
| image.save(buffer, format="PNG") | |
| buffer.seek(0) | |
| image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") | |
| print(f"โ [API img2img] ์ด๋ฏธ์ง ๋ณํ ์๋ฃ! (์๋: {seed})") | |
| return GenerateResponse( | |
| success=True, | |
| message="์ด๋ฏธ์ง ๋ณํ ์๋ฃ", | |
| image_base64=image_base64, | |
| seed=seed | |
| ) | |
| except Exception as e: | |
| print(f"โ [API img2img] ์ด๋ฏธ์ง ๋ณํ ์คํจ: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def health_check(): | |
| """์๋ฒ ์ํ ํ์ธ""" | |
| return { | |
| "status": "healthy", | |
| "device": DEVICE, | |
| "model_loaded": current_model_id is not None, | |
| "pipeline_type": current_pipeline_type | |
| } | |
| # ================================ | |
| # ๋ฉ์ธ ์คํ | |
| # ================================ | |
| if __name__ == "__main__": | |
| print("๐ธ Anime Diffusion WebUI + API ์์...") | |
| print(" - txt2img: ํ ์คํธ โ ์ด๋ฏธ์ง ์์ฑ") | |
| print(" - img2img: ์ด๋ฏธ์ง โ ์ด๋ฏธ์ง ๋ณํ") | |
| # Gradio ์ฑ ์์ฑ | |
| demo = create_interface() | |
| # FastAPI์ Gradio ๋ง์ดํธ | |
| app = gr.mount_gradio_app(api_app, demo, path="/") | |
| # uvicorn์ผ๋ก ํตํฉ ์๋ฒ ์คํ | |
| import uvicorn | |
| print("๐ก API ๋ฌธ์: http://localhost:7860/docs") | |
| print("๐ ์น UI: http://localhost:7860/") | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |