Spaces:
Configuration error
Configuration error
| # app.py | |
| import os | |
| import io | |
| import re | |
| import random | |
| import asyncio | |
| from typing import List, Optional, Tuple | |
| from datetime import datetime | |
| import torch | |
| import gradio as gr | |
| from diffusers import ( | |
| StableDiffusionPipeline, | |
| StableDiffusionXLPipeline, | |
| ) | |
| from huggingface_hub import HfApi | |
| from PIL import Image | |
# ----------------------
# Constants & Utilities
# ----------------------
# Hub repo ids. The former "runwayml/stable-diffusion-v1-5" repo was taken down
# from the Hugging Face Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is
# the maintained mirror of the same v1.5 weights.
SD15_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Model dropdown label -> Hub repo id.
DEFAULT_MODELS = {
    "Stable Diffusion 1.5 (fastest)": SD15_MODEL_ID,
    "Stable Diffusion XL Base 1.0": SDXL_MODEL_ID,
}
# CPU-friendly (width, height) defaults per repo id; the UI resets its size
# sliders to these values whenever the model dropdown changes.
DEFAULT_W_H = {
    SD15_MODEL_ID: (512, 768),
    SDXL_MODEL_ID: (768, 1024),
}
# A scene heading at the start of a line, e.g. "Scene 1:", "scene 2 -", "Scene 3 –".
SCENE_HEADER = re.compile(r"^\s*Scene\s*\d+\s*[:\-–]", re.IGNORECASE | re.MULTILINE)
# Loaded pipelines keyed by repo id; populated lazily by get_pipeline().
PIPELINES = {}
# Hub token: HF_TOKEN wins, otherwise fall back to HUGGING_FACE_HUB_TOKEN.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
# Repo id of the running Space; None when neither variable is set.
SPACE_ID = os.getenv("SPACE_ID") or os.getenv("SPACE_REPO")
# Shared Hub client used when saving images back to the repo.
API = HfApi()
def get_pipeline(model_id: str):
    """Return the CPU pipeline for *model_id*, loading and caching it on first use.

    Repo ids containing "stable-diffusion-xl" are loaded as
    ``StableDiffusionXLPipeline``; everything else as ``StableDiffusionPipeline``.
    Weights are float32 on CPU with attention and VAE slicing enabled to reduce
    peak memory. Loaded pipelines are stored in the module-level ``PIPELINES``.

    The SD 1.5 safety checker is not loaded at all (content policy is assumed to
    be handled upstream).
    """
    cached = PIPELINES.get(model_id)
    if cached is not None:
        return cached

    dtype = torch.float32  # full precision: this app runs on CPU
    if "stable-diffusion-xl" in model_id:
        pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=dtype)
    else:
        # Passing safety_checker=None skips downloading/instantiating the checker
        # model, rather than loading it on CPU and discarding it afterwards.
        # requires_safety_checker=False suppresses diffusers' warning about it.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
    pipe = pipe.to("cpu")
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
    PIPELINES[model_id] = pipe
    return pipe
def split_into_scene_prompts(
    text: str,
    max_scenes: int = 5,
    header_pattern: Optional["re.Pattern[str]"] = None,
) -> List[str]:
    """Turn a prompt or multi-scene script into exactly ``max_scenes`` prompts.

    - Empty or whitespace-only ``text`` yields ``[]``.
    - If no scene headers are found, the whole text is repeated ``max_scenes`` times.
    - Otherwise each header starts a scene that runs up to the next header.
      Scenes beyond ``max_scenes`` are dropped; if there are fewer, the last
      scene is repeated to pad the list.
    - Any text before the first header ("ambience") is prepended to every
      scene, separated by a blank line.

    Args:
        text: a single prompt or a script with "Scene N:" style headings.
        max_scenes: number of prompts to return (default 5); must be >= 1.
        header_pattern: compiled regex whose matches mark scene starts;
            defaults to the module-level ``SCENE_HEADER``.

    Returns:
        A list of exactly ``max_scenes`` prompt strings, or ``[]`` for empty input.

    Raises:
        ValueError: if ``max_scenes`` is less than 1.
    """
    if max_scenes < 1:
        raise ValueError(f"max_scenes must be >= 1, got {max_scenes}")
    pattern = SCENE_HEADER if header_pattern is None else header_pattern

    text = (text or "").strip()
    if not text:
        return []

    starts = [m.start() for m in pattern.finditer(text)]
    if not starts:
        return [text] * max_scenes

    ambience = text[: starts[0]].strip()
    ends = starts[1:] + [len(text)]
    blocks = [text[s:e].strip() for s, e in zip(starts, ends)]

    # Truncate, then pad with the last scene (blocks is non-empty here).
    blocks = blocks[:max_scenes]
    blocks += [blocks[-1]] * (max_scenes - len(blocks))

    if ambience:
        blocks = [f"{ambience}\n\n{b}" for b in blocks]
    return blocks
def clamp_size(model_id: str, width: int, height: int) -> Tuple[int, int]:
    """Snap (width, height) down to multiples of 8, then clamp to a CPU-safe range.

    Repo ids containing "stable-diffusion-xl" are bounded to 640..1152 on both
    axes. Any other model gets 384..896 for width and 384..1152 for height.
    Inputs are passed through ``int()`` first, so fractional values truncate.
    """
    if "stable-diffusion-xl" in model_id:
        # SDXL prefers larger canvases; the caps keep CPU runtime bounded.
        width_bounds, height_bounds = (640, 1152), (640, 1152)
    else:
        # SD 1.5 sweet spot, with safe caps for CPU.
        width_bounds, height_bounds = (384, 896), (384, 1152)

    def _fit(value, bounds):
        lo, hi = bounds
        snapped = int(value) // 8 * 8
        return min(max(snapped, lo), hi)

    return _fit(width, width_bounds), _fit(height, height_bounds)
| def _seed_everything(seed: Optional[int]): | |
| if seed is None or seed < 0: | |
| seed = random.randint(0, 2**32 - 1) | |
| generator = torch.Generator(device="cpu").manual_seed(seed) | |
| return seed, generator | |
def _generate_one(
    prompt: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    seed: int,
) -> Image.Image:
    """Render one image for *prompt* using the cached pipeline for *model_id*.

    A negative seed is randomized by ``_seed_everything``; the resolved seed is
    not returned. An empty negative prompt is sent to the pipeline as ``None``.
    """
    _, rng = _seed_everything(seed)
    pipeline = get_pipeline(model_id)
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else None,
        width=width,
        height=height,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=rng,
    )
    with torch.inference_mode():
        output = pipeline(**call_kwargs)
    return output.images[0]
async def _generate_one_async(**kwargs) -> Image.Image:
    """Await ``_generate_one(**kwargs)`` on a worker thread.

    The blocking pipeline call runs via ``asyncio.to_thread`` so the event loop
    is not blocked while an image renders.
    """
    image = await asyncio.to_thread(_generate_one, **kwargs)
    return image
async def generate_per_scene(
    script_text: str,
    negative_prompt: str,
    model_id: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    seed: int,
    progress: Optional["gr.Progress"] = None,
):
    """Generate one image per scene of *script_text*, one scene at a time (CPU-friendly).

    Scenes come from ``split_into_scene_prompts``. Scene i (1-based) uses
    ``seed + i - 1`` when ``seed >= 0``. A negative seed is passed through
    unchanged, so each scene gets its own random seed.

    A scene that raises is logged (with traceback) and replaced by a light-grey
    ``width`` x ``height`` placeholder, so partial results still come back.

    Args:
        progress: optional ``gr.Progress`` to report to. When omitted, a new
            ``gr.Progress(track_tqdm=True)`` is created, as before.

    Returns:
        A list of PIL images, one per scene prompt.

    Raises:
        gr.Error: if the script is empty, or if every scene failed (e.g. the
            model could not be loaded), instead of returning only placeholders.
    """
    import traceback

    prompts = split_into_scene_prompts(script_text)
    if not prompts:
        raise gr.Error("Please enter a prompt or scene script.")
    if progress is None:
        progress = gr.Progress(track_tqdm=True)

    images: List[Image.Image] = []
    total = len(prompts)
    failures = 0
    last_error: Optional[BaseException] = None
    for i, p in enumerate(prompts, start=1):
        # Report the fraction already *completed* before starting scene i.
        progress((i - 1) / total, desc=f"Generating scene {i}/{total}")
        scene_seed = seed + (i - 1) if seed >= 0 else seed
        try:
            img = await _generate_one_async(
                prompt=p,
                negative_prompt=negative_prompt,
                model_id=model_id,
                width=width,
                height=height,
                steps=steps,
                guidance=guidance,
                seed=scene_seed,
            )
        except Exception as e:
            # Best effort: keep going with a placeholder, but keep the full
            # traceback in the logs so failures are diagnosable.
            print(f"[error] scene {i} failed:", e)
            traceback.print_exc()
            failures += 1
            last_error = e
            img = Image.new("RGB", (width, height), color=(220, 220, 220))
        images.append(img)

    if failures == total:
        raise gr.Error(f"All {total} scenes failed to generate: {last_error}")
    return images
def _save_images_to_repo(imgs: List[Image.Image], subdir: str = "outputs") -> List[str]:
    """Upload *imgs* as PNGs to this Space's repo in a single commit.

    Does nothing and returns ``[]`` unless both ``HF_TOKEN`` and ``SPACE_ID`` are
    set, or when *imgs* is empty. Files are written to
    ``{subdir}/{UTC timestamp}_scene{N}.png`` with N starting at 1. All files go
    into one commit, so a failure uploads nothing rather than a partial set.

    Returns:
        The repo paths that were committed.

    Raises:
        Whatever the Hub client raises; the caller handles it.

    NOTE(review): commits to a Space repo typically trigger a rebuild/restart of
    the Space; a dataset repo may be a safer destination. Confirm intent.
    """
    if not (HF_TOKEN and SPACE_ID) or not imgs:
        return []

    from datetime import timezone

    from huggingface_hub import CommitOperationAdd

    # Timezone-aware replacement for the deprecated datetime.utcnow().
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    operations = []
    paths = []
    for idx, img in enumerate(imgs, start=1):
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        remote_path = f"{subdir}/{ts}_scene{idx}.png"
        operations.append(
            CommitOperationAdd(path_in_repo=remote_path, path_or_fileobj=buf.getvalue())
        )
        paths.append(remote_path)

    API.create_commit(
        repo_id=SPACE_ID,
        repo_type="space",
        operations=operations,
        commit_message=f"Add {len(paths)} generated image(s) ({ts})",
        # Pass the token resolved from HF_TOKEN / HUGGING_FACE_HUB_TOKEN explicitly.
        token=HF_TOKEN,
    )
    return paths
def validate_inputs(script_text: str, steps: int, guidance: float):
    """Raise ``gr.Error`` with a user-facing message if an input is invalid.

    Checks, in order:
    1. the script is not empty or blank;
    2. ``int(steps)`` is in [10, 60];
    3. ``float(guidance)`` is in [1.0, 12.0].

    Returns ``None`` when every check passes.
    """
    has_text = bool(script_text) and bool(script_text.strip())
    if not has_text:
        raise gr.Error("Please enter a prompt or scene script.")

    step_count = int(steps)
    if not (10 <= step_count <= 60):
        raise gr.Error("Steps must be between 10 and 60.")

    cfg = float(guidance)
    if not (1.0 <= cfg <= 12.0):
        raise gr.Error("Guidance must be between 1.0 and 12.0.")
with gr.Blocks(title="Loomvale Image Lab — CPU") as demo:
    gr.Markdown(
        """
# Loomvale Image Lab — CPU
Enter a single prompt or a multi-scene script using headings like **Scene 1: ...**, **Scene 2: ...**.
The app always generates **5** images (repeating the last scene or dropping extra scenes as needed).
"""
    )
    with gr.Row():
        model = gr.Dropdown(
            label="Model",
            choices=list(DEFAULT_MODELS.keys()),
            value="Stable Diffusion 1.5 (fastest)",
        )
        # Hub repo id for the current dropdown choice; kept in sync by _sync_model_choice.
        model_id_state = gr.State(DEFAULT_MODELS["Stable Diffusion 1.5 (fastest)"])
    script = gr.Textbox(
        label="Prompt or Multi-Scene Script",
        lines=6,
        placeholder=(
            "Optional ambience on top...\n\n"
            "Scene 1: A cozy studio filled with soft morning light\n"
            "Scene 2: A minimalist desk with a steaming cup of tea\n"
            "Scene 3: ..."
        ),
    )
    negative = gr.Textbox(
        label="Negative Prompt (optional)",
        placeholder="blurry, low quality, watermark, text, nsfw",
        value="blurry, low quality, watermark, text, worst quality, lowres",
    )
    w = gr.Slider(384, 1024, value=512, step=8, label="Width")
    h = gr.Slider(512, 1280, value=768, step=8, label="Height")
    steps = gr.Slider(10, 60, value=28, step=1, label="Steps")
    guidance = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="Guidance Scale")
    seed = gr.Number(value=-1, label="Seed (-1 = random)")
    can_save = bool(HF_TOKEN and SPACE_ID)
    save_to_repo = gr.Checkbox(
        # Avoid rendering "(None)" when no Space repo is configured.
        label=f"Save generated images to this Space repo ({SPACE_ID or 'not configured'})",
        value=can_save,
        interactive=can_save,
        visible=True,
    )
    btn = gr.Button("Generate Images", variant="primary")
    btn_clear = gr.Button("Clear")
    # Gradio 4 removed Component.style(), which this file's use of
    # concurrency_limit already requires, so layout options go in the constructor.
    gallery = gr.Gallery(
        label="Images",
        columns=5,
        rows=1,
        height="auto",
        allow_preview=True,
        preview=True,
        object_fit="contain",
    )
    status = gr.Markdown(visible=True)
    # Examples for quick testing
    gr.Examples(
        examples=[
            ["Ambient: gentle morning light\n\nScene 1: pastel living room\nScene 2: sunlight on linen curtains\nScene 3: ceramic mug on wooden table"],
            ["Scene 1: cyberpunk alley, neon reflections\nScene 2: rooftop garden at dusk\nScene 3: rainy crosswalk with umbrellas"],
        ],
        inputs=[script],
        label="Examples",
    )

    def _sync_model_choice(choice):
        """Map the dropdown label to its repo id and reset the size sliders to its defaults."""
        mid = DEFAULT_MODELS[choice]
        base_w, base_h = DEFAULT_W_H[mid]
        return mid, gr.update(value=base_w), gr.update(value=base_h)

    model.change(_sync_model_choice, inputs=model, outputs=[model_id_state, w, h])

    async def _on_click(
        script_text,
        negative_prompt,
        _model_choice,
        _model_id,
        width,
        height,
        steps_,
        guidance_,
        seed_,
        save_flag,
        # Declared so Gradio registers progress tracking for this event;
        # track_tqdm=True also surfaces tqdm bars (presumably the diffusers
        # denoising loop) in the UI.
        progress=gr.Progress(track_tqdm=True),
    ):
        """Validate inputs, generate the scene images, optionally save them; return (images, status)."""
        validate_inputs(script_text, steps_, guidance_)
        w_clamped, h_clamped = clamp_size(_model_id, int(width), int(height))
        # gr.Number yields None when the field is cleared; treat that as "random".
        seed_value = -1 if seed_ is None else int(seed_)
        imgs = await generate_per_scene(
            script_text=script_text,
            negative_prompt=negative_prompt,
            model_id=_model_id,
            width=w_clamped,
            height=h_clamped,
            steps=int(steps_),
            guidance=float(guidance_),
            seed=seed_value,
        )
        msg = f"✅ Generated {len(imgs)} image(s) at {w_clamped}×{h_clamped}."
        if save_flag:
            try:
                links = _save_images_to_repo(imgs)
                if links:
                    saved_list = "\n".join(f"- {p}" for p in links)
                    msg += f"\nSaved:\n{saved_list}"
                else:
                    msg += "\nℹ️ Skipped saving (token/repo not configured)."
            except Exception as e:
                # Saving is best effort: report it in the status, keep the images.
                print("[save_error]", e)
                msg += "\n⚠️ Save failed (see logs)."
        return imgs, msg

    btn.click(
        _on_click,
        inputs=[script, negative, model, model_id_state, w, h, steps, guidance, seed, save_to_repo],
        outputs=[gallery, status],
        concurrency_limit=1,
    )

    def _on_clear():
        """Empty the gallery and the status text."""
        return None, ""

    btn_clear.click(_on_clear, outputs=[gallery, status])
if __name__ == "__main__":
    # Enable Gradio's request queue before serving.
    demo.queue()
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)