# test / app.py — uploaded by Wasi8 ("Update app.py", commit 6c41afa, verified)
import io
import os
import tempfile
import zipfile
from datetime import datetime

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler
from PIL import Image
# --------- Helper: load model ----------
@torch.inference_mode()
def load_pipeline(model_id: str, torch_dtype=torch.float16, device=None):
    """Load a text-to-image pipeline and move it to the target device.

    Args:
        model_id: Hugging Face model repo id (e.g. "stabilityai/sdxl-turbo").
        torch_dtype: Weight dtype, float16 by default. On CPU-only hosts
            float16 inference is poorly supported, so float32 is used there.
        device: Explicit device string ("cuda"/"cpu"); auto-detected if None.

    Returns:
        Tuple of (pipeline, device_string).
    """
    # Resolve the device first so the dtype can be chosen to match it.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 ops are unsupported/unstable on CPU — fall back to float32.
    if device == "cpu" and torch_dtype == torch.float16:
        torch_dtype = torch.float32
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
    )
    # Swap in DPM-Solver++ multistep: good quality at low step counts.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    # Small memory tweaks — trade a little speed for lower VRAM usage.
    if device == "cuda":
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
    return pipe, device
# Pipeline cache keyed by model id, so switching between models is fast.
_PIPELINES = {}


def get_pipe(model_id: str):
    """Return the cached pipeline for *model_id*, loading it on first use."""
    pipe = _PIPELINES.get(model_id)
    if pipe is None:
        pipe, _device = load_pipeline(model_id)
        _PIPELINES[model_id] = pipe
    return pipe
# --------- Core generation ----------
def parse_prompts(text: str):
    """Split *text* on commas into stripped, non-empty prompt strings."""
    stripped = (piece.strip() for piece in text.split(","))
    return [prompt for prompt in stripped if prompt]
def generate_images(
    prompts_text,
    negative_prompt,
    model_id,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    batch_per_prompt,
    seed
):
    """Generate images for each comma-separated prompt and zip them up.

    Args:
        prompts_text: Comma-separated prompts string from the UI.
        negative_prompt: Optional negative prompt ("" or None disables it).
        model_id: Model repo id, resolved via get_pipe().
        width / height: Output resolution in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        batch_per_prompt: Images generated per prompt.
        seed: Optional integer seed (arrives as text); random when blank or
            not parseable as an int.

    Returns:
        (images, zip_path_or_None, status_message) matching the Gallery,
        File and Markdown outputs.
    """
    prompts = parse_prompts(prompts_text)
    if not prompts:
        return [], None, "Please enter at least one prompt (use commas to separate)."
    pipe = get_pipe(model_id)

    # Seeding: blank input -> fresh random seed; malformed input falls back
    # to a random seed instead of crashing the request.
    if seed is None or str(seed).strip() == "":
        seed_val = torch.seed()
    else:
        try:
            seed_val = int(seed)
        except (TypeError, ValueError):  # was a bare except — be specific
            seed_val = torch.seed()
    generator = torch.Generator(device=pipe.device).manual_seed(seed_val)

    all_images = []
    names = []
    for i, p in enumerate(prompts, start=1):
        images = pipe(
            prompt=p,
            negative_prompt=negative_prompt if negative_prompt else None,
            # Slider/state values may arrive as floats; diffusers expects ints.
            width=int(width),
            height=int(height),
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(num_inference_steps),
            num_images_per_prompt=int(batch_per_prompt),
            generator=generator
        ).images
        # Derive a filesystem-safe name from the prompt text.
        for j, img in enumerate(images, start=1):
            all_images.append(img)
            safe_prompt = "".join(
                c for c in p[:40] if c.isalnum() or c in "-_ "
            ).strip().replace(" ", "_")
            if not safe_prompt:
                safe_prompt = f"prompt_{i}"
            names.append(f"{safe_prompt}_{j}.png")

    # Build the ZIP on disk: gr.File expects a filepath, not an in-memory
    # (name, buffer) tuple, so write into a fresh temp directory.
    zip_name = f"images_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
    zip_path = os.path.join(tempfile.mkdtemp(prefix="t2i_"), zip_name)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for img, name in zip(all_images, names):
            bio = io.BytesIO()
            img.save(bio, format="PNG")
            zf.writestr(name, bio.getvalue())

    return all_images, zip_path, f"Generated {len(all_images)} image(s) from {len(prompts)} prompt(s)."
# --------- UI ----------
# Cap the overall app width; injected into gr.Blocks via css=CSS below.
CSS = """
.gradio-container {max-width: 1100px !important}
"""
# Build the Gradio UI: inputs on the left, results + ZIP download on the right.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ Multi-Prompt Text-to-Image (Hugging Face Space)")
    gr.Markdown("Enter **comma-separated prompts** to generate multiple images at once. Choose size, batch count, and download all results as a ZIP.")
    with gr.Row():
        with gr.Column():
            # --- generation inputs ---
            prompts_text = gr.Textbox(
                label="Prompts (comma-separated)",
                placeholder="A futuristic city at sunset, A cozy cabin in the woods, A portrait of a cyberpunk samurai",
                lines=4
            )
            negative_prompt = gr.Textbox(
                label="Negative prompt (optional)",
                placeholder="blurry, low quality, distorted"
            )
            model_id = gr.Dropdown(
                label="Model",
                value="stabilityai/sdxl-turbo",
                choices=[
                    "stabilityai/sdxl-turbo",  # very fast SDXL
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-1"
                ]
            )
            size = gr.Dropdown(
                label="Image Size",
                value="1024x1024",
                choices=["512x512", "768x768", "1024x1024", "768x1024 (portrait)", "1024x768 (landscape)"]
            )
            with gr.Row():
                guidance_scale = gr.Slider(0.0, 12.0, value=2.0, step=0.5, label="Guidance scale (SDXL-Turbo likes low)")
                steps = gr.Slider(2, 50, value=8, step=1, label="Steps")
            with gr.Row():
                batch_per_prompt = gr.Slider(1, 6, value=2, step=1, label="Images per prompt")
                seed = gr.Textbox(label="Seed (optional, integer)")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            # --- outputs: image gallery, downloadable ZIP, status text ---
            gallery = gr.Gallery(label="Results", show_label=True, columns=3, height=520)
            zip_file = gr.File(label="Download all images (.zip)")
            status = gr.Markdown("")

    def on_size_change(s):
        """Map a size-dropdown label to a (width, height) pixel pair."""
        # Plain "WxH" labels (no parenthetical suffix) parse directly.
        if "x" in s and s.count("x") == 1 and "(" not in s:
            w, h = s.split("x")
            return int(w), int(h)
        # Labeled presets are matched explicitly.
        if s == "768x1024 (portrait)":
            return 768, 1024
        if s == "1024x768 (landscape)":
            return 1024, 768
        # Fallback mirrors the dropdown's default value.
        return 1024, 1024

    # Hidden state holding the current output resolution; kept in sync with
    # the size dropdown via the change handler, and fed to generate_images.
    width = gr.State(1024)
    height = gr.State(1024)
    size.change(fn=on_size_change, inputs=size, outputs=[width, height])

    run_btn.click(
        fn=generate_images,
        inputs=[prompts_text, negative_prompt, model_id, width, height, guidance_scale, steps, batch_per_prompt, seed],
        outputs=[gallery, zip_file, status]
    )

if __name__ == "__main__":
    demo.launch()