# NOTE(review): the three lines below are Hugging Face Space page chrome that was
# accidentally pasted into the file; kept as comments so the module stays importable.
# Theloomvale's picture / Update app.py / 0371a09 verified
# app.py
import asyncio
import io
import os
import random
import re
from datetime import datetime, timezone
from typing import List, Optional, Tuple

import gradio as gr
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
from huggingface_hub import HfApi
from PIL import Image
# ----------------------
# Constants & Utilities
# ----------------------
# Friendly dropdown label -> Hugging Face model repo id.
DEFAULT_MODELS = {
"Stable Diffusion 1.5 (fastest)": "runwayml/stable-diffusion-v1-5",
"Stable Diffusion XL Base 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
}
# CPU-friendly defaults; auto-updated on model switch.
# Maps model repo id -> (width, height) used to reset the size sliders.
DEFAULT_W_H = {
"runwayml/stable-diffusion-v1-5": (512, 768),
"stabilityai/stable-diffusion-xl-base-1.0": (768, 1024),
}
# Matches "Scene <n>:" / "Scene <n> -" headings at line start (case-insensitive, multiline).
SCENE_HEADER = re.compile(r"^\s*Scene\s*\d+\s*[:\-–]", re.IGNORECASE | re.MULTILINE)
# Cache of loaded diffusion pipelines, keyed by model repo id (see get_pipeline).
PIPELINES = {}
API = HfApi()
# Credentials/ids read from the environment; saving images to the repo needs both.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
SPACE_ID = os.environ.get("SPACE_ID") or os.environ.get("SPACE_REPO")
def get_pipeline(model_id: str):
    """Return a cached CPU diffusion pipeline for *model_id*, loading on first use.

    SDXL repos get a ``StableDiffusionXLPipeline``; everything else a
    ``StableDiffusionPipeline``. The pipeline is moved to CPU, attention and
    VAE slicing are enabled to reduce peak memory, and the safety checker is
    dropped (content policy assumed handled upstream).
    """
    cached = PIPELINES.get(model_id)
    if cached is not None:
        return cached

    pipeline_cls = (
        StableDiffusionXLPipeline
        if "stable-diffusion-xl" in model_id
        else StableDiffusionPipeline
    )
    # float32 is the CPU-safe dtype choice.
    pipe = pipeline_cls.from_pretrained(model_id, torch_dtype=torch.float32)
    pipe = pipe.to("cpu")
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
    pipe.safety_checker = None  # assuming safe usage/content policy is handled upstream
    PIPELINES[model_id] = pipe
    return pipe
def split_into_scene_prompts(text: str, max_scenes: int = 5) -> List[str]:
    """Split input script into up to *max_scenes* scene prompts.

    Rules (unchanged from the original 5-scene behavior when called with the
    default):
    - If no explicit "Scene N:" headers are found, repeat the whole text to
      make *max_scenes* prompts.
    - If fewer than *max_scenes* scenes, pad with the last scene.
    - If more than *max_scenes*, truncate.
    - Any text before the first header is treated as shared ambience and is
      prepended to every scene prompt.

    Args:
        text: Raw prompt or multi-scene script (may be None/empty).
        max_scenes: Number of prompts to produce (generalized; default 5).

    Returns:
        A list of scene prompts, or [] for empty input.
    """
    # Compiled locally so the splitter is self-contained; the `re` module
    # caches compiled patterns, so this costs nothing per call.
    scene_header = re.compile(r"^\s*Scene\s*\d+\s*[:\-–]", re.IGNORECASE | re.MULTILINE)

    text = (text or "").strip()
    if not text:
        return []
    headers = list(scene_header.finditer(text))
    if not headers:
        return [text] * max_scenes
    # Everything above the first "Scene N:" heading is shared ambience.
    ambience = text[: headers[0].start()].strip()
    blocks = []
    for i, m in enumerate(headers):
        start = m.start()
        end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
        blocks.append(text[start:end].strip())
    if len(blocks) < max_scenes and blocks:
        blocks += [blocks[-1]] * (max_scenes - len(blocks))  # pad with last scene
    elif len(blocks) > max_scenes:
        blocks = blocks[:max_scenes]
    if ambience:
        blocks = [f"{ambience}\n\n{b}" for b in blocks]
    return blocks
def clamp_size(model_id: str, width: int, height: int) -> Tuple[int, int]:
    """Clamp (width, height) into CPU-safe bounds for the model.

    Values are rounded down to multiples of 8 (a diffusion-pipeline
    requirement) before clamping; all bounds are themselves multiples of 8.
    """

    def _fit(value: int, lo: int, hi: int) -> int:
        # Round down to a multiple of 8, then clamp into [lo, hi].
        aligned = (int(value) // 8) * 8
        return min(max(aligned, lo), hi)

    if "stable-diffusion-xl" in model_id:
        # SDXL works best with longer edge >= ~768; constrain for CPU
        return _fit(width, 640, 1152), _fit(height, 640, 1152)
    # SD 1.5 sweet spot; keep safe caps for CPU
    return _fit(width, 384, 896), _fit(height, 384, 1152)
def _seed_everything(seed: Optional[int]):
if seed is None or seed < 0:
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(device="cpu").manual_seed(seed)
return seed, generator
def _generate_one(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    seed: int,
) -> Image.Image:
    """Synchronously render a single image for *prompt* on the CPU pipeline.

    A negative *seed* means "random" (resolved by _seed_everything).
    """
    _, generator = _seed_everything(seed)
    pipe = get_pipeline(model_id)
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt or None,  # empty string -> no negative prompt
        width=width,
        height=height,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
    )
    with torch.inference_mode():
        result = pipe(**call_kwargs)
    return result.images[0]
async def _generate_one_async(**kwargs) -> Image.Image:
    """Run the blocking _generate_one in a worker thread so the event loop stays responsive."""
    image = await asyncio.to_thread(_generate_one, **kwargs)
    return image
async def generate_per_scene(
    script_text: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    seed: int,
):
    """Generate one image per scene, sequentially (CPU-friendly), with progress feedback.

    Scenes that fail to render are replaced by a flat light-gray placeholder
    so the gallery always shows one tile per scene.

    Raises:
        gr.Error: if the script produces no prompts.
    """
    prompts = split_into_scene_prompts(script_text)
    if not prompts:
        raise gr.Error("Please enter a prompt or scene script.")

    progress = gr.Progress(track_tqdm=True)
    total = len(prompts)
    results: List[Image.Image] = []
    for idx, scene_prompt in enumerate(prompts, start=1):
        progress(idx / total, desc=f"Generating scene {idx}/{total}")
        # Fixed seeds get a per-scene offset; a negative seed stays random.
        scene_seed = seed + (idx - 1) if seed >= 0 else seed
        try:
            image = await _generate_one_async(
                prompt=scene_prompt,
                negative_prompt=negative_prompt,
                model_id=model_id,
                width=width,
                height=height,
                steps=steps,
                guidance=guidance,
                seed=scene_seed,
            )
        except Exception as exc:
            # Best-effort: log and substitute a placeholder instead of aborting the batch.
            print(f"[error] scene {idx} failed:", exc)
            image = Image.new("RGB", (width, height), color=(220, 220, 220))
        results.append(image)
    return results
def _save_images_to_repo(imgs: List[Image.Image], subdir: str = "outputs") -> List[str]:
    """Upload *imgs* as PNGs into *subdir* of this Space's repo.

    Requires both HF_TOKEN and SPACE_ID to be configured; otherwise saving is
    silently skipped.

    Args:
        imgs: Images to upload, in scene order.
        subdir: Directory inside the Space repo to write into.

    Returns:
        Repo-relative paths of the uploaded files, or [] when not configured.
    """
    if not (HF_TOKEN and SPACE_ID):
        return []
    # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated since
    # Python 3.12 and returns a naive datetime. strftime output is identical.
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    paths: List[str] = []
    for idx, img in enumerate(imgs, start=1):
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        buf.seek(0)  # rewind so upload_file reads from the start
        remote_path = f"{subdir}/{ts}_scene{idx}.png"
        API.upload_file(
            path_or_fileobj=buf,
            path_in_repo=remote_path,
            repo_id=SPACE_ID,
            repo_type="space",
        )
        paths.append(remote_path)
    return paths
def validate_inputs(script_text: str, steps: int, guidance: float):
    """Raise gr.Error when the script is empty or steps/guidance are out of range."""
    if not (script_text and script_text.strip()):
        raise gr.Error("Please enter a prompt or scene script.")
    steps_val = int(steps)
    if steps_val < 10 or steps_val > 60:
        raise gr.Error("Steps must be between 10 and 60.")
    guidance_val = float(guidance)
    if guidance_val < 1.0 or guidance_val > 12.0:
        raise gr.Error("Guidance must be between 1.0 and 12.0.")
# UI definition. Fix: the original called `gallery.style(grid=5, ...)`, but
# Component.style() was removed in Gradio 4 (which this file otherwise targets:
# `columns=` on Gallery, `concurrency_limit=` on click), so that call raised
# AttributeError at startup. Layout now comes solely from the Gallery constructor.
with gr.Blocks(title="Loomvale Image Lab — CPU") as demo:
    gr.Markdown("""
# Loomvale Image Lab — CPU
Enter a single prompt or a multi-scene script using headings like **Scene 1: ...**, **Scene 2: ...**.
The app will generate up to **5** images (padding/truncating as needed).
""")
    with gr.Row():
        model = gr.Dropdown(
            label="Model",
            choices=list(DEFAULT_MODELS.keys()),
            value="Stable Diffusion 1.5 (fastest)",
        )
    # Hidden state holding the repo id behind the friendly dropdown label.
    model_id_state = gr.State(DEFAULT_MODELS["Stable Diffusion 1.5 (fastest)"])
    script = gr.Textbox(
        label="Prompt or Multi-Scene Script",
        lines=6,
        placeholder=(
            "Optional ambience on top...\n\n"
            "Scene 1: A cozy studio filled with soft morning light\n"
            "Scene 2: A minimalist desk with a steaming cup of tea\n"
            "Scene 3: ..."
        ),
    )
    negative = gr.Textbox(
        label="Negative Prompt (optional)",
        placeholder="blurry, low quality, watermark, text, nsfw",
        value="blurry, low quality, watermark, text, worst quality, lowres",
    )
    w = gr.Slider(384, 1024, value=512, step=8, label="Width")
    h = gr.Slider(512, 1280, value=768, step=8, label="Height")
    steps = gr.Slider(10, 60, value=28, step=1, label="Steps")
    guidance = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="Guidance Scale")
    seed = gr.Number(value=-1, label="Seed (-1 = random)")
    # Saving needs both a token and a Space id; otherwise the checkbox is disabled.
    can_save = bool(HF_TOKEN and SPACE_ID)
    save_to_repo = gr.Checkbox(
        label=f"Save generated images to this Space repo ({SPACE_ID})",
        value=can_save,
        interactive=can_save,
        visible=True,
    )
    btn = gr.Button("Generate Images", variant="primary")
    btn_clear = gr.Button("Clear")
    gallery = gr.Gallery(label="Images", columns=5, rows=1, height="auto", allow_preview=True)
    status = gr.Markdown(visible=True)
    # Examples for quick testing
    gr.Examples(
        examples=[
            ["Ambient: gentle morning light\n\nScene 1: pastel living room\nScene 2: sunlight on linen curtains\nScene 3: ceramic mug on wooden table"],
            ["Scene 1: cyberpunk alley, neon reflections\nScene 2: rooftop garden at dusk\nScene 3: rainy crosswalk with umbrellas"],
        ],
        inputs=[script],
        label="Examples",
    )

    def _sync_model_choice(choice):
        """Resolve the dropdown label to a repo id and reset W/H to that model's defaults."""
        mid = DEFAULT_MODELS[choice]
        base_w, base_h = DEFAULT_W_H[mid]
        return mid, gr.update(value=base_w), gr.update(value=base_h)

    model.change(_sync_model_choice, inputs=model, outputs=[model_id_state, w, h])

    async def _on_click(
        script_text, negative_prompt, _model_choice, _model_id, width, height, steps_, guidance_, seed_, save_flag
    ):
        """Validate inputs, generate per-scene images, optionally save, and report status."""
        validate_inputs(script_text, steps_, guidance_)
        w_clamped, h_clamped = clamp_size(_model_id, int(width), int(height))
        imgs = await generate_per_scene(
            script_text=script_text,
            negative_prompt=negative_prompt,
            model_id=_model_id,
            width=w_clamped,
            height=h_clamped,
            steps=int(steps_),
            guidance=float(guidance_),
            seed=int(seed_),
        )
        msg = f"✅ Generated {len(imgs)} image(s) at {w_clamped}×{h_clamped}."
        links = []
        if save_flag:
            try:
                links = _save_images_to_repo(imgs)
                if links:
                    saved_list = "\n".join(f"- {p}" for p in links)
                    msg += f"\nSaved:\n{saved_list}"
                else:
                    msg += "\nℹ️ Skipped saving (token/repo not configured)."
            except Exception as e:
                # Best-effort save: surface the failure in the UI, keep the images.
                print("[save_error]", e)
                msg += "\n⚠️ Save failed (see logs)."
        return imgs, msg

    btn.click(
        _on_click,
        inputs=[script, negative, model, model_id_state, w, h, steps, guidance, seed, save_to_repo],
        outputs=[gallery, status],
        concurrency_limit=1,  # CPU Space: one generation at a time
    )

    def _on_clear():
        """Reset the gallery and the status line."""
        return None, ""

    btn_clear.click(_on_clear, outputs=[gallery, status])
if __name__ == "__main__":
    # Queueing serializes requests; bind to all interfaces on the Space's port.
    demo.queue()
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)