Anand Gupta committed on
Commit
886392e
·
verified ·
1 Parent(s): 0d0fd57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +336 -125
app.py CHANGED
@@ -1,142 +1,353 @@
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
- import random
4
- #import spaces #[uncomment to use ZeroGPU]
5
- from diffusers import DiffusionPipeline
6
  import torch
 
 
 
7
 
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use
10
 
11
- if torch.cuda.is_available():
12
- torch_dtype = torch.float16
13
- else:
14
- torch_dtype = torch.float32
15
-
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
- pipe = pipe.to(device)
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
- MAX_IMAGE_SIZE = 1024
21
-
22
- #@spaces.GPU #[uncomment to use ZeroGPU]
23
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
24
-
25
- if randomize_seed:
26
- seed = random.randint(0, MAX_SEED)
27
-
28
- generator = torch.Generator().manual_seed(seed)
29
-
30
- image = pipe(
31
- prompt = prompt,
32
- negative_prompt = negative_prompt,
33
- guidance_scale = guidance_scale,
34
- num_inference_steps = num_inference_steps,
35
- width = width,
36
- height = height,
37
- generator = generator
38
- ).images[0]
39
-
40
- return image, seed
41
-
42
- examples = [
43
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
44
- "An astronaut riding a green horse",
45
- "A delicious ceviche cheesecake slice",
 
 
 
 
 
 
 
 
 
 
 
 
46
  ]
47
 
48
- css="""
49
- #col-container {
50
- margin: 0 auto;
51
- max-width: 640px;
52
- }
53
- """
54
 
55
- with gr.Blocks(css=css) as demo:
56
-
57
- with gr.Column(elem_id="col-container"):
58
- gr.Markdown(f"""
59
- # Text-to-Image Gradio Template
60
- """)
61
-
62
- with gr.Row():
63
-
64
- prompt = gr.Text(
65
- label="Prompt",
66
- show_label=False,
67
- max_lines=1,
68
- placeholder="Enter your prompt",
69
- container=False,
70
- )
71
-
72
- run_button = gr.Button("Run", scale=0)
73
-
74
- result = gr.Image(label="Result", show_label=False)
75
-
76
- with gr.Accordion("Advanced Settings", open=False):
77
-
78
- negative_prompt = gr.Text(
79
- label="Negative prompt",
80
- max_lines=1,
81
- placeholder="Enter a negative prompt",
82
- visible=False,
83
- )
84
-
85
- seed = gr.Slider(
86
- label="Seed",
87
- minimum=0,
88
- maximum=MAX_SEED,
89
- step=1,
90
- value=0,
91
- )
92
-
93
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  with gr.Row():
96
-
97
- width = gr.Slider(
98
- label="Width",
99
- minimum=256,
100
- maximum=MAX_IMAGE_SIZE,
101
- step=32,
102
- value=1024, #Replace with defaults that work for your model
103
- )
104
-
105
- height = gr.Slider(
106
- label="Height",
107
- minimum=256,
108
- maximum=MAX_IMAGE_SIZE,
109
- step=32,
110
- value=1024, #Replace with defaults that work for your model
111
  )
112
-
113
- with gr.Row():
114
-
115
- guidance_scale = gr.Slider(
116
- label="Guidance scale",
117
- minimum=0.0,
118
- maximum=10.0,
119
- step=0.1,
120
- value=0.0, #Replace with defaults that work for your model
 
121
  )
122
-
123
- num_inference_steps = gr.Slider(
124
- label="Number of inference steps",
125
- minimum=1,
126
- maximum=50,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  step=1,
128
- value=2, #Replace with defaults that work for your model
129
  )
130
-
131
- gr.Examples(
132
- examples = examples,
133
- inputs = [prompt]
134
- )
135
- gr.on(
136
- triggers=[run_button.click, prompt.submit],
137
- fn = infer,
138
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
139
- outputs = [result, seed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  )
 
141
 
142
- demo.queue().launch()
 
1
+ from functools import partial
2
+
3
+ import cv2
4
+ import random
5
+ from typing import Tuple, Optional
6
+
7
  import gradio as gr
8
  import numpy as np
9
+ import requests
10
+ import spaces
 
11
  import torch
12
+ from PIL import Image, ImageFilter
13
+ from diffusers import FluxInpaintPipeline
14
+ from gradio_client import Client, handle_file
15
 
16
# Markdown banner rendered at the top of the demo page.
MARKDOWN = """
# FLUX.1 Inpainting 🔥

Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
for taking it to the next level by enabling inpainting with the FLUX.
"""

# Upper bound for the seed slider and for random seed generation.
MAX_SEED = np.iinfo(np.int32).max
# Longest image side fed to FLUX; final dims are snapped down to multiples of 32.
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): model weights are downloaded and loaded at import time, before
# the UI comes up — startup blocks on this.
PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
# Shared fallback client for text-prompted mask generation; a per-session
# client is created in set_client_for_session when an x-ip-token is present.
CLIENT = Client("SkalskiP/florence-sam-masking")


# Each row matches the `inputs` list of gr.Examples below:
# [editor dict, inpainting prompt, masking prompt, mask inflation, mask blur,
#  seed, randomize seed, strength, num inference steps].
# NOTE(review): example images are fetched over HTTP at import time.
EXAMPLES = [
    [
        {
            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
            "layers": [Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-2-removebg.png", stream=True).raw)],
            "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-2.png", stream=True).raw),
        },
        "little lion",
        "",
        5,
        5,
        42,
        False,
        0.85,
        20
    ],
    [
        {
            # No drawn mask here: the masking prompt ("eyes") drives mask
            # generation instead.
            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-5.jpeg", stream=True).raw),
            "layers": None,
            "composite": None
        },
        "big blue eyes",
        "eyes",
        10,
        5,
        42,
        False,
        0.9,
        20
    ]
]
64
 
 
 
 
 
 
 
65
 
66
def calculate_image_dimensions_for_flux(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    """Scale a (width, height) pair so its longer side equals *maximum_dimension*.

    The aspect ratio is preserved (up to integer truncation) and both results
    are snapped DOWN to the nearest multiple of 32, as required by FLUX.

    Args:
        original_resolution_wh: Source image size as (width, height).
        maximum_dimension: Target length for the longer side.

    Returns:
        The new (width, height), each divisible by 32.
    """
    width, height = original_resolution_wh
    # One scale factor for both axes, chosen from the dominant side.
    scale = maximum_dimension / max(width, height)
    scaled_width = int(width * scale)
    scaled_height = int(height * scale)
    # Snap down to 32-pixel alignment.
    return scaled_width - scaled_width % 32, scaled_height - scaled_height % 32
84
+
85
+
86
def is_mask_empty(image: Image.Image) -> bool:
    """Return True when the mask contains no lit (non-zero) pixels.

    Args:
        image: Mask image in any mode; it is reduced to grayscale first.

    Returns:
        True if every pixel of the grayscale mask is 0.
    """
    # getextrema() on a single-band image returns (min, max) in one C-level
    # pass; an all-black mask has extrema (0, 0). This replaces the previous
    # list(getdata()) scan, which materialised every pixel as a Python int.
    return image.convert("L").getextrema() == (0, 0)
90
+
91
+
92
def process_mask(
    mask: Image.Image,
    mask_inflation: Optional[int] = None,
    mask_blur: Optional[int] = None
) -> Image.Image:
    """Optionally dilate the mask's white region, then soften its edges.

    Args:
        mask (Image.Image): The input mask image.
        mask_inflation (Optional[int]): Pixels of dilation to apply; skipped
            when None or <= 0.
        mask_blur (Optional[int]): Gaussian blur radius; skipped when None
            or <= 0.

    Returns:
        Image.Image: The processed mask.
    """
    if mask_inflation and mask_inflation > 0:
        # Morphological dilation with a square structuring element grows the
        # white area by roughly mask_inflation pixels.
        structuring_element = np.ones((mask_inflation, mask_inflation), np.uint8)
        dilated = cv2.dilate(np.array(mask), structuring_element, iterations=1)
        mask = Image.fromarray(dilated)

    if mask_blur and mask_blur > 0:
        # Feather the edge so the inpainted region blends into the original.
        mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    return mask
118
+
119
+
120
def set_client_for_session(request: gr.Request):
    """Build a per-session masking client carrying the caller's ZeroGPU token.

    Args:
        request: The incoming Gradio request; its `x-ip-token` header (when
            present) is forwarded so the masking Space bills the caller's quota.

    Returns:
        A fresh authenticated Client, or the shared module-level CLIENT when
        the header is missing or client construction fails.
    """
    try:
        x_ip_token = request.headers['x-ip-token']
        return Client("SkalskiP/florence-sam-masking", headers={"X-IP-Token": x_ip_token})
    except Exception:
        # Was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Falling back to the shared client is still the
        # intended best-effort behavior for KeyError / connection errors.
        return CLIENT
126
+
127
+
128
@spaces.GPU(duration=50)
def run_flux(
    image: Image.Image,
    mask: Image.Image,
    prompt: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    resolution_wh: Tuple[int, int],
) -> Image.Image:
    """Run the FLUX inpainting pipeline and return the generated image.

    Args:
        image: Source image, already resized to `resolution_wh`.
        mask: Inpainting mask (white = region to regenerate).
        prompt: Text prompt describing the desired content.
        seed_slicer: Seed to use when randomization is off.
        randomize_seed_checkbox: When True, draw a fresh seed instead.
        strength_slider: How strongly to transform the source image (0..1).
        num_inference_steps_slider: Number of denoising steps.
        resolution_wh: Output (width, height); both multiples of 32.

    Returns:
        The first generated PIL image.
    """
    print("Running FLUX...")
    width, height = resolution_wh
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed_checkbox else seed_slicer
    rng = torch.Generator().manual_seed(effective_seed)
    output = PIPE(
        prompt=prompt,
        image=image,
        mask_image=mask,
        width=width,
        height=height,
        strength=strength_slider,
        generator=rng,
        num_inference_steps=num_inference_steps_slider,
    )
    return output.images[0]
154
+
155
+
156
def process(
    client,
    input_image_editor: dict,
    inpainting_prompt_text: str,
    masking_prompt_text: str,
    mask_inflation_slider: int,
    mask_blur_slider: int,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int
):
    """Validate inputs, resolve the mask (drawn or text-generated), run FLUX.

    Args:
        client: gradio_client Client used for text-prompted mask generation.
        input_image_editor: gr.ImageEditor value dict with 'background' and
            'layers' file paths (editor type='filepath').
        inpainting_prompt_text: Prompt describing what to paint into the mask.
        masking_prompt_text: Optional prompt used to auto-generate a mask.
        mask_inflation_slider / mask_blur_slider: Mask post-processing knobs.
        seed_slicer / randomize_seed_checkbox / strength_slider /
        num_inference_steps_slider: Forwarded to run_flux.

    Returns:
        (generated image, final mask), or (None, None) on validation failure.
    """
    if not inpainting_prompt_text:
        gr.Info("Please enter inpainting text prompt.")
        return None, None

    image_path = input_image_editor['background']
    # Validate BEFORE Image.open: previously a missing upload crashed inside
    # Image.open(None) instead of showing this message.
    if not image_path:
        gr.Info("Please upload an image.")
        return None, None
    image = Image.open(image_path)

    # `layers` is None when the user has drawn nothing (see EXAMPLES row 2);
    # previously `['layers'][0]` raised TypeError in that case. Treat it as
    # an empty (all-black) mask so the masking-prompt path can take over.
    layers = input_image_editor.get('layers') or []
    mask = Image.open(layers[0]) if layers else Image.new("L", image.size, 0)

    if is_mask_empty(mask) and not masking_prompt_text:
        gr.Info("Please draw a mask or enter a masking prompt.")
        return None, None

    if not is_mask_empty(mask) and masking_prompt_text:
        gr.Info("Both mask and masking prompt are provided. Please provide only one.")
        return None, None

    if is_mask_empty(mask):
        # Generate the mask remotely from the masking prompt; the client
        # returns a file path to the produced mask image.
        print("Generating mask...")
        mask = client.predict(
            image_input=handle_file(image_path),
            text_input=masking_prompt_text,
            api_name="/process_image")
        mask = Image.open(mask)
        print("Mask generated.")

    width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
    image = image.resize((width, height), Image.LANCZOS)
    mask = mask.resize((width, height), Image.LANCZOS)
    mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
    image = run_flux(
        image=image,
        mask=mask,
        prompt=inpainting_prompt_text,
        seed_slicer=seed_slicer,
        randomize_seed_checkbox=randomize_seed_checkbox,
        strength_slider=strength_slider,
        num_inference_steps_slider=num_inference_steps_slider,
        resolution_wh=(width, height)
    )
    return image, mask
214
+
215
+
216
# Example rows always use the shared CLIENT (no per-session token available
# when rendering cached examples).
process_example = partial(process, client=CLIENT)


with gr.Blocks() as demo:
    # Per-session masking client, populated by demo.load below.
    client_component = gr.State()
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='filepath',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))

            with gr.Row():
                inpainting_prompt_text_component = gr.Text(
                    label="Inpainting prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate inpainting",
                    container=False,
                )
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)

            with gr.Accordion("Advanced Settings", open=False):
                masking_prompt_text_component = gr.Text(
                    label="Masking prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate masking",
                    container=False,
                )

                with gr.Row():
                    mask_inflation_slider_component = gr.Slider(
                        label="Mask inflation",
                        info="Adjusts the amount of mask edge expansion before "
                             "inpainting.",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )

                    mask_blur_slider_component = gr.Slider(
                        label="Mask blur",
                        info="Controls the intensity of the Gaussian blur applied to "
                             "the mask edges.",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )

                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )

                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)

                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )

                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")
    # NOTE(review): gr.Examples inputs do not include the client; the partial
    # above pre-binds it, and run_on_click=False means examples only populate
    # the inputs rather than executing process.
    gr.Examples(
        fn=process_example,
        examples=EXAMPLES,
        inputs=[
            input_image_editor_component,
            inpainting_prompt_text_component,
            masking_prompt_text_component,
            mask_inflation_slider_component,
            mask_blur_slider_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ],
        run_on_click=False
    )

    # Submit wires the per-session client (gr.State) in as the first argument
    # of process, matching its signature.
    submit_button_component.click(
        fn=process,
        inputs=[
            client_component,
            input_image_editor_component,
            inpainting_prompt_text_component,
            masking_prompt_text_component,
            mask_inflation_slider_component,
            mask_blur_slider_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )
    # Populate the per-session client once per page load.
    demo.load(set_client_for_session, None, client_component)

demo.launch(debug=False, show_error=True)