# Wai_0 / app.py
# Author: mcuo — "Upload app.py" (commit c411baa, verified)
import os
import uuid
import time
import random
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from compel import Compel, ReturnedEmbeddingsType
# Pick the compute device once at import time; everything below is placed on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only dependable on CUDA (CPU Half kernels are incomplete
# and slow on many torch versions), so fall back to fp32 when there is no GPU.
DTYPE = torch.float16 if device.type == "cuda" else torch.float32

# NOTE(review): StableDiffusionXLPipeline ships no safety checker, the checkpoint
# name indicates an NSFW-tuned model, and nothing in this app filters prompts or
# outputs. For a publicly hosted demo, confirm this meets the host's content
# policy — in particular, prompts describing minors must never yield sexual content.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
    torch_dtype=DTYPE,
    variant="fp16",
    use_safetensors=True,
)
# Swap in the Euler-ancestral sampler, reusing the checkpoint's scheduler config.
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(device)
# Explicitly pin every sub-model to the chosen dtype.
pipe.text_encoder.to(DTYPE)
pipe.text_encoder_2.to(DTYPE)
pipe.vae.to(DTYPE)
pipe.unet.to(DTYPE)

# Compel turns (possibly weighted) prompt text into embeddings for both SDXL
# text encoders. Only the second encoder contributes a pooled embedding, and
# long prompts are not truncated.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False,
)
# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 1216
# Directory that receives generated JPEGs; created eagerly so saves never hit
# a missing directory.
OUTPUT_DIR = "/tmp/generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)
def save_image_jpg(pil_image: Image.Image) -> str:
    """Save *pil_image* as a quality-95 JPEG in OUTPUT_DIR and return the file path.

    JPEG cannot store alpha or palette data, so any non-RGB image is converted
    to RGB first. The file name is a random UUID4 hex string, so successive
    calls write distinct files.
    """
    rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
    destination = os.path.join(OUTPUT_DIR, uuid.uuid4().hex + ".jpg")
    rgb_image.save(destination, "JPEG", quality=95)
    return destination
@spaces.GPU(duration=15)
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
):
    """Generate one SDXL image and return ``(jpeg_path, seed_used)``.

    Runs under ``@spaces.GPU(duration=15)`` (Hugging Face ZeroGPU allocation).

    Args:
        prompt: Positive prompt text, encoded through Compel.
        negative_prompt: Negative prompt text; ``None`` is treated as empty.
        seed: Seed used when *randomize_seed* is false.
        randomize_seed: If true, draw a fresh seed uniformly from ``[0, MAX_SEED]``.
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.

    Returns:
        The path of the saved JPEG and the seed actually used (as ``int``).
        If the pipeline raises ``RuntimeError`` (e.g. CUDA out of memory), the
        error is printed and a black placeholder of the requested size is
        saved and returned instead, so callers such as the consecutive loop
        keep running.

    Raises:
        gr.Error: If *prompt* is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty.")

    # Gradio sliders may deliver whole numbers as floats; manual_seed, the
    # pipeline's size/step arguments and Image.new all require ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    negative_prompt = negative_prompt or ""

    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        # Encode positive and negative prompts as one batch of two:
        # row 0 is the prompt, row 1 is the negative prompt.
        conditioning, pooled = compel([prompt, negative_prompt])
        image = pipe(
            prompt_embeds=conditioning[0:1],
            pooled_prompt_embeds=pooled[0:1],
            negative_prompt_embeds=conditioning[1:2],
            negative_pooled_prompt_embeds=pooled[1:2],
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
    except RuntimeError as e:
        # Deliberate best effort: log and fall back to a black placeholder
        # rather than failing the request.
        print(f"Error during generation: {e}")
        image = Image.new("RGB", (width, height), color=(0, 0, 0))
    return save_image_jpg(image), seed
def generation_loop(
    prompt,
    negative_prompt,
    current_seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    interval_sec,
):
    """Yield ``(image_path, seed)`` pairs indefinitely, one image per iteration.

    Every iteration calls ``infer`` with seed randomisation forced on, so
    *randomize_seed* is accepted for signature parity but ignored. After each
    yield the generator sleeps *interval_sec* seconds before the next image.

    The tuple order matches the event's ``outputs=[result, seed]``.

    Stopping: Gradio cancels the event (Stop button) and stops pulling from
    this sync generator. When the generator is closed, Python raises
    ``GeneratorExit`` at the paused ``yield``. That is logged here and the
    generator returns cleanly.

    Raises:
        gr.Error: If *prompt* is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty to start consecutive generation.")
    try:
        while True:
            image_path, new_seed = infer(
                prompt,
                negative_prompt,
                current_seed,
                True,  # consecutive mode draws a new random seed every time
                width,
                height,
                guidance_scale,
                num_inference_steps,
            )
            yield image_path, new_seed
            time.sleep(interval_sec)
    except GeneratorExit:
        print("Generation loop cancelled by user.")
        return
# Page-level stylesheet passed to gr.Blocks(css=...).
# - #col-container: centres the main column horizontally and caps it at 1024px.
# - .transparent-btn: renders the tagged buttons fully transparent (opacity 0,
#   no background/border/shadow/focus outline) while leaving them clickable;
#   the CSS comment inside the string says the same in Japanese.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
/* 完全透過(非表示だがクリック等は可能なまま) */
.transparent-btn,
.transparent-btn * {
opacity: 0 !important;
}
.transparent-btn button {
background: transparent !important;
border: 0 !important;
box-shadow: none !important;
}
.transparent-btn button:focus,
.transparent-btn button:focus-visible {
outline: none !important;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("<br>" * 1)

        # Prompt row (intentionally no Generate button to its right).
        with gr.Row(equal_height=True):
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                value="",
                container=False,
                scale=1,
            )

        # Generated image display.
        with gr.Row():
            result = gr.Image(format="jpeg", show_label=False, interactive=False, elem_id="result_image")

        # Directly under the image and above the 20-line spacer: Generate and
        # Consecutive side by side, made fully transparent by the
        # "transparent-btn" CSS class. Both start disabled until the prompt is
        # non-blank.
        with gr.Row(equal_height=True):
            run_button = gr.Button("Generate", scale=0, interactive=False, elem_classes=["transparent-btn"])
            consecutive_button = gr.Button("Consecutive", scale=0, interactive=False, elem_classes=["transparent-btn"])

        gr.Markdown("<br>" * 20)

        # Stop / clear controls.
        with gr.Row():
            stop_button = gr.Button("Stop", scale=0, visible=True, interactive=True)
            clear_button = gr.Button("Trash", scale=0, variant="secondary")

        # Copy button with the image-URL field to its right
        # (placed between Trash and Advanced Settings).
        with gr.Row(equal_height=True):
            copy_button = gr.Button("Copy", scale=0, variant="secondary")
            image_url = gr.Textbox(
                label="Image URL",
                show_label=False,
                interactive=False,
                max_lines=2,
                placeholder="生成後、ここに外部URLが表示されます",
                scale=1,
            )

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                value="photoreal, bad quality, low quality, worst quality, worst detail, bad anatomy, extra hand, viewer's hand",
            )
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=8)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=28, step=1, value=25)
            # Pause between images in consecutive mode.
            interval_seconds = gr.Slider(label="Interval (seconds)", minimum=1, maximum=60, step=1, value=1)

        gr.Markdown("<br>" * 20)

        # NOTE(review): this example depicts a child and is offered by an app
        # serving an NSFW-tuned checkpoint with no content filtering; consider
        # replacing it and blocking sexual prompts involving minors.
        gr.Examples(
            examples=[
                ["masterpiece, solo, A little girl with blonde short side tails, red eyes, "],
            ],
            inputs=[prompt],
            label="Examples (Click to copy to prompt)",
        )

    # Enable Generate/Consecutive only while the prompt is non-blank. Uses
    # .change rather than .input because .input fires only on user typing;
    # .change also fires when a gr.Examples entry fills the prompt.
    prompt.change(
        fn=None,
        inputs=[prompt],
        outputs=[run_button, consecutive_button],
        js="(p) => { const interactive = (p || '').trim().length > 0; return [{ interactive: interactive, '__type__': 'update' }, { interactive: interactive, '__type__': 'update' }]; }",
    )

    # Clear: empty the prompt, disable both generate buttons, empty the URL field.
    clear_button.click(
        fn=None,
        inputs=None,
        outputs=[prompt, run_button, consecutive_button, image_url],
        js="""
        function() {
            return [
                "",
                { "interactive": false, "__type__": "update" },
                { "interactive": false, "__type__": "update" },
                ""
            ];
        }
        """,
    )

    # Single generation.
    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

    # Consecutive generation: streams images until cancelled by Stop.
    gen_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        interval_seconds,
    ]
    consecutive_event = consecutive_button.click(
        fn=generation_loop,
        inputs=gen_inputs,
        outputs=[result, seed],
    )

    # Stop: cancel the running consecutive-generation event.
    stop_button.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[consecutive_event],
    )

    # When the result image changes, show its absolute URL in the URL field.
    # The component value's URL is preferred. The rendered <img> may not yet
    # show the new image when this handler runs, so it is only a fallback.
    result.change(
        fn=None,
        inputs=[result],
        outputs=[image_url],
        js=r"""
        (value) => {
            if (value && value.url) return new URL(value.url, window.location.href).href;
            const img = document.querySelector("#result_image img");
            if (!img || !img.src) return "";
            return new URL(img.src, window.location.href).href;
        }
        """,
    )

    # Copy button: copy the URL string to the clipboard (errors are only logged).
    copy_button.click(
        fn=None,
        inputs=[image_url],
        outputs=None,
        js=r"""
        async (url) => {
            if (!url) return;
            try {
                await navigator.clipboard.writeText(url);
                console.log("URL copied");
            } catch (e) {
                console.error("Copy failed", e);
            }
        }
        """,
    )

demo.queue().launch()