| import spaces |
| from diffusers import ( |
| StableDiffusionXLPipeline, |
| EulerDiscreteScheduler, |
| UNet2DConditionModel, |
| AutoencoderTiny, |
| ) |
| import torch |
| import os |
| from huggingface_hub import hf_hub_download |
| from compel import Compel, ReturnedEmbeddingsType |
| from gradio_promptweighting import PromptWeighting |
|
|
| from PIL import Image |
| import gradio as gr |
| import time |
| from safetensors.torch import load_file |
| import tempfile |
| from pathlib import Path |
| import openai |
|
|
| |
# Model identifiers: SDXL base weights, the ByteDance SDXL-Lightning distilled
# 2-step UNet checkpoint, and the tiny TAESD autoencoder for fast decoding.
BASE = "stabilityai/stable-diffusion-xl-base-1.0"
REPO = "ByteDance/SDXL-Lightning"
CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
taesd_model = "madebyollin/taesdxl"

# Feature flags — all opt-in via environment variables; "1" enables.
SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
USE_TAESD = os.environ.get("USE_TAESD", "0") == "1"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_device = device  # alias used interchangeably with `device` below
torch_dtype = torch.float16

print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"SFAST_COMPILE: {SFAST_COMPILE}")
print(f"USE_TAESD: {USE_TAESD}")
print(f"device: {device}")

# Build the UNet from the base config, then load the distilled 2-step weights.
# NOTE(review): loading is hard-coded to "cuda" even though `device` may fall
# back to CPU — confirm this app only ever runs on GPU hosts.
unet = UNet2DConditionModel.from_config(BASE, subfolder="unet").to(
    "cuda", torch.float16
)
unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device="cuda"))
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
).to("cuda")
unet = unet.to(dtype=torch.float16)
|
|
# Compel implements the "(text)weight" prompt-weighting syntax across both
# SDXL tokenizer/text-encoder pairs; SDXL takes penultimate hidden states and
# pooled embeddings from the second encoder only.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

if USE_TAESD:
    # Optionally swap in the tiny autoencoder: much faster decode at some
    # quality cost.
    pipe.vae = AutoencoderTiny.from_pretrained(
        taesd_model, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)

# SDXL-Lightning expects "trailing" timestep spacing for its few-step
# sampling (per the model card); override the scheduler accordingly.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
|
|
if SAFETY_CHECKER:
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Classify *images* with the CLIP-based safety checker.

        Returns the input images unchanged together with one boolean per
        image indicating whether an NSFW concept was detected.
        """
        safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
        # BUGFIX: the checker returns a (checked_images, has_nsfw_concepts)
        # tuple; the original code assigned the WHOLE tuple to
        # has_nsfw_concepts, so any(has_nsfw_concepts) upstream was truthy for
        # every non-empty batch and every result was flagged as NSFW. Also
        # pass the image list directly instead of wrapping it in another list.
        # NOTE(review): assumes the project-local safety_checker module keeps
        # the diffusers return convention — confirm.
        _, has_nsfw_concepts = safety_checker(
            images=images,
            clip_input=safety_checker_input.pixel_values.to(torch_device),
        )
        return images, has_nsfw_concepts
|
|
if SFAST_COMPILE:
    # stable-fast ahead-of-time compilation of the whole pipeline.
    from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

    sfast_config = CompilationConfig.Default()

    # Enable optional accelerators only when their packages are importable.
    try:
        import xformers

        sfast_config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skip")

    try:
        import triton

        sfast_config.enable_triton = True
    except ImportError:
        print("Triton not installed, skip")

    # CUDA graphs capture the sampling loop for lower launch overhead.
    sfast_config.enable_cuda_graph = True

    pipe = compile(pipe, sfast_config)
|
|
|
|
| |
| import requests |
|
|
def generate_ai_prompt(base_prompt):
    """Expand *base_prompt* into a richer image prompt via the Groq chat API.

    Returns the model's rewritten prompt, or *base_prompt* unchanged when no
    API key is configured or any step of the request fails (best-effort).
    """
    api_key = os.getenv('GROQ_API_KEY')
    if not api_key:
        # Without a key the request is guaranteed to be rejected; skip the
        # network round-trip (and its potential 30 s timeout) entirely.
        # Previously the call was attempted anyway and only failed later.
        return base_prompt

    try:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that generates detailed and crisp image prompts in under 50 words. Create vivid, specific descriptions that would work well with image generation AI. Focus on visual details, style, lighting, and composition."
                },
                {
                    "role": "user",
                    "content": f"Generate a detailed image prompt based on: {base_prompt}"
                }
            ],
            # NOTE(review): "mixtral-8x7b-32768" may have been retired by
            # Groq — verify this model id is still served.
            "model": "mixtral-8x7b-32768",
            "temperature": 0.7,
            "max_tokens": 150
        }

        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"].strip()

    except Exception as e:
        # Best-effort: any failure (network, HTTP status, unexpected response
        # schema) falls back to the user's original prompt.
        print(f"Error generating AI prompt: {e}")
        return base_prompt
|
|
@spaces.GPU
def predict(prompt, prompt_w, seed=1231231, use_ai_prompt=False):
    """Generate one 2-step SDXL-Lightning image for *prompt*.

    *prompt_w* is the PromptWeighting payload: a list of {"prompt", "scale"}
    dicts folded into compel's "(text)scale" weighting syntax. Returns a
    (image file path, AI-rewritten prompt or "") pair; on an NSFW hit the
    image is replaced by a blank 512x512 canvas.
    """
    guidance_scale = 0.5
    generated_prompt = ""
    if use_ai_prompt:
        generated_prompt = generate_ai_prompt(prompt)
        prompt = generated_prompt
        print(f"AI-generated prompt: {prompt}")

    generator = torch.manual_seed(seed)
    start_time = time.time()

    # Keep only entries with non-empty text; render each as "(text)weight".
    weighted_fragments = " ".join(
        f"({entry['prompt']}){entry['scale']}" for entry in prompt_w if entry["prompt"]
    )

    # Batch of two embeddings: [positive prompt + weights, empty negative].
    conditioning, pooled = compel([prompt + " " + weighted_fragments, ""])

    results = pipe(
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        generator=generator,
        num_inference_steps=2,
        guidance_scale=guidance_scale,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - start_time} seconds")

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            return Image.new("RGB", (512, 512)), generated_prompt

    image = results.images[0]
    # Persist to a temp JPEG so the gr.Image(type="filepath") output can
    # serve it. delete=False keeps the file alive after the handler returns.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
        image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)
    return Path(tmpfile.name), generated_prompt
|
|
# Path to the app logo asset.
# NOTE(review): LOGO_PATH is never referenced elsewhere in this file —
# confirm it is used by another module or remove it.
LOGO_PATH = "logo.png"

# Custom CSS injected into gr.Blocks: container card, prompt box, generate
# button, and the output image frame.
css = """
#container {
margin: 0 auto;
max-width: 70rem;
padding: 2rem;
background-color: #f9f9f9;
border: 1px solid #e6e6e6;
border-radius: 10px;
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
}
#intro {
margin-bottom: 2rem;
}
#prompt {
border: 1px solid #ddd;
border-radius: 5px;
padding: 0.5rem;
}
#generate-button {
background-color: #007bff;
color: white;
border-radius: 5px;
padding: 0.8rem;
width: 100%;
border: none;
font-size: 1rem;
transition: all 0.3s ease-in-out;
}
#generate-button:hover {
background-color: #0056b3;
cursor: pointer;
}
#output-image {
max-height: 400px;
border: 1px solid #ddd;
border-radius: 5px;
padding: 0.5rem;
}
"""
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        # Page header: inline CSS theme plus the app title and tagline.
        gr.Markdown(
            """
<style>
body {
background: linear-gradient(135deg, #89CFF0, #6A5ACD);
font-family: Arial, sans-serif;
}
h1 {
color: #fff;
text-align: center;
margin-top: 20px;
font-size: 2.5rem;
}
p {
color: #ddd;
text-align: center;
font-size: 1.2rem;
margin-bottom: 20px;
}
.gr-row {
justify-content: center;
padding: 20px;
}
.gr-textbox, .gr-slider, .gr-checkbox {
background-color: #f8f9fa;
border-radius: 8px;
border: 1px solid #ddd;
box-shadow: 0px 2px 5px rgba(0, 0, 0, 0.1);
padding: 10px;
margin-bottom: 10px;
}
.gr-image {
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
.gr-checkbox span {
font-weight: bold;
color: #fff;
}
</style>
<h1>Info-TypeToArt</h1>
<p>Type a creative prompt below and watch images come to life in real time!</p>
            """,
            elem_id="intro",
        )
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    # Free-text prompt, single line.
                    prompt = gr.Textbox(
                        placeholder="Insert your prompt here:",
                        max_lines=1,
                        label="Prompt",
                    )
                    # When checked, predict() rewrites the prompt via Groq first.
                    use_ai_prompt = gr.Checkbox(label="Generate AI Prompt")
                    # Read-only echo of the AI-rewritten prompt.
                    generated_prompt_display = gr.Textbox(
                        label="AI-Generated Prompt",
                        interactive=False,
                    )
                    # Per-fragment weight editor; feeds predict()'s prompt_w
                    # list of {"prompt", "scale"} dicts.
                    prompt_w = PromptWeighting(
                        min=0,
                        max=3,
                        step=0.005,
                        show_label=False,
                        info="Drag up and down to adjust the weight of each prompt.",
                    )

                with gr.Accordion("Advanced options", open=True):
                    seed = gr.Slider(
                        minimum=0,
                        maximum=12013012031030,
                        label="Seed",
                        step=1,
                    )

            with gr.Column():
                # Output image; predict() returns a file path, hence "filepath".
                image = gr.Image(type="filepath")

    inputs = [
        prompt,
        prompt_w,
        seed,
        use_ai_prompt,
    ]
    outputs = [image, generated_prompt_display]

    # Regenerate live on any input change; trigger_mode="always_last"
    # coalesces rapid keystrokes into a single trailing call.
    gr.on(
        triggers=[
            prompt.input,
            prompt_w.input,
            seed.input,
            use_ai_prompt.change,
        ],
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        show_progress="hidden",
        show_api=False,
        trigger_mode="always_last",
    )

demo.queue(api_open=False)
demo.launch()