import os
import io
import zipfile
from datetime import datetime
import gradio as gr
from PIL import Image
import torch
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler
# --------- Helper: load model ----------
@torch.inference_mode()
def load_pipeline(model_id: str, torch_dtype=torch.float16, device=None):
    """Load a text-to-image pipeline and move it to the target device.

    Args:
        model_id: Hugging Face model repo id (e.g. "stabilityai/sdxl-turbo").
        torch_dtype: weight dtype; fp16 keeps VRAM usage low on GPU.
        device: explicit device string ("cuda", "cuda:0", "cpu");
            auto-detects CUDA when None.

    Returns:
        Tuple of (pipeline, device string actually used).
    """
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        use_safetensors=True
    )
    # Swap in a multistep solver — usable quality at low step counts.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = pipe.to(device)
    # Small memory tweaks. Use startswith() so explicit devices such as
    # "cuda:0" also get them (the original `== "cuda"` comparison skipped
    # the optimizations for indexed device strings).
    if str(device).startswith("cuda"):
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
    return pipe, device
# Cache of loaded pipelines so switching between models is fast.
_PIPELINES = {}

def get_pipe(model_id: str):
    """Return the pipeline for *model_id*, loading and caching it on first use."""
    pipe = _PIPELINES.get(model_id)
    if pipe is None:
        pipe, _ = load_pipeline(model_id)
        _PIPELINES[model_id] = pipe
    return pipe
# --------- Core generation ----------
def parse_prompts(text: str):
    """Split a comma-separated prompt string into a list of non-empty prompts."""
    # Strip surrounding whitespace from each piece and discard blanks.
    return [piece for piece in map(str.strip, text.split(",")) if piece]
def generate_images(
    prompts_text,
    negative_prompt,
    model_id,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    batch_per_prompt,
    seed
):
    """Generate images for each comma-separated prompt and bundle them in a ZIP.

    Args:
        prompts_text: comma-separated prompts (see parse_prompts).
        negative_prompt: optional negative prompt; empty string disables it.
        model_id: HF model repo id, resolved through the pipeline cache.
        width / height: output dimensions in pixels.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: denoising steps per image.
        batch_per_prompt: images generated per prompt.
        seed: optional integer seed as text; blank/invalid -> random seed.

    Returns:
        (images, zip_path_or_None, status_message) for the Gradio outputs
        (Gallery, File, Markdown).
    """
    import tempfile  # stdlib; local import keeps the top-of-file imports untouched

    prompts = parse_prompts(prompts_text)
    if not prompts:
        return [], None, "Please enter at least one prompt (use commas to separate)."
    pipe = get_pipe(model_id)

    # Gradio sliders/state may deliver floats; diffusers expects ints here.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    batch_per_prompt = int(batch_per_prompt)

    # Seeding: blank input -> fresh random seed; non-integer text falls back
    # to a random seed. Catch only parse errors (the original bare `except:`
    # would have swallowed KeyboardInterrupt and everything else).
    if seed is None or str(seed).strip() == "":
        seed_val = torch.seed()
    else:
        try:
            seed_val = int(seed)
        except (TypeError, ValueError):
            seed_val = torch.seed()
    generator = torch.Generator(device=pipe.device).manual_seed(seed_val)

    all_images = []
    names = []
    for i, p in enumerate(prompts, start=1):
        images = pipe(
            prompt=p,
            negative_prompt=negative_prompt if negative_prompt else None,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=batch_per_prompt,
            generator=generator
        ).images
        # Collect images and derive filesystem-safe archive names.
        for j, img in enumerate(images, start=1):
            all_images.append(img)
            safe_prompt = "".join(c for c in p[:40] if c.isalnum() or c in "-_ ").strip().replace(" ", "_")
            if not safe_prompt:
                safe_prompt = f"prompt_{i}"
            # Include the prompt index so identical/similar prompts cannot
            # produce colliding entry names inside the ZIP.
            names.append(f"{safe_prompt}_{i}_{j}.png")

    # Write the ZIP to a real file: gr.File expects a filepath, not an
    # in-memory (name, BytesIO) tuple, so the original return value could
    # not be downloaded.
    zip_name = f"images_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
    zip_path = os.path.join(tempfile.mkdtemp(prefix="t2i_"), zip_name)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for img, name in zip(all_images, names):
            bio = io.BytesIO()
            img.save(bio, format="PNG")
            zf.writestr(name, bio.getvalue())
    return all_images, zip_path, f"Generated {len(all_images)} image(s) from {len(prompts)} prompt(s)."
# --------- UI ----------
# Constrain the app width; injected into the Blocks layout below.
CSS = """
.gradio-container {max-width: 1100px !important}
"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ Multi-Prompt Text-to-Image (Hugging Face Space)")
    gr.Markdown("Enter **comma-separated prompts** to generate multiple images at once. Choose size, batch count, and download all results as a ZIP.")
    with gr.Row():
        with gr.Column():
            # Left column: all generation inputs.
            prompts_text = gr.Textbox(
                label="Prompts (comma-separated)",
                placeholder="A futuristic city at sunset, A cozy cabin in the woods, A portrait of a cyberpunk samurai",
                lines=4
            )
            negative_prompt = gr.Textbox(
                label="Negative prompt (optional)",
                placeholder="blurry, low quality, distorted"
            )
            model_id = gr.Dropdown(
                label="Model",
                value="stabilityai/sdxl-turbo",
                choices=[
                    "stabilityai/sdxl-turbo",  # very fast SDXL
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-1"
                ]
            )
            size = gr.Dropdown(
                label="Image Size",
                value="1024x1024",
                choices=["512x512", "768x768", "1024x1024", "768x1024 (portrait)", "1024x768 (landscape)"]
            )
            with gr.Row():
                guidance_scale = gr.Slider(0.0, 12.0, value=2.0, step=0.5, label="Guidance scale (SDXL-Turbo likes low)")
                steps = gr.Slider(2, 50, value=8, step=1, label="Steps")
            with gr.Row():
                batch_per_prompt = gr.Slider(1, 6, value=2, step=1, label="Images per prompt")
                seed = gr.Textbox(label="Seed (optional, integer)")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            # Right column: generation outputs.
            gallery = gr.Gallery(label="Results", show_label=True, columns=3, height=520)
            zip_file = gr.File(label="Download all images (.zip)")
            status = gr.Markdown("")

    def on_size_change(s: str):
        # Map a size-dropdown label to a (width, height) pair of ints.
        # Plain "WxH" labels are parsed directly; the annotated portrait/
        # landscape labels (which contain "(") are matched explicitly.
        if "x" in s and s.count("x") == 1 and "(" not in s:
            w, h = s.split("x")
            return int(w), int(h)
        if s == "768x1024 (portrait)":
            return 768, 1024
        if s == "1024x768 (landscape)":
            return 1024, 768
        # Fallback for any unrecognized label.
        return 1024, 1024

    # Hidden per-session state holding the currently selected dimensions;
    # updated whenever the size dropdown changes and fed into generation.
    width = gr.State(1024)
    height = gr.State(1024)
    size.change(fn=on_size_change, inputs=size, outputs=[width, height])
    run_btn.click(
        fn=generate_images,
        inputs=[prompts_text, negative_prompt, model_id, width, height, guidance_scale, steps, batch_per_prompt, seed],
        outputs=[gallery, zip_file, status]
    )
if __name__ == "__main__":
    demo.launch()