primerz committed on
Commit
5dc2f53
·
verified ·
1 Parent(s): 34256b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -20
app.py CHANGED
@@ -97,6 +97,20 @@ class FireRedTheme(Soft):
97
 
98
  theme = FireRedTheme()
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # ═══════════════════════════════════════════════════════════════════════
101
  # MODEL
102
  # ═══════════════════════════════════════════════════════════════════════
@@ -113,22 +127,75 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
113
 
114
  dtype = torch.bfloat16
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  pipe = QwenImageEditPlusPipeline.from_pretrained(
117
  "FireRedTeam/FireRed-Image-Edit-1.1",
118
- transformer=QwenImageTransformer2DModel.from_pretrained(
119
- "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V23",
120
- torch_dtype=dtype,
121
- device_map="cuda",
122
- ),
123
  torch_dtype=dtype,
124
  ).to(device)
125
 
 
126
  try:
127
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
128
  print("Flash Attention 3 Processor set successfully.")
129
  except Exception as e:
130
  print(f"Warning: Could not set FA3 processor: {e}")
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  MAX_SEED = np.iinfo(np.int32).max
133
 
134
  DEFAULT_NEGATIVE_PROMPT = (
@@ -188,9 +255,7 @@ def infer(
188
  seed, randomize_seed, guidance_scale, steps,
189
  progress=gr.Progress(track_tqdm=True),
190
  ):
191
- gc.collect()
192
- torch.cuda.empty_cache()
193
-
194
  if not images:
195
  raise gr.Error("⚠️ Please upload at least one image.")
196
  if not prompt or not prompt.strip():
@@ -219,18 +284,24 @@ def infer(
219
  width, height = update_dimensions_on_upload(pil_images[0])
220
 
221
  try:
222
- result = pipe(
223
- image=pil_images,
224
- prompt=prompt,
225
- negative_prompt=negative_prompt,
226
- height=height,
227
- width=width,
228
- num_inference_steps=steps,
229
- generator=generator,
230
- true_cfg_scale=guidance_scale,
231
- ).images[0]
 
 
 
232
  return result, seed
233
  finally:
 
 
 
234
  gc.collect()
235
  torch.cuda.empty_cache()
236
 
@@ -589,7 +660,7 @@ with gr.Blocks(css=css, theme=theme, title="🔥 FireRed Image Edit") as demo:
589
  outputs=[images, prompt, output_image, info_box],
590
  )
591
 
592
- # Generate
593
  run_button.click(
594
  fn=infer,
595
  inputs=[
@@ -597,6 +668,7 @@ with gr.Blocks(css=css, theme=theme, title="🔥 FireRed Image Edit") as demo:
597
  seed, randomize_seed, guidance_scale, steps,
598
  ],
599
  outputs=[output_image, seed],
 
600
  ).then(
601
  fn=format_info,
602
  inputs=[seed, images],
@@ -608,7 +680,12 @@ with gr.Blocks(css=css, theme=theme, title="🔥 FireRed Image Edit") as demo:
608
  # ═══════════════════════════════════════════════════════════════════════
609
 
610
  if __name__ == "__main__":
611
- demo.queue(max_size=30).launch(
 
 
 
 
 
612
  mcp_server=True,
613
  ssr_mode=False,
614
  show_error=True,
 
97
 
98
  theme = FireRedTheme()
99
 
100
+ # ═══════════════════════════════════════════════════════════════════════
101
+ # GLOBAL CUDA OPTIMIZATIONS
102
+ # ═══════════════════════════════════════════════════════════════════════
103
+
104
+ # Enable cuDNN autotuner — finds the fastest convolution algorithms for
105
+ # the hardware and input sizes after a short warm-up.
106
+ torch.backends.cudnn.benchmark = True
107
+
108
+ # Allow TF32 on Ampere+ GPUs for ~3× faster matmuls with negligible
109
+ # precision loss (already bf16 pipeline, so this is free perf).
110
+ torch.backends.cuda.matmul.allow_tf32 = True
111
+ torch.backends.cudnn.allow_tf32 = True
112
+ torch.set_float32_matmul_precision("high")
113
+
114
  # ═══════════════════════════════════════════════════════════════════════
115
  # MODEL
116
  # ═══════════════════════════════════════════════════════════════════════
 
127
 
128
  dtype = torch.bfloat16
129
 
130
+ # Load transformer separately so we can optimise it before plugging in
131
+ transformer = QwenImageTransformer2DModel.from_pretrained(
132
+ "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V23",
133
+ torch_dtype=dtype,
134
+ device_map="cuda",
135
+ )
136
+
137
+ # Attempt torch.compile for a fused-kernel speed-up on the denoising
138
+ # backbone. Falls back gracefully if the environment doesn't support it
139
+ # (older driver / torch version / dynamic-shape issues).
140
+ try:
141
+ transformer = torch.compile(transformer, mode="reduce-overhead")
142
+ print("torch.compile applied to transformer (reduce-overhead).")
143
+ except Exception as e:
144
+ print(f"torch.compile skipped: {e}")
145
+
146
  pipe = QwenImageEditPlusPipeline.from_pretrained(
147
  "FireRedTeam/FireRed-Image-Edit-1.1",
148
+ transformer=transformer,
 
 
 
 
149
  torch_dtype=dtype,
150
  ).to(device)
151
 
152
+ # Flash Attention 3 processor — fastest path when available
153
  try:
154
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
155
  print("Flash Attention 3 Processor set successfully.")
156
  except Exception as e:
157
  print(f"Warning: Could not set FA3 processor: {e}")
158
 
159
+ # VAE optimisations — process large images in tiles / slices so we
160
+ # never OOM on the decode step, and still stay fast for normal sizes.
161
+ try:
162
+ pipe.vae.enable_tiling()
163
+ print("VAE tiling enabled.")
164
+ except Exception:
165
+ pass
166
+
167
+ try:
168
+ pipe.vae.enable_slicing()
169
+ print("VAE slicing enabled.")
170
+ except Exception:
171
+ pass
172
+
173
+ # ── Warmup pass ─────────────────────────────────────────────────────
174
+ # The first inference is always slower (CUDA context init, cuDNN
175
+ # autotuner, torch.compile tracing). Run a tiny dummy forward so that
176
+ # cost is paid at startup, not on the first user request.
177
+ print("Running warmup inference …")
178
+ try:
179
+ _warmup_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
180
+ _warmup_gen = torch.Generator(device=device).manual_seed(0)
181
+ with torch.inference_mode():
182
+ pipe(
183
+ image=[_warmup_img],
184
+ prompt="warmup",
185
+ negative_prompt="",
186
+ height=64,
187
+ width=64,
188
+ num_inference_steps=1,
189
+ generator=_warmup_gen,
190
+ true_cfg_scale=1.0,
191
+ )
192
+ del _warmup_img, _warmup_gen
193
+ gc.collect()
194
+ torch.cuda.empty_cache()
195
+ print("Warmup complete.")
196
+ except Exception as e:
197
+ print(f"Warmup skipped: {e}")
198
+
199
  MAX_SEED = np.iinfo(np.int32).max
200
 
201
  DEFAULT_NEGATIVE_PROMPT = (
 
255
  seed, randomize_seed, guidance_scale, steps,
256
  progress=gr.Progress(track_tqdm=True),
257
  ):
258
+ # ── Input validation (cheap, do first) ──────────────────────────
 
 
259
  if not images:
260
  raise gr.Error("⚠️ Please upload at least one image.")
261
  if not prompt or not prompt.strip():
 
284
  width, height = update_dimensions_on_upload(pil_images[0])
285
 
286
  try:
287
+ # torch.inference_mode is strictly faster than torch.no_grad —
288
+ # it also disables view-tracking and version-counter bumps.
289
+ with torch.inference_mode():
290
+ result = pipe(
291
+ image=pil_images,
292
+ prompt=prompt,
293
+ negative_prompt=negative_prompt,
294
+ height=height,
295
+ width=width,
296
+ num_inference_steps=steps,
297
+ generator=generator,
298
+ true_cfg_scale=guidance_scale,
299
+ ).images[0]
300
  return result, seed
301
  finally:
302
+ # GC *after* inference to reclaim any temporaries the pipeline
303
+ # allocated. Avoid gc.collect() + empty_cache() *before*
304
+ # inference — that stalls the CUDA stream for nothing.
305
  gc.collect()
306
  torch.cuda.empty_cache()
307
 
 
660
  outputs=[images, prompt, output_image, info_box],
661
  )
662
 
663
+ # Generate — with a public api_name so the endpoint is discoverable
664
  run_button.click(
665
  fn=infer,
666
  inputs=[
 
668
  seed, randomize_seed, guidance_scale, steps,
669
  ],
670
  outputs=[output_image, seed],
671
+ api_name="edit",
672
  ).then(
673
  fn=format_info,
674
  inputs=[seed, images],
 
680
  # ═══════════════════════════════════════════════════════════════════════
681
 
682
  if __name__ == "__main__":
683
+ demo.queue(
684
+ max_size=30,
685
+ default_concurrency_limit=2, # allow 2 concurrent GPU jobs
686
+ ).launch(
687
+ share=True, # ← public shareable link
688
+ show_api=True, # ← API docs visible at /docs
689
  mcp_server=True,
690
  ssr_mode=False,
691
  show_error=True,