File size: 7,100 Bytes
de9947c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Text-to-Image Generator
Dataset: rhli/genarena  |  Model: runwayml/stable-diffusion-v1-5
Deploy on: Hugging Face Spaces (Gradio SDK)
"""

import torch
import gradio as gr
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from datasets import load_dataset
import random

# ── Model setup ───────────────────────────────────────────────────────────────
# Checkpoint identifier handed to `from_pretrained` below.
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# Run on the GPU when torch can see one; half precision is only selected
# together with CUDA, the CPU path stays in float32.
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE    = torch.float16 if DEVICE == "cuda" else torch.float32

print(f"Loading model on {DEVICE} …")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    # NOTE(review): the safety checker is explicitly disabled here, so generated
    # images are presumably not filtered — confirm this is intended for a
    # publicly deployed demo.
    safety_checker=None,
    requires_safety_checker=False,
)
# Replace the pipeline's default scheduler with DPM-Solver multistep, built from
# the existing scheduler's configuration.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(DEVICE)
if DEVICE == "cuda":
    # Only enabled on the GPU path (presumably to lower peak VRAM use — see the
    # diffusers memory-optimisation docs).
    pipe.enable_attention_slicing()
print("Model ready ✓")

# ── Dataset: pull example prompts for the "Surprise me" button ───────────────
PROMPT_COLUMN = "prompt"   # adjust if the column name differs
_MAX_DATASET_PROMPTS = 200  # upper bound on how many rows are sampled from the dataset
# Built-in prompts used whenever the dataset cannot supply any usable text.
_FALLBACK_PROMPTS = [
    "a futuristic city at sunset",
    "a cozy cottage in a misty forest",
    "a robot painting a watercolor",
    "an astronaut on a purple alien planet",
]

try:
    _ds = load_dataset("rhli/genarena", split="train")
    _sample = _ds.select(range(min(_MAX_DATASET_PROMPTS, len(_ds))))
    # Read just the prompt column instead of indexing whole rows one by one, and
    # keep only non-blank strings so the "Surprise me" button never yields an
    # empty or non-text value.
    DATASET_PROMPTS = [
        text for text in _sample[PROMPT_COLUMN]
        if isinstance(text, str) and text.strip()
    ]
except Exception as e:
    # Best-effort: the app must still start when the dataset is unreachable or
    # does not have the expected split/column.
    print(f"Could not load dataset prompts: {e}")
    DATASET_PROMPTS = []

if DATASET_PROMPTS:
    print(f"Loaded {len(DATASET_PROMPTS)} prompts from rhli/genarena ✓")
else:
    # An empty list would make random.choice() raise, so always fall back.
    print("Using built-in fallback prompts.")
    DATASET_PROMPTS = list(_FALLBACK_PROMPTS)

# ── Inference function ────────────────────────────────────────────────────────
def generate_image(
    prompt: str,
    negative_prompt: str,
    num_steps: int,
    guidance_scale: float,
    seed: int,
):
    """Generate one 512x512 image for *prompt* and report the outcome.

    Args:
        prompt: Text description of the desired image. ``None`` or a
            whitespace-only value is rejected with a warning message.
        negative_prompt: Text passed to the pipeline as its negative prompt.
            An empty or whitespace-only value is forwarded as ``None``.
        num_steps: Number of inference steps (coerced with ``int``).
        guidance_scale: Guidance scale (coerced with ``float``).
        seed: Seed for the torch generator. ``None`` (for example a cleared
            number field) selects a random seed instead of failing.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the first image returned by
        the pipeline, or ``None`` when nothing was generated; ``status`` is a
        short message meant for display next to the image.
    """
    # Guard against both a missing value and a blank string.
    if not prompt or not prompt.strip():
        return None, "⚠️  Please enter a prompt."

    try:
        # Seed handling lives inside the try so a bad value is reported through
        # the status message like any other failure instead of escaping.
        seed = random.randint(0, 2**31 - 1) if seed is None else int(seed)
        generator = torch.Generator(DEVICE).manual_seed(seed)

        cleaned_negative = (negative_prompt or "").strip() or None

        # Autocast is only entered on CUDA; other devices just run without
        # gradient tracking.
        with torch.autocast(DEVICE) if DEVICE == "cuda" else torch.no_grad():
            result = pipe(
                prompt,
                negative_prompt=cleaned_negative,
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                generator=generator,
                height=512,
                width=512,
            )
        image = result.images[0]
        return image, f"✅ Generated with seed {seed}"
    except Exception as e:
        # UI boundary: surface the failure as a status message rather than
        # letting the exception propagate to the caller.
        return None, f"❌ Error: {e}"

def random_prompt():
    """Return one entry of ``DATASET_PROMPTS`` selected with ``random.choice``."""
    candidates = DATASET_PROMPTS
    return random.choice(candidates)

def random_seed():
    """Return a random integer seed in the inclusive range 0 .. 2**31 - 1."""
    largest_seed = 2**31 - 1
    return random.randint(0, largest_seed)

# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Two-column layout: generation inputs on the left, the resulting image and a
# status line on the right. `demo` is the object launched by the entry point.
with gr.Blocks(title="Text-to-Image Generator", theme=gr.themes.Soft()) as demo:

    gr.Markdown(
        """
        # 🎨 Text-to-Image Generator
        Powered by **Stable Diffusion v1.5** · Prompts sourced from [`rhli/genarena`](https://huggingface.co/datasets/rhli/genarena)
        """
    )

    with gr.Row():
        # ── Left column: inputs ──────────────────────────────────────────────
        with gr.Column(scale=1):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate…",
                lines=3,
            )

            # Fills the prompt box with a prompt picked by random_prompt().
            surprise_btn = gr.Button("🎲 Surprise me (dataset prompt)", variant="secondary", size="sm")

            neg_prompt_box = gr.Textbox(
                label="Negative prompt (optional)",
                placeholder="blurry, low quality, ugly, distorted",
                value="blurry, low quality, ugly, distorted",
                lines=2,
            )

            # Collapsed by default so the basic flow is prompt → Generate.
            with gr.Accordion("⚙️ Advanced settings", open=False):
                steps_slider = gr.Slider(
                    minimum=10, maximum=50, value=20, step=1,
                    label="Inference steps (more = slower but better)",
                )
                guidance_slider = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance scale (how closely to follow the prompt)",
                )
                with gr.Row():
                    # precision=0 restricts the field to whole numbers.
                    seed_box = gr.Number(label="Seed", value=42, precision=0)
                    # Replaces the seed with a value from random_seed().
                    rand_seed_btn = gr.Button("🔀", size="sm")

            generate_btn = gr.Button("✨ Generate", variant="primary")

        # ── Right column: output ─────────────────────────────────────────────
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated image", type="pil")
            status_text  = gr.Markdown("")

    # ── Event wiring ─────────────────────────────────────────────────────────
    # The input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_box, neg_prompt_box, steps_slider, guidance_slider, seed_box],
        outputs=[output_image, status_text],
    )

    surprise_btn.click(fn=random_prompt, outputs=prompt_box)
    rand_seed_btn.click(fn=random_seed, outputs=seed_box)

    # Each example row is (prompt, negative prompt, steps, guidance, seed).
    # NOTE(review): with cache_examples=False, clicking an example presumably
    # only fills the inputs and does not run `fn` unless run_on_click is set —
    # confirm against the Gradio Examples documentation.
    gr.Examples(
        examples=[
            ["a golden sunset over a calm ocean, photorealistic", "blurry, low quality", 20, 7.5, 42],
            ["a watercolor painting of a Japanese cherry blossom garden", "", 25, 8.0, 7],
            ["a futuristic robot chef cooking in a neon-lit kitchen", "low quality", 20, 7.5, 99],
            ["an ancient library filled with glowing magical books", "", 20, 9.0, 12],
        ],
        inputs=[prompt_box, neg_prompt_box, steps_slider, guidance_slider, seed_box],
        outputs=[output_image, status_text],
        fn=generate_image,
        cache_examples=False,
    )

    gr.Markdown(
        """
        ---
        **Tips**
        - Keep prompts descriptive: add style words like *oil painting*, *photorealistic*, *anime*, *watercolor*.
        - Use the negative prompt to suppress unwanted features.
        - Higher guidance (> 9) locks tightly to the prompt but can look over-saturated.
        - The "Surprise me" button pulls real prompts from the `rhli/genarena` dataset.
        """
    )

# ── Entry point ───────────────────────────────────────────────────────────────
# Launch the Gradio app only when this file is executed as a script, so that
# importing the module does not call launch() as a side effect.
if __name__ == "__main__":
    demo.launch()