File size: 6,061 Bytes
6c41afa
 
 
 
 
27b8f0d
6c41afa
27b8f0d
 
6c41afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27b8f0d
6c41afa
27b8f0d
 
 
 
6c41afa
 
27b8f0d
6c41afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27b8f0d
 
6c41afa
 
 
27b8f0d
6c41afa
 
 
 
 
 
27b8f0d
6c41afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27b8f0d
 
 
6c41afa
 
27b8f0d
 
6c41afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27b8f0d
 
 
 
6c41afa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import io
import os
import tempfile
import zipfile
from datetime import datetime

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler
from PIL import Image

# --------- Helper: load model ----------
@torch.inference_mode()
def load_pipeline(model_id: str, torch_dtype=None, device=None):
    """Load a text-to-image pipeline for *model_id* and move it to a device.

    Args:
        model_id: Hub repo id (or local path) passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        torch_dtype: Weight dtype. ``None`` (default) picks ``torch.float16``
            on CUDA and ``torch.float32`` otherwise, since half precision on
            CPU is slow and some ops lack Half kernels there. An explicit
            value is always used as given.
        device: Target device (string or ``torch.device``). ``None`` picks
            ``"cuda"`` when available, otherwise ``"cpu"``.

    Returns:
        ``(pipe, device)``: the pipeline (scheduler replaced by
        ``DPMSolverMultistepScheduler``) and the device it was moved to.
    """
    # Resolve the device first so the dtype default can depend on it.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Normalise so "cuda", "cuda:0" and torch.device("cuda") all count as CUDA.
    is_cuda = torch.device(device).type == "cuda"
    if torch_dtype is None:
        torch_dtype = torch.float16 if is_cuda else torch.float32

    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
    )
    # Replace the model's scheduler with DPM-Solver, reusing its config.
    # NOTE(review): SDXL-Turbo ships a scheduler tuned for very few steps;
    # confirm DPM-Solver gives acceptable results for it.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    # Small memory tweaks: compute attention and VAE decode in slices on GPU.
    if is_cuda:
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
    return pipe, device

# Process-wide cache: model id -> loaded pipeline, so switching back is fast.
# NOTE(review): entries are never evicted; every model ever selected stays
# resident on the device — confirm the target GPU has room for all choices.
_PIPELINES = {}

def get_pipe(model_id: str):
    """Return the pipeline for *model_id*, loading it on first use.

    A cached pipeline is returned as-is. Otherwise ``load_pipeline`` is
    called, its pipeline is stored in ``_PIPELINES`` (the device it reports
    is ignored), and that same object is returned.
    """
    if model_id in _PIPELINES:
        return _PIPELINES[model_id]
    pipeline, _device = load_pipeline(model_id)
    _PIPELINES[model_id] = pipeline
    return pipeline

# --------- Core generation ----------
def parse_prompts(text: str):
    """Split comma-separated *text* into a list of prompts.

    Each piece is stripped of surrounding whitespace and empty pieces are
    dropped, so ``"a, ,b,"`` yields ``["a", "b"]``. ``None`` or ``""``
    yields ``[]`` instead of raising, in case a caller passes an unset value.

    Because the comma is the separator, a single prompt cannot contain a comma.
    """
    if not text:
        return []
    return [part.strip() for part in text.split(",") if part.strip()]

def _make_generator(seed, device):
    """Return a ``torch.Generator`` on *device* seeded from the user's *seed*.

    ``None``, a blank string, or a value ``int()`` cannot parse falls back to
    a fresh non-deterministic seed from ``torch.seed()`` (which also reseeds
    torch's global default generator).
    """
    seed_val = None
    if seed is not None and str(seed).strip() != "":
        try:
            seed_val = int(seed)
        except (TypeError, ValueError):
            # Not an integer (e.g. "abc"): use a random seed rather than
            # failing the whole request.
            seed_val = None
    if seed_val is None:
        seed_val = torch.seed()
    return torch.Generator(device=device).manual_seed(seed_val)


def _image_stem(prompt, index):
    """Return a filename-safe stem from the first 40 chars of *prompt*.

    Keeps alphanumerics, '-', '_' and spaces (spaces become '_'); falls back
    to ``prompt_<index>`` when nothing usable remains.
    """
    stem = "".join(c for c in prompt[:40] if c.isalnum() or c in "-_ ").strip().replace(" ", "_")
    return stem or f"prompt_{index}"


def _build_zip(images, names, zip_name):
    """Write *images* as PNG entries named *names* into a ZIP; return its path."""
    # gr.File outputs take a path on disk, not an in-memory buffer, so the
    # archive goes into a fresh per-call temp directory. Nothing here deletes
    # it afterwards.
    out_dir = tempfile.mkdtemp(prefix="t2i_zip_")
    zip_path = os.path.join(out_dir, zip_name)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for img, name in zip(images, names):
            png = io.BytesIO()
            img.save(png, format="PNG")
            zf.writestr(name, png.getvalue())
    return zip_path


def generate_images(
    prompts_text,
    negative_prompt,
    model_id,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    batch_per_prompt,
    seed
):
    """Generate images for every comma-separated prompt and bundle them.

    Args:
        prompts_text: Comma-separated prompts (see ``parse_prompts``).
        negative_prompt: Optional negative prompt; empty means none.
        model_id: Model id resolved through ``get_pipe``.
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Denoising steps per image.
        batch_per_prompt: Images generated per prompt.
        seed: Optional integer seed (string or int); blank/invalid -> random.
            One generator is shared across all prompts in the request.

    Returns:
        ``(images, zip_path, status)``. If there are no prompts, returns
        ``([], None, <message>)`` without loading a model.
    """
    prompts = parse_prompts(prompts_text)
    if not prompts:
        return [], None, "Please enter at least one prompt (use commas to separate)."

    pipe = get_pipe(model_id)
    generator = _make_generator(seed, pipe.device)

    # Slider values may arrive as floats; the pipeline needs ints for these.
    steps = int(num_inference_steps)
    per_prompt = int(batch_per_prompt)

    all_images = []
    names = []
    for i, p in enumerate(prompts, start=1):
        images = pipe(
            prompt=p,
            negative_prompt=negative_prompt or None,
            width=int(width),
            height=int(height),
            guidance_scale=float(guidance_scale),
            num_inference_steps=steps,
            num_images_per_prompt=per_prompt,
            generator=generator,
        ).images

        # Prefix the prompt index so prompts that sanitize to the same stem
        # (e.g. identical first 40 chars) never produce duplicate ZIP entries.
        stem = f"{i:02d}_{_image_stem(p, i)}"
        for j, img in enumerate(images, start=1):
            all_images.append(img)
            names.append(f"{stem}_{j}.png")

    zip_name = f"images_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
    zip_path = _build_zip(all_images, names, zip_name)
    return all_images, zip_path, f"Generated {len(all_images)} image(s) from {len(prompts)} prompt(s)."

# --------- UI ----------
# Extra stylesheet for the Blocks app: caps the main container at 1100px wide.
CSS = "\n.gradio-container {max-width: 1100px !important}\n"

# Build the Gradio UI. Components render in the order they are created inside
# each layout context, so statement order below defines the page layout.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ Multi-Prompt Text-to-Image (Hugging Face Space)")
    gr.Markdown("Enter **comma-separated prompts** to generate multiple images at once. Choose size, batch count, and download all results as a ZIP.")

    with gr.Row():
        # Left column: prompts and generation settings.
        with gr.Column():
            prompts_text = gr.Textbox(
                label="Prompts (comma-separated)",
                placeholder="A futuristic city at sunset, A cozy cabin in the woods, A portrait of a cyberpunk samurai",
                lines=4
            )
            negative_prompt = gr.Textbox(
                label="Negative prompt (optional)",
                placeholder="blurry, low quality, distorted"
            )
            # The selected value is passed to generate_images as model_id.
            # NOTE(review): confirm "runwayml/stable-diffusion-v1-5" is still
            # available on the Hub under that id.
            model_id = gr.Dropdown(
                label="Model",
                value="stabilityai/sdxl-turbo",
                choices=[
                    "stabilityai/sdxl-turbo",       # very fast SDXL
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-1"
                ]
            )
            # Size labels are converted to (width, height) by on_size_change below.
            size = gr.Dropdown(
                label="Image Size",
                value="1024x1024",
                choices=["512x512", "768x768", "1024x1024", "768x1024 (portrait)", "1024x768 (landscape)"]
            )

            # NOTE(review): the diffusers SDXL-Turbo examples use
            # guidance_scale=0.0 and 1-4 steps; the defaults/minimums here
            # may be worth revisiting.
            with gr.Row():
                guidance_scale = gr.Slider(0.0, 12.0, value=2.0, step=0.5, label="Guidance scale (SDXL-Turbo likes low)")
                steps = gr.Slider(2, 50, value=8, step=1, label="Steps")

            with gr.Row():
                batch_per_prompt = gr.Slider(1, 6, value=2, step=1, label="Images per prompt")
                # Free-text seed; parsed (or replaced by a random one) in generate_images.
                seed = gr.Textbox(label="Seed (optional, integer)")

            run_btn = gr.Button("Generate", variant="primary")

        # Right column: outputs, filled from generate_images' 3-tuple
        # (images, zip value, status message) in this order.
        with gr.Column():
            gallery = gr.Gallery(label="Results", show_label=True, columns=3, height=520)
            zip_file = gr.File(label="Download all images (.zip)")
            status = gr.Markdown("")

    def on_size_change(s):
        """Map a size-dropdown label to a ``(width, height)`` pair.

        Plain ``"WxH"`` labels are parsed directly; the two labelled presets
        are matched explicitly; any other value falls back to 1024x1024.
        NOTE(review): if the dropdown can ever deliver None, ``"x" in s``
        raises TypeError — confirm.
        """
        if "x" in s and s.count("x") == 1 and "(" not in s:
            w, h = s.split("x")
            return int(w), int(h)
        if s == "768x1024 (portrait)":
            return 768, 1024
        if s == "1024x768 (landscape)":
            return 1024, 768
        return 1024, 1024

    # Per-session resolved size. The change handler only fires when the
    # dropdown changes, so these initial values must match its default
    # ("1024x1024").
    width = gr.State(1024)
    height = gr.State(1024)
    size.change(fn=on_size_change, inputs=size, outputs=[width, height])

    # Input order must match generate_images' positional parameters.
    run_btn.click(
        fn=generate_images,
        inputs=[prompts_text, negative_prompt, model_id, width, height, guidance_scale, steps, batch_per_prompt, seed],
        outputs=[gallery, zip_file, status]
    )

# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()