# crenshaw / app.py — Hugging Face Space (author: lainlives, revision cdc565e)
#import spaces
import random
import gradio as gr
import numpy as np
import PIL.Image
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import (EulerAncestralDiscreteScheduler,
StableDiffusionXLPipeline)
from PIL import Image
# Prefer CUDA when available. Everything below is forced to fp16, which is
# only practical on GPU — NOTE(review): confirm the CPU fallback path is ever
# actually exercised; many fp16 ops are slow or unsupported on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images generated per run (UI default) and the hard upper bound, which also
# sizes the fixed grid of output gr.Image components in the UI.
DEFAULT_BATCH_SIZE = 2
MAX_BATCH_SIZE = 12

# Make sure to use torch.float16 consistently throughout the pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
    torch_dtype=torch.float16,
    variant="fp16",  # Explicitly use fp16 variant
    use_safetensors=True  # Use safetensors if available
)
# Swap in Euler-Ancestral sampling while keeping the model's scheduler config.
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(device)

# Force all components to use the same dtype
pipe.text_encoder.to(torch.float16)
pipe.text_encoder_2.to(torch.float16)
pipe.vae.to(torch.float16)
pipe.unet.to(torch.float16)

# Initialize Compel for long (>77-token) prompt processing. Both SDXL
# tokenizer/encoder pairs are wired in; pooled embeddings are requested only
# from the second encoder (requires_pooled=[False, True]).
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False
)

# Largest value the seed slider can take (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 2000
# Simple helper for encoding long prompts via Compel.
def process_long_prompt(prompt, negative_prompt=""):
    """Encode a (possibly >77-token) prompt pair with Compel.

    Returns a ``(conditioning, pooled)`` embedding pair on success, or
    ``(None, None)`` when Compel encoding fails for any reason, letting the
    caller fall back to the pipeline's built-in tokenization.
    """
    try:
        # compel() takes [positive, negative] and yields (conditioning, pooled).
        return compel([prompt, negative_prompt])
    except Exception as err:
        print(f"Long prompt processing failed: {err}")
        return None, None
def generate_image(
    prompt,
    negative_prompt,
    generator,
    width,
    height,
    guidance_scale,
    num_inference_steps,
):
    """Run the SDXL pipeline for a single prompt and return the first image.

    All sampling parameters are forwarded verbatim to ``pipe``; ``generator``
    controls reproducibility.
    """
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": generator,
    }
    result = pipe(**call_kwargs)
    return result.images[0]
#@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    batch_size,
    width,
    height,
    guidance_scale,
    num_inference_steps,
):
    """Generate ``batch_size`` images and pad the result to MAX_BATCH_SIZE.

    Parameters mirror the Gradio inputs wired up in the UI. On a RuntimeError
    (e.g. CUDA OOM) the failed slots are filled with black placeholder images
    instead of crashing the UI.

    Returns a list of exactly MAX_BATCH_SIZE entries: ``batch_size`` PIL
    images followed by ``None`` fillers, matching the fixed grid of gr.Image
    outputs declared on run_button.click().
    """
    batch_size = int(batch_size)  # Gradio sliders can deliver floats
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    prompts = [prompt] * batch_size
    negative_prompts = [negative_prompt] * batch_size
    # One generator per image so each slot is independently reproducible.
    # seed - i can go negative for small seeds; torch.Generator.manual_seed
    # remaps negative seeds internally, so this is safe.
    generators = [
        torch.Generator(device=device).manual_seed(seed - i)
        for i in range(batch_size)
    ]
    try:
        images = pipe(
            prompt=prompts,
            negative_prompt=negative_prompts,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generators,
        ).images
    except RuntimeError as e:
        # Best-effort fallback: surface black frames rather than an error.
        print(f"Generation error: {e}")
        images = [Image.new("RGB", (width, height), (0, 0, 0))] * batch_size
    # BUG FIX: run_button.click() declares MAX_BATCH_SIZE image outputs, but
    # the original returned only batch_size values, making Gradio fail with a
    # "not enough output values" error whenever batch_size < MAX_BATCH_SIZE.
    # Pad with None so the unused image slots are simply cleared.
    return images + [None] * (MAX_BATCH_SIZE - len(images))
css = """
#col-container {
margin: 0 auto;
max-width: 2000px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=5,
placeholder="Enter a negative prompt",
value=" ((unimaginative, dry, flat, static, stiff, uninspired), bad quality overexposed, too bright, washed out, high exposure, low resolution, artifact, compression artifacts, low poly, blocky, banding, color bleed, texture seams, oversaturation, fused fingers, malformed eyes, missing iris sclera, poorly drawn background)"
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=20,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=600,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=800,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=30.0,
step=0.1,
value=7,
)
batch_size = gr.Slider(
label="Batch size",
minimum=1,
maximum=MAX_BATCH_SIZE,
step=1,
value=DEFAULT_BATCH_SIZE,
)
with gr.Row():
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=99,
step=1,
value=20,
)
with gr.Column(elem_id="col-container"):
with gr.Row():
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=16,
placeholder="Enter your prompt (long prompts are automatically supported)",
value="Show me a glorious mountain range covered in colorful crystaline trees with a laser disco show from space.",
container=False,
)
results = [
gr.Image(format="png", show_label=False)
for _ in range(MAX_BATCH_SIZE)
]
run_button = gr.Button("Run", scale=0)
run_button.click(
concurrency_limit=25,
fn=infer,
inputs=[
prompt,
negative_prompt,
seed,
randomize_seed,
batch_size,
width,
height,
guidance_scale,
num_inference_steps,
],
outputs=results,
)
with gr.Column():
for _ in range(5):
with gr.Row():
for _ in range(2):
results.append(gr.Image(format="png", show_label=False, visible=True))
demo.queue(default_concurrency_limit=25).launch()