# app.py
# Robust Hugging Face Space: load a Diffusers model with safe fallbacks.
# No branding in source — ready to publish under any HF account.
import os
import time
import traceback
import logging
from typing import Optional, Tuple

import torch
from PIL import Image
import gradio as gr
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from transformers import logging as trf_logging
# -------------------------
# Logging
# -------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(levelname)s — %(message)s")
logger = logging.getLogger("prompt-image-editor")
trf_logging.set_verbosity_error()  # silence transformers' download chatter
# -------------------------
# Config via environment
# -------------------------
MODEL_ID = os.getenv("MODEL_ID", "runwayml/stable-diffusion-v1-5")  # recommended default
HF_TOKEN = os.getenv("HF_API_TOKEN")  # optional; set as a Space Secret if needed
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RETRY_COUNT = int(os.getenv("MODEL_LOAD_RETRIES", "3"))
RETRY_WAIT_SECONDS = float(os.getenv("MODEL_LOAD_RETRY_WAIT", "2.0"))

# Optional: switch to Inference API mode instead of loading the model in-process.
USE_INFERENCE_API = os.getenv("USE_INFERENCE_API", "false").lower() in ("1", "true", "yes")
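# Example values for Settings → Variables & Secrets (the model ID below is
# just an illustration; any Diffusers text-to-image checkpoint should work):
#   MODEL_ID           = stabilityai/stable-diffusion-2-1
#   HF_API_TOKEN       = hf_... (store as a Secret, not a Variable)
#   MODEL_LOAD_RETRIES = 3
#   USE_INFERENCE_API  = false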
# -------------------------
# Utilities
# -------------------------
def safe_from_pretrained(model_id: str, token: Optional[str] = None):
    """
    Load a Diffusers pipeline with safe options (fp16 on CUDA, fp32 on CPU).
    Raises to the caller on failure.
    """
    kwargs = {}
    if token:
        kwargs["token"] = token  # `use_auth_token` is deprecated in recent diffusers
    # Use float16 on CUDA for memory saving; float32 elsewhere.
    torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        **kwargs,
    )
    # Swap in a faster multistep scheduler if the config allows it (optional improvement).
    try:
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    except Exception:
        pass  # keep the model's default scheduler if incompatible
    pipe = pipe.to(DEVICE)
    return pipe
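# Extension point (a sketch, not enabled by default): recent diffusers
# pipelines expose memory savers that could be called at the end of
# safe_from_pretrained() if the Space runs out of VRAM or RAM:
#   pipe.enable_attention_slicing()    # lower peak VRAM, slightly slower
#   pipe.enable_vae_tiling()           # decode large images in tiles
#   pipe.enable_model_cpu_offload()    # requires `accelerate`; replaces .to(DEVICE)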
def load_pipeline_with_retries(model_id: str, token: Optional[str], retries: int = 3, wait: float = 2.0):
    """
    Attempt to load the model with retries and fallback logic.
    Returns (pipeline or None, error_message or None).
    """
    last_err = None
    for attempt in range(1, retries + 1):
        try:
            logger.info(f"[load] Attempt {attempt}/{retries} to load model '{model_id}' (token set: {'yes' if token else 'no'}).")
            pipe = safe_from_pretrained(model_id, token)
            logger.info(f"[load] Successfully loaded model '{model_id}'.")
            return pipe, None
        except Exception as e:
            last_err = traceback.format_exc()
            logger.warning(f"[load] Failed attempt {attempt}: {e}")
            if attempt < retries:
                time.sleep(wait * attempt)  # linearly growing backoff
    # Fall back to a known public model if the requested one failed.
    fallback = "runwayml/stable-diffusion-v1-5"
    if model_id != fallback:
        try:
            logger.info(f"[load] Trying fallback model '{fallback}'.")
            pipe = safe_from_pretrained(fallback, None)
            logger.info(f"[load] Successfully loaded fallback '{fallback}'.")
            return pipe, None
        except Exception as e:
            last_err = traceback.format_exc()
            logger.error(f"[load] Fallback also failed: {e}")
    return None, last_err
# -------------------------
# Pipeline init
# -------------------------
pipe = None
load_error = None
if USE_INFERENCE_API:
    logger.info("Configured to use Inference API mode. The app will not load local models.")
else:
    try:
        pipe, load_error = load_pipeline_with_retries(MODEL_ID, HF_TOKEN, retries=RETRY_COUNT, wait=RETRY_WAIT_SECONDS)
    except Exception:
        pipe = None
        load_error = traceback.format_exc()
        logger.error("Unexpected error during model load:\n" + load_error)
# -------------------------
# Inference function
# -------------------------
def generate_image(prompt: str, steps: int = 28, guidance: float = 7.5) -> Tuple[Optional[Image.Image], str]:
    """
    Returns (PIL.Image or None, status message).
    """
    if USE_INFERENCE_API:
        return None, "Inference API mode enabled — implement the API call flow or disable USE_INFERENCE_API."
    if pipe is None:
        return None, "Model is not loaded. Check Space Settings (MODEL_ID & HF_API_TOKEN) and server logs."
    if not prompt or not prompt.strip():
        return None, "Please enter a valid prompt."
    try:
        # Autocast only on CUDA; on CPU run in full precision.
        if DEVICE == "cuda":
            with torch.autocast("cuda"):
                out = pipe(prompt=prompt, guidance_scale=guidance, num_inference_steps=int(steps))
        else:
            out = pipe(prompt=prompt, guidance_scale=guidance, num_inference_steps=int(steps))
        img = out.images[0]
        return img, "OK"
    except Exception as e:
        logger.exception("Inference failed")
        return None, f"Inference error: {e}"
# -------------------------
# Gradio UI
# -------------------------
title = "Prompt Image Editor"
description = "Generate or edit images using a Diffusers-compatible model. Configure MODEL_ID and HF_API_TOKEN in Settings → Variables & Secrets."

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(lines=4, label="Prompt", placeholder="e.g. A portrait of an astronaut riding a horse, cinematic lighting")
            steps = gr.Slider(minimum=10, maximum=60, step=1, value=28, label="Steps")
            guidance = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, value=7.5, label="Guidance scale")
            run_btn = gr.Button("Generate")
            status = gr.Textbox(label="Status", interactive=False, value="Model loaded." if pipe else "Model not loaded. Check settings.")
        with gr.Column(scale=3):
            out_img = gr.Image(label="Output image", type="pil")

    # generate_image already matches the (inputs) -> (outputs) signature,
    # so it can be wired to the button directly.
    run_btn.click(generate_image, inputs=[prompt, steps, guidance], outputs=[out_img, status])
if __name__ == "__main__":
    demo.launch()
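# Optional hardening (a sketch): Gradio's request queue serializes GPU work
# under concurrent users, which is a common pattern on Spaces:
#   demo.queue().launch()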