Gemini899 commited on
Commit
a5f87e0
·
verified ·
1 Parent(s): f936547

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -235
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import spaces
3
  import torch
@@ -12,6 +13,15 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
 
 
 
 
 
 
 
 
 
15
  config_file = hf_hub_download(
16
  "xinsir/controlnet-union-sdxl-1.0",
17
  filename="config_promax.json",
@@ -20,26 +30,19 @@ config_file = hf_hub_download(
20
  config = ControlNetModel_Union.load_config(config_file)
21
  controlnet_model = ControlNetModel_Union.from_config(config)
22
 
23
- # Load the state dictionary
24
  model_file = hf_hub_download(
25
  "xinsir/controlnet-union-sdxl-1.0",
26
  filename="diffusion_pytorch_model_promax.safetensors",
27
  )
28
  state_dict = load_state_dict(model_file)
29
-
30
- # Extract the keys from the state_dict
31
  loaded_keys = list(state_dict.keys())
32
 
33
- # Call the method and store all returns in a variable
34
  result = ControlNetModel_Union._load_pretrained_model(
35
  controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0", loaded_keys
36
  )
37
-
38
- # Use the first element from the result
39
  model = result[0]
40
  model = model.to(device="cuda", dtype=torch.float16)
41
 
42
-
43
  vae = AutoencoderKL.from_pretrained(
44
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
45
  ).to("cuda")
@@ -55,8 +58,10 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
55
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
56
 
57
 
 
 
 
58
  def can_expand(source_width, source_height, target_width, target_height, alignment):
59
- """Checks if the image can be expanded based on the alignment."""
60
  if alignment in ("Left", "Right") and source_width >= target_width:
61
  return False
62
  if alignment in ("Top", "Bottom") and source_height >= target_height:
@@ -66,15 +71,13 @@ def can_expand(source_width, source_height, target_width, target_height, alignme
66
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
67
  target_size = (width, height)
68
 
69
- # Calculate the scaling factor to fit the image within the target size
70
  scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
71
  new_width = int(image.width * scale_factor)
72
  new_height = int(image.height * scale_factor)
73
-
74
- # Resize the source image to fit within target size
75
  source = image.resize((new_width, new_height), Image.LANCZOS)
76
 
77
- # Apply resize option using percentages
78
  if resize_option == "Full":
79
  resize_percentage = 100
80
  elif resize_option == "50%":
@@ -83,66 +86,44 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
83
  resize_percentage = 33
84
  elif resize_option == "25%":
85
  resize_percentage = 25
86
- else: # Custom
87
  resize_percentage = custom_resize_percentage
88
 
89
- # Calculate new dimensions based on percentage
90
- resize_factor = resize_percentage / 100
91
- new_width = int(source.width * resize_factor)
92
- new_height = int(source.height * resize_factor)
93
-
94
- # Ensure minimum size of 64 pixels
95
- new_width = max(new_width, 64)
96
- new_height = max(new_height, 64)
97
-
98
- # Resize the image
99
  source = source.resize((new_width, new_height), Image.LANCZOS)
100
 
101
- # Calculate the overlap in pixels based on the percentage
102
- overlap_x = int(new_width * (overlap_percentage / 100))
103
- overlap_y = int(new_height * (overlap_percentage / 100))
104
 
105
- # Ensure minimum overlap of 1 pixel
106
- overlap_x = max(overlap_x, 1)
107
- overlap_y = max(overlap_y, 1)
108
-
109
- # Calculate margins based on alignment
110
  if alignment == "Middle":
111
  margin_x = (target_size[0] - new_width) // 2
112
  margin_y = (target_size[1] - new_height) // 2
113
  elif alignment == "Left":
114
- margin_x = 0
115
- margin_y = (target_size[1] - new_height) // 2
116
  elif alignment == "Right":
117
- margin_x = target_size[0] - new_width
118
- margin_y = (target_size[1] - new_height) // 2
119
  elif alignment == "Top":
120
- margin_x = (target_size[0] - new_width) // 2
121
- margin_y = 0
122
  elif alignment == "Bottom":
123
- margin_x = (target_size[0] - new_width) // 2
124
- margin_y = target_size[1] - new_height
125
 
126
- # Adjust margins to eliminate gaps
127
  margin_x = max(0, min(margin_x, target_size[0] - new_width))
128
  margin_y = max(0, min(margin_y, target_size[1] - new_height))
129
 
130
- # Create a new background image and paste the resized source image
131
  background = Image.new('RGB', target_size, (255, 255, 255))
132
  background.paste(source, (margin_x, margin_y))
133
 
134
- # Create the mask
135
  mask = Image.new('L', target_size, 255)
136
  mask_draw = ImageDraw.Draw(mask)
137
 
138
- # Calculate overlap areas
139
  white_gaps_patch = 2
140
-
141
  left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
142
  right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
143
  top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
144
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
145
-
146
  if alignment == "Left":
147
  left_overlap = margin_x + overlap_x if overlap_left else margin_x
148
  elif alignment == "Right":
@@ -152,37 +133,19 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
152
  elif alignment == "Bottom":
153
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
154
 
155
-
156
- # Draw the mask
157
- mask_draw.rectangle([
158
- (left_overlap, top_overlap),
159
- (right_overlap, bottom_overlap)
160
- ], fill=0)
161
-
162
  return background, mask
163
 
164
- def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
165
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
166
-
167
- # Create a preview image showing the mask
168
- preview = background.copy().convert('RGBA')
169
-
170
- # Create a semi-transparent red overlay
171
- red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) # Reduced alpha to 64 (25% opacity)
172
-
173
- # Convert black pixels in the mask to semi-transparent red
174
- red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
175
- red_mask.paste(red_overlay, (0, 0), mask)
176
-
177
- # Overlay the red mask on the background
178
- preview = Image.alpha_composite(preview, red_mask)
179
-
180
- return preview
181
-
182
  @spaces.GPU(duration=24)
183
- def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
184
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
185
-
 
 
 
 
 
186
  if not can_expand(background.width, background.height, width, height, alignment):
187
  alignment = "Middle"
188
 
@@ -191,7 +154,6 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
191
 
192
  final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
193
 
194
- # Use with torch.autocast to ensure consistent dtype
195
  with torch.autocast(device_type="cuda", dtype=torch.float16):
196
  (
197
  prompt_embeds,
@@ -200,7 +162,8 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
200
  negative_pooled_prompt_embeds,
201
  ) = pipe.encode_prompt(final_prompt, "cuda", True)
202
 
203
- for image in pipe(
 
204
  prompt_embeds=prompt_embeds,
205
  negative_prompt_embeds=negative_prompt_embeds,
206
  pooled_prompt_embeds=pooled_prompt_embeds,
@@ -208,31 +171,54 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
208
  image=cnet_image,
209
  num_inference_steps=num_inference_steps
210
  ):
211
- yield cnet_image, image
 
 
 
212
 
213
- image = image.convert("RGBA")
214
- cnet_image.paste(image, (0, 0), mask)
 
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  yield background, cnet_image
217
 
218
  def clear_result():
219
- """Clears the result ImageSlider."""
220
  return gr.update(value=None)
221
 
222
  def preload_presets(target_ratio, ui_width, ui_height):
223
- """Updates the width and height sliders based on the selected aspect ratio."""
224
  if target_ratio == "9:16":
225
- changed_width = 720
226
- changed_height = 1280
227
- return changed_width, changed_height, gr.update()
228
  elif target_ratio == "16:9":
229
- changed_width = 1280
230
- changed_height = 720
231
- return changed_width, changed_height, gr.update()
232
  elif target_ratio == "1:1":
233
- changed_width = 1024
234
- changed_height = 1024
235
- return changed_width, changed_height, gr.update()
236
  elif target_ratio == "Custom":
237
  return ui_width, ui_height, gr.update(open=True)
238
 
@@ -250,7 +236,6 @@ def toggle_custom_resize_slider(resize_option):
250
  return gr.update(visible=(resize_option == "Custom"))
251
 
252
  def update_history(new_image, history):
253
- """Updates the history gallery with the new image."""
254
  if history is None:
255
  history = []
256
  history.insert(0, new_image)
@@ -262,68 +247,37 @@ css = """
262
  }
263
  """
264
 
265
- # Define the title HTML string
266
- title = """<h1 align="center">Re-Size Image Outpaint</h1>
267
- """
268
 
269
  with gr.Blocks(theme="soft", css=css) as demo:
270
  with gr.Column():
271
  gr.HTML(title)
272
-
273
  with gr.Row():
274
  with gr.Column():
275
- input_image = gr.Image(
276
- type="pil",
277
- label="Input Image"
278
- )
279
-
280
  with gr.Row():
281
  with gr.Column(scale=2):
282
  prompt_input = gr.Textbox(label="Prompt (Optional)")
283
  with gr.Column(scale=1):
284
  run_button = gr.Button("Generate")
285
-
286
  with gr.Row():
287
  target_ratio = gr.Radio(
288
  label="Expected Ratio",
289
  choices=["9:16", "16:9", "1:1", "Custom"],
290
- value="9:16",
291
- scale=2
292
  )
293
-
294
  alignment_dropdown = gr.Dropdown(
295
  choices=["Middle", "Left", "Right", "Top", "Bottom"],
296
- value="Middle",
297
- label="Alignment"
298
  )
299
-
300
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
301
  with gr.Column():
302
  with gr.Row():
303
- width_slider = gr.Slider(
304
- label="Target Width",
305
- minimum=720,
306
- maximum=1536,
307
- step=8,
308
- value=720,
309
- )
310
- height_slider = gr.Slider(
311
- label="Target Height",
312
- minimum=720,
313
- maximum=1536,
314
- step=8,
315
- value=1280,
316
- )
317
-
318
  num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
319
  with gr.Group():
320
- overlap_percentage = gr.Slider(
321
- label="Mask overlap (%)",
322
- minimum=1,
323
- maximum=50,
324
- value=10,
325
- step=1
326
- )
327
  with gr.Row():
328
  overlap_top = gr.Checkbox(label="Overlap Top", value=True)
329
  overlap_right = gr.Checkbox(label="Overlap Right", value=True)
@@ -331,24 +285,11 @@ with gr.Blocks(theme="soft", css=css) as demo:
331
  overlap_left = gr.Checkbox(label="Overlap Left", value=True)
332
  overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
333
  with gr.Row():
334
- resize_option = gr.Radio(
335
- label="Resize input image",
336
- choices=["Full", "50%", "33%", "25%", "Custom"],
337
- value="Full"
338
- )
339
- custom_resize_percentage = gr.Slider(
340
- label="Custom resize (%)",
341
- minimum=1,
342
- maximum=100,
343
- step=1,
344
- value=50,
345
- visible=False
346
- )
347
-
348
  with gr.Column():
349
  preview_button = gr.Button("Preview alignment and mask")
350
-
351
-
352
  gr.Examples(
353
  examples=[
354
  ["./examples/example_2.jpg", 1440, 810, "Left"],
@@ -358,108 +299,106 @@ with gr.Blocks(theme="soft", css=css) as demo:
358
  inputs=[input_image, width_slider, height_slider, alignment_dropdown],
359
  )
360
 
361
-
362
-
363
  with gr.Column():
364
- result = ImageSlider(
365
- interactive=False,
366
- label="Generated Image",
367
- )
368
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
369
-
370
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
371
  preview_image = gr.Image(label="Preview")
372
 
373
-
374
-
375
  def use_output_as_input(output_image):
376
- """Sets the generated output as the new input image."""
377
  return gr.update(value=output_image[1])
378
 
379
- use_as_input_button.click(
380
- fn=use_output_as_input,
381
- inputs=[result],
382
- outputs=[input_image]
383
- )
384
-
385
- target_ratio.change(
386
- fn=preload_presets,
387
- inputs=[target_ratio, width_slider, height_slider],
388
- outputs=[width_slider, height_slider, settings_panel],
389
- queue=False
390
- )
391
-
392
- width_slider.change(
393
- fn=select_the_right_preset,
394
- inputs=[width_slider, height_slider],
395
- outputs=[target_ratio],
396
- queue=False
397
- )
398
-
399
- height_slider.change(
400
- fn=select_the_right_preset,
401
- inputs=[width_slider, height_slider],
402
- outputs=[target_ratio],
403
- queue=False
404
- )
405
-
406
- resize_option.change(
407
- fn=toggle_custom_resize_slider,
408
- inputs=[resize_option],
409
- outputs=[custom_resize_percentage],
410
- queue=False
411
- )
412
-
413
- run_button.click(
414
- fn=clear_result,
415
- inputs=None,
416
- outputs=result,
417
- ).then(
418
- fn=infer,
419
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
420
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
421
- overlap_left, overlap_right, overlap_top, overlap_bottom],
422
- outputs=result,
423
- ).then(
424
- # --- FIX APPLIED HERE ---
425
- # Safely update history only if the result (x) is not None.
426
- fn=lambda x, history: update_history(x[1], history) if x else history,
427
- inputs=[result, history_gallery],
428
- outputs=history_gallery,
429
- ).then(
430
- fn=lambda: gr.update(visible=True),
431
- inputs=None,
432
- outputs=use_as_input_button,
433
- )
434
-
435
- prompt_input.submit(
436
- fn=clear_result,
437
- inputs=None,
438
- outputs=result,
439
- ).then(
440
- fn=infer,
441
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
442
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
443
- overlap_left, overlap_right, overlap_top, overlap_bottom],
444
- outputs=result,
445
- ).then(
446
- # --- FIX APPLIED HERE ---
447
- # Safely update history only if the result (x) is not None.
448
- fn=lambda x, history: update_history(x[1], history) if x else history,
449
- inputs=[result, history_gallery],
450
- outputs=history_gallery,
451
- ).then(
452
- fn=lambda: gr.update(visible=True),
453
- inputs=None,
454
- outputs=use_as_input_button,
455
- )
456
 
457
  preview_button.click(
458
- fn=preview_image_and_mask,
459
  inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
460
  overlap_left, overlap_right, overlap_top, overlap_bottom],
461
- outputs=preview_image,
462
- queue=False
463
  )
464
 
465
- demo.queue(max_size=12).launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
  import gradio as gr
3
  import spaces
4
  import torch
 
13
  from PIL import Image, ImageDraw
14
  import numpy as np
15
 
16
+ # --- NEW: FastAPI bits for custom REST endpoint ---
17
+ from fastapi import FastAPI, File, UploadFile, Form
18
+ from fastapi.responses import StreamingResponse, JSONResponse
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ # -------------------------------------------------
21
+
22
+ # =========================
23
+ # MODEL / PIPELINE LOAD
24
+ # =========================
25
  config_file = hf_hub_download(
26
  "xinsir/controlnet-union-sdxl-1.0",
27
  filename="config_promax.json",
 
30
  config = ControlNetModel_Union.load_config(config_file)
31
  controlnet_model = ControlNetModel_Union.from_config(config)
32
 
 
33
  model_file = hf_hub_download(
34
  "xinsir/controlnet-union-sdxl-1.0",
35
  filename="diffusion_pytorch_model_promax.safetensors",
36
  )
37
  state_dict = load_state_dict(model_file)
 
 
38
  loaded_keys = list(state_dict.keys())
39
 
 
40
  result = ControlNetModel_Union._load_pretrained_model(
41
  controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0", loaded_keys
42
  )
 
 
43
  model = result[0]
44
  model = model.to(device="cuda", dtype=torch.float16)
45
 
 
46
  vae = AutoencoderKL.from_pretrained(
47
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
48
  ).to("cuda")
 
58
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
59
 
60
 
61
+ # =========================
62
+ # HELPERS
63
+ # =========================
64
  def can_expand(source_width, source_height, target_width, target_height, alignment):
 
65
  if alignment in ("Left", "Right") and source_width >= target_width:
66
  return False
67
  if alignment in ("Top", "Bottom") and source_height >= target_height:
 
71
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
72
  target_size = (width, height)
73
 
74
+ # Fit image into target canvas
75
  scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
76
  new_width = int(image.width * scale_factor)
77
  new_height = int(image.height * scale_factor)
 
 
78
  source = image.resize((new_width, new_height), Image.LANCZOS)
79
 
80
+ # Resize option (%)
81
  if resize_option == "Full":
82
  resize_percentage = 100
83
  elif resize_option == "50%":
 
86
  resize_percentage = 33
87
  elif resize_option == "25%":
88
  resize_percentage = 25
89
+ else:
90
  resize_percentage = custom_resize_percentage
91
 
92
+ resize_factor = max(1, int(resize_percentage)) / 100.0
93
+ new_width = max(int(source.width * resize_factor), 64)
94
+ new_height = max(int(source.height * resize_factor), 64)
 
 
 
 
 
 
 
95
  source = source.resize((new_width, new_height), Image.LANCZOS)
96
 
97
+ overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
98
+ overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)
 
99
 
 
 
 
 
 
100
  if alignment == "Middle":
101
  margin_x = (target_size[0] - new_width) // 2
102
  margin_y = (target_size[1] - new_height) // 2
103
  elif alignment == "Left":
104
+ margin_x = 0; margin_y = (target_size[1] - new_height) // 2
 
105
  elif alignment == "Right":
106
+ margin_x = target_size[0] - new_width; margin_y = (target_size[1] - new_height) // 2
 
107
  elif alignment == "Top":
108
+ margin_x = (target_size[0] - new_width) // 2; margin_y = 0
 
109
  elif alignment == "Bottom":
110
+ margin_x = (target_size[0] - new_width) // 2; margin_y = target_size[1] - new_height
 
111
 
 
112
  margin_x = max(0, min(margin_x, target_size[0] - new_width))
113
  margin_y = max(0, min(margin_y, target_size[1] - new_height))
114
 
 
115
  background = Image.new('RGB', target_size, (255, 255, 255))
116
  background.paste(source, (margin_x, margin_y))
117
 
 
118
  mask = Image.new('L', target_size, 255)
119
  mask_draw = ImageDraw.Draw(mask)
120
 
 
121
  white_gaps_patch = 2
 
122
  left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
123
  right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
124
  top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
125
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
126
+
127
  if alignment == "Left":
128
  left_overlap = margin_x + overlap_x if overlap_left else margin_x
129
  elif alignment == "Right":
 
133
  elif alignment == "Bottom":
134
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
135
 
136
+ mask_draw.rectangle([(left_overlap, top_overlap), (right_overlap, bottom_overlap)], fill=0)
 
 
 
 
 
 
137
  return background, mask
138
 
139
+ # --- NEW: single-call synchronous generator for both UI and REST ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  @spaces.GPU(duration=24)
141
+ def run_outpaint_sync(image, width, height, overlap_percentage, num_inference_steps, resize_option,
142
+ custom_resize_percentage, prompt_input, alignment,
143
+ overlap_left, overlap_right, overlap_top, overlap_bottom):
144
+ background, mask = prepare_image_and_mask(
145
+ image, width, height, overlap_percentage, resize_option, custom_resize_percentage,
146
+ alignment, overlap_left, overlap_right, overlap_top, overlap_bottom
147
+ )
148
+
149
  if not can_expand(background.width, background.height, width, height, alignment):
150
  alignment = "Middle"
151
 
 
154
 
155
  final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
156
 
 
157
  with torch.autocast(device_type="cuda", dtype=torch.float16):
158
  (
159
  prompt_embeds,
 
162
  negative_pooled_prompt_embeds,
163
  ) = pipe.encode_prompt(final_prompt, "cuda", True)
164
 
165
+ last_image = None
166
+ for img in pipe(
167
  prompt_embeds=prompt_embeds,
168
  negative_prompt_embeds=negative_prompt_embeds,
169
  pooled_prompt_embeds=pooled_prompt_embeds,
 
171
  image=cnet_image,
172
  num_inference_steps=num_inference_steps
173
  ):
174
+ last_image = img
175
+
176
+ if last_image is None:
177
+ raise RuntimeError("Pipeline did not return an image.")
178
 
179
+ last_image = last_image.convert("RGBA")
180
+ cnet_image.paste(last_image, (0, 0), mask)
181
+ return background, cnet_image
182
 
183
+ # (Original streaming infer for UI remains, unchanged)
184
+ @spaces.GPU(duration=24)
185
+ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
186
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
187
+ if not can_expand(background.width, background.height, width, height, alignment):
188
+ alignment = "Middle"
189
+ cnet_image = background.copy()
190
+ cnet_image.paste(0, (0, 0), mask)
191
+ final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
192
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
193
+ (
194
+ prompt_embeds,
195
+ negative_prompt_embeds,
196
+ pooled_prompt_embeds,
197
+ negative_pooled_prompt_embeds,
198
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
199
+ for img in pipe(
200
+ prompt_embeds=prompt_embeds,
201
+ negative_prompt_embeds=negative_prompt_embeds,
202
+ pooled_prompt_embeds=pooled_prompt_embeds,
203
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
204
+ image=cnet_image,
205
+ num_inference_steps=num_inference_steps
206
+ ):
207
+ yield cnet_image, img
208
+ img = img.convert("RGBA")
209
+ cnet_image.paste(img, (0, 0), mask)
210
  yield background, cnet_image
211
 
212
  def clear_result():
 
213
  return gr.update(value=None)
214
 
215
  def preload_presets(target_ratio, ui_width, ui_height):
 
216
  if target_ratio == "9:16":
217
+ return 720, 1280, gr.update()
 
 
218
  elif target_ratio == "16:9":
219
+ return 1280, 720, gr.update()
 
 
220
  elif target_ratio == "1:1":
221
+ return 1024, 1024, gr.update()
 
 
222
  elif target_ratio == "Custom":
223
  return ui_width, ui_height, gr.update(open=True)
224
 
 
236
  return gr.update(visible=(resize_option == "Custom"))
237
 
238
  def update_history(new_image, history):
 
239
  if history is None:
240
  history = []
241
  history.insert(0, new_image)
 
247
  }
248
  """
249
 
250
+ title = """<h1 align="center">Re-Size Image Outpaint</h1>"""
 
 
251
 
252
  with gr.Blocks(theme="soft", css=css) as demo:
253
  with gr.Column():
254
  gr.HTML(title)
 
255
  with gr.Row():
256
  with gr.Column():
257
+ input_image = gr.Image(type="pil", label="Input Image")
 
 
 
 
258
  with gr.Row():
259
  with gr.Column(scale=2):
260
  prompt_input = gr.Textbox(label="Prompt (Optional)")
261
  with gr.Column(scale=1):
262
  run_button = gr.Button("Generate")
 
263
  with gr.Row():
264
  target_ratio = gr.Radio(
265
  label="Expected Ratio",
266
  choices=["9:16", "16:9", "1:1", "Custom"],
267
+ value="9:16", scale=2
 
268
  )
 
269
  alignment_dropdown = gr.Dropdown(
270
  choices=["Middle", "Left", "Right", "Top", "Bottom"],
271
+ value="Middle", label="Alignment"
 
272
  )
 
273
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
274
  with gr.Column():
275
  with gr.Row():
276
+ width_slider = gr.Slider(label="Target Width", minimum=720, maximum=1536, step=8, value=720)
277
+ height_slider = gr.Slider(label="Target Height", minimum=720, maximum=1536, step=8, value=1280)
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
279
  with gr.Group():
280
+ overlap_percentage = gr.Slider(label="Mask overlap (%)", minimum=1, maximum=50, value=10, step=1)
 
 
 
 
 
 
281
  with gr.Row():
282
  overlap_top = gr.Checkbox(label="Overlap Top", value=True)
283
  overlap_right = gr.Checkbox(label="Overlap Right", value=True)
 
285
  overlap_left = gr.Checkbox(label="Overlap Left", value=True)
286
  overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
287
  with gr.Row():
288
+ resize_option = gr.Radio(label="Resize input image", choices=["Full", "50%", "33%", "25%", "Custom"], value="Full")
289
+ custom_resize_percentage = gr.Slider(label="Custom resize (%)", minimum=1, maximum=100, step=1, value=50, visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
290
  with gr.Column():
291
  preview_button = gr.Button("Preview alignment and mask")
292
+
 
293
  gr.Examples(
294
  examples=[
295
  ["./examples/example_2.jpg", 1440, 810, "Left"],
 
299
  inputs=[input_image, width_slider, height_slider, alignment_dropdown],
300
  )
301
 
 
 
302
  with gr.Column():
303
+ result = ImageSlider(interactive=False, label="Generated Image")
 
 
 
304
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
 
305
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
306
  preview_image = gr.Image(label="Preview")
307
 
 
 
308
  def use_output_as_input(output_image):
 
309
  return gr.update(value=output_image[1])
310
 
311
+ use_as_input_button.click(fn=use_output_as_input, inputs=[result], outputs=[input_image])
312
+
313
+ target_ratio.change(fn=preload_presets, inputs=[target_ratio, width_slider, height_slider], outputs=[width_slider, height_slider, settings_panel], queue=False)
314
+ width_slider.change(fn=select_the_right_preset, inputs=[width_slider, height_slider], outputs=[target_ratio], queue=False)
315
+ height_slider.change(fn=select_the_right_preset, inputs=[width_slider, height_slider], outputs=[target_ratio], queue=False)
316
+ resize_option.change(fn=toggle_custom_resize_slider, inputs=[resize_option], outputs=[custom_resize_percentage], queue=False)
317
+
318
+ run_button.click(fn=clear_result, inputs=None, outputs=result)\
319
+ .then(fn=infer,
320
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
321
+ resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
322
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
323
+ outputs=result)\
324
+ .then(fn=lambda x, history: update_history(x[1], history) if x else history,
325
+ inputs=[result, history_gallery],
326
+ outputs=history_gallery)\
327
+ .then(fn=lambda: gr.update(visible=True), inputs=None, outputs=use_as_input_button)
328
+
329
+ prompt_input.submit(fn=clear_result, inputs=None, outputs=result)\
330
+ .then(fn=infer,
331
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
332
+ resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
333
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
334
+ outputs=result)\
335
+ .then(fn=lambda x, history: update_history(x[1], history) if x else history,
336
+ inputs=[result, history_gallery],
337
+ outputs=history_gallery)\
338
+ .then(fn=lambda: gr.update(visible=True), inputs=None, outputs=use_as_input_button)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
  preview_button.click(
341
+ fn=lambda *args: preview_image_and_mask(*args),
342
  inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
343
  overlap_left, overlap_right, overlap_top, overlap_bottom],
344
+ outputs=preview_image, queue=False
 
345
  )
346
 
347
+ # =========================================
348
+ # FASTAPI APP + CUSTOM REST ENDPOINT
349
+ # =========================================
350
+ app = FastAPI()
351
+ app.add_middleware(
352
+ CORSMiddleware,
353
+ allow_origins=["*"], allow_credentials=True,
354
+ allow_methods=["*"], allow_headers=["*"],
355
+ )
356
+
357
+ @app.post("/rest/infer")
358
+ def rest_infer(
359
+ file: UploadFile = File(...),
360
+ width: int = Form(1024),
361
+ height: int = Form(1024),
362
+ overlap_percentage: float = Form(10),
363
+ num_inference_steps: int = Form(8),
364
+ resize_option: str = Form("Full"),
365
+ custom_resize_percentage: float = Form(50),
366
+ prompt_input: str = Form(""),
367
+ alignment: str = Form("Middle"),
368
+ overlap_left: bool = Form(True),
369
+ overlap_right: bool = Form(True),
370
+ overlap_top: bool = Form(True),
371
+ overlap_bottom: bool = Form(True),
372
+ ):
373
+ try:
374
+ img_bytes = file.file.read()
375
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
376
+ except Exception as e:
377
+ return JSONResponse({"error": f"Invalid image upload: {e}"}, status_code=400)
378
+
379
+ try:
380
+ _, outpainted = run_outpaint_sync(
381
+ image=image,
382
+ width=width,
383
+ height=height,
384
+ overlap_percentage=overlap_percentage,
385
+ num_inference_steps=num_inference_steps,
386
+ resize_option=resize_option,
387
+ custom_resize_percentage=custom_resize_percentage,
388
+ prompt_input=prompt_input,
389
+ alignment=alignment,
390
+ overlap_left=overlap_left,
391
+ overlap_right=overlap_right,
392
+ overlap_top=overlap_top,
393
+ overlap_bottom=overlap_bottom,
394
+ )
395
+ except Exception as e:
396
+ return JSONResponse({"error": str(e)}, status_code=500)
397
+
398
+ buf = io.BytesIO()
399
+ outpainted.save(buf, format="PNG")
400
+ buf.seek(0)
401
+ return StreamingResponse(buf, media_type="image/png")
402
+
403
+ # Mount the Gradio UI at root path
404
+ app = gr.mount_gradio_app(app, demo, path="/")