IFMedTechdemo committed on
Commit
a5653b2
·
verified ·
1 Parent(s): d11a0ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -158
app.py CHANGED
@@ -1,179 +1,272 @@
1
  import gradio as gr
 
 
2
  import torch
3
- import os
4
  from PIL import Image
5
- from spaces import GPU
6
- from diffusers import QwenImageEditPipeline
7
- from diffusers.utils import load_image
 
 
 
8
 
9
# Model configuration
MODEL_ID = "Qwen/Qwen-Image-Edit"

# Global pipeline variable; stays None until load_pipeline() has run.
pipeline = None


def load_pipeline():
    """Create the Qwen Image Edit pipeline and publish it as the global ``pipeline``.

    The pipeline is built from ``MODEL_ID`` in bfloat16 with automatic device
    placement. Any failure is printed and then re-raised to the caller.
    """
    global pipeline

    load_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    try:
        pipeline = QwenImageEditPipeline.from_pretrained(MODEL_ID, **load_options)
        print(" Pipeline loaded successfully")
    except Exception as e:
        print(f"Error loading pipeline: {e}")
        raise
30
-
31
@GPU
def edit_image(image, prompt, negative_prompt="", num_steps=8, guidance_scale=2.5, seed=0):
    """Edit image based on text prompt using ZeroGPU.

    Args:
        image: A PIL image, or a string accepted by ``load_image``.
        prompt: Natural-language edit instruction.
        negative_prompt: Optional text describing what to avoid; ignored when
            empty, whitespace-only or ``None``.
        num_steps: Number of inference steps (coerced to ``int``).
        guidance_scale: Guidance strength passed to the pipeline.
        seed: Seed for the random generator (coerced to ``int``).

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no image was supplied.
        Exception: Any loading or inference failure is printed and re-raised.
    """
    global pipeline

    # Fail with a readable UI message instead of an AttributeError on None.
    if image is None:
        raise gr.Error("Please upload an image to edit.")

    try:
        # Load pipeline on first call
        if pipeline is None:
            print("Loading Qwen Image Edit pipeline...")
            load_pipeline()

        # Ensure image is RGB; convert() returns a new image, so the
        # in-place thumbnail() below never touches the caller's object.
        source = load_image(image) if isinstance(image, str) else image
        img = source.convert("RGB")

        # Resize to optimal size (model supports up to 1024x1024)
        max_size = 768
        img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        # Use a dedicated generator so the process-wide RNG state is not
        # reseeded as a side effect of every request.
        generator = torch.Generator().manual_seed(int(seed))

        # Prepare inputs
        inputs = {
            "image": img,
            "prompt": prompt,
            "generator": generator,
            "guidance_scale": guidance_scale,
            "num_inference_steps": int(num_steps),
        }

        # Add negative prompt if provided (tolerates None from the UI)
        if negative_prompt and negative_prompt.strip():
            inputs["negative_prompt"] = negative_prompt

        # Run inference
        with torch.no_grad():
            output = pipeline(**inputs)

        return output.images[0]

    except Exception as e:
        print(f"Error during inference: {e}")
        raise
79
-
80
# Create Gradio interface
def _advanced_controls():
    """Build the widgets of the "Advanced Options" accordion and return them."""
    with gr.Accordion("Advanced Options", open=False):
        negative_prompt = gr.Textbox(
            label="Negative Prompt (Optional)",
            placeholder="What to avoid",
            lines=2,
        )
        num_steps = gr.Slider(
            label="Inference Steps",
            minimum=4,
            maximum=20,
            value=8,
            step=1,
            info="More steps = better quality but slower",
        )
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=1.0,
            maximum=7.5,
            value=2.5,
            step=0.1,
            info="Higher = stronger prompt adherence",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=2147483647,
            value=0,
            step=1,
            info="For reproducible results",
        )
    return negative_prompt, num_steps, guidance_scale, seed


def create_interface():
    """Assemble the image-editing Blocks app and return it (not yet launched).

    Layout: a left column with the upload widget, the edit prompt, the
    advanced options and the submit button; a right column with the edited
    image; a tips/model-info section underneath. The submit button is wired
    to ``edit_image``.
    """
    with gr.Blocks(title="Qwen Image Edit - ZeroGPU", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎨 Qwen Image Edit with ZeroGPU")
        gr.Markdown(
            "Edit images using natural language prompts powered by Qwen Image Edit on ZeroGPU."
        )

        with gr.Row():
            with gr.Column():
                # Input image
                input_image = gr.Image(label="Upload Image", type="pil", height=400)

                # Edit prompt
                prompt = gr.Textbox(
                    label="Edit Instructions",
                    placeholder="e.g., 'Add a red hat to the person' or 'Change background to sunset'",
                    lines=2,
                )

                # Advanced options
                negative_prompt, num_steps, guidance_scale, seed = _advanced_controls()

                # Submit button
                submit_btn = gr.Button("🚀 Edit Image", variant="primary", scale=1)

            with gr.Column():
                # Output image
                output_image = gr.Image(label="Edited Image", type="pil", height=400)

        # Connect the function
        submit_btn.click(
            fn=edit_image,
            inputs=[input_image, prompt, negative_prompt, num_steps, guidance_scale, seed],
            outputs=output_image,
            show_progress=True,
        )

        gr.Markdown("""
        ### 💡 Tips for best results:
        - Use clear, descriptive prompts
        - Start with 8 steps (Qwen Image Edit is optimized for low step counts)
        - Guidance scale 2.0-3.0 for subtle edits, 4.0-6.0 for stronger changes
        - Image editing takes ~30-60 seconds per inference on ZeroGPU

        ### ⚙️ Model Info:
        - **Model**: Qwen/Qwen-Image-Edit (20B parameters)
        - **Architecture**: Multi-modal Diffusion Transformer
        - **Input Resolution**: Up to 768x768 (optimized)
        """)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import random
4
  import torch
5
+ import spaces
6
  from PIL import Image
7
+ from diffusers import FlowMatchEulerDiscreteScheduler
8
+ from optimization import optimize_pipeline_
9
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
10
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
11
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
12
+ import math
13
 
14
# --- UI / Inference Constants ---
# Defined before the pipeline is built so the warm-up call below can reuse
# the exact prompt that inference uses instead of repeating the literal.
MAX_SEED = np.iinfo(np.int32).max

# Hardcoded prompt for acne removal
HARDCODED_PROMPT = "remove acne marks and blemishes from the face"
NEGATIVE_PROMPT = " "

# --- Model Loading & Optimization ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Scheduler configuration for Lightning
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

# Initialize scheduler with Lightning config
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

# Load the model pipeline
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509",
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)

# Merge the 8-step Lightning LoRA into the base weights.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Edit-2509/Qwen-Image-Edit-2509-Lightning-8steps-V1.0-bf16.safetensors",
)
pipe.fuse_lora()

# Apply optimizations
pipe.transformer.__class__ = QwenImageTransformer2DModel
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

# Enable memory optimizations
# NOTE(review): confirm this does not replace the FA3 attention processor
# installed just above.
pipe.enable_attention_slicing()

# Ahead-of-time compilation for faster subsequent runs; warmed up with two
# blank 1024x1024 images and the production prompt.
optimize_pipeline_(
    pipe,
    image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))],
    prompt=HARDCODED_PROMPT,
)
72
+
73
# --- Main Inference Function (Optimized for Speed) ---
def _load_gallery_image(item, max_side=1024):
    """Return one gallery entry as an RGB PIL image at most ``max_side`` per side.

    Accepts an ``(image, caption)`` pair, a bare PIL image, a file path, or
    an object exposing a ``name`` path attribute.

    Raises:
        TypeError: If the entry is none of the supported kinds.
    """
    # Unwrap (image, caption) pairs; leave bare images/paths/files untouched
    # so they are not indexed by mistake.
    source = item[0] if isinstance(item, (tuple, list)) else item

    if isinstance(source, Image.Image):
        img = source.convert("RGB")
    elif isinstance(source, str):
        img = Image.open(source).convert("RGB")
    elif hasattr(source, "name"):
        img = Image.open(source.name).convert("RGB")
    else:
        raise TypeError(f"unsupported gallery item type: {type(source).__name__}")

    # Resize to optimal inference size for speed (in place, keeps aspect ratio)
    img.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
    return img


@spaces.GPU()
def infer(
    images,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=1.0,
    num_inference_steps=8,
    height=1024,
    width=1024,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Optimized inference for acne removal with hardcoded prompt.
    Removes prompt rewriting to save inference time.

    Args:
        images: Gallery value (iterable of entries) or ``None``.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` when true.
        true_guidance_scale: Passed to the pipeline as ``true_cfg_scale``.
        num_inference_steps: Number of denoising steps.
        height: Output height in pixels.
        width: Output width in pixels.
        progress: Gradio progress tracker (tracks tqdm).

    Returns:
        ``(output_images, seed_used, update)`` where ``update`` makes the
        "use as input" button visible.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI widgets may hand numbers over as floats; the generator needs an int.
    seed = int(seed)

    # Set up generator for reproducibility
    generator = torch.Generator(device=device).manual_seed(seed)

    # Load and preprocess input images; unreadable entries are logged and
    # skipped so one bad upload does not abort the whole request.
    pil_images = []
    for item in images or []:
        try:
            pil_images.append(_load_gallery_image(item))
        except Exception as e:
            print(f"Error loading image: {e}")

    print(f"Using hardcoded prompt: '{HARDCODED_PROMPT}'")
    print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}")

    # Generate the image with optimized settings
    with torch.inference_mode():
        output = pipe(
            image=pil_images or None,
            prompt=HARDCODED_PROMPT,
            height=int(height),
            width=int(width),
            negative_prompt=NEGATIVE_PROMPT,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            true_cfg_scale=true_guidance_scale,
            num_images_per_prompt=1,
        ).images

    return output, seed, gr.update(visible=True)
135
+
136
+
137
def use_output_as_input(output_images):
    """Convert output images to input format for the gallery.

    Args:
        output_images: Sequence of result-gallery entries, or ``None``.

    Returns:
        A new list holding the same entries; empty when there is nothing
        to forward.
    """
    if not output_images:
        return []
    # Return a fresh list so the input gallery does not share the result's
    # container object.
    return list(output_images)
142
+
143
+
144
# --- CSS Styling ---
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
#logo-title {
    text-align: center;
}
#logo-title img {
    width: 400px;
}
#edit_text {
    margin-top: -62px !important;
}
"""

# --- UI Layout ---
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header: logo, title and a short description with links.
        gr.HTML("""
        <div id="logo-title">
            <img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_edit_logo.png" alt="Qwen-Image Edit Logo" width="400" style="display: block; margin: 0 auto;">
            <h2 style="font-style: italic;color: #5b47d1;margin-top: -27px !important;margin-left: 96px">[Acne Remover] Fast 8-step Lightning LoRA</h2>
        </div>
        """)
        gr.Markdown("""
        **Remove acne marks and blemishes** from facial images using Qwen-Image-Edit with Lightning LoRA optimization.

        This demo uses [Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) with
        [Qwen-Image-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Lightning) + FA3 for ultra-fast inference.

        [Learn more](https://github.com/QwenLM/Qwen-Image) | [Download model](https://huggingface.co/Qwen/Qwen-Image-Edit-2509)
        """)

        # Input gallery on the left, result gallery (plus feedback button) on the right.
        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Upload facial image",
                    show_label=False,
                    type="pil",
                    interactive=True,
                )

            with gr.Column():
                result = gr.Gallery(
                    label="Acne-removed result",
                    show_label=False,
                    type="pil",
                )
                # Hidden until infer() reveals it through its third output.
                use_output_btn = gr.Button(
                    "↗️ Use as input",
                    variant="secondary",
                    size="sm",
                    visible=False,
                )

        with gr.Row():
            run_button = gr.Button("Remove Acne!", variant="primary", size="lg")

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                true_guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=1.0,
                )

                num_inference_steps = gr.Slider(
                    label="Inference steps (fewer = faster)",
                    minimum=1,
                    maximum=40,
                    step=1,
                    value=8,
                )

            with gr.Row():
                # Height and width share the same range and step.
                size_range = {"minimum": 512, "maximum": 1024, "step": 64, "value": 1024}
                height = gr.Slider(label="Height", **size_range)

                width = gr.Slider(label="Width", **size_range)

    # Event handlers
    infer_inputs = [
        input_images,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        height,
        width,
    ]
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=infer_inputs,
        outputs=[result, seed, use_output_btn],
    )

    # Feed the generated images back into the input gallery.
    use_output_btn.click(
        fn=use_output_as_input,
        inputs=[result],
        outputs=[input_images],
    )

if __name__ == "__main__":
    demo.launch()