Gemini899 committed on
Commit
032ce8c
·
verified ·
1 Parent(s): ad12b8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +359 -420
app.py CHANGED
@@ -1,61 +1,47 @@
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
 
3
  import torch
4
- from diffusers import AutoencoderKL, TCDScheduler
5
- from diffusers.models.model_loading_utils import load_state_dict
6
- from gradio_imageslider import ImageSlider
7
- from huggingface_hub import hf_hub_download
8
-
9
- from controlnet_union import ControlNetModel_Union
10
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
11
-
12
- from PIL import Image, ImageDraw
13
- import numpy as np
14
-
15
- config_file = hf_hub_download(
16
- "xinsir/controlnet-union-sdxl-1.0",
17
- filename="config_promax.json",
18
- )
19
-
20
- config = ControlNetModel_Union.load_config(config_file)
21
- controlnet_model = ControlNetModel_Union.from_config(config)
22
-
23
- # Load the state dictionary
24
- model_file = hf_hub_download(
25
- "xinsir/controlnet-union-sdxl-1.0",
26
- filename="diffusion_pytorch_model_promax.safetensors",
27
- )
28
- state_dict = load_state_dict(model_file)
29
-
30
- # Extract the keys from the state_dict
31
- loaded_keys = list(state_dict.keys())
32
 
33
- # Call the method and store all returns in a variable
34
- result = ControlNetModel_Union._load_pretrained_model(
35
- controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0", loaded_keys
36
- )
 
 
 
37
 
38
- # Use the first element from the result
39
- model = result[0]
40
- model = model.to(device="cuda", dtype=torch.float16)
41
 
 
 
42
 
43
- vae = AutoencoderKL.from_pretrained(
44
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
45
- ).to("cuda")
46
-
47
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
48
- "SG161222/RealVisXL_V5.0_Lightning",
49
- torch_dtype=torch.float16,
50
- vae=vae,
51
- controlnet=model,
52
- variant="fp16",
53
- ).to("cuda")
 
 
54
 
55
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
56
 
 
57
 
58
- def can_expand(source_width, source_height, target_width, target_height, alignment):
59
  """Checks if the image can be expanded based on the alignment."""
60
  if alignment in ("Left", "Right") and source_width >= target_width:
61
  return False
@@ -63,403 +49,356 @@ def can_expand(source_width, source_height, target_width, target_height, alignme
63
  return False
64
  return True
65
 
66
- def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
67
- target_size = (width, height)
68
-
69
- # Calculate the scaling factor to fit the image within the target size
70
- scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
71
- new_width = int(image.width * scale_factor)
72
- new_height = int(image.height * scale_factor)
73
-
74
- # Resize the source image to fit within target size
75
- source = image.resize((new_width, new_height), Image.LANCZOS)
76
-
77
- # Apply resize option using percentages
78
  if resize_option == "Full":
79
- resize_percentage = 100
80
- elif resize_option == "50%":
81
- resize_percentage = 50
82
- elif resize_option == "33%":
83
- resize_percentage = 33
84
- elif resize_option == "25%":
85
- resize_percentage = 25
86
- else: # Custom
87
- resize_percentage = custom_resize_percentage
88
-
89
- # Calculate new dimensions based on percentage
90
- resize_factor = resize_percentage / 100
91
- new_width = int(source.width * resize_factor)
92
- new_height = int(source.height * resize_factor)
93
-
94
- # Ensure minimum size of 64 pixels
95
- new_width = max(new_width, 64)
96
- new_height = max(new_height, 64)
97
-
98
- # Resize the image
99
- source = source.resize((new_width, new_height), Image.LANCZOS)
100
-
101
- # Calculate the overlap in pixels based on the percentage
102
- overlap_x = int(new_width * (overlap_percentage / 100))
103
- overlap_y = int(new_height * (overlap_percentage / 100))
104
-
105
- # Ensure minimum overlap of 1 pixel
106
- overlap_x = max(overlap_x, 1)
107
- overlap_y = max(overlap_y, 1)
108
-
109
- # Calculate margins based on alignment
110
- if alignment == "Middle":
111
- margin_x = (target_size[0] - new_width) // 2
112
- margin_y = (target_size[1] - new_height) // 2
113
- elif alignment == "Left":
114
- margin_x = 0
115
- margin_y = (target_size[1] - new_height) // 2
116
  elif alignment == "Right":
117
- margin_x = target_size[0] - new_width
118
- margin_y = (target_size[1] - new_height) // 2
119
  elif alignment == "Top":
120
- margin_x = (target_size[0] - new_width) // 2
121
- margin_y = 0
122
  elif alignment == "Bottom":
123
- margin_x = (target_size[0] - new_width) // 2
124
- margin_y = target_size[1] - new_height
125
-
126
- # Adjust margins to eliminate gaps
127
- margin_x = max(0, min(margin_x, target_size[0] - new_width))
128
- margin_y = max(0, min(margin_y, target_size[1] - new_height))
129
-
130
- # Create a new background image and paste the resized source image
131
- background = Image.new('RGB', target_size, (255, 255, 255))
132
- background.paste(source, (margin_x, margin_y))
133
-
134
- # Create the mask
135
- mask = Image.new('L', target_size, 255)
136
- mask_draw = ImageDraw.Draw(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- # Calculate overlap areas
139
- white_gaps_patch = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
142
- right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
143
- top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
144
- bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
145
-
146
- if alignment == "Left":
147
- left_overlap = margin_x + overlap_x if overlap_left else margin_x
148
- elif alignment == "Right":
149
- right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
150
- elif alignment == "Top":
151
- top_overlap = margin_y + overlap_y if overlap_top else margin_y
152
- elif alignment == "Bottom":
153
- bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
154
-
155
-
156
- # Draw the mask
157
- mask_draw.rectangle([
158
- (left_overlap, top_overlap),
159
- (right_overlap, bottom_overlap)
160
- ], fill=0)
161
-
162
- return background, mask
163
-
164
- def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
165
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
166
-
167
- # Create a preview image showing the mask
168
- preview = background.copy().convert('RGBA')
169
-
170
- # Create a semi-transparent red overlay
171
- red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) # Reduced alpha to 64 (25% opacity)
172
-
173
- # Convert black pixels in the mask to semi-transparent red
174
- red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
175
- red_mask.paste(red_overlay, (0, 0), mask)
176
-
177
- # Overlay the red mask on the background
178
- preview = Image.alpha_composite(preview, red_mask)
179
-
180
  return preview
181
 
182
- @spaces.GPU(duration=24)
183
- def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
184
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
185
-
186
- if not can_expand(background.width, background.height, width, height, alignment):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  alignment = "Middle"
188
 
189
- cnet_image = background.copy()
190
- cnet_image.paste(0, (0, 0), mask)
191
-
192
- final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
193
-
194
- # Use with torch.autocast to ensure consistent dtype
195
- with torch.autocast(device_type="cuda", dtype=torch.float16):
196
- (
197
- prompt_embeds,
198
- negative_prompt_embeds,
199
- pooled_prompt_embeds,
200
- negative_pooled_prompt_embeds,
201
- ) = pipe.encode_prompt(final_prompt, "cuda", True)
202
-
203
- for image in pipe(
204
- prompt_embeds=prompt_embeds,
205
- negative_prompt_embeds=negative_prompt_embeds,
206
- pooled_prompt_embeds=pooled_prompt_embeds,
207
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
208
- image=cnet_image,
209
- num_inference_steps=num_inference_steps
210
- ):
211
- yield cnet_image, image
212
-
213
- image = image.convert("RGBA")
214
- cnet_image.paste(image, (0, 0), mask)
215
-
216
- yield background, cnet_image
217
-
218
- def clear_result():
219
- """Clears the result ImageSlider."""
220
- return gr.update(value=None)
221
-
222
- def preload_presets(target_ratio, ui_width, ui_height):
223
- """Updates the width and height sliders based on the selected aspect ratio."""
224
- if target_ratio == "9:16":
225
- changed_width = 720
226
- changed_height = 1280
227
- return changed_width, changed_height, gr.update()
228
- elif target_ratio == "16:9":
229
- changed_width = 1280
230
- changed_height = 720
231
- return changed_width, changed_height, gr.update()
232
- elif target_ratio == "1:1":
233
- changed_width = 1024
234
- changed_height = 1024
235
- return changed_width, changed_height, gr.update()
236
- elif target_ratio == "Custom":
237
- return ui_width, ui_height, gr.update(open=True)
238
-
239
- def select_the_right_preset(user_width, user_height):
240
- if user_width == 720 and user_height == 1280:
241
- return "9:16"
242
- elif user_width == 1280 and user_height == 720:
243
- return "16:9"
244
- elif user_width == 1024 and user_height == 1024:
245
- return "1:1"
246
- else:
247
- return "Custom"
248
-
249
- def toggle_custom_resize_slider(resize_option):
250
- return gr.update(visible=(resize_option == "Custom"))
251
-
252
- def update_history(new_image, history):
253
- """Updates the history gallery with the new image."""
254
- if history is None:
255
- history = []
256
- history.insert(0, new_image)
257
- return history
258
-
259
- css = """
260
- .gradio-container {
261
- width: 1200px !important;
262
- }
263
- """
264
-
265
- # Define the title HTML string
266
- title = """<h1 align="center">Re-Size Image Outpaint</h1>
267
- """
268
-
269
- with gr.Blocks(theme="soft", css=css) as demo:
270
- with gr.Column():
271
- gr.HTML(title)
272
-
273
- with gr.Row():
274
- with gr.Column():
275
- input_image = gr.Image(
276
- type="pil",
277
- label="Input Image"
278
- )
279
-
280
- with gr.Row():
281
- with gr.Column(scale=2):
282
- prompt_input = gr.Textbox(label="Prompt (Optional)")
283
- with gr.Column(scale=1):
284
- run_button = gr.Button("Generate")
285
-
286
- with gr.Row():
287
- target_ratio = gr.Radio(
288
- label="Expected Ratio",
289
- choices=["9:16", "16:9", "1:1", "Custom"],
290
- value="9:16",
291
- scale=2
292
- )
293
-
294
- alignment_dropdown = gr.Dropdown(
295
- choices=["Middle", "Left", "Right", "Top", "Bottom"],
296
- value="Middle",
297
- label="Alignment"
298
- )
299
-
300
- with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
301
- with gr.Column():
302
- with gr.Row():
303
- width_slider = gr.Slider(
304
- label="Target Width",
305
- minimum=720,
306
- maximum=1536,
307
- step=8,
308
- value=720,
309
- )
310
- height_slider = gr.Slider(
311
- label="Target Height",
312
- minimum=720,
313
- maximum=1536,
314
- step=8,
315
- value=1280,
316
- )
317
-
318
- num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
319
- with gr.Group():
320
- overlap_percentage = gr.Slider(
321
- label="Mask overlap (%)",
322
- minimum=1,
323
- maximum=50,
324
- value=10,
325
- step=1
326
- )
327
- with gr.Row():
328
- overlap_top = gr.Checkbox(label="Overlap Top", value=True)
329
- overlap_right = gr.Checkbox(label="Overlap Right", value=True)
330
- with gr.Row():
331
- overlap_left = gr.Checkbox(label="Overlap Left", value=True)
332
- overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
333
- with gr.Row():
334
- resize_option = gr.Radio(
335
- label="Resize input image",
336
- choices=["Full", "50%", "33%", "25%", "Custom"],
337
- value="Full"
338
- )
339
- custom_resize_percentage = gr.Slider(
340
- label="Custom resize (%)",
341
- minimum=1,
342
- maximum=100,
343
- step=1,
344
- value=50,
345
- visible=False
346
- )
347
-
348
- with gr.Column():
349
- preview_button = gr.Button("Preview alignment and mask")
350
-
351
-
352
- gr.Examples(
353
- examples=[
354
- ["./examples/example_2.jpg", 1440, 810, "Left"],
355
- ["./examples/example_3.jpg", 1024, 1024, "Top"],
356
- ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
357
- ],
358
- inputs=[input_image, width_slider, height_slider, alignment_dropdown],
359
- )
360
-
361
-
362
-
363
- with gr.Column():
364
- result = ImageSlider(
365
- interactive=False,
366
- label="Generated Image",
367
- )
368
- use_as_input_button = gr.Button("Use as Input Image", visible=False)
369
-
370
- history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
371
- preview_image = gr.Image(label="Preview")
372
-
373
-
374
-
375
- def use_output_as_input(output_image):
376
- """Sets the generated output as the new input image."""
377
- return gr.update(value=output_image[1])
378
-
379
- use_as_input_button.click(
380
- fn=use_output_as_input,
381
- inputs=[result],
382
- outputs=[input_image]
383
- )
384
-
385
- target_ratio.change(
386
- fn=preload_presets,
387
- inputs=[target_ratio, width_slider, height_slider],
388
- outputs=[width_slider, height_slider, settings_panel],
389
- queue=False
390
- )
391
 
392
- width_slider.change(
393
- fn=select_the_right_preset,
394
- inputs=[width_slider, height_slider],
395
- outputs=[target_ratio],
396
- queue=False
397
  )
398
 
399
- height_slider.change(
400
- fn=select_the_right_preset,
401
- inputs=[width_slider, height_slider],
402
- outputs=[target_ratio],
403
- queue=False
 
 
 
 
 
 
 
 
 
404
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
 
 
 
 
 
 
 
 
 
406
  resize_option.change(
407
  fn=toggle_custom_resize_slider,
408
- inputs=[resize_option],
409
- outputs=[custom_resize_percentage],
410
- queue=False
411
  )
412
-
413
- run_button.click(
414
- fn=clear_result,
415
- inputs=None,
416
- outputs=result,
417
- ).then(
418
- fn=infer,
419
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
420
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
421
  overlap_left, overlap_right, overlap_top, overlap_bottom],
422
- outputs=result,
423
- ).then(
424
- # --- FIX APPLIED HERE ---
425
- # Safely update history only if the result (x) is not None.
426
- fn=lambda x, history: update_history(x[1], history) if x else history,
427
- inputs=[result, history_gallery],
428
- outputs=history_gallery,
429
- ).then(
430
- fn=lambda: gr.update(visible=True),
431
- inputs=None,
432
- outputs=use_as_input_button,
433
  )
434
 
435
- prompt_input.submit(
436
- fn=clear_result,
437
- inputs=None,
438
- outputs=result,
439
- ).then(
440
- fn=infer,
 
 
 
 
441
  inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
442
  resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
443
  overlap_left, overlap_right, overlap_top, overlap_bottom],
444
- outputs=result,
445
- ).then(
446
- # --- FIX APPLIED HERE ---
447
- # Safely update history only if the result (x) is not None.
448
- fn=lambda x, history: update_history(x[1], history) if x else history,
449
- inputs=[result, history_gallery],
450
- outputs=history_gallery,
451
- ).then(
452
- fn=lambda: gr.update(visible=True),
453
- inputs=None,
454
- outputs=use_as_input_button,
455
  )
456
 
457
- preview_button.click(
458
- fn=preview_image_and_mask,
459
- inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
460
- overlap_left, overlap_right, overlap_top, overlap_bottom],
461
- outputs=preview_image,
462
- queue=False
 
 
 
 
 
 
 
 
 
 
 
 
463
  )
464
 
465
- demo.queue(max_size=12).launch(share=False)
 
1
+ import os
2
+ import io
3
+ import math
4
+ import tempfile
5
+ from typing import Tuple
6
+
7
  import gradio as gr
8
  import spaces
9
+ from PIL import Image, ImageDraw, ImageOps
10
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
# ===== Pipeline setup =====
# Generation uses the SDXL inpainting pipeline from diffusers.
# If CUDA isn't available, everything falls back to CPU (slower).
try:
    from diffusers import StableDiffusionXLInpaintPipeline
except ImportError as e:
    # Only a failed import gets the "add it to requirements" hint; any other
    # error raised while importing diffusers propagates unchanged so its real
    # cause is not hidden behind this message.
    raise RuntimeError("diffusers is required. Please ensure requirements.txt includes diffusers>=0.27.0") from e

# float16 is selected only together with CUDA; CPU runs use float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Defaults to the SDXL inpaint checkpoint; can be overridden through the
# INPAINT_MODEL_ID environment variable.
MODEL_ID = os.environ.get("INPAINT_MODEL_ID", "diffusers/stable-diffusion-xl-1.0-inpainting-0.1")
25
 
26
def _load_pipe():
    """Load the inpainting pipeline for ``MODEL_ID`` and move it to ``DEVICE``.

    On CUDA, xformers memory-efficient attention is enabled on a best-effort
    basis: if enabling it fails, the failure is logged and the pipeline is
    returned without it.

    Returns:
        The ``StableDiffusionXLInpaintPipeline`` instance placed on ``DEVICE``.
    """
    import logging

    pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
        MODEL_ID, torch_dtype=DTYPE
    )
    if DEVICE == "cuda":
        pipe = pipe.to("cuda")
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            # Deliberately non-fatal, but no longer silent: record why the
            # optimisation was skipped instead of swallowing the error.
            logging.getLogger(__name__).warning(
                "xformers memory-efficient attention not enabled: %s", exc
            )
    else:
        pipe = pipe.to("cpu")
    return pipe


# Loaded once at import time and shared by every request handler below.
pipe = _load_pipe()
41
 
42
+ # ===== Helpers =====
43
 
44
+ def can_expand(source_width: int, source_height: int, target_width: int, target_height: int, alignment: str) -> bool:
45
  """Checks if the image can be expanded based on the alignment."""
46
  if alignment in ("Left", "Right") and source_width >= target_width:
47
  return False
 
49
  return False
50
  return True
51
 
52
def _resize_input_for_option(img: Image.Image, resize_option: str, custom_resize_percentage: float) -> Image.Image:
    """Scale ``img`` according to the "Resize input image" choice.

    Args:
        img: Image to scale.
        resize_option: One of "Full", "50%", "33%", "25%" or "Custom".
        custom_resize_percentage: Percentage used when ``resize_option`` is
            "Custom"; truncated to an int and clamped to the range 1..400.

    Returns:
        ``img`` itself for "Full" or any unrecognised option, otherwise a
        LANCZOS-resampled copy whose sides are at least 1 pixel.
    """
    preset_percentages = {"50%": 50, "33%": 33, "25%": 25}

    if resize_option in preset_percentages:
        percent = preset_percentages[resize_option]
    elif resize_option == "Custom":
        percent = min(400, int(custom_resize_percentage))
        percent = max(1, percent)
    else:
        # "Full" and anything unknown leave the image untouched.
        return img

    width, height = img.size
    scaled_size = (
        max(1, int(width * percent / 100.0)),
        max(1, int(height * percent / 100.0)),
    )
    return img.resize(scaled_size, Image.LANCZOS)
65
+
66
+ def _place_rect(canvas_w: int, canvas_h: int, img_w: int, img_h: int, alignment: str) -> Tuple[int, int]:
67
+ """Top-left placement for given alignment."""
68
+ if alignment == "Left":
69
+ x = 0
70
+ y = (canvas_h - img_h) // 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  elif alignment == "Right":
72
+ x = canvas_w - img_w
73
+ y = (canvas_h - img_h) // 2
74
  elif alignment == "Top":
75
+ x = (canvas_w - img_w) // 2
76
+ y = 0
77
  elif alignment == "Bottom":
78
+ x = (canvas_w - img_w) // 2
79
+ y = canvas_h - img_h
80
+ else: # Middle
81
+ x = (canvas_w - img_w) // 2
82
+ y = (canvas_h - img_h) // 2
83
+ return x, y
84
+
85
+ def _apply_side_overlaps(x, y, ow, oh, margin, overlap_left, overlap_right, overlap_top, overlap_bottom):
86
+ left = x + (margin if overlap_left else 0)
87
+ top = y + (margin if overlap_top else 0)
88
+ right = x + ow - (margin if overlap_right else 0)
89
+ bottom = y + oh - (margin if overlap_bottom else 0)
90
+ # ensure rectangle is valid
91
+ if right <= left: right = left + 1
92
+ if bottom <= top: bottom = top + 1
93
+ return left, top, right, bottom
94
+
95
def prepare_image_and_mask(
    image: Image.Image,
    target_w: int,
    target_h: int,
    overlap_percentage: float,
    resize_option: str,
    custom_resize_percentage: float,
    alignment: str,
    overlap_left: bool,
    overlap_right: bool,
    overlap_top: bool,
    overlap_bottom: bool,
):
    """Build the canvas and mask used for inpainting.

    The input is converted to RGB, optionally resized according to
    ``resize_option``, and pasted onto a black canvas at the position given by
    ``alignment``. The canvas is at least ``target_w`` x ``target_h``, grows to
    fit the image when the image is larger, and is rounded up to a multiple of
    8 pixels per side.

    Args:
        image: Source image, or ``None``.
        target_w: Requested canvas width in pixels.
        target_h: Requested canvas height in pixels.
        overlap_percentage: Overlap margin as a percentage of the shorter image
            side; negative values count as 0 and the margin is capped at one
            third of that side.
        resize_option: Passed to ``_resize_input_for_option``.
        custom_resize_percentage: Passed to ``_resize_input_for_option``.
        alignment: Passed to ``_place_rect``.
        overlap_left: Whether the margin is applied on the left side.
        overlap_right: Whether the margin is applied on the right side.
        overlap_top: Whether the margin is applied on the top side.
        overlap_bottom: Whether the margin is applied on the bottom side.

    Returns:
        ``(background, mask)`` where ``background`` is the RGB canvas and
        ``mask`` is an "L" image (white = to generate, black = keep), or
        ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Optional initial resize for the input
    image = _resize_input_for_option(image, resize_option, custom_resize_percentage)

    # Canvas size: never smaller than the (resized) image.
    iw, ih = image.size
    cw = max(int(target_w), iw)
    ch = max(int(target_h), ih)
    # The SDXL pipeline requires sides divisible by 8; round up so an oversized
    # input with odd dimensions does not produce an unusable canvas. The extra
    # pixels are outside the keep-rectangle and are therefore generated.
    cw = -(-cw // 8) * 8
    ch = -(-ch // 8) * 8

    base = Image.new("RGB", (cw, ch), (0, 0, 0))
    x, y = _place_rect(cw, ch, iw, ih, alignment)
    base.paste(image, (x, y))

    # Mask creation: white everywhere, then black over the "keep" rectangle.
    mask = Image.new("L", (cw, ch), 255)
    draw = ImageDraw.Draw(mask)

    margin = int(min(iw, ih) * max(0.0, float(overlap_percentage)) / 100.0)
    margin = max(0, min(margin, min(iw, ih) // 3))

    left, top, right, bottom = _apply_side_overlaps(
        x, y, iw, ih, margin, overlap_left, overlap_right, overlap_top, overlap_bottom
    )
    # right/bottom are exclusive (x + width style) while ImageDraw.rectangle
    # includes both endpoints, so subtract 1 to avoid marking one row/column of
    # empty canvas as "keep". _apply_side_overlaps guarantees right > left and
    # bottom > top, so the rectangle stays valid.
    draw.rectangle([left, top, right - 1, bottom - 1], fill=0)

    return base, mask
145
+
146
+ # ===== Core inference (UI) =====
147
+
148
@spaces.GPU(duration=60)
def infer(
    image: Image.Image,
    width: int = 720,
    height: int = 1280,
    overlap_percentage: float = 10.0,
    num_inference_steps: int = 8,
    resize_option: str = "Full",
    custom_resize_percentage: float = 50.0,
    prompt_input: str = "",
    alignment: str = "Middle",
    overlap_left: bool = True,
    overlap_right: bool = True,
    overlap_top: bool = True,
    overlap_bottom: bool = True,
):
    """Outpaint ``image`` onto a larger canvas.

    Args:
        image: Source image; ``None`` short-circuits to ``None``.
        width: Target canvas width.
        height: Target canvas height.
        overlap_percentage: Overlap margin, forwarded to
            ``prepare_image_and_mask``.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        resize_option: Input resize choice, forwarded to
            ``prepare_image_and_mask``.
        custom_resize_percentage: Percentage for the "Custom" resize choice.
        prompt_input: Optional prompt; an empty string is used when falsy.
        alignment: Placement of the image on the canvas; replaced by "Middle"
            when ``can_expand`` rejects it for the given sizes.
        overlap_left: Apply the overlap margin on the left side.
        overlap_right: Apply the overlap margin on the right side.
        overlap_top: Apply the overlap margin on the top side.
        overlap_bottom: Apply the overlap margin on the bottom side.

    Returns:
        ``(control_preview_image, generated_image)``, or ``None`` when there is
        no input image.
    """
    if image is None:
        return None

    # safety: if alignment can't expand, center instead
    iw, ih = image.size
    if not can_expand(iw, ih, int(width), int(height), alignment):
        alignment = "Middle"

    background, mask = prepare_image_and_mask(
        image, int(width), int(height), float(overlap_percentage),
        resize_option, float(custom_resize_percentage), alignment,
        overlap_left, overlap_right, overlap_top, overlap_bottom
    )
    if background is None:
        return None

    # Control preview: the area to be generated is blacked out.
    control_preview = background.copy()
    control_overlay = Image.new("RGB", control_preview.size, (0, 0, 0))
    control_preview.paste(control_overlay, (0, 0), mask)

    # Fresh random seed per call; on CPU the pipeline's default RNG is used.
    generator = None
    if DEVICE == "cuda":
        generator = torch.Generator(device="cuda")
        generator.manual_seed(torch.seed())

    # Without explicit width/height the SDXL inpaint pipeline falls back to its
    # default square resolution and rescales the canvas to it, losing the
    # requested aspect ratio. Sides must be divisible by 8, so floor them.
    gen_w = max(8, background.width // 8 * 8)
    gen_h = max(8, background.height // 8 * 8)

    # Run inpainting
    result = pipe(
        prompt=prompt_input or "",
        image=background,
        mask_image=mask,
        width=gen_w,
        height=gen_h,
        guidance_scale=3.5,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    )
    out = result.images[0]

    return (control_preview, out)
209
+
210
+ # ===== Preview helper =====
211
+
212
def preview_image_and_mask(
    image: Image.Image,
    width: int,
    height: int,
    overlap_percentage: float,
    resize_option: str,
    custom_resize_percentage: float,
    alignment: str,
    overlap_left: bool,
    overlap_right: bool,
    overlap_top: bool,
    overlap_bottom: bool,
):
    """Return a single preview image for the UI.

    The preview is the prepared canvas with the area that would be generated
    (white in the mask) tinted red. Arguments are forwarded unchanged to
    ``prepare_image_and_mask``.

    Returns:
        An RGB preview image, or ``None`` when there is no input image.
    """
    if image is None:
        return None

    background, mask = prepare_image_and_mask(
        image, int(width), int(height), float(overlap_percentage),
        resize_option, float(custom_resize_percentage), alignment,
        overlap_left, overlap_right, overlap_top, overlap_bottom
    )
    if background is None:
        return None

    # Pasting an RGBA overlay onto an RGB image with an explicit mask drops the
    # overlay's alpha and paints solid red, hiding the overlap strip. Blend the
    # red in at the intended opacity (90/255) first, then use the mask to pick
    # the tinted pixels only where content will be generated.
    red = Image.new("RGB", background.size, (255, 0, 0))
    tinted = Image.blend(background, red, 90 / 255)
    preview = Image.composite(tinted, background, mask)
    return preview
243
 
244
+ # ===== img2img-style API (single image path string) =====
245
+
246
@spaces.GPU(duration=60)
def process_images(
    image: Image.Image,
    prompt: str = "",
    strength: float = 0.75,  # kept for client parity; not forwarded to the pipeline
    seed: int = 0,
    inference_step: int = 8,
    width: int = 720,
    height: int = 1280,
    overlap_percentage: float = 10.0,
    alignment: str = "Middle",
):
    """Adapter endpoint for an img2img-style client.

    Accepts a single image, runs the same preparation and inpaint call as the
    UI with the UI defaults ("Full" resize, overlap on all four sides), and
    returns the path of a PNG file holding the result.

    Args:
        image: Source image; ``None`` short-circuits to ``None``.
        prompt: Optional prompt; an empty string is used when falsy.
        strength: Accepted for client compatibility and otherwise unused.
        seed: ``0`` or ``None`` means "do not seed"; any other value seeds the
            generator so the run is reproducible.
        inference_step: Number of denoising steps passed to the pipeline.
        width: Target canvas width.
        height: Target canvas height.
        overlap_percentage: Overlap margin, forwarded to
            ``prepare_image_and_mask``.
        alignment: Placement of the image on the canvas; replaced by "Middle"
            when ``can_expand`` rejects it for the given sizes.

    Returns:
        Path (string) of the saved PNG, or ``None`` when there is no input.
    """
    if image is None:
        return None

    iw, ih = image.size
    if not can_expand(iw, ih, int(width), int(height), alignment):
        alignment = "Middle"

    # Use the same defaults as the UI
    resize_option = "Full"
    custom_resize_percentage = 50.0
    overlap_left = overlap_right = overlap_top = overlap_bottom = True

    background, mask = prepare_image_and_mask(
        image, int(width), int(height), float(overlap_percentage),
        resize_option, float(custom_resize_percentage), alignment,
        overlap_left, overlap_right, overlap_top, overlap_bottom
    )

    # Seed handling. A freshly constructed torch.Generator is not randomly
    # seeded, so for "no seed" pass generator=None and let the pipeline use the
    # default RNG. An explicit seed is honoured on CPU as well as CUDA.
    seed = 0 if seed is None else int(seed)
    generator = None
    if seed != 0:
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Without explicit width/height the SDXL inpaint pipeline falls back to its
    # default square resolution and rescales the canvas to it, losing the
    # requested aspect ratio. Sides must be divisible by 8, so floor them.
    gen_w = max(8, background.width // 8 * 8)
    gen_h = max(8, background.height // 8 * 8)

    result = pipe(
        prompt=prompt or "",
        image=background,
        mask_image=mask,
        width=gen_w,
        height=gen_h,
        guidance_scale=3.5,
        num_inference_steps=int(inference_step),
        generator=generator,
    )
    out = result.images[0]

    # Save to a temp file and return its PATH. Writing through the open handle
    # avoids reopening the file by name while it is still held open.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        out.save(tmp, format="PNG")
    return tmp.name
303
+
304
+ # ===== Gradio UI =====
305
+
306
with gr.Blocks(css="#wrap {max-width: 1100px; margin: 0 auto;}") as demo:
    gr.Markdown("## ReSize Image Outpainting")

    with gr.Row(elem_id="wrap"):
        # Left column: input image and all generation controls.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"], height=380)

            with gr.Row():
                width_slider = gr.Slider(256, 2048, value=720, step=8, label="Target Width")
                height_slider = gr.Slider(256, 2048, value=1280, step=8, label="Target Height")

            with gr.Row():
                overlap_percentage = gr.Slider(0, 30, value=10, step=1, label="Mask overlap (%)")
                num_inference_steps = gr.Slider(4, 50, value=8, step=1, label="Steps")

            resize_option = gr.Radio(
                ["Full", "50%", "33%", "25%", "Custom"], value="Full", label="Resize input image"
            )
            # Hidden until "Custom" is selected. Visibility is set in the
            # constructor because component-level .update() is not available
            # in Gradio 4, which this file targets (it uses `sources=`).
            custom_resize_percentage = gr.Slider(
                1, 400, value=50, step=1, label="Custom resize (%)", visible=False
            )

            alignment_dropdown = gr.Dropdown(
                ["Middle", "Left", "Right", "Top", "Bottom"], value="Middle", label="Alignment"
            )

            with gr.Row():
                overlap_left = gr.Checkbox(value=True, label="Overlap Left")
                overlap_right = gr.Checkbox(value=True, label="Overlap Right")
                overlap_top = gr.Checkbox(value=True, label="Overlap Top")
                overlap_bottom = gr.Checkbox(value=True, label="Overlap Bottom")

            prompt_input = gr.Textbox(label="Prompt (Optional)", placeholder="extend the scene softly")

            with gr.Row():
                preview_button = gr.Button("Preview")
                generate_button = gr.Button("Generate")

        # Right column: mask preview and generated result.
        with gr.Column():
            preview_image = gr.Image(label="Preview", height=300)
            slider = gr.Image(label="Generated Image (control vs result)", height=380, show_label=True)

    # Reactive helpers
    def toggle_custom_resize_slider(resize_option):
        """Show the custom-percentage slider only for the "Custom" choice."""
        return gr.update(visible=(resize_option == "Custom"))

    resize_option.change(
        fn=toggle_custom_resize_slider,
        inputs=resize_option,
        outputs=custom_resize_percentage
    )

    # Hook buttons.
    # api_name is given without a leading slash: clients address endpoints as
    # "/" + api_name, so "preview_image_and_mask" is reached as
    # "/preview_image_and_mask".
    preview_button.click(
        fn=preview_image_and_mask,
        inputs=[input_image, width_slider, height_slider, overlap_percentage,
                resize_option, custom_resize_percentage, alignment_dropdown,
                overlap_left, overlap_right, overlap_top, overlap_bottom],
        outputs=preview_image,
        api_name="preview_image_and_mask"
    )

    def _infer_wrapper(image, width, height, overlap_percentage, num_inference_steps,
                       resize_option, custom_resize_percentage, prompt_input, alignment,
                       overlap_left, overlap_right, overlap_top, overlap_bottom):
        """Run ``infer`` and return only the generated image.

        ``infer`` returns ``(control_preview, generated)``; the output
        component here is a plain ``gr.Image``, which takes a single image
        rather than a tuple.
        """
        res = infer(image, width, height, overlap_percentage, num_inference_steps,
                    resize_option, custom_resize_percentage, prompt_input, alignment,
                    overlap_left, overlap_right, overlap_top, overlap_bottom)
        if res is None:
            return None
        return res[1]

    generate_button.click(
        fn=_infer_wrapper,
        inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
                resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
                overlap_left, overlap_right, overlap_top, overlap_bottom],
        outputs=slider,
        api_name="infer"
    )

    # ===== Hidden API binding for img2img-compatible client =====
    # Every component used only by the API is created hidden; inline
    # gr.Number(...) inputs would otherwise render as unlabeled widgets.
    api_output_path = gr.Textbox(visible=False)
    api_strength = gr.Number(value=0.75, visible=False)
    api_seed = gr.Number(value=0, visible=False)
    api_trigger = gr.Button(visible=False)
    api_trigger.click(
        fn=process_images,
        inputs=[
            input_image,          # image
            prompt_input,         # prompt
            api_strength,         # strength (ignored)
            api_seed,             # seed
            num_inference_steps,  # inference_step
            width_slider,         # width
            height_slider,        # height
            overlap_percentage,   # overlap_percentage
            alignment_dropdown    # alignment
        ],
        outputs=[api_output_path],
        api_name="process_images"
    )

demo.queue(max_size=12).launch(share=False)