Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import zipfile | |
| from datetime import datetime | |
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler | |
# --------- Helper: load model ----------
def load_pipeline(model_id: str, torch_dtype=torch.float16, device=None):
    """Load a text-to-image pipeline for *model_id* and move it to *device*.

    Args:
        model_id: Hub repo id (or local path) passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        torch_dtype: Requested weight dtype. When the target device is the
            CPU, ``torch.float16`` is promoted to ``torch.float32``: half
            precision on CPU is generally unsupported or very slow for these
            models.
        device: Target device (string or ``torch.device``). When ``None``,
            CUDA is used if available, otherwise the CPU.

    Returns:
        ``(pipe, device)``: the configured pipeline and the device it was
        moved to.
    """
    # Resolve the device first so the dtype can be chosen to match it.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device_type = torch.device(device).type

    if device_type == "cpu" and torch_dtype == torch.float16:
        torch_dtype = torch.float32

    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
    )
    # Swap in the multistep DPM-Solver scheduler, reusing the model's own config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    # Small memory tweaks: trade some speed for lower peak GPU memory.
    # Comparing the device *type* also covers "cuda:0" and torch.device objects.
    if device_type == "cuda":
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
    return pipe, device
# Process-wide cache: model id -> loaded pipeline, so switching models is fast.
_PIPELINES = {}


def get_pipe(model_id: str):
    """Return the pipeline for *model_id*, loading it on first request.

    Subsequent calls with the same id reuse the cached object from
    ``_PIPELINES``; the device returned by ``load_pipeline`` is discarded.
    """
    if model_id in _PIPELINES:
        return _PIPELINES[model_id]
    pipeline, _device = load_pipeline(model_id)
    _PIPELINES[model_id] = pipeline
    return pipeline
# --------- Core generation ----------
def parse_prompts(text: str, sep: str = ","):
    """Split *text* on *sep* into a list of stripped, non-empty prompts.

    ``None`` or an empty string (e.g. a cleared UI textbox) yields ``[]``.

    >>> parse_prompts("a cat,  a dog ,,")
    ['a cat', 'a dog']
    """
    if not text:
        return []
    stripped = (part.strip() for part in text.split(sep))
    return [part for part in stripped if part]
def _make_generator(seed, device):
    """Return a ``torch.Generator`` on *device* seeded from *seed*.

    ``None``, blank, or non-integer seeds fall back to a fresh random seed
    drawn with ``torch.seed()``.
    """
    seed_val = None
    if seed is not None and str(seed).strip():
        try:
            seed_val = int(seed)
        except (TypeError, ValueError):
            seed_val = None  # unparseable -> random, as for a blank seed
    if seed_val is None:
        seed_val = torch.seed()
    return torch.Generator(device=device).manual_seed(seed_val)


def _safe_prompt_stem(prompt: str, index: int, max_len: int = 40) -> str:
    """Build a ZIP-safe file-name stem from the first *max_len* chars of *prompt*."""
    cleaned = "".join(c for c in prompt[:max_len] if c.isalnum() or c in "-_ ")
    stem = cleaned.strip().replace(" ", "_")
    if not stem:
        return f"prompt_{index}"
    # The 1-based prompt index keeps names unique when two prompts share a
    # 40-character prefix or sanitize to the same text.
    return f"{index:02d}_{stem}"


def _write_zip(images, names) -> str:
    """Save *images* as PNG entries named *names* in a new temp ZIP; return its path."""
    import tempfile

    zip_name = f"images_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
    # A fresh directory per call keeps the friendly file name while avoiding
    # clashes between runs started within the same second.
    zip_path = os.path.join(tempfile.mkdtemp(prefix="t2i_"), zip_name)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for img, name in zip(images, names):
            png = io.BytesIO()
            img.save(png, format="PNG")
            zf.writestr(name, png.getvalue())
    return zip_path


def generate_images(
    prompts_text,
    negative_prompt,
    model_id,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    batch_per_prompt,
    seed,
):
    """Generate images for each comma-separated prompt and bundle them as a ZIP.

    Args:
        prompts_text: Comma-separated prompts (split by ``parse_prompts``).
        negative_prompt: Optional negative prompt applied to every prompt;
            blank or whitespace-only means none.
        model_id: Model id resolved through the ``get_pipe`` cache.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Denoising steps per image.
        batch_per_prompt: Number of images generated per prompt.
        seed: Optional integer seed (int or string). Blank or unparseable
            values use a random seed. A single generator is shared across all
            prompts, so a prompt's output also depends on the prompts before it.

    Returns:
        ``(images, zip_path, status_message)`` for the gallery, file, and
        status outputs. With no usable prompts, returns ``([], None, message)``.
    """
    prompts = parse_prompts(prompts_text)
    if not prompts:
        return [], None, "Please enter at least one prompt (use commas to separate)."

    pipe = get_pipe(model_id)
    generator = _make_generator(seed, pipe.device)
    negative = (negative_prompt or "").strip() or None

    all_images = []
    names = []
    for i, p in enumerate(prompts, start=1):
        # UI sliders may deliver floats; the pipeline expects integer counts/sizes.
        images = pipe(
            prompt=p,
            negative_prompt=negative,
            width=int(width),
            height=int(height),
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(num_inference_steps),
            num_images_per_prompt=int(batch_per_prompt),
            generator=generator,
        ).images
        stem = _safe_prompt_stem(p, i)
        for j, img in enumerate(images, start=1):
            all_images.append(img)
            names.append(f"{stem}_{j}.png")

    # gr.File outputs take a filesystem path, so the ZIP is written to disk
    # rather than returned as an in-memory buffer.
    zip_path = _write_zip(all_images, names)
    return all_images, zip_path, f"Generated {len(all_images)} image(s) from {len(prompts)} prompt(s)."
# --------- UI ----------
CSS = """
.gradio-container {max-width: 1100px !important}
"""

# Size used when the dropdown value is empty or cannot be parsed.
_DEFAULT_SIZE = (1024, 1024)


def on_size_change(s):
    """Convert a size label such as ``"768x1024 (portrait)"`` to ``(width, height)``.

    Only the leading ``WxH`` token is parsed, so descriptive suffixes like
    ``"(portrait)"`` are ignored. ``None``, empty, or malformed labels fall
    back to ``(1024, 1024)``.
    """
    tokens = (s or "").split()
    if tokens:
        w, sep, h = tokens[0].partition("x")
        if sep:
            try:
                return int(w), int(h)
            except ValueError:
                pass
    return _DEFAULT_SIZE


with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ Multi-Prompt Text-to-Image (Hugging Face Space)")
    gr.Markdown("Enter **comma-separated prompts** to generate multiple images at once. Choose size, batch count, and download all results as a ZIP.")
    with gr.Row():
        with gr.Column():
            prompts_text = gr.Textbox(
                label="Prompts (comma-separated)",
                placeholder="A futuristic city at sunset, A cozy cabin in the woods, A portrait of a cyberpunk samurai",
                lines=4,
            )
            negative_prompt = gr.Textbox(
                label="Negative prompt (optional)",
                placeholder="blurry, low quality, distorted",
            )
            model_id = gr.Dropdown(
                label="Model",
                value="stabilityai/sdxl-turbo",
                choices=[
                    "stabilityai/sdxl-turbo",  # very fast SDXL
                    # The former "runwayml/stable-diffusion-v1-5" Hub repo was
                    # removed; this id is the Hub mirror of the same weights.
                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-1",
                ],
            )
            size = gr.Dropdown(
                label="Image Size",
                value="1024x1024",
                choices=["512x512", "768x768", "1024x1024", "768x1024 (portrait)", "1024x768 (landscape)"],
            )
            with gr.Row():
                guidance_scale = gr.Slider(0.0, 12.0, value=2.0, step=0.5, label="Guidance scale (SDXL-Turbo likes low)")
                steps = gr.Slider(2, 50, value=8, step=1, label="Steps")
            with gr.Row():
                batch_per_prompt = gr.Slider(1, 6, value=2, step=1, label="Images per prompt")
                seed = gr.Textbox(label="Seed (optional, integer)")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            gallery = gr.Gallery(label="Results", show_label=True, columns=3, height=520)
            zip_file = gr.File(label="Download all images (.zip)")
            status = gr.Markdown("")

    # Current output size. The initial value matches the dropdown default
    # ("1024x1024") and is updated whenever the size selection changes.
    width = gr.State(1024)
    height = gr.State(1024)
    size.change(fn=on_size_change, inputs=size, outputs=[width, height])
    run_btn.click(
        fn=generate_images,
        inputs=[prompts_text, negative_prompt, model_id, width, height, guidance_scale, steps, batch_per_prompt, seed],
        outputs=[gallery, zip_file, status],
    )

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()