# Source note: Hugging Face Space by "munio"; last commit 344aa5f ("removed logo", verified).
import spaces
from diffusers import (
StableDiffusionXLPipeline,
EulerDiscreteScheduler,
UNet2DConditionModel,
AutoencoderTiny,
)
import torch
import os
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType
from gradio_promptweighting import PromptWeighting
from PIL import Image
import gradio as gr
import time
from safetensors.torch import load_file
import tempfile
from pathlib import Path
import openai
# Constants
BASE = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base repo (tokenizers, text encoders, VAE, UNet config)
REPO = "ByteDance/SDXL-Lightning"  # hosts the distilled Lightning UNet weights
CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"  # 2-step distilled UNet checkpoint
taesd_model = "madebyollin/taesdxl"  # tiny VAE, swapped in when USE_TAESD is enabled
# Feature flags: each is enabled only when its env var is exactly "1".
SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
USE_TAESD = os.environ.get("USE_TAESD", "0") == "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_device = device  # alias for `device`; both names are referenced below
torch_dtype = torch.float16  # half precision throughout the pipeline
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"SFAST_COMPILE: {SFAST_COMPILE}")
print(f"USE_TAESD: {USE_TAESD}")
print(f"device: {device}")
# Build the SDXL pipeline with the distilled SDXL-Lightning UNet swapped in.
# NOTE(review): the UNet and pipeline are moved to "cuda" unconditionally even
# though `device` above may resolve to CPU — confirm this Space always has a GPU.
unet = UNet2DConditionModel.from_config(BASE, subfolder="unet").to(
    "cuda", torch.float16
)
# Load the 2-step Lightning weights directly onto the GPU.
unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device="cuda"))
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
).to("cuda")
unet = unet.to(dtype=torch.float16)
# Compel builds weighted prompt embeddings across both SDXL text encoders;
# only the second (OpenCLIP) encoder provides pooled embeddings.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
if USE_TAESD:
    # Tiny autoencoder: much faster latent decoding at a small quality cost.
    pipe.vae = AutoencoderTiny.from_pretrained(
        taesd_model, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
# "trailing" timestep spacing is the scheduler setting recommended for
# SDXL-Lightning few-step checkpoints.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
if SAFETY_CHECKER:
    # Optional NSFW filtering; `safety_checker` here is a project-local module.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Run the safety checker over *images* and return them together with
        the checker's NSFW verdicts, which `predict` consumes via `any(...)`."""
        safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
        # NOTE(review): `images` is already a list, so `[images]` nests it one
        # level deeper — verify the project-local checker expects that shape.
        # Also confirm its return value is just the flags; the stock diffusers
        # checker returns an (images, flags) tuple, which would make
        # `has_nsfw_concepts` a tuple here.
        has_nsfw_concepts = safety_checker(
            images=[images],
            clip_input=safety_checker_input.pixel_values.to(torch_device),
        )
        return images, has_nsfw_concepts
if SFAST_COMPILE:
    # Optionally compile the whole pipeline with stable-fast for lower latency.
    from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig
    config = CompilationConfig.Default()
    # xformers and triton are optional accelerators: probe each import and
    # enable the corresponding backend only when it is installed.
    try:
        import xformers
        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skip")
    try:
        import triton
        config.enable_triton = True
    except ImportError:
        print("Triton not installed, skip")
    # CUDA graphs reduce per-call kernel-launch overhead after warmup.
    config.enable_cuda_graph = True
    pipe = compile(pipe, config)
# AI Prompt setup
import requests
def generate_ai_prompt(base_prompt):
    """Expand *base_prompt* into a richer image prompt via Groq's chat API.

    Best-effort by design: returns *base_prompt* unchanged when no
    GROQ_API_KEY is configured or when the request fails for any reason,
    so image generation never blocks on the prompt-enhancement service.

    Args:
        base_prompt: the user's raw prompt text.

    Returns:
        The model-generated prompt string, or *base_prompt* on any failure.
    """
    api_key = os.getenv('GROQ_API_KEY')
    if not api_key:
        # No key configured: the request could only fail (after burning up to
        # the 30s timeout), so skip the network round-trip entirely.
        return base_prompt
    try:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that generates detailed and crisp image prompts in under 50 words. Create vivid, specific descriptions that would work well with image generation AI. Focus on visual details, style, lighting, and composition."
                },
                {
                    "role": "user",
                    "content": f"Generate a detailed image prompt based on: {base_prompt}"
                }
            ],
            "model": "mixtral-8x7b-32768",
            "temperature": 0.7,
            "max_tokens": 150
        }
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        # Turn HTTP error statuses into exceptions so they hit the handler below.
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"].strip()
    except Exception as e:
        # Deliberate broad catch: prompt enhancement is optional, never fatal.
        print(f"Error generating AI prompt: {e}")
        return base_prompt  # Return the original prompt if there's an error
@spaces.GPU
def predict(prompt, prompt_w, seed=1231231, use_ai_prompt=False):
    """Run one 2-step SDXL-Lightning generation.

    Args:
        prompt: free-text prompt from the UI textbox.
        prompt_w: list of {"prompt", "scale"} dicts from the PromptWeighting widget.
        seed: RNG seed for reproducible sampling.
        use_ai_prompt: when True, expand the prompt via generate_ai_prompt first.

    Returns:
        (image path or PIL placeholder, AI-generated prompt string or "").
    """
    cfg_scale = 0.5
    generated_prompt = ""
    if use_ai_prompt:
        generated_prompt = generate_ai_prompt(prompt)
        prompt = generated_prompt
        print(f"AI-generated prompt: {prompt}")
    rng = torch.manual_seed(seed)
    started = time.time()
    # Fold the weighted sub-prompts into compel's "(text)weight" syntax,
    # skipping entries whose text is empty.
    weighted = (f"({item['prompt']}){item['scale']}" for item in prompt_w if item["prompt"])
    prompt_w = " ".join(weighted)
    # Index 0 holds the positive prompt, index 1 the (empty) negative prompt.
    conditioning, pooled = compel([prompt + " " + prompt_w, ""])
    results = pipe(
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        generator=rng,
        num_inference_steps=2,
        guidance_scale=cfg_scale,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - started} seconds")
    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            # Replace flagged output with a blank image.
            gr.Warning("NSFW content detected.")
            return Image.new("RGB", (512, 512)), generated_prompt
    final_image = results.images[0]
    # Persist as JPEG and hand Gradio a filepath (the Image output uses type="filepath").
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as out:
        final_image.save(out, "JPEG", quality=80, optimize=True, progressive=True)
    return Path(out.name), generated_prompt
LOGO_PATH = "logo.png"
css = """
#container {
margin: 0 auto;
max-width: 70rem;
padding: 2rem;
background-color: #f9f9f9;
border: 1px solid #e6e6e6;
border-radius: 10px;
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
}
#intro {
margin-bottom: 2rem;
}
#prompt {
border: 1px solid #ddd;
border-radius: 5px;
padding: 0.5rem;
}
#generate-button {
background-color: #007bff;
color: white;
border-radius: 5px;
padding: 0.8rem;
width: 100%;
border: none;
font-size: 1rem;
transition: all 0.3s ease-in-out;
}
#generate-button:hover {
background-color: #0056b3;
cursor: pointer;
}
#output-image {
max-height: 400px;
border: 1px solid #ddd;
border-radius: 5px;
padding: 0.5rem;
}
"""
# Gradio UI: prompt inputs on the left, generated image on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        # gr.Image(LOGO_PATH, elem_id="logo", show_label=False)
        # Page header: inline <style> block plus title/subtitle markdown.
        gr.Markdown(
            """
            <style>
                body {
                    background: linear-gradient(135deg, #89CFF0, #6A5ACD);
                    font-family: Arial, sans-serif;
                }
                h1 {
                    color: #fff;
                    text-align: center;
                    margin-top: 20px;
                    font-size: 2.5rem;
                }
                p {
                    color: #ddd;
                    text-align: center;
                    font-size: 1.2rem;
                    margin-bottom: 20px;
                }
                .gr-row {
                    justify-content: center;
                    padding: 20px;
                }
                .gr-textbox, .gr-slider, .gr-checkbox {
                    background-color: #f8f9fa;
                    border-radius: 8px;
                    border: 1px solid #ddd;
                    box-shadow: 0px 2px 5px rgba(0, 0, 0, 0.1);
                    padding: 10px;
                    margin-bottom: 10px;
                }
                .gr-image {
                    border-radius: 10px;
                    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
                }
                .gr-checkbox span {
                    font-weight: bold;
                    color: #fff;
                }
            </style>
            <h1>Info-TypeToArt</h1>
            <p>Type a creative prompt below and watch images come to life in real time!</p>
            """,
            elem_id="intro",
        )
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    prompt = gr.Textbox(
                        placeholder="Insert your prompt here:",
                        max_lines=1,
                        label="Prompt",
                    )
                    use_ai_prompt = gr.Checkbox(label="Generate AI Prompt")
                    # Read-only echo of the Groq-expanded prompt returned by predict.
                    generated_prompt_display = gr.Textbox(
                        label="AI-Generated Prompt",
                        interactive=False,
                    )
                    # Per-phrase weight editor; predict folds these into
                    # compel's "(text)weight" syntax.
                    prompt_w = PromptWeighting(
                        min=0,
                        max=3,
                        step=0.005,
                        show_label=False,
                        info="Drag up and down to adjust the weight of each prompt.",
                    )
                with gr.Accordion("Advanced options", open=True):
                    seed = gr.Slider(
                        minimum=0,
                        maximum=12013012031030,
                        label="Seed",
                        step=1,
                    )
                # generate_bt = gr.Button("Generate")
            with gr.Column():
                # predict returns a file path, hence type="filepath".
                image = gr.Image(type="filepath")
        inputs = [
            prompt,
            prompt_w,
            seed,
            use_ai_prompt,
        ]
        outputs = [image, generated_prompt_display]
        # Re-run predict on every input change; "always_last" keeps only the
        # most recent event while a generation is still in flight.
        gr.on(
            triggers=[
                prompt.input,
                prompt_w.input,
                # generate_bt.click,
                seed.input,
                use_ai_prompt.change,
            ],
            fn=predict,
            inputs=inputs,
            outputs=outputs,
            show_progress="hidden",
            show_api=False,
            trigger_mode="always_last",
        )
demo.queue(api_open=False)
demo.launch()