prithivMLmods committed on
Commit
5946916
·
verified ·
1 Parent(s): eab7268

add app [fast_inferences] ✅

Browse files
Files changed (1) hide show
  1. app.py +266 -0
app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import gradio as gr
4
+ import numpy as np
5
+ import spaces
6
+ import torch
7
+ import random
8
+ from PIL import Image
9
+ from typing import Iterable
10
+ from gradio.themes import Soft
11
+ from gradio.themes.utils import colors, fonts, sizes
12
+
13
# Register a custom "orange_red" palette on gradio's color registry so it can
# be referenced as a theme hue below (c50 = lightest tint ... c950 = darkest shade).
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
27
+
28
class OrangeRedTheme(Soft):
    """Gradio theme based on Soft: gray primary, custom orange-red accent, slate neutrals.

    All hue/font parameters are keyword-only and forwarded to Soft; gradient
    backgrounds, gradient buttons, and heavier borders are layered on via set().
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override Soft defaults: gradient page/button fills, orange-red sliders,
        # thicker block borders, drop shadows. Dark-mode variants mirror each key.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )
78
+
79
# Single shared theme instance for the UI.
orange_red_theme = OrangeRedTheme()

# Prefer CUDA when available; the pipeline and all generators use this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Startup diagnostics for the Space logs.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("Using device:", device)
86
+
87
+ from diffusers import FlowMatchEulerDiscreteScheduler
88
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
89
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
90
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
91
+
92
# bfloat16 halves memory versus fp32 while keeping fp32-like dynamic range.
dtype = torch.bfloat16

# Base edit pipeline with a replacement transformer swapped in — presumably a
# distilled/merged checkpoint for few-step "rapid" inference (TODO confirm).
# NOTE(review): device_map='cuda' assumes a GPU is present even though `device`
# above can fall back to CPU — verify behavior on CPU-only hosts.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "FireRedTeam/FireRed-Image-Edit-1.0",
    transformer=QwenImageTransformer2DModel.from_pretrained(
        "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
        torch_dtype=dtype,
        device_map='cuda'
    ),
    torch_dtype=dtype
).to(device)

# Best-effort upgrade to a Flash Attention 3 processor; any failure (missing
# kernels, unsupported hardware) leaves the default attention in place.
try:
    pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
    print("Flash Attention 3 Processor set successfully.")
except Exception as e:
    print(f"Warning: Could not set FA3 processor: {e}")
109
+
110
# Upper bound for seeds exposed in the UI (int32 max = 2147483647).
MAX_SEED = np.iinfo(np.int32).max
111
+
112
def update_dimensions_on_upload(image, max_dim=1024, multiple=8):
    """Compute pipeline-friendly output dimensions for an uploaded image.

    Scales the image so its longer side equals ``max_dim`` while preserving
    aspect ratio, then snaps both sides down to a multiple of ``multiple``
    (diffusion pipelines typically require dimensions divisible by 8).

    Args:
        image: PIL-style image exposing ``.size`` -> (width, height), or None.
        max_dim: Target length of the longer side (default 1024).
        multiple: Both dimensions are floored to this multiple (default 8).

    Returns:
        (width, height) tuple; (max_dim, max_dim) when ``image`` is None.
        Both values are guaranteed >= ``multiple``, so extreme aspect ratios
        can never yield a zero dimension.
    """
    if image is None:
        return max_dim, max_dim

    original_width, original_height = image.size

    if original_width > original_height:
        new_width = max_dim
        aspect_ratio = original_height / original_width
        new_height = int(new_width * aspect_ratio)
    else:
        new_height = max_dim
        aspect_ratio = original_width / original_height
        new_width = int(new_height * aspect_ratio)

    # Snap down to the nearest multiple, but never below one multiple:
    # e.g. a 2000x8 upload would otherwise round its height down to 0,
    # which crashes the pipeline.
    new_width = max(multiple, (new_width // multiple) * multiple)
    new_height = max(multiple, (new_height // multiple) * multiple)

    return new_width, new_height
131
+
132
@spaces.GPU
def infer(
    images,
    prompt,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    progress=gr.Progress(track_tqdm=True)
):
    """Run a multi-image edit through the pipeline.

    Args:
        images: Gallery payload — each entry may be a (path, caption)
            tuple/list, a file-path string, a PIL.Image, or a file-like
            object exposing ``.name``.
        prompt: Edit instruction text.
        seed: RNG seed; ignored when ``randomize_seed`` is True.
        randomize_seed: Draw a fresh seed from [0, MAX_SEED] when True.
        guidance_scale: Forwarded to the pipeline as ``true_cfg_scale``.
        steps: Number of denoising steps.
        progress: Gradio progress tracker (mirrors tqdm bars from diffusers).

    Returns:
        Tuple of (edited PIL.Image, seed actually used).

    Raises:
        gr.Error: If no image was uploaded or none could be decoded.
    """
    gc.collect()
    torch.cuda.empty_cache()

    if not images:
        raise gr.Error("Please upload at least one image to edit.")

    # Normalize the heterogeneous gallery payload into RGB PIL images,
    # skipping any entry that cannot be decoded.
    pil_images = []
    for item in images:
        try:
            # Gallery entries may arrive as (path, caption) pairs.
            path_or_img = item[0] if isinstance(item, (tuple, list)) else item

            if isinstance(path_or_img, str):
                pil_images.append(Image.open(path_or_img).convert("RGB"))
            elif isinstance(path_or_img, Image.Image):
                pil_images.append(path_or_img.convert("RGB"))
            else:
                # File-like object (e.g. a tempfile wrapper) exposing .name.
                pil_images.append(Image.open(path_or_img.name).convert("RGB"))
        except Exception as e:
            print(f"Skipping invalid image item: {e}")
            continue

    if not pil_images:
        raise gr.Error("Could not process uploaded images.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator(device=device).manual_seed(seed)
    negative_prompt = "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"

    # Output resolution follows the first uploaded image's aspect ratio.
    width, height = update_dimensions_on_upload(pil_images[0])

    # try/finally (no except) lets exceptions propagate with their original
    # traceback while still freeing VRAM between requests.
    try:
        result_image = pipe(
            image=pil_images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            generator=generator,
            true_cfg_scale=guidance_scale,
        ).images[0]
        return result_image, seed
    finally:
        gc.collect()
        torch.cuda.empty_cache()
197
+
198
@spaces.GPU
def infer_example(images, prompt):
    """Adapter for gr.Examples: run a fast 4-step edit with a randomized seed.

    Accepts either a single file path or a list of gallery items and returns
    the same (image, seed) pair that ``infer`` produces.
    """
    if not images:
        return None, 0

    return infer(
        images=[images] if isinstance(images, str) else images,
        prompt=prompt,
        seed=0,
        randomize_seed=True,
        guidance_scale=1.0,
        steps=4,
    )
217
+
218
# Page-level CSS: center the main column and enlarge the title heading.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1000px;
}
#main-title h1 {font-size: 2.4em !important;}
"""
225
+
226
# css/theme are gr.Blocks() constructor parameters; Blocks.launch() does not
# accept them, so passing them to launch() (as before) never applied the
# custom theme or stylesheet.
with gr.Blocks(css=css, theme=orange_red_theme) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **Qwen-Image-Edit-2511 Fast**", elem_id="main-title")
        # NOTE(review): the copy references Qwen-Image-Edit-2511, but the
        # pipeline above loads FireRed-Image-Edit / Rapid-AIO weights —
        # confirm which model name should be shown to users.
        gr.Markdown("Perform image edits using [Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) with 4-step fast inference.")

        with gr.Row(equal_height=True):
            with gr.Column():
                images = gr.Gallery(
                    label="Upload Images",
                    type="filepath",
                    columns=2,
                    rows=1,
                    height=300,
                    allow_preview=True
                )

                prompt = gr.Text(
                    label="Edit Prompt",
                    show_label=True,
                    placeholder="e.g., transform into anime, upscale, change lighting...",
                )

                run_button = gr.Button("Edit Image", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Output Image", interactive=False, format="png", height=365)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
            steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=4)

        # Wire the button to the inference function; the seed slider doubles
        # as an output so the randomized seed is reported back to the user.
        run_button.click(
            fn=infer,
            inputs=[images, prompt, seed, randomize_seed, guidance_scale, steps],
            outputs=[output_image, seed]
        )

if __name__ == "__main__":
    demo.queue(max_size=30).launch(mcp_server=True, ssr_mode=False, show_error=True)