ARIN460-ELE / app.py
etchisone's picture
Upload 2 files
de9947c verified
"""
Text-to-Image Generator
Dataset: rhli/genarena | Model: runwayml/stable-diffusion-v1-5
Deploy on: Hugging Face Spaces (Gradio SDK)
"""
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from datasets import load_dataset
import random
# ── Model setup ───────────────────────────────────────────────────────────────
# Hub id of the checkpoint handed to StableDiffusionPipeline.from_pretrained.
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only selected together with CUDA; the CPU path uses float32.
# Derived from DEVICE so the two settings can never disagree.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print(f"Loading model on {DEVICE} …")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    # The pipeline is loaded without its safety checker.
    safety_checker=None,
    requires_safety_checker=False,
)
# Replace the checkpoint's default scheduler with DPM-Solver multistep, reusing
# the existing scheduler configuration.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(DEVICE)
if DEVICE == "cuda":
    pipe.enable_attention_slicing()
print("Model ready ✓")
# ── Dataset: pull example prompts for the "Surprise me" button ───────────────
PROMPT_COLUMN = "prompt"  # adjust if the column name differs
# Upper bound on how many dataset rows are sampled for example prompts.
_MAX_DATASET_PROMPTS = 200
# Used whenever the dataset cannot be loaded or yields no usable prompt text.
_FALLBACK_PROMPTS = [
    "a futuristic city at sunset",
    "a cozy cottage in a misty forest",
    "a robot painting a watercolor",
    "an astronaut on a purple alien planet",
]

try:
    _ds = load_dataset("rhli/genarena", split="train")
    # Restrict to the first rows, then read only the prompt column instead of
    # materialising every column of every row.
    _sample = _ds.select(range(min(_MAX_DATASET_PROMPTS, len(_ds))))
    _loaded_prompts = [
        text
        for text in _sample[PROMPT_COLUMN]
        if isinstance(text, str) and text.strip()
    ]
except Exception as e:
    # Best-effort: a missing dataset, split or column must not stop the app
    # from starting.
    print(f"Could not load dataset prompts: {e}")
    _loaded_prompts = []

if _loaded_prompts:
    DATASET_PROMPTS = _loaded_prompts
    print(f"Loaded {len(DATASET_PROMPTS)} prompts from rhli/genarena ✓")
else:
    # Never leave the list empty: random.choice() on an empty list raises.
    DATASET_PROMPTS = list(_FALLBACK_PROMPTS)
# ── Inference function ────────────────────────────────────────────────────────
def generate_image(
    prompt: str,
    negative_prompt: str,
    num_steps: int,
    guidance_scale: float,
    seed: int,
):
    """Run the module-level pipeline once and return the image with a status line.

    Args:
        prompt: Text describing the desired image. ``None``, empty or
            whitespace-only input is rejected with a warning message.
        negative_prompt: Text to steer away from; a falsy value is passed to
            the pipeline as ``None``.
        num_steps: Number of inference steps (coerced with ``int``).
        guidance_scale: Guidance scale (coerced with ``float``).
        seed: Seed for the torch generator. ``None`` selects a random seed,
            which is then reported in the status message.

    Returns:
        ``(image, status)``. ``image`` is the first image produced by the
        pipeline, or ``None`` when the prompt is missing or generation fails;
        ``status`` is a human-readable message for the UI.
    """
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt."

    # An emptied seed field can reach this function as None (assumed from the
    # Gradio Number component — confirm); int(None) would raise, so choose a
    # seed instead and report it to the user.
    if seed is None:
        seed = random.randint(0, 2**31 - 1)

    try:
        # Seed handling lives inside the try so that an invalid seed becomes an
        # error message rather than an unhandled exception.
        seed = int(seed)
        generator = torch.Generator(DEVICE).manual_seed(seed)
        inference_ctx = (
            torch.autocast(DEVICE) if DEVICE == "cuda" else torch.no_grad()
        )
        with inference_ctx:
            result = pipe(
                prompt,
                negative_prompt=negative_prompt or None,
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                generator=generator,
                height=512,
                width=512,
            )
        image = result.images[0]
        return image, f"✅ Generated with seed {seed}"
    except Exception as e:
        # UI boundary: surface the failure in the status area.
        return None, f"❌ Error: {e}"
def random_prompt():
    """Return one entry of DATASET_PROMPTS selected with random.choice."""
    picked = random.choice(DATASET_PROMPTS)
    return picked
def random_seed():
    """Return a random integer seed between 0 and 2**31 - 1, both inclusive."""
    exclusive_limit = 2**31  # one past the largest seed that can be returned
    return random.randrange(exclusive_limit)
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Two-column layout: prompt and settings on the left, generated image and
# status message on the right. Markdown literals are kept flush-left so the
# strings themselves carry no leading indentation.
with gr.Blocks(title="Text-to-Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🎨 Text-to-Image Generator
Powered by **Stable Diffusion v1.5** · Prompts sourced from [`rhli/genarena`](https://huggingface.co/datasets/rhli/genarena)
"""
    )
    with gr.Row():
        # ── Left column: inputs ──────────────────────────────────────────────
        with gr.Column(scale=1):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate…",
                lines=3,
            )
            surprise_btn = gr.Button(
                "🎲 Surprise me (dataset prompt)", variant="secondary", size="sm"
            )
            neg_prompt_box = gr.Textbox(
                label="Negative prompt (optional)",
                placeholder="blurry, low quality, ugly, distorted",
                value="blurry, low quality, ugly, distorted",
                lines=2,
            )
            with gr.Accordion("⚙️ Advanced settings", open=False):
                steps_slider = gr.Slider(
                    minimum=10, maximum=50, value=20, step=1,
                    label="Inference steps (more = slower but better)",
                )
                guidance_slider = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance scale (how closely to follow the prompt)",
                )
                with gr.Row():
                    seed_box = gr.Number(label="Seed", value=42, precision=0)
                    # Fills the seed box with a fresh random value.
                    rand_seed_btn = gr.Button("🔀", size="sm")
            generate_btn = gr.Button("✨ Generate", variant="primary")
        # ── Right column: output ─────────────────────────────────────────────
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated image", type="pil")
            status_text = gr.Markdown("")
    # ── Event wiring ─────────────────────────────────────────────────────────
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_box, neg_prompt_box, steps_slider, guidance_slider, seed_box],
        outputs=[output_image, status_text],
    )
    surprise_btn.click(fn=random_prompt, outputs=prompt_box)
    rand_seed_btn.click(fn=random_seed, outputs=seed_box)
    # Each example row follows the order of `inputs`:
    # prompt, negative prompt, steps, guidance scale, seed.
    gr.Examples(
        examples=[
            ["a golden sunset over a calm ocean, photorealistic", "blurry, low quality", 20, 7.5, 42],
            ["a watercolor painting of a Japanese cherry blossom garden", "", 25, 8.0, 7],
            ["a futuristic robot chef cooking in a neon-lit kitchen", "low quality", 20, 7.5, 99],
            ["an ancient library filled with glowing magical books", "", 20, 9.0, 12],
        ],
        inputs=[prompt_box, neg_prompt_box, steps_slider, guidance_slider, seed_box],
        outputs=[output_image, status_text],
        fn=generate_image,
        cache_examples=False,
    )
    gr.Markdown(
        """
---
**Tips**
- Keep prompts descriptive: add style words like *oil painting*, *photorealistic*, *anime*, *watercolor*.
- Use the negative prompt to suppress unwanted features.
- Higher guidance (> 9) locks tightly to the prompt but can look over-saturated.
- The "Surprise me" button pulls real prompts from the `rhli/genarena` dataset.
"""
    )
# ── Entry point ───────────────────────────────────────────────────────────────
# Launch the Gradio app only when this file is executed as a script; importing
# the module does not call launch().
if __name__ == "__main__":
    demo.launch()