Rosyad Almas committed on
Commit
6ea230c
·
verified ·
1 Parent(s): 9227950

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -2
app.py CHANGED
@@ -1,5 +1,26 @@
1
  import os
2
  import gc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
  import numpy as np
5
  import spaces
@@ -21,6 +42,7 @@ print("torch.__version__ =", torch.__version__)
21
  print("Using device:", device)
22
 
23
  from diffusers import FlowMatchEulerDiscreteScheduler
 
24
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
25
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
26
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
@@ -28,7 +50,7 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
28
  dtype = torch.bfloat16
29
 
30
  pipe = QwenImageEditPlusPipeline.from_pretrained(
31
- "Qwen/Qwen-Image-Edit-2511",
32
  transformer=QwenImageTransformer2DModel.from_pretrained(
33
  "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
34
  torch_dtype=dtype,
@@ -43,6 +65,92 @@ try:
43
  except Exception as e:
44
  print(f"Warning: Could not set FA3 processor: {e}")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  EXAMPLES_CONFIG = [
47
  {
48
  "images": ["examples/1.jpg"],
@@ -195,11 +303,21 @@ def infer(images_b64_json, prompt, seed, randomize_seed, guidance_scale, steps,
195
  negative_prompt = "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
196
  width, height = update_dimensions_on_upload(pil_images[0])
197
  try:
198
- result_image = pipe(
199
  image=pil_images, prompt=prompt, negative_prompt=negative_prompt,
200
  height=height, width=width, num_inference_steps=steps,
201
  generator=generator, true_cfg_scale=guidance_scale,
202
  ).images[0]
 
 
 
 
 
 
 
 
 
 
203
  return result_image, seed
204
  except Exception as e:
205
  raise e
 
1
  import os
2
  import gc
3
+ import subprocess
4
+
5
+ # ── DLSS 5 patch for Flux2KleinKV ──
6
+ def _apply_dlss_patch():
7
+ try:
8
+ import diffusers
9
+ site_packages = os.path.dirname(diffusers.__file__)
10
+ patch_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "flux2_klein_kv.patch")
11
+ if os.path.exists(patch_file):
12
+ result = subprocess.run(
13
+ ["patch", "-p2", "--forward", "--batch"],
14
+ cwd=os.path.dirname(site_packages),
15
+ stdin=open(patch_file),
16
+ capture_output=True, text=True,
17
+ )
18
+ print("DLSS patch:", result.stdout or "already applied")
19
+ except Exception as e:
20
+ print(f"DLSS patch warning: {e}")
21
+
22
+ _apply_dlss_patch()
23
+
24
  import gradio as gr
25
  import numpy as np
26
  import spaces
 
42
  print("Using device:", device)
43
 
44
  from diffusers import FlowMatchEulerDiscreteScheduler
45
+ from PIL import ImageDraw, ImageFont
46
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
47
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
48
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
 
50
  dtype = torch.bfloat16
51
 
52
  pipe = QwenImageEditPlusPipeline.from_pretrained(
53
+ "FireRedTeam/FireRed-Image-Edit-1.1",
54
  transformer=QwenImageTransformer2DModel.from_pretrained(
55
  "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
56
  torch_dtype=dtype,
 
65
  except Exception as e:
66
  print(f"Warning: Could not set FA3 processor: {e}")
67
 
68
+ # ── DLSS 5 model (Flux2KleinKV) ──
69
_dlss_pipe = None


def _get_dlss_pipe():
    """Return the shared Flux2KleinKV pipeline, loading it on first use.

    The pipeline is cached in the module-level ``_dlss_pipe``. It is loaded
    from ``black-forest-labs/FLUX.2-klein-9b-kv`` in bfloat16, using the
    ``HF_TOKEN`` environment variable as the hub token, and moved to the
    module-level ``device``.

    A failed load is printed and leaves the cache empty, so the next call
    attempts the load again.

    Returns:
        The loaded pipeline, or ``None`` when loading failed.
    """
    global _dlss_pipe
    if _dlss_pipe is not None:
        return _dlss_pipe

    try:
        # Imported lazily: this module is only needed once the pipeline is
        # actually requested.
        from diffusers.pipelines.flux2.pipeline_flux2_klein_kv import (
            Flux2KleinKVPipeline,
        )

        hub_token = os.environ.get("HF_TOKEN")
        print("Loading DLSS 5 (FLUX.2-klein-9b-kv)...")
        loaded = Flux2KleinKVPipeline.from_pretrained(
            "black-forest-labs/FLUX.2-klein-9b-kv",
            torch_dtype=torch.bfloat16,
            token=hub_token,
        )
        _dlss_pipe = loaded.to(device)
        print("DLSS 5 model loaded.")
    except Exception as e:
        print(f"DLSS model load failed: {e}")
        _dlss_pipe = None
    return _dlss_pipe
87
+
88
+ # ── DLSS font helper ──
89
_DLSS_FONT_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "Inter-Bold.ttf",
)


def _get_dlss_font(size):
    """Return a font of the given size for the comparison labels.

    Tries the ``Inter-Bold.ttf`` file located next to this module first and
    falls back to Pillow's default font at the same size if that fails for
    any reason.

    Args:
        size: Requested font size, passed through to Pillow.

    Returns:
        The font object returned by ``ImageFont``.
    """
    try:
        font = ImageFont.truetype(_DLSS_FONT_PATH, size)
    except Exception:
        font = ImageFont.load_default(size=size)
    return font
95
+
96
def _dlss_comparison(original: "Image.Image", enhanced: "Image.Image") -> "Image.Image":
    """Build a side-by-side comparison image with a label on each half.

    The left half is ``original``; the right half is ``enhanced`` resized to
    the original's size. A dark "Before Edit" badge is drawn near the bottom
    of the left half, and a light "DLSS 5 On" badge with a green bar beneath
    it near the bottom of the right half.

    Args:
        original: Image shown on the left; its size defines each half.
        enhanced: Image shown on the right after resizing.

    Returns:
        An RGB image twice as wide as ``original``.
    """
    w, h = original.size

    canvas = Image.new("RGB", (w * 2, h))
    canvas.paste(original, (0, 0))
    canvas.paste(enhanced.resize((w, h), LANCZOS), (w, 0))

    # Badges are drawn on a transparent layer and composited at the end so
    # the semi-transparent dark fill blends with the picture underneath.
    overlay = Image.new("RGBA", (w * 2, h), (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # All badge metrics scale with the image height.
    font_size = max(16, int(h * 0.06))
    font = _get_dlss_font(font_size)
    pad_x = int(font_size * 1.0)
    pad_y = int(font_size * 0.55)
    margin_bottom = int(h * 0.06)

    # (label text, horizontal centre, dark theme?, green bar underneath?)
    badges = (
        ("Before Edit", w // 2, True, False),
        ("DLSS 5 On", w + w // 2, False, True),
    )
    for label, cx, dark, green_bar in badges:
        left, top, right, bottom = font.getbbox(label)
        lw = (right - left) + 2 * pad_x
        lh = (bottom - top) + 2 * pad_y
        gh = max(4, int(lh * 0.13)) if green_bar else 0
        x = cx - lw // 2
        # Reserve room for the green bar so badge + bar sit above the margin.
        y = h - margin_bottom - lh - gh

        if dark:
            box_fill = (10, 10, 10, 225)
            box_outline = (75, 75, 75, 255)
            text_fill = (255, 255, 255, 255)
        else:
            box_fill = (255, 255, 255, 255)
            box_outline = (190, 190, 190, 255)
            text_fill = (0, 0, 0, 255)

        draw.rectangle([x, y, x + lw, y + lh], fill=box_fill, outline=box_outline, width=1)
        draw.text((x + lw // 2, y + lh // 2), label, fill=text_fill, font=font, anchor="mm")
        if not dark:
            # Green accent bar directly beneath the light badge.
            draw.rectangle([x, y + lh, x + lw, y + lh + gh], fill=(118, 185, 0, 255))

    composed = Image.alpha_composite(canvas.convert("RGBA"), overlay)
    return composed.convert("RGB")
129
+
130
def _run_dlss(image: "Image.Image", seed: int) -> "Image.Image":
    """Run the DLSS 5 enhancement pass on an image.

    The pipeline is asked for an output whose longer side is 1024 px; the
    shorter side follows the input's aspect ratio, is rounded to a multiple
    of 8, and both sides are clamped to the range 256..1024.

    Args:
        image: Image to enhance.
        seed: Seed for the torch generator passed to the pipeline.

    Returns:
        The first image produced by the pipeline, or ``image`` unchanged
        when the pipeline could not be loaded.
    """
    dlss = _get_dlss_pipe()
    if dlss is None:
        # Graceful fallback: hand back the input if the model didn't load.
        return image

    iw, ih = image.size
    ar = iw / ih
    if ar >= 1:
        # Landscape or square: pin the width, derive the height.
        width, height = 1024, round(1024 / ar / 8) * 8
    else:
        # Portrait: pin the height, derive the width.
        width, height = round(1024 * ar / 8) * 8, 1024
    width = max(256, min(1024, width))
    height = max(256, min(1024, height))

    generator = torch.Generator(device=device).manual_seed(seed)
    output = dlss(
        prompt="make it more realistic",
        image=[image],
        height=height,
        width=width,
        num_inference_steps=4,
        generator=generator,
    )
    return output.images[0]
153
+
154
  EXAMPLES_CONFIG = [
155
  {
156
  "images": ["examples/1.jpg"],
 
303
  negative_prompt = "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
304
  width, height = update_dimensions_on_upload(pil_images[0])
305
  try:
306
+ edited = pipe(
307
  image=pil_images, prompt=prompt, negative_prompt=negative_prompt,
308
  height=height, width=width, num_inference_steps=steps,
309
  generator=generator, true_cfg_scale=guidance_scale,
310
  ).images[0]
311
+
312
+ # Auto-DLSS: upscale/enhance the edited result
313
+ dlss_seed = seed + 1 if not randomize_seed else seed
314
+ try:
315
+ dlss_enhanced = _run_dlss(edited, dlss_seed)
316
+ result_image = _dlss_comparison(edited, dlss_enhanced)
317
+ except Exception as dlss_err:
318
+ print(f"DLSS step skipped: {dlss_err}")
319
+ result_image = edited # fallback to plain edited image
320
+
321
  return result_image, seed
322
  except Exception as e:
323
  raise e