| | import os |
| | import uuid |
| | import time |
| | import random |
| |
|
| | import spaces |
| | import gradio as gr |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler |
| | from compel import Compel, ReturnedEmbeddingsType |
| |
|
# Run on the GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): the weights are loaded as float16 even on the CPU fallback —
# confirm half precision on CPU is acceptable for this deployment.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the SDXL checkpoint in half precision from its fp16 safetensors variant.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# Swap the checkpoint's scheduler for Euler Ancestral, reusing its config,
# then move the whole pipeline onto the selected device.
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(device)

# Explicitly cast every sub-model to float16 (same order as before:
# both text encoders, then the VAE, then the UNet).
for _submodel in (pipe.text_encoder, pipe.text_encoder_2, pipe.vae, pipe.unet):
    _submodel.to(torch.float16)

# Compel turns weighted prompt syntax into embeddings for both SDXL text
# encoders; only the second encoder supplies the pooled embedding, and long
# prompts are kept whole instead of being truncated.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False,
)
| |
|
# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Largest width/height (in pixels) offered by the size sliders.
MAX_IMAGE_SIZE = 1216

# Directory that receives every generated JPEG; created on import if missing.
OUTPUT_DIR = "/tmp/generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)
| |
|
| |
|
def save_image_jpg(pil_image: Image.Image) -> str:
    """Write *pil_image* to ``OUTPUT_DIR`` as a JPEG and return its path.

    The image is converted to RGB first when it is in any other mode. The
    file name is a fresh random UUID (hex) with a ``.jpg`` extension, and the
    file is saved with JPEG quality 95.
    """
    rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
    file_name = f"{uuid.uuid4().hex}.jpg"
    destination = os.path.join(OUTPUT_DIR, file_name)
    rgb_image.save(destination, "JPEG", quality=95)
    return destination
| |
|
| |
|
@spaces.GPU(duration=15)
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
):
    """Generate one image for *prompt* and return ``(image_path, seed)``.

    Args:
        prompt: Positive prompt (Compel syntax). Must not be empty/blank.
        negative_prompt: Negative prompt; ``None`` is treated as empty.
        seed: Seed to use when *randomize_seed* is false.
        randomize_seed: When true, a random seed in ``[0, MAX_SEED]`` replaces
            *seed*.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.

    Returns:
        A tuple of the saved JPEG path and the seed that was actually used.
        If generation fails with a ``RuntimeError``, the error is printed and
        the path points at a black placeholder image of the requested size.

    Raises:
        gr.Error: If *prompt* is missing or contains only whitespace.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty.")

    # Slider values may arrive as floats (e.g. from API clients), but the
    # generator seed, the image size and the step count must be integers.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        # Encode both prompts in one call so their embeddings come back as a
        # single batch; autograd is not needed for inference-only embeddings.
        with torch.no_grad():
            conditioning, pooled = compel([prompt, negative_prompt or ""])

        image = pipe(
            prompt_embeds=conditioning[0:1],
            pooled_prompt_embeds=pooled[0:1],
            negative_prompt_embeds=conditioning[1:2],
            negative_pooled_prompt_embeds=pooled[1:2],
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
    except RuntimeError as e:
        # Best effort: keep the UI responsive by returning a black placeholder
        # instead of propagating the failure.
        print(f"Error during generation: {e}")
        image = Image.new("RGB", (width, height), color=(0, 0, 0))

    return save_image_jpg(image), seed
| |
|
| |
|
def generation_loop(
    prompt,
    negative_prompt,
    current_seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    interval_sec,
):
    """Generate images back-to-back until the event is cancelled.

    This is a generator: each iteration calls ``infer`` and yields a dict
    that maps the ``result`` and ``seed`` components to the new image path
    and the seed that produced it, then sleeps for *interval_sec* seconds.

    ``infer`` is always called with seed randomisation forced on, so the
    *randomize_seed* argument is accepted but not consulted here.

    Raises:
        gr.Error: If *prompt* contains only whitespace.
    """
    if not prompt.strip():
        raise gr.Error("Prompt cannot be empty to start consecutive generation.")

    # NOTE(review): confirm the installed Gradio version exposes
    # ``gr.exceptions.CancelledError``; the handler below relies on it.
    try:
        while True:
            latest_path, used_seed = infer(
                prompt,
                negative_prompt,
                current_seed,
                True,  # always draw a fresh seed for every iteration
                width,
                height,
                guidance_scale,
                num_inference_steps,
            )
            yield {result: latest_path, seed: used_seed}
            time.sleep(interval_sec)
    except gr.exceptions.CancelledError:
        print("Generation loop cancelled by user.")
| |
|
| |
|
# Page stylesheet handed to gr.Blocks: centres the main column and makes any
# element carrying the "transparent-btn" class fully transparent.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}

/* 完全透過(非表示だがクリック等は可能なまま) */
.transparent-btn,
.transparent-btn * {
    opacity: 0 !important;
}

.transparent-btn button {
    background: transparent !important;
    border: 0 !important;
    box-shadow: none !important;
}

.transparent-btn button:focus,
.transparent-btn button:focus-visible {
    outline: none !important;
}
"""
| |
|
# Gradio UI: layout, event wiring and launch.
# NOTE(review): the nesting of rows/accordion inside the column is the most
# plausible reading of the layout — confirm against the deployed app.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("<br>" * 1)

        # Prompt input.
        with gr.Row(equal_height=True):
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                value="",
                container=False,
                scale=1,
            )

        # Generated image.
        with gr.Row():
            result = gr.Image(format="jpeg", show_label=False, interactive=False, elem_id="result_image")

        # Generation buttons; they start disabled and are styled transparent.
        with gr.Row(equal_height=True):
            run_button = gr.Button("Generate", scale=0, interactive=False, elem_classes=["transparent-btn"])
            consecutive_button = gr.Button("Consecutive", scale=0, interactive=False, elem_classes=["transparent-btn"])

        gr.Markdown("<br>" * 20)

        # Stop / clear controls.
        with gr.Row():
            stop_button = gr.Button("Stop", scale=0, visible=True, interactive=True)
            clear_button = gr.Button("Trash", scale=0, variant="secondary")

        # URL of the current image plus a copy-to-clipboard button.
        with gr.Row(equal_height=True):
            copy_button = gr.Button("Copy", scale=0, variant="secondary")
            image_url = gr.Textbox(
                label="Image URL",
                show_label=False,
                interactive=False,
                max_lines=2,
                placeholder="生成後、ここに外部URLが表示されます",
                scale=1,
            )

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                value="photoreal, bad quality, low quality, worst quality, worst detail, bad anatomy, extra hand, viewer's hand",
            )

            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=8)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=28, step=1, value=25)

            interval_seconds = gr.Slider(label="Interval (seconds)", minimum=1, maximum=60, step=1, value=1)

        gr.Markdown("<br>" * 20)

        gr.Examples(
            examples=[
                ["masterpiece, solo, A little girl with blonde short side tails, red eyes, "],
            ],
            inputs=[prompt],
            label="Examples (Click to copy to prompt)",
        )

    # Enable the generation buttons only while the prompt is non-blank.
    # `change` (rather than `input`) is used so that programmatic updates,
    # such as clicking an example, also re-evaluate the button state.
    prompt.change(
        fn=None,
        inputs=[prompt],
        outputs=[run_button, consecutive_button],
        js="(p) => { const interactive = p.trim().length > 0; return [{ interactive: interactive, '__type__': 'update' }, { interactive: interactive, '__type__': 'update' }]; }",
    )

    # Clear the prompt and the image URL, and disable both generation buttons.
    clear_button.click(
        fn=None,
        inputs=None,
        outputs=[prompt, run_button, consecutive_button, image_url],
        js="""
        function() {
            return [
                "",
                { "interactive": false, "__type__": "update" },
                { "interactive": false, "__type__": "update" },
                ""
            ];
        }
        """,
    )

    # Single generation.
    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

    # Consecutive generation: same inputs as a single run plus the interval.
    gen_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        interval_seconds,
    ]

    consecutive_event = consecutive_button.click(
        fn=generation_loop,
        inputs=gen_inputs,
        outputs=[result, seed],
    )

    # Stop cancels the running consecutive-generation event.
    stop_button.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[consecutive_event],
    )

    # Whenever the image changes, read its <img> src from the page and show
    # it as an absolute URL in the URL box.
    result.change(
        fn=None,
        inputs=None,
        outputs=[image_url],
        js=r"""
        () => {
            const img = document.querySelector("#result_image img");
            if (!img || !img.src) return "";
            return new URL(img.src, window.location.href).href;
        }
        """,
    )

    # Copy the displayed URL to the clipboard (no-op when it is empty).
    copy_button.click(
        fn=None,
        inputs=[image_url],
        outputs=None,
        js=r"""
        async (url) => {
            if (!url) return;
            try {
                await navigator.clipboard.writeText(url);
                console.log("URL copied");
            } catch (e) {
                console.error("Copy failed", e);
            }
        }
        """,
    )

demo.queue().launch()