00Boobs00 committed on
Commit
0dccb9c
·
verified ·
1 Parent(s): 5c25205

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +283 -63
app.py CHANGED
@@ -1,72 +1,233 @@
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
  import spaces
5
  import torch
6
- from diffusers import DiffusionPipeline
7
 
8
- dtype = torch.bfloat16
 
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
 
 
11
 
 
 
12
 
13
- # Load the model pipeline
14
- pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
 
 
 
15
 
16
- torch.cuda.empty_cache()
 
 
 
17
 
18
- MAX_SEED = np.iinfo(np.int32).max
19
- MAX_IMAGE_SIZE = 2048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
 
 
 
 
 
22
 
23
- @spaces.GPU()
24
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
25
- if randomize_seed:
26
- seed = random.randint(0, MAX_SEED)
27
- generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
-
30
- if lora_id and lora_id.strip() != "":
31
- pipe.unload_lora_weights()
32
- pipe.load_lora_weights(lora_id.strip())
33
-
34
  try:
35
- image = pipe(
36
- prompt=prompt,
37
- negative_prompt="",
38
- width=width,
39
- height=height,
40
- num_inference_steps=num_inference_steps,
41
- generator=generator,
42
- true_cfg_scale=guidance_scale,
43
- guidance_scale=1.0 # Use a fixed default for distilled guidance
44
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return image, seed
 
46
  finally:
47
- # Unload LoRA weights if they were loaded
48
- if lora_id:
49
- pipe.unload_lora_weights()
50
-
 
 
 
 
 
 
 
 
 
51
  examples = [
52
  "a tiny astronaut hatching from an egg on the moon",
53
  "a cat holding a sign that says hello world",
54
  "an anime illustration of a wiener schnitzel",
55
  ]
56
-
57
  css = """
58
  #col-container {
59
- margin: 0 auto;
60
- max-width: 960px;
61
  }
62
  .generate-btn {
63
- background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
64
- border: none !important;
65
- color: white !important;
66
  }
67
  .generate-btn:hover {
68
- transform: translateY(-2px);
69
- box-shadow: 0 5px 15px rgba(0,0,0,0.2);
70
  }
71
  """
72
 
@@ -76,9 +237,18 @@ with gr.Blocks(css=css) as app:
76
  with gr.Row():
77
  with gr.Column():
78
  with gr.Row():
79
- text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=3, elem_id="prompt-text-input")
 
 
 
 
 
80
  with gr.Row():
81
- custom_lora = gr.Textbox(label="Custom LoRA (optional)", info="LoRA Hugging Face path", placeholder="flymy-ai/qwen-image-realism-lora")
 
 
 
 
82
  with gr.Row():
83
  with gr.Accordion("Advanced Settings", open=False):
84
  lora_scale = gr.Slider(
@@ -89,36 +259,86 @@ with gr.Blocks(css=css) as app:
89
  value=1,
90
  )
91
  with gr.Row():
92
- width = gr.Slider(label="Width", value=1024, minimum=64, maximum=2048, step=16)
93
- height = gr.Slider(label="Height", value=1024, minimum=64, maximum=2048, step=16)
94
- seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=4294967296, step=1)
95
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  with gr.Row():
97
- steps = gr.Slider(label="Inference steps steps", value=28, minimum=1, maximum=100, step=1)
98
- cfg = gr.Slider(label="Guidance Scale", value=4, minimum=1, maximum=20, step=0.5)
99
- # method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"])
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  with gr.Row():
102
- # text_button = gr.Button("Run", variant='primary', elem_id="gen-button")
103
- text_button = gr.Button("✨ Generate Image", variant='primary', elem_classes=["generate-btn"])
 
 
 
 
104
  with gr.Column():
105
  with gr.Row():
106
- image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
107
-
108
- # gr.Markdown(article_text)
 
 
 
109
  with gr.Column():
110
  gr.Examples(
111
- examples = examples,
112
- inputs = [text_prompt],
113
  )
 
 
114
  gr.on(
115
  triggers=[text_button.click, text_prompt.submit],
116
- fn = infer,
117
- inputs=[text_prompt, seed, randomize_seed, width, height, cfg, steps, custom_lora, lora_scale],
118
- outputs=[image_output, seed]
 
 
 
 
 
 
 
 
 
 
119
  )
120
-
121
- # text_button.click(query, inputs=[custom_lora, text_prompt, steps, cfg, randomize_seed, seed, width, height], outputs=[image_output,seed_output, seed])
122
- # text_button.click(infer, inputs=[text_prompt, seed, randomize_seed, width, height, cfg, steps, custom_lora, lora_scale], outputs=[image_output,seed_output, seed])
123
 
124
- app.launch(share=True)
 
 
 
1
+ import os
2
+ import traceback
3
  import gradio as gr
4
  import numpy as np
5
  import random
6
  import spaces
7
  import torch
8
+ from diffusers import DiffusionPipeline
9
 
10
+ # --------------------
11
+ # Global config
12
+ # --------------------
13
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ MAX_SEED = np.iinfo(np.int32).max
17
+ MAX_IMAGE_SIZE = 2048
18
+ MIN_IMAGE_SIZE = 64
19
 
20
+ # Optional: environment override for model name
21
+ MODEL_ID = os.getenv("QWEN_IMAGE_MODEL_ID", "Qwen/Qwen-Image")
22
 
23
+ # --------------------
24
+ # Pipeline load with guard
25
+ # --------------------
26
+ pipe = None
27
+ pipe_load_error = None
28
 
29
def _load_pipeline():
    """Load the diffusion pipeline once and cache it module-wide.

    On failure the error is recorded in the module global ``pipe_load_error``
    (and the traceback printed) instead of crashing app startup, so the UI
    can still come up and report the problem later.

    Returns:
        The cached pipeline instance, or ``None`` when loading failed.
    """
    global pipe, pipe_load_error
    if pipe is None:
        try:
            pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
            pipe = pipe.to(device)
            # Free any transient allocation left over from the load.
            torch.cuda.empty_cache()
        except Exception as exc:
            # Remember the failure so infer() can fail fast with a clear message.
            pipe_load_error = f"Failed to load model '{MODEL_ID}': {repr(exc)}"
            traceback.print_exc()
    return pipe


_load_pipeline()  # eager load on startup
47
+
48
+
49
def _safe_clamp_size(width: int, height: int):
    """Clamp a requested image size into safe bounds.

    Each dimension is clamped to [MIN_IMAGE_SIZE, MAX_IMAGE_SIZE] and then
    floored to the nearest multiple of 16 (model-friendly sizes).

    Returns:
        The adjusted ``(width, height)`` tuple of ints.
    """
    def snap(value):
        # Clamp first, then drop the remainder to land on a 16-pixel grid.
        clamped = min(MAX_IMAGE_SIZE, max(MIN_IMAGE_SIZE, value))
        return int(clamped - clamped % 16)

    return snap(width), snap(height)
59
+
60
+
61
def _normalize_seed(seed, randomize_seed: bool):
    """Return a usable RNG seed.

    A fresh random seed is drawn when randomization is requested, when the
    seed is missing (``None``), or when it is negative (the UI uses -1 to
    mean "random"). Otherwise the given seed is wrapped into [0, MAX_SEED].
    """
    wants_fresh = randomize_seed or seed is None or int(seed) < 0
    if wants_fresh:
        return random.randint(0, MAX_SEED)
    return int(seed) % (MAX_SEED + 1)
68
+
69
+
70
def _maybe_load_lora(lora_id: str, lora_scale: float):
    """Best-effort LoRA loader for the global pipeline.

    Args:
        lora_id: Hugging Face repo path of the LoRA; blank/None means "none".
        lora_scale: reported only — whether a scale is applied depends on the
            underlying pipeline implementation.

    Returns:
        ``(loaded, message)`` where ``loaded`` is True when weights were
        attached, and ``message`` is a human-readable warning on failure
        (``None`` on success or when no LoRA was requested).
    """
    if not lora_id or not lora_id.strip():
        return False, None

    lora_id = lora_id.strip()
    try:
        # Drop any previously attached LoRA before loading a new one.
        if hasattr(pipe, "unload_lora_weights"):
            pipe.unload_lora_weights()

        if not hasattr(pipe, "load_lora_weights"):
            return False, f"LoRA support not available in this pipeline. (Tried: {lora_id})"

        pipe.load_lora_weights(lora_id, adapter_name="default", weight_name=None)
        # NOTE(review): lora_scale is not applied here; actual use of a scale
        # depends on the underlying pipeline — confirm if scaling is needed.
        return True, None
    except Exception as e:
        traceback.print_exc()
        return False, f"Failed to load LoRA '{lora_id}': {repr(e)}"
94
+
95
+
96
def _maybe_unload_lora():
    """Detach LoRA weights from the global pipeline if supported; never raises."""
    try:
        unload = getattr(pipe, "unload_lora_weights", None)
        if unload is not None:
            unload()
    except Exception:
        # Cleanup is best-effort: log the traceback and move on.
        traceback.print_exc()
102
+
103
+
104
# --------------------
# Inference function with robust error handling
# --------------------
@spaces.GPU(duration=120)
def infer(
    prompt: str,
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 4.0,
    num_inference_steps: int = 28,
    lora_id: str = None,
    lora_scale: float = 0.95,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Main inference entrypoint for Gradio.

    Args:
        prompt: text prompt (required, non-empty).
        seed: RNG seed; negative values mean "pick one at random".
        randomize_seed: when True, always draw a fresh seed.
        width, height: requested output size; clamped to a 16-px grid
            within [MIN_IMAGE_SIZE, MAX_IMAGE_SIZE].
        guidance_scale: forwarded as ``true_cfg_scale`` (Qwen-style CFG).
        num_inference_steps: number of denoising steps.
        lora_id: optional Hugging Face LoRA repo path.
        lora_scale: reported only; see _maybe_load_lora.
        progress: Gradio progress tracker.

    Returns:
        (PIL.Image, seed) on success.

    Raises:
        gr.Error: on empty prompt, startup model-load failure, CUDA OOM,
            or any internal pipeline failure (user-facing message).
    """
    # Basic validation
    if not prompt or prompt.strip() == "":
        raise gr.Error("Prompt is empty. Please provide a text prompt.")

    # If the model failed to load at startup, fail fast.
    # FIX: the message was a broken multi-line f-string (syntax error);
    # rebuilt with implicit string concatenation.
    if pipe_load_error is not None:
        raise gr.Error(
            f"Model failed to load on startup: {pipe_load_error}\n"
            "Try restarting the Space or check the logs."
        )

    # Clamp dimensions
    width, height = _safe_clamp_size(width, height)

    # Normalize seed
    seed = _normalize_seed(seed, randomize_seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    lora_loaded = False
    lora_warning = None

    try:
        # LoRA loading
        if lora_id and lora_id.strip() != "":
            lora_loaded, lora_warning = _maybe_load_lora(lora_id, lora_scale)

        progress(0.1, desc="Running generation...")

        # Core pipeline call.
        # true_cfg_scale enables Qwen-style CFG; guidance_scale is pinned to
        # the fixed distilled-guidance default. FIX: was guidance_scale=None,
        # which breaks pipelines that use it numerically (e.g. `None > 1`).
        try:
            result = pipe(
                prompt=prompt,
                negative_prompt="",  # required even if empty for true_cfg_scale CFG
                width=width,
                height=height,
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                true_cfg_scale=float(guidance_scale),
                guidance_scale=1.0,  # fixed default for distilled guidance
            )
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            raise gr.Error(
                "CUDA out-of-memory during generation. Try reducing image size or steps."
            )
        except Exception as e:
            traceback.print_exc()
            # FIX: this message was also a broken multi-line f-string.
            raise gr.Error(
                f"Inference failed with an internal error: {repr(e)}\n"
                "Please try again with smaller dimensions or fewer steps."
            )

        if not hasattr(result, "images") or not result.images:
            raise gr.Error(
                "Pipeline returned no images. This may indicate a model or configuration issue."
            )

        image = result.images[0]

        # If there was a LoRA warning, surface it in the logs (non-fatal).
        if lora_warning:
            print(lora_warning)

        progress(1.0, desc="Done")

        return image, seed

    finally:
        # Ensure we always try to clean up LoRA & memory even on errors.
        if lora_loaded:
            _maybe_unload_lora()
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
+
208
+
209
+ # --------------------
210
+ # UI
211
+ # --------------------
212
  examples = [
213
  "a tiny astronaut hatching from an egg on the moon",
214
  "a cat holding a sign that says hello world",
215
  "an anime illustration of a wiener schnitzel",
216
  ]
217
+
218
  css = """
219
  #col-container {
220
+ margin: 0 auto;
221
+ max-width: 960px;
222
  }
223
  .generate-btn {
224
+ background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
225
+ border: none !important;
226
+ color: white !important;
227
  }
228
  .generate-btn:hover {
229
+ transform: translateY(-2px);
230
+ box-shadow: 0 5px 15px rgba(0,0,0,0.2);
231
  }
232
  """
233
 
 
237
  with gr.Row():
238
  with gr.Column():
239
  with gr.Row():
240
+ text_prompt = gr.Textbox(
241
+ label="Prompt",
242
+ placeholder="Enter a prompt here",
243
+ lines=3,
244
+ elem_id="prompt-text-input",
245
+ )
246
  with gr.Row():
247
+ custom_lora = gr.Textbox(
248
+ label="Custom LoRA (optional)",
249
+ info="LoRA Hugging Face path (e.g. flymy-ai/qwen-image-realism-lora)",
250
+ placeholder="flymy-ai/qwen-image-realism-lora",
251
+ )
252
  with gr.Row():
253
  with gr.Accordion("Advanced Settings", open=False):
254
  lora_scale = gr.Slider(
 
259
  value=1,
260
  )
261
  with gr.Row():
262
+ width = gr.Slider(
263
+ label="Width",
264
+ value=1024,
265
+ minimum=MIN_IMAGE_SIZE,
266
+ maximum=MAX_IMAGE_SIZE,
267
+ step=16,
268
+ )
269
+ height = gr.Slider(
270
+ label="Height",
271
+ value=1024,
272
+ minimum=MIN_IMAGE_SIZE,
273
+ maximum=MAX_IMAGE_SIZE,
274
+ step=16,
275
+ )
276
+ seed = gr.Slider(
277
+ label="Seed (-1 = random)",
278
+ value=-1,
279
+ minimum=-1,
280
+ maximum=MAX_SEED,
281
+ step=1,
282
+ )
283
+ randomize_seed = gr.Checkbox(
284
+ label="Randomize seed each run",
285
+ value=True,
286
+ )
287
  with gr.Row():
288
+ steps = gr.Slider(
289
+ label="Inference steps",
290
+ value=28,
291
+ minimum=1,
292
+ maximum=100,
293
+ step=1,
294
+ )
295
+ cfg = gr.Slider(
296
+ label="Guidance Scale (true_cfg_scale)",
297
+ value=4,
298
+ minimum=1,
299
+ maximum=20,
300
+ step=0.5,
301
+ )
302
 
303
  with gr.Row():
304
+ text_button = gr.Button(
305
+ "✨ Generate Image",
306
+ variant="primary",
307
+ elem_classes=["generate-btn"],
308
+ )
309
+
310
  with gr.Column():
311
  with gr.Row():
312
+ image_output = gr.Image(
313
+ type="pil",
314
+ label="Image Output",
315
+ elem_id="gallery",
316
+ )
317
+
318
  with gr.Column():
319
  gr.Examples(
320
+ examples=examples,
321
+ inputs=[text_prompt],
322
  )
323
+
324
+ # Shared handler for button click and prompt submit
325
  gr.on(
326
  triggers=[text_button.click, text_prompt.submit],
327
+ fn=infer,
328
+ inputs=[
329
+ text_prompt,
330
+ seed,
331
+ randomize_seed,
332
+ width,
333
+ height,
334
+ cfg,
335
+ steps,
336
+ custom_lora,
337
+ lora_scale,
338
+ ],
339
+ outputs=[image_output, seed],
340
  )
 
 
 
341
 
342
if __name__ == "__main__":
    # In Spaces, HF will call app.launch() implicitly, but keeping this for local dev.
    # FIX: stray trailing characters ("899492") after the call made this a syntax error.
    app.launch(share=False)