|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import math |
|
|
import torch |
|
|
import gradio as gr |
|
|
from typing import List, Optional |
|
|
from PIL import Image |
|
|
from diffusers import ( |
|
|
DiffusionPipeline, |
|
|
StableDiffusionPipeline, |
|
|
AutoPipelineForText2Image, |
|
|
) |
|
|
|
|
|
|
|
|
# Dropdown label -> Hugging Face Hub model id. Insertion order here is the
# order shown in the UI dropdown (choices=list(MODEL_CHOICES.keys()) below).
MODEL_CHOICES = {
    "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)": "runwayml/stable-diffusion-v1-5",
    "SDXL Turbo (stabilityai/sdxl-turbo)": "stabilityai/sdxl-turbo",
}

# Initial dropdown value; must be an exact key of MODEL_CHOICES.
DEFAULT_MODEL_LABEL = "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)"

# Initial state of the "Disable safety checker" checkbox in the UI.
DISABLE_SAFETY_DEFAULT = True
|
|
|
|
|
|
|
|
def get_device() -> str:
    """Return the best available torch device name: "cuda", "mps", or "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps does not exist on older torch builds, hence getattr.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
|
|
|
|
|
def nearest_multiple_of_8(x: int) -> int:
    """Snap x to the nearest multiple of 8, with a floor of 64.

    Uses Python's round(), so exact ties follow banker's rounding
    (e.g. 100 -> 96, not 104).
    """
    return 64 if x < 64 else 8 * round(x / 8)
|
|
|
|
|
|
|
|
# Cache keyed on (model_id, device, fp16) so repeated requests reuse pipelines.
_PIPE_CACHE = {}


def load_pipe(model_id: str, device: str, fp16: bool) -> DiffusionPipeline:
    """Load (or return a cached) text-to-image pipeline for `model_id`.

    fp16 only takes effect on CUDA; CPU and MPS always get float32.
    The pipeline is moved to `device` and cached before returning.
    """
    key = (model_id, device, fp16)
    if key in _PIPE_CACHE:
        return _PIPE_CACHE[key]

    dtype = torch.float16 if (fp16 and device == "cuda") else torch.float32

    try:
        # BUG FIX: this previously referenced `AutoPipelineForTextToImage`,
        # but the imported name is `AutoPipelineForText2Image`. The resulting
        # NameError was swallowed by `except Exception`, so every model —
        # including SDXL Turbo — silently fell back to StableDiffusionPipeline.
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
            trust_remote_code=False,
        )
    except Exception:
        # Fallback for checkpoints the auto-pipeline cannot resolve.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
        )

    pipe = pipe.to(device)

    if device == "cuda":
        # xformers attention is an optional speedup; continue without it
        # when the package is missing or incompatible.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass

    _PIPE_CACHE[key] = pipe
    return pipe
|
|
|
|
|
|
|
|
def generate(
    prompt: str,
    negative: str,
    model_label: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: Optional[int],
    batch_size: int,
    disable_safety: bool,
) -> List[Image.Image]:
    """Generate `batch_size` images for `prompt` with the selected model.

    Parameters mirror the Gradio inputs. `seed` arrives from a gr.Textbox,
    so despite the annotation it is typically a string; "" or a non-integer
    value means "unseeded". Returns a list of PIL images.

    Raises:
        gr.Error: if the prompt is empty or whitespace-only.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Enter a non-empty prompt.")

    model_id = MODEL_CHOICES[model_label]
    device = get_device()

    # SDXL Turbo is step-distilled: very few steps, and CFG must be off.
    is_turbo = "sdxl-turbo" in model_id.lower()
    if is_turbo:
        steps = max(1, min(steps, 6))
        guidance = 0.0

    # Diffusion UNets require dimensions divisible by 8.
    width = nearest_multiple_of_8(width)
    height = nearest_multiple_of_8(height)
    batch_size = max(1, min(batch_size, 8))

    pipe = load_pipe(model_id, device, fp16=(device == "cuda"))

    # FIX: the pipeline is cached across requests, so assigning None to
    # safety_checker used to be irreversible — unchecking the box on a later
    # request had no effect. Stash the original checker once and restore it
    # when safety is re-enabled.
    if hasattr(pipe, "safety_checker"):
        if not hasattr(pipe, "_orig_safety_checker"):
            pipe._orig_safety_checker = pipe.safety_checker
        pipe.safety_checker = None if disable_safety else pipe._orig_safety_checker

    # Parse the seed; blanks and non-integers mean "no fixed seed".
    generator = None
    if seed is not None and seed != "":
        try:
            seed = int(seed)
        except ValueError:
            seed = None
    if seed is not None:
        # Seed on CPU for both "cpu" and "mps" (the original did the same;
        # MPS generators are not used here).
        gen_device = "cuda" if device == "cuda" else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(seed)

    prompts = [prompt] * batch_size
    negative_prompts = [negative] * batch_size if negative else None

    # autocast is enabled only on CUDA; elsewhere it is a no-op wrapper.
    with torch.autocast("cuda", enabled=(device == "cuda")):
        out = pipe(
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            width=int(width),
            height=int(height),
            generator=generator,
        )

    return out.images
|
|
|
|
|
|
|
|
# UI definition. NOTE: component creation order determines on-screen layout,
# so statement order inside this block is significant.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
# Text-to-Image (Diffusers)
- **Models:** SD 1.5 and SDXL Turbo
- **Tip:** SD 1.5 = better detail on CPU; Turbo = very fast on GPU, fewer steps.
"""
    )

    # Row 1: model selection and sampler settings.
    with gr.Row():
        model_dd = gr.Dropdown(
            label="Model",
            choices=list(MODEL_CHOICES.keys()),
            value=DEFAULT_MODEL_LABEL,
        )
        steps = gr.Slider(1, 75, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance (CFG)")

    # Row 2: output geometry. generate() re-snaps width/height to multiples
    # of 8, so the step=8 here is a convenience, not a guarantee.
    with gr.Row():
        width = gr.Slider(256, 1024, value=768, step=8, label="Width (multiple of 8)")
        height = gr.Slider(256, 1024, value=768, step=8, label="Height (multiple of 8)")
        batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")

    prompt = gr.Textbox(label="Prompt", lines=2, placeholder="a cozy cabin at twilight beside a lake, cinematic lighting")
    negative = gr.Textbox(label="Negative Prompt", lines=1, placeholder="blurry, low quality, distorted")
    with gr.Row():
        # Free-text seed: generate() treats "" / non-integers as unseeded.
        seed = gr.Textbox(label="Seed (optional integer)", value="")
        disable_safety = gr.Checkbox(label="Disable safety checker (you are responsible)", value=DISABLE_SAFETY_DEFAULT)

    run_btn = gr.Button("Generate", variant="primary")
    gallery = gr.Gallery(label="Results", columns=2, height=512, preview=True)

    def _on_change_model(label):
        # Sync slider defaults to the selected model: Turbo models are
        # step-distilled (few steps, CFG off); SD 1.5 uses standard settings.
        if "Turbo" in label:
            return gr.update(value=4), gr.update(value=0.0)
        else:
            return gr.update(value=30), gr.update(value=7.5)

    model_dd.change(_on_change_model, inputs=model_dd, outputs=[steps, guidance])

    # Wire the Generate button; input order must match generate()'s signature.
    run_btn.click(
        fn=generate,
        inputs=[prompt, negative, model_dd, steps, guidance, width, height, seed, batch_size, disable_safety],
        outputs=[gallery],
        api_name="generate",
        scroll_to_output=True,
        concurrency_limit=2,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Bind on all interfaces; port comes from $PORT (default 7860).
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port, debug=True)
|
|
|