File size: 3,914 Bytes
6a659d3
1b9b37d
6a659d3
9dc50fa
6a659d3
 
 
9dc50fa
6a659d3
 
e50f297
6a659d3
9dc50fa
6a659d3
e50f297
 
6a659d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc50fa
6a659d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b9b37d
9dc50fa
6a659d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# app.py
# Gradio Space entry point: loads a Stable Diffusion pipeline (with fallback)
# at import time and exposes a minimal text-to-image UI.
import os
import traceback
import gradio as gr
from PIL import Image
import torch

from diffusers import StableDiffusionPipeline
from transformers import logging
# Silence transformers' informational messages; only errors are printed.
logging.set_verbosity_error()

# Config from environment
MODEL_ID = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-2-1")  # first model to try
HF_TOKEN = os.getenv("HF_API_TOKEN")  # Secret in Spaces (optional)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present

def try_load(model_id, token=None):
    """
    Load a StableDiffusionPipeline for *model_id* and move it to DEVICE.

    Parameters:
        model_id: Hugging Face model repo id.
        token: optional HF access token for gated/private repos.

    Returns:
        The loaded pipeline, already placed on DEVICE.

    Raises:
        Exception: whatever ``from_pretrained`` raises is propagated
            unchanged so the caller can run its fallback logic.
    """
    kwargs = {}
    if token:
        # NOTE(review): ``use_auth_token`` is deprecated in recent diffusers
        # releases in favor of ``token`` — confirm against the pinned version.
        kwargs["use_auth_token"] = token
    # Half precision only makes sense on GPU; CPU runs in float32.
    torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    # The original wrapped this in a try/except that only re-raised; that was
    # dead code, so the exception now propagates naturally.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        # "fp16" revision exists only for some repos; None selects the default.
        revision="fp16" if DEVICE == "cuda" else None,
        torch_dtype=torch_dtype,
        **kwargs,
    )
    # One .to(DEVICE) call replaces the redundant cuda/cpu branch.
    return pipe.to(DEVICE)

# Load pipeline with fallback logic
def load_pipeline_with_fallback():
    """Load the configured model, falling back to a known public checkpoint.

    Returns the first pipeline that loads successfully; raises RuntimeError
    summarizing every failed attempt when nothing can be loaded.
    """
    failures = []

    # First choice: the model configured via MODEL_ID (with token, if set).
    print(f"Attempting to load MODEL_ID='{MODEL_ID}' (token set: {'yes' if HF_TOKEN else 'no'}) on {DEVICE}")
    try:
        return try_load(MODEL_ID, token=HF_TOKEN)
    except Exception as err:
        failures.append((MODEL_ID, str(err)))
        print(f"Failed to load {MODEL_ID}: {err}")

    # Second choice: a well-known public model that needs no token.
    fallback_model = "runwayml/stable-diffusion-v1-5"
    print(f"Attempting fallback model '{fallback_model}'")
    try:
        return try_load(fallback_model, token=None)
    except Exception as err:
        failures.append((fallback_model, str(err)))
        print(f"Failed to load fallback {fallback_model}: {err}")

    # Every candidate failed — surface a truncated reason for each attempt.
    summary = ", ".join(f"{model}: {reason[:80]}" for model, reason in failures)
    raise RuntimeError("Failed to load any model. Tried: " + summary)

# initialize
# Load the pipeline once at import time so the Gradio callback can reuse it.
try:
    pipe = load_pipeline_with_fallback()
except Exception as e:
    # If pipeline can't be loaded, set pipe = None and keep running (UI will show error)
    # NOTE(review): load_error is bound only on this failure path; nothing in
    # this file reads it afterwards, so that is harmless — confirm no other
    # module imports it.
    pipe = None
    load_error = traceback.format_exc()
    print("MODEL LOAD ERROR:\n", load_error)

# Inference function
def generate_image(prompt: str, steps: int = 28, guidance: float = 7.5):
    """Run text-to-image inference with the globally loaded pipeline.

    Parameters:
        prompt: text prompt; must be non-empty after stripping whitespace.
        steps: number of denoising steps. Gradio sliders deliver floats, so
            the value is coerced to int before reaching the scheduler.
        guidance: classifier-free guidance scale.

    Returns:
        Tuple of (PIL.Image or None, status message). Failures are reported
        through the status string rather than raised, so the UI never crashes.
    """
    if pipe is None:
        return None, "Model not loaded. Check Space Settings (MODEL_ID & HF_API_TOKEN). See server logs."
    if not prompt or not prompt.strip():
        return None, "Please provide a prompt."
    try:
        # inference_mode disables autograd on BOTH devices (the old code only
        # wrapped the CPU path in no_grad). Autocast is unnecessary: on CUDA
        # the pipeline is already loaded in float16.
        with torch.inference_mode():
            out = pipe(
                prompt=prompt,
                guidance_scale=float(guidance),
                # Schedulers require an integer step count.
                num_inference_steps=int(steps),
            )
            img = out.images[0]
        return img, "OK"
    except Exception as e:
        print("Inference error:", e)
        return None, f"Inference error: {str(e)}"

# Gradio UI
with gr.Blocks(title="Prompt Image Editor") as demo:
    gr.Markdown("# Prompt Image Editor")
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(lines=3, label="Prompt")
            steps_slider = gr.Slider(minimum=10, maximum=60, step=1, value=28, label="Steps")
            guidance_slider = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance")
            generate_btn = gr.Button("Generate")
            status_box = gr.Textbox(label="Status")
        with gr.Column():
            output_image = gr.Image(label="Output", type="pil")

    # Wire the button straight to the inference function: it already returns
    # the (image, status) pair in the order the outputs list expects, so the
    # pass-through wrapper the original defined is not needed.
    generate_btn.click(
        generate_image,
        inputs=[prompt_box, steps_slider, guidance_slider],
        outputs=[output_image, status_box],
    )

if __name__ == "__main__":
    demo.launch()