os-odyssey commited on
Commit
6a659d3
·
verified ·
1 Parent(s): 9dc50fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -43
app.py CHANGED
@@ -1,53 +1,110 @@
 
1
  import os
2
- import torch
3
  import gradio as gr
 
 
 
4
  from diffusers import StableDiffusionPipeline
 
 
5
 
 
6
  MODEL_ID = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-2-1")
 
7
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- # -------------------------
10
- # Load Model
11
- # -------------------------
12
- def load_pipeline():
13
- print(f"Loading model: {MODEL_ID} on {DEVICE}")
14
- pipe = StableDiffusionPipeline.from_pretrained(
15
- MODEL_ID,
16
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
17
- )
18
- pipe = pipe.to(DEVICE)
19
- return pipe
20
-
21
- pipe = load_pipeline()
22
-
23
- # -------------------------
24
- # Inference Function
25
- # -------------------------
26
- def generate(prompt):
27
- if not prompt or prompt.strip() == "":
28
- return "Please enter a valid prompt.", None
29
-
30
- print("Running inference...")
31
-
32
- result = pipe(
33
- prompt=prompt,
34
- num_inference_steps=25,
35
- guidance_scale=7.5
36
- )
37
-
38
- image = result.images[0]
39
- return f"Generated image for: {prompt}", image
40
-
41
- # -------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # Gradio UI
43
- # -------------------------
44
- interface = gr.Interface(
45
- fn=generate,
46
- inputs=gr.Textbox(label="Prompt", placeholder="Enter your image prompt..."),
47
- outputs=[gr.Textbox(label="Status"), gr.Image(label="Generated Image")],
48
- title="Prompt Image Editor",
49
- description="Generate AI images using text prompts.",
50
- )
 
 
 
 
 
 
 
 
 
51
 
52
  if __name__ == "__main__":
53
- interface.launch()
 
1
+ # app.py
2
  import os
3
+ import traceback
4
  import gradio as gr
5
+ from PIL import Image
6
+ import torch
7
+
8
  from diffusers import StableDiffusionPipeline
9
+ from transformers import logging
10
+ logging.set_verbosity_error()
11
 
12
+ # Config from environment
13
  MODEL_ID = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-2-1")
14
+ HF_TOKEN = os.getenv("HF_API_TOKEN") # Secret in Spaces (optional)
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
+ def try_load(model_id, token=None):
18
+ """
19
+ Try to load a diffusers pipeline. Raises the original exception on fatal error.
20
+ """
21
+ kwargs = {}
22
+ if token:
23
+ kwargs["use_auth_token"] = token
24
+ # choose dtype based on device
25
+ torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
26
+ try:
27
+ pipe = StableDiffusionPipeline.from_pretrained(
28
+ model_id,
29
+ revision="fp16" if DEVICE == "cuda" else None,
30
+ torch_dtype=torch_dtype,
31
+ **kwargs
32
+ )
33
+ if DEVICE == "cuda":
34
+ pipe = pipe.to("cuda")
35
+ else:
36
+ pipe = pipe.to("cpu")
37
+ return pipe
38
+ except Exception:
39
+ # re-raise to let caller decide (we'll handle fallback outside)
40
+ raise
41
+
42
+ # Load pipeline with fallback logic
43
+ def load_pipeline_with_fallback():
44
+ tried = []
45
+ # first attempt: user-provided MODEL_ID with token if exists
46
+ try:
47
+ print(f"Attempting to load MODEL_ID='{MODEL_ID}' (token set: {'yes' if HF_TOKEN else 'no'}) on {DEVICE}")
48
+ return try_load(MODEL_ID, token=HF_TOKEN)
49
+ except Exception as e:
50
+ tried.append((MODEL_ID, str(e)))
51
+ print(f"Failed to load {MODEL_ID}: {e}")
52
+
53
+ # fallback: try a known-public model
54
+ fallback_model = "runwayml/stable-diffusion-v1-5"
55
+ try:
56
+ print(f"Attempting fallback model '{fallback_model}'")
57
+ return try_load(fallback_model, token=None)
58
+ except Exception as e:
59
+ tried.append((fallback_model, str(e)))
60
+ print(f"Failed to load fallback {fallback_model}: {e}")
61
+
62
+ # if we get here, nothing could be loaded
63
+ msg = "Failed to load any model. Tried: " + ", ".join([f"{m}: {err[:80]}" for m,err in tried])
64
+ raise RuntimeError(msg)
65
+
66
+ # initialize
67
+ try:
68
+ pipe = load_pipeline_with_fallback()
69
+ except Exception as e:
70
+ # If pipeline can't be loaded, set pipe = None and keep running (UI will show error)
71
+ pipe = None
72
+ load_error = traceback.format_exc()
73
+ print("MODEL LOAD ERROR:\n", load_error)
74
+
75
+ # Inference function
76
+ def generate_image(prompt: str, steps: int = 28, guidance: float = 7.5):
77
+ if pipe is None:
78
+ return None, "Model not loaded. Check Space Settings (MODEL_ID & HF_API_TOKEN). See server logs."
79
+ if not prompt or not prompt.strip():
80
+ return None, "Please provide a prompt."
81
+ try:
82
+ with torch.autocast("cuda") if DEVICE == "cuda" else torch.no_grad():
83
+ out = pipe(prompt=prompt, guidance_scale=guidance, num_inference_steps=steps)
84
+ img = out.images[0]
85
+ return img, "OK"
86
+ except Exception as e:
87
+ print("Inference error:", e)
88
+ return None, f"Inference error: {str(e)}"
89
+
90
  # Gradio UI
91
+ with gr.Blocks(title="Prompt Image Editor") as demo:
92
+ gr.Markdown("# Prompt Image Editor")
93
+ with gr.Row():
94
+ with gr.Column():
95
+ prompt = gr.Textbox(lines=3, label="Prompt")
96
+ steps = gr.Slider(minimum=10, maximum=60, step=1, value=28, label="Steps")
97
+ guidance = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance")
98
+ run = gr.Button("Generate")
99
+ status = gr.Textbox(label="Status")
100
+ with gr.Column():
101
+ out_img = gr.Image(label="Output", type="pil")
102
+
103
+ def _run(prompt, steps, guidance):
104
+ img, msg = generate_image(prompt, steps, guidance)
105
+ return img, msg
106
+
107
+ run.click(_run, inputs=[prompt, steps, guidance], outputs=[out_img, status])
108
 
109
  if __name__ == "__main__":
110
+ demo.launch()