# app.py
import os
import traceback
import gradio as gr
from PIL import Image
import torch
from diffusers import StableDiffusionPipeline
from transformers import logging

logging.set_verbosity_error()

# Config from environment
MODEL_ID = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-2-1")
HF_TOKEN = os.getenv("HF_API_TOKEN")  # Secret in Spaces (optional)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
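
# In a Hugging Face Space, MODEL_ID and HF_API_TOKEN can be set under
# Settings -> Variables and secrets (store the token as a secret, not a
# plain variable, so it never appears in logs or the public config).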

def try_load(model_id, token=None):
    """
    Try to load a diffusers pipeline. Any exception propagates to the
    caller, which decides whether to fall back to another model.
    """
    kwargs = {}
    if token:
        # Recent diffusers versions take `token=`; the old `use_auth_token=`
        # kwarg is deprecated.
        kwargs["token"] = token
    # Choose dtype based on device: fp16 halves memory on GPU, fp32 on CPU.
    # Selecting fp16 via torch_dtype works for any repo; the legacy
    # revision="fp16" branch does not exist for all models.
    torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        **kwargs,
    )
    return pipe.to(DEVICE)
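
# Usage sketch (assumes the weights are accessible with the given token):
#   pipe = try_load("stabilityai/stable-diffusion-2-1", token=HF_TOKEN)
#   image = pipe("a lighthouse at dusk").images[0]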

# Load pipeline with fallback logic
def load_pipeline_with_fallback():
    tried = []
    # First attempt: user-provided MODEL_ID, with the token if one is set.
    try:
        print(f"Attempting to load MODEL_ID='{MODEL_ID}' (token set: {'yes' if HF_TOKEN else 'no'}) on {DEVICE}")
        return try_load(MODEL_ID, token=HF_TOKEN)
    except Exception as e:
        tried.append((MODEL_ID, str(e)))
        print(f"Failed to load {MODEL_ID}: {e}")

    # Fallback: a known-public model. The original runwayml/stable-diffusion-v1-5
    # repo was removed from the Hub; the community mirror lives at
    # stable-diffusion-v1-5/stable-diffusion-v1-5.
    fallback_model = "stable-diffusion-v1-5/stable-diffusion-v1-5"
    try:
        print(f"Attempting fallback model '{fallback_model}'")
        return try_load(fallback_model, token=None)
    except Exception as e:
        tried.append((fallback_model, str(e)))
        print(f"Failed to load fallback {fallback_model}: {e}")

    # If we get here, nothing could be loaded.
    msg = "Failed to load any model. Tried: " + ", ".join(f"{m}: {err[:80]}" for m, err in tried)
    raise RuntimeError(msg)

# Initialize at import time so the Gradio UI comes up even if loading fails.
load_error = None
try:
    pipe = load_pipeline_with_fallback()
except Exception:
    # If the pipeline can't be loaded, keep running with pipe = None;
    # the UI will surface the error to the user.
    pipe = None
    load_error = traceback.format_exc()
    print("MODEL LOAD ERROR:\n", load_error)

# Inference function
def generate_image(prompt: str, steps: int = 28, guidance: float = 7.5):
    if pipe is None:
        return None, "Model not loaded. Check Space Settings (MODEL_ID & HF_API_TOKEN). See server logs."
    if not prompt or not prompt.strip():
        return None, "Please provide a prompt."
    try:
        # Autocast on GPU for fp16 inference; plain no-grad on CPU.
        ctx = torch.autocast("cuda") if DEVICE == "cuda" else torch.no_grad()
        with ctx:
            # Gradio sliders deliver floats, so cast steps back to int.
            out = pipe(prompt=prompt, guidance_scale=float(guidance), num_inference_steps=int(steps))
        img = out.images[0]
        return img, "OK"
    except Exception as e:
        print("Inference error:", e)
        return None, f"Inference error: {e}"
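
# Usage sketch (assumes the pipeline loaded successfully):
#   img, msg = generate_image("a watercolor fox", steps=30, guidance=7.0)
#   if img is not None:
#       img.save("out.png")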

# Gradio UI
with gr.Blocks(title="Prompt Image Editor") as demo:
    gr.Markdown("# Prompt Image Editor")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=3, label="Prompt")
            steps = gr.Slider(minimum=10, maximum=60, step=1, value=28, label="Steps")
            guidance = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance")
            run = gr.Button("Generate")
            status = gr.Textbox(label="Status")
        with gr.Column():
            out_img = gr.Image(label="Output", type="pil")

    # generate_image already returns (image, status), so it can be wired directly.
    run.click(generate_image, inputs=[prompt, steps, guidance], outputs=[out_img, status])

if __name__ == "__main__":
    demo.launch()
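
# To run outside Spaces (a sketch; exact package versions are assumptions):
#   pip install torch diffusers transformers accelerate gradio
#   MODEL_ID=stabilityai/stable-diffusion-2-1 python app.py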