Gemini899 commited on
Commit
820711f
·
verified ·
1 Parent(s): 032ce8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +465 -309
app.py CHANGED
@@ -1,47 +1,70 @@
1
- import os
2
- import io
3
- import math
4
- import tempfile
5
- from typing import Tuple
6
-
7
  import gradio as gr
8
  import spaces
9
- from PIL import Image, ImageDraw, ImageOps
10
  import torch
 
 
 
 
11
 
12
- # ===== Pipeline setup =====
13
- # We try to keep quality similar to your current Space by using SDXL Inpainting.
14
- # If CUDA isn't available, it'll fall back to CPU (slower).
15
- try:
16
- from diffusers import StableDiffusionXLInpaintPipeline
17
- except Exception as e:
18
- raise RuntimeError("diffusers is required. Please ensure requirements.txt includes diffusers>=0.27.0") from e
19
 
20
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 
22
 
23
- # Prefer the official SDXL inpaint checkpoint
24
- MODEL_ID = os.environ.get("INPAINT_MODEL_ID", "diffusers/stable-diffusion-xl-1.0-inpainting-0.1")
25
 
26
- def _load_pipe():
27
- pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
28
- MODEL_ID, torch_dtype=DTYPE
29
- )
30
- if DEVICE == "cuda":
31
- pipe = pipe.to("cuda")
32
- try:
33
- pipe.enable_xformers_memory_efficient_attention()
34
- except Exception:
35
- pass
36
- else:
37
- pipe = pipe.to("cpu")
38
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- pipe = _load_pipe()
 
 
41
 
42
- # ===== Helpers =====
 
 
43
 
44
- def can_expand(source_width: int, source_height: int, target_width: int, target_height: int, alignment: str) -> bool:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  """Checks if the image can be expanded based on the alignment."""
46
  if alignment in ("Left", "Right") and source_width >= target_width:
47
  return False
@@ -49,228 +72,192 @@ def can_expand(source_width: int, source_height: int, target_width: int, target_
49
  return False
50
  return True
51
 
52
- def _resize_input_for_option(img: Image.Image, resize_option: str, custom_resize_percentage: float) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  if resize_option == "Full":
54
- return img
55
- if resize_option in ("50%", "33%", "25%"):
56
- pct = {"50%": 50, "33%": 33, "25%": 25}[resize_option]
 
 
 
 
57
  elif resize_option == "Custom":
58
- pct = max(1, min(400, int(custom_resize_percentage)))
59
  else:
60
- return img
61
- w, h = img.size
62
- nw = max(1, int(w * pct / 100.0))
63
- nh = max(1, int(h * pct / 100.0))
64
- return img.resize((nw, nh), Image.LANCZOS)
65
-
66
- def _place_rect(canvas_w: int, canvas_h: int, img_w: int, img_h: int, alignment: str) -> Tuple[int, int]:
67
- """Top-left placement for given alignment."""
68
- if alignment == "Left":
69
- x = 0
70
- y = (canvas_h - img_h) // 2
 
 
 
 
 
 
 
 
71
  elif alignment == "Right":
72
- x = canvas_w - img_w
73
- y = (canvas_h - img_h) // 2
74
  elif alignment == "Top":
75
- x = (canvas_w - img_w) // 2
76
- y = 0
77
  elif alignment == "Bottom":
78
- x = (canvas_w - img_w) // 2
79
- y = canvas_h - img_h
80
- else: # Middle
81
- x = (canvas_w - img_w) // 2
82
- y = (canvas_h - img_h) // 2
83
- return x, y
84
-
85
- def _apply_side_overlaps(x, y, ow, oh, margin, overlap_left, overlap_right, overlap_top, overlap_bottom):
86
- left = x + (margin if overlap_left else 0)
87
- top = y + (margin if overlap_top else 0)
88
- right = x + ow - (margin if overlap_right else 0)
89
- bottom = y + oh - (margin if overlap_bottom else 0)
90
- # ensure rectangle is valid
91
- if right <= left: right = left + 1
92
- if bottom <= top: bottom = top + 1
93
- return left, top, right, bottom
94
-
95
- def prepare_image_and_mask(
96
- image: Image.Image,
97
- target_w: int,
98
- target_h: int,
99
- overlap_percentage: float,
100
- resize_option: str,
101
- custom_resize_percentage: float,
102
- alignment: str,
103
- overlap_left: bool,
104
- overlap_right: bool,
105
- overlap_top: bool,
106
- overlap_bottom: bool,
107
- ):
108
- """
109
- Returns (background, mask) for inpainting:
110
- - background: RGB, input pasted onto a larger canvas
111
- - mask: L (white = to generate, black = keep)
112
- """
113
- if image is None:
114
- return None, None
115
 
116
- if image.mode != "RGB":
117
- image = image.convert("RGB")
 
118
 
119
- # Optional initial resize for the input
120
- image = _resize_input_for_option(image, resize_option, custom_resize_percentage)
 
121
 
122
- # Canvas size
123
- cw, ch = int(target_w), int(target_h)
124
- iw, ih = image.size
125
- cw = max(cw, iw)
126
- ch = max(ch, ih)
127
 
128
- base = Image.new("RGB", (cw, ch), (0, 0, 0))
129
- x, y = _place_rect(cw, ch, iw, ih, alignment)
130
- base.paste(image, (x, y))
131
 
132
- # Mask creation: white outside the "keep" rect
133
- mask = Image.new("L", (cw, ch), 255)
134
- draw = ImageDraw.Draw(mask)
 
135
 
136
- margin = int(min(iw, ih) * max(0.0, float(overlap_percentage)) / 100.0)
137
- margin = max(0, min(margin, min(iw, ih)//3))
 
 
 
 
 
 
 
138
 
139
- left, top, right, bottom = _apply_side_overlaps(
140
- x, y, iw, ih, margin, overlap_left, overlap_right, overlap_top, overlap_bottom
141
- )
142
- draw.rectangle([left, top, right, bottom], fill=0)
143
-
144
- return base, mask
145
-
146
- # ===== Core inference (UI) =====
147
-
148
- @spaces.GPU(duration=60)
149
- def infer(
150
- image: Image.Image,
151
- width: int = 720,
152
- height: int = 1280,
153
- overlap_percentage: float = 10.0,
154
- num_inference_steps: int = 8,
155
- resize_option: str = "Full",
156
- custom_resize_percentage: float = 50.0,
157
- prompt_input: str = "",
158
- alignment: str = "Middle",
159
- overlap_left: bool = True,
160
- overlap_right: bool = True,
161
- overlap_top: bool = True,
162
- overlap_bottom: bool = True,
163
- ):
164
- """
165
- UI endpoint that returns an ImageSlider-compatible tuple:
166
- (control_preview_image, generated_image)
167
- """
168
- if image is None:
169
- return None
170
 
171
- # safety: if alignment can't expand, center instead
172
- iw, ih = image.size
173
- if not can_expand(iw, ih, int(width), int(height), alignment):
174
- alignment = "Middle"
175
 
 
 
 
176
  background, mask = prepare_image_and_mask(
177
- image, int(width), int(height), float(overlap_percentage),
178
- resize_option, float(custom_resize_percentage), alignment,
179
  overlap_left, overlap_right, overlap_top, overlap_bottom
180
  )
181
- if background is None:
182
- return None
183
 
184
- # Control preview: show masked area in black overlay
185
- control_preview = background.copy()
186
- control_overlay = Image.new("RGB", control_preview.size, (0, 0, 0))
187
- control_preview.paste(control_overlay, (0, 0), mask)
188
-
189
- # Seed/generator
190
- generator = None
191
- if DEVICE == "cuda":
192
- generator = torch.Generator(device="cuda")
193
- if generator is not None:
194
- generator.manual_seed(torch.seed())
195
-
196
- # Run inpainting
197
- result = pipe(
198
- prompt=prompt_input or "",
199
- image=background,
200
- mask_image=mask,
201
- guidance_scale=3.5,
202
- num_inference_steps=int(num_inference_steps),
203
- generator=generator,
204
- )
205
- out = result.images[0]
206
-
207
- # Return slider tuple
208
- return (control_preview, out)
209
-
210
- # ===== Preview helper =====
211
-
212
- def preview_image_and_mask(
213
- image: Image.Image,
214
- width: int,
215
- height: int,
216
- overlap_percentage: float,
217
- resize_option: str,
218
- custom_resize_percentage: float,
219
- alignment: str,
220
- overlap_left: bool,
221
- overlap_right: bool,
222
- overlap_top: bool,
223
- overlap_bottom: bool,
224
- ):
225
- """
226
- Return a single preview image for the UI.
227
- """
228
- if image is None:
229
- return None
230
 
231
  background, mask = prepare_image_and_mask(
232
- image, int(width), int(height), float(overlap_percentage),
233
- resize_option, float(custom_resize_percentage), alignment,
234
  overlap_left, overlap_right, overlap_top, overlap_bottom
235
  )
236
- if background is None:
237
- return None
238
 
239
- preview = background.copy()
240
- overlay = Image.new("RGBA", preview.size, (255, 0, 0, 90))
241
- preview.paste(overlay, (0, 0), mask)
242
- return preview
243
-
244
- # ===== img2img-style API (single image path string) =====
245
 
246
- @spaces.GPU(duration=60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def process_images(
248
- image: Image.Image,
249
- prompt: str = "",
250
- strength: float = 0.75, # kept for client parity; unused by SDXL inpaint
251
- seed: int = 0,
252
- inference_step: int = 8,
253
- width: int = 720,
254
- height: int = 1280,
255
- overlap_percentage: float = 10.0,
256
- alignment: str = "Middle",
257
  ):
258
- """
259
- Adapter endpoint to match your img2img client contract:
260
- - accepts a single file input
261
- - returns a single file path (string)
262
- - internally reuses the same preparation and inpaint call as the UI
263
- """
264
  if image is None:
265
  return None
266
 
267
- iw, ih = image.size
268
- if not can_expand(iw, ih, int(width), int(height), alignment):
269
- alignment = "Middle"
270
-
271
- # Use the same defaults as the UI
272
  resize_option = "Full"
273
- custom_resize_percentage = 50.0
274
  overlap_left = overlap_right = overlap_top = overlap_bottom = True
275
 
276
  background, mask = prepare_image_and_mask(
@@ -279,109 +266,278 @@ def process_images(
279
  overlap_left, overlap_right, overlap_top, overlap_bottom
280
  )
281
 
282
- # Seed handling
283
- if seed is None:
284
- seed = 0
285
- generator = torch.Generator(device=DEVICE) if DEVICE == "cuda" else None
286
- if generator is not None and int(seed) != 0:
287
- generator.manual_seed(int(seed))
288
-
289
- result = pipe(
290
- prompt=prompt or "",
291
- image=background,
292
- mask_image=mask,
293
- guidance_scale=3.5,
294
- num_inference_steps=int(inference_step),
295
- generator=generator,
296
- )
297
- out = result.images[0]
 
 
 
 
 
 
 
 
 
 
298
 
299
- # Save to temp file and return PATH
300
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
301
- out.save(tmp.name)
302
  return tmp.name
303
 
304
- # ===== Gradio UI =====
305
-
306
- with gr.Blocks(css="#wrap {max-width: 1100px; margin: 0 auto;}") as demo:
307
- gr.Markdown("## ReSize Image Outpainting")
308
-
309
- with gr.Row(elem_id="wrap"):
310
- with gr.Column():
311
- input_image = gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"], height=380)
312
-
313
- with gr.Row():
314
- width_slider = gr.Slider(256, 2048, value=720, step=8, label="Target Width")
315
- height_slider = gr.Slider(256, 2048, value=1280, step=8, label="Target Height")
316
-
317
- with gr.Row():
318
- overlap_percentage = gr.Slider(0, 30, value=10, step=1, label="Mask overlap (%)")
319
- num_inference_steps = gr.Slider(4, 50, value=8, step=1, label="Steps")
320
-
321
- resize_option = gr.Radio(
322
- ["Full", "50%", "33%", "25%", "Custom"], value="Full", label="Resize input image"
323
- )
324
- custom_resize_percentage = gr.Slider(1, 400, value=50, step=1, label="Custom resize (%)")
325
 
326
- alignment_dropdown = gr.Dropdown(
327
- ["Middle", "Left", "Right", "Top", "Bottom"], value="Middle", label="Alignment"
328
- )
329
-
330
- with gr.Row():
331
- overlap_left = gr.Checkbox(value=True, label="Overlap Left")
332
- overlap_right = gr.Checkbox(value=True, label="Overlap Right")
333
- overlap_top = gr.Checkbox(value=True, label="Overlap Top")
334
- overlap_bottom = gr.Checkbox(value=True, label="Overlap Bottom")
335
-
336
- prompt_input = gr.Textbox(label="Prompt (Optional)", placeholder="extend the scene softly")
337
-
338
- with gr.Row():
339
- preview_button = gr.Button("Preview")
340
- generate_button = gr.Button("Generate")
341
-
342
- with gr.Column():
343
- preview_image = gr.Image(label="Preview", height=300)
344
- slider = gr.Image(label="Generated Image (control vs result)", height=380, show_label=True)
345
-
346
- # Reactive helpers
347
- def toggle_custom_resize_slider(resize_option):
348
- return gr.update(visible=(resize_option == "Custom"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
 
350
- custom_resize_percentage.update(visible=False)
351
  resize_option.change(
352
  fn=toggle_custom_resize_slider,
353
- inputs=resize_option,
354
- outputs=custom_resize_percentage
 
355
  )
356
 
357
- # Hook buttons
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  preview_button.click(
359
  fn=preview_image_and_mask,
360
  inputs=[input_image, width_slider, height_slider, overlap_percentage,
361
- resize_option, custom_resize_percentage, alignment_dropdown,
362
  overlap_left, overlap_right, overlap_top, overlap_bottom],
363
  outputs=preview_image,
364
- api_name="/preview_image_and_mask"
365
  )
366
 
367
- def _infer_wrapper(image, width, height, overlap_percentage, num_inference_steps,
368
- resize_option, custom_resize_percentage, prompt_input, alignment,
369
- overlap_left, overlap_right, overlap_top, overlap_bottom):
370
- res = infer(image, width, height, overlap_percentage, num_inference_steps,
371
- resize_option, custom_resize_percentage, prompt_input, alignment,
372
- overlap_left, overlap_right, overlap_top, overlap_bottom)
373
- return res
374
-
375
- generate_button.click(
376
- fn=_infer_wrapper,
377
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
378
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
379
- overlap_left, overlap_right, overlap_top, overlap_bottom],
380
- outputs=slider,
381
- api_name="/infer"
382
  )
383
 
384
  # ===== Hidden API binding for img2img-compatible client =====
 
385
  api_output_path = gr.Textbox(visible=False)
386
  api_trigger = gr.Button(visible=False)
387
  api_trigger.click(
@@ -395,7 +551,7 @@ with gr.Blocks(css="#wrap {max-width: 1100px; margin: 0 auto;}") as demo:
395
  width_slider, # width
396
  height_slider, # height
397
  overlap_percentage, # overlap_percentage
398
- alignment_dropdown # alignment
399
  ],
400
  outputs=[api_output_path],
401
  api_name="/process_images"
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
 
3
  import torch
4
+ from diffusers import AutoencoderKL, TCDScheduler
5
+ from diffusers.models.model_loading_utils import load_state_dict
6
+ from gradio_imageslider import ImageSlider
7
+ from huggingface_hub import hf_hub_download
8
 
9
+ from controlnet_union import ControlNetModel_Union
10
+ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 
 
 
 
 
11
 
12
+ from PIL import Image, ImageDraw
13
+ import numpy as np
14
+ import tempfile
15
 
 
 
16
 
17
+ # ---------------------------
18
+ # Load ControlNet-Union + VAE + SDXL Fill pipeline (same as your Space)
19
+ # ---------------------------
20
+
21
+ config_file = hf_hub_download(
22
+ "xinsir/controlnet-union-sdxl-1.0",
23
+ filename="config_promax.json",
24
+ )
25
+
26
+ config = ControlNetModel_Union.load_config(config_file)
27
+ controlnet_model = ControlNetModel_Union.from_config(config)
28
+
29
+ # Load the state dictionary
30
+ model_file = hf_hub_download(
31
+ "xinsir/controlnet-union-sdxl-1.0",
32
+ filename="diffusion_pytorch_model_promax.safetensors",
33
+ )
34
+ state_dict = load_state_dict(model_file)
35
+
36
+ # Extract the keys from the state_dict
37
+ loaded_keys = list(state_dict.keys())
38
+
39
+ # Call the method and store all returns in a variable
40
+ result = ControlNetModel_Union._load_pretrained_model(
41
+ controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0", loaded_keys
42
+ )
43
 
44
+ # Use the first element from the result
45
+ model = result[0]
46
+ model = model.to(device="cuda", dtype=torch.float16)
47
 
48
+ vae = AutoencoderKL.from_pretrained(
49
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
50
+ ).to("cuda")
51
 
52
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
53
+ "SG161222/RealVisXL_V5.0_Lightning",
54
+ torch_dtype=torch.float16,
55
+ vae=vae,
56
+ controlnet=model,
57
+ variant="fp16",
58
+ ).to("cuda")
59
+
60
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
61
+
62
+
63
+ # ---------------------------
64
+ # Helpers (unchanged behavior)
65
+ # ---------------------------
66
+
67
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
68
  """Checks if the image can be expanded based on the alignment."""
69
  if alignment in ("Left", "Right") and source_width >= target_width:
70
  return False
 
72
  return False
73
  return True
74
 
75
+
76
+ def prepare_image_and_mask(image, width, height, overlap_percentage,
77
+ resize_option, custom_resize_percentage, alignment,
78
+ overlap_left, overlap_right, overlap_top, overlap_bottom):
79
+ target_size = (int(width), int(height))
80
+
81
+ # Calculate the scaling factor to fit the image within the target size
82
+ scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
83
+ new_width = int(image.width * scale_factor)
84
+ new_height = int(image.height * scale_factor)
85
+
86
+ # Resize the source image to fit within target size
87
+ source = image.resize((new_width, new_height), Image.LANCZOS)
88
+
89
+ # Apply resize option using percentages
90
  if resize_option == "Full":
91
+ resize_percentage = 100
92
+ elif resize_option == "50%":
93
+ resize_percentage = 50
94
+ elif resize_option == "33%":
95
+ resize_percentage = 33
96
+ elif resize_option == "25%":
97
+ resize_percentage = 25
98
  elif resize_option == "Custom":
99
+ resize_percentage = max(1, min(400, int(custom_resize_percentage)))
100
  else:
101
+ resize_percentage = 100
102
+
103
+ # Apply the resize percentage to the already fitted source
104
+ resize_factor = resize_percentage / 100.0
105
+ new_width = max(64, int(source.width * resize_factor))
106
+ new_height = max(64, int(source.height * resize_factor))
107
+ source = source.resize((new_width, new_height), Image.LANCZOS)
108
+
109
+ # Calculate the overlap in pixels based on the percentage
110
+ overlap_x = max(1, int(new_width * (float(overlap_percentage) / 100.0)))
111
+ overlap_y = max(1, int(new_height * (float(overlap_percentage) / 100.0)))
112
+
113
+ # Calculate margins based on alignment
114
+ if alignment == "Middle":
115
+ margin_x = (target_size[0] - new_width) // 2
116
+ margin_y = (target_size[1] - new_height) // 2
117
+ elif alignment == "Left":
118
+ margin_x = 0
119
+ margin_y = (target_size[1] - new_height) // 2
120
  elif alignment == "Right":
121
+ margin_x = target_size[0] - new_width
122
+ margin_y = (target_size[1] - new_height) // 2
123
  elif alignment == "Top":
124
+ margin_x = (target_size[0] - new_width) // 2
125
+ margin_y = 0
126
  elif alignment == "Bottom":
127
+ margin_x = (target_size[0] - new_width) // 2
128
+ margin_y = target_size[1] - new_height
129
+ else:
130
+ margin_x = (target_size[0] - new_width) // 2
131
+ margin_y = (target_size[1] - new_height) // 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # Adjust margins to eliminate gaps
134
+ margin_x = max(0, min(margin_x, target_size[0] - new_width))
135
+ margin_y = max(0, min(margin_y, target_size[1] - new_height))
136
 
137
+ # Create a new background image and paste the resized source image
138
+ background = Image.new('RGB', target_size, (255, 255, 255))
139
+ background.paste(source, (margin_x, margin_y))
140
 
141
+ # Create the mask
142
+ mask = Image.new('L', target_size, 255)
143
+ mask_draw = ImageDraw.Draw(mask)
 
 
144
 
145
+ # Calculate overlap areas
146
+ white_gaps_patch = 2
 
147
 
148
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
149
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
150
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
151
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
152
 
153
+ # Tighten edges further depending on chosen alignment
154
+ if alignment == "Left":
155
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
156
+ elif alignment == "Right":
157
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
158
+ elif alignment == "Top":
159
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
160
+ elif alignment == "Bottom":
161
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
162
 
163
+ # Draw the mask (black = keep, white = generate)
164
+ mask_draw.rectangle([
165
+ (left_overlap, top_overlap),
166
+ (right_overlap, bottom_overlap)
167
+ ], fill=0)
168
+
169
+ return background, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
 
171
 
172
+ def preview_image_and_mask(image, width, height, overlap_percentage,
173
+ resize_option, custom_resize_percentage, alignment,
174
+ overlap_left, overlap_right, overlap_top, overlap_bottom):
175
  background, mask = prepare_image_and_mask(
176
+ image, width, height, overlap_percentage,
177
+ resize_option, custom_resize_percentage, alignment,
178
  overlap_left, overlap_right, overlap_top, overlap_bottom
179
  )
 
 
180
 
181
+ # Create a preview image showing the mask overlay
182
+ preview = background.copy().convert('RGBA')
183
+ red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64))
184
+ red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
185
+ red_mask.paste(red_overlay, (0, 0), mask)
186
+ preview = Image.alpha_composite(preview, red_mask)
187
+ return preview
188
+
189
+
190
+ # ---------------------------
191
+ # Main UI inference (returns ImageSlider tuple)
192
+ # ---------------------------
193
+
194
+ @spaces.GPU(duration=24)
195
+ def infer(image, width, height, overlap_percentage, num_inference_steps,
196
+ resize_option, custom_resize_percentage, prompt_input, alignment,
197
+ overlap_left, overlap_right, overlap_top, overlap_bottom):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  background, mask = prepare_image_and_mask(
200
+ image, width, height, overlap_percentage,
201
+ resize_option, custom_resize_percentage, alignment,
202
  overlap_left, overlap_right, overlap_top, overlap_bottom
203
  )
 
 
204
 
205
+ if not can_expand(background.width, background.height, width, height, alignment):
206
+ alignment = "Middle"
 
 
 
 
207
 
208
+ cnet_image = background.copy()
209
+ cnet_image.paste(0, (0, 0), mask)
210
+
211
+ final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
212
+
213
+ # Encode prompt + run pipeline yielding previews then final
214
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
215
+ (
216
+ prompt_embeds,
217
+ negative_prompt_embeds,
218
+ pooled_prompt_embeds,
219
+ negative_pooled_prompt_embeds,
220
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
221
+
222
+ for image in pipe(
223
+ prompt_embeds=prompt_embeds,
224
+ negative_prompt_embeds=negative_prompt_embeds,
225
+ pooled_prompt_embeds=pooled_prompt_embeds,
226
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
227
+ image=cnet_image,
228
+ num_inference_steps=num_inference_steps
229
+ ):
230
+ # Streaming preview to slider (left = control, right = preview)
231
+ yield cnet_image, image
232
+
233
+ # Final composite (place the original inside the masked area)
234
+ image = image.convert("RGBA")
235
+ cnet_image.paste(image, (0, 0), mask)
236
+ yield background, cnet_image
237
+
238
+
239
+ # ---------------------------
240
+ # img2img-style API: /process_images (single file path string)
241
+ # ---------------------------
242
+
243
+ @spaces.GPU(duration=24)
244
  def process_images(
245
+ image, # PIL image from handle_file
246
+ prompt="", # str
247
+ strength=0.75, # kept for client parity; unused
248
+ seed=0, # int
249
+ inference_step=8, # int
250
+ width=720, # int
251
+ height=1280, # int
252
+ overlap_percentage=10, # float
253
+ alignment="Middle", # str
254
  ):
 
 
 
 
 
 
255
  if image is None:
256
  return None
257
 
258
+ # Use same prep as UI
 
 
 
 
259
  resize_option = "Full"
260
+ custom_resize_percentage = 50
261
  overlap_left = overlap_right = overlap_top = overlap_bottom = True
262
 
263
  background, mask = prepare_image_and_mask(
 
266
  overlap_left, overlap_right, overlap_top, overlap_bottom
267
  )
268
 
269
+ cnet_image = background.copy()
270
+ cnet_image.paste(0, (0, 0), mask)
271
+
272
+ final_prompt = f"{prompt} , high quality, 4k" if prompt else "high quality, 4k"
273
+
274
+ last_img = None
275
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
276
+ (
277
+ prompt_embeds,
278
+ negative_prompt_embeds,
279
+ pooled_prompt_embeds,
280
+ negative_pooled_prompt_embeds,
281
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
282
+
283
+ for gen_img in pipe(
284
+ prompt_embeds=prompt_embeds,
285
+ negative_prompt_embeds=negative_prompt_embeds,
286
+ pooled_prompt_embeds=pooled_prompt_embeds,
287
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
288
+ image=cnet_image,
289
+ num_inference_steps=int(inference_step)
290
+ ):
291
+ last_img = gen_img
292
+
293
+ if last_img is None:
294
+ return None
295
 
 
296
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
297
+ last_img.save(tmp.name)
298
  return tmp.name
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
+ # ---------------------------
302
+ # Misc helpers & UI wiring
303
+ # ---------------------------
304
+
305
+ def clear_result():
306
+ """Clears the result ImageSlider."""
307
+ return gr.update(value=None)
308
+
309
+ def preload_presets(target_ratio, ui_width, ui_height):
310
+ """Updates the width and height sliders based on the selected aspect ratio."""
311
+ if target_ratio == "9:16":
312
+ changed_width = 720
313
+ changed_height = 1280
314
+ return changed_width, changed_height, gr.update()
315
+ elif target_ratio == "16:9":
316
+ changed_width = 1280
317
+ changed_height = 720
318
+ return changed_width, changed_height, gr.update()
319
+ elif target_ratio == "1:1":
320
+ changed_width = 1024
321
+ changed_height = 1024
322
+ return changed_width, changed_height, gr.update()
323
+ else:
324
+ return ui_width, ui_height, gr.update()
325
+
326
+ def select_the_right_preset(user_width, user_height):
327
+ """Chooses the closest preset by ratio (for display)."""
328
+ ratio = user_width / max(1, user_height)
329
+ if abs(ratio - (9/16)) < 0.05:
330
+ return "9:16"
331
+ if abs(ratio - (16/9)) < 0.05:
332
+ return "16:9"
333
+ if abs(ratio - 1.0) < 0.05:
334
+ return "1:1"
335
+ return "Custom"
336
+
337
+ def toggle_custom_resize_slider(resize_option):
338
+ """Controls visibility of the custom resize slider."""
339
+ return gr.update(visible=(resize_option == "Custom"))
340
+
341
+ def use_output_as_input(x):
342
+ """API bridge for ImageSlider -> Image. Returns right-hand image as next input."""
343
+ if not x:
344
+ return None
345
+ if isinstance(x, (list, tuple)) and len(x) >= 2:
346
+ # return the generated (right) image
347
+ return x[1]
348
+ return None
349
+
350
+ def update_history(new_image, history):
351
+ """Updates the history gallery with the new image."""
352
+ if history is None:
353
+ history = []
354
+ history.insert(0, new_image)
355
+ return history
356
+
357
+
358
+ css = """
359
+ .gradio-container {
360
+ width: 1200px !important;
361
+ }
362
+ """
363
+
364
+ title = """<h1 align="center">Re-Size Image Outpaint</h1>
365
+ <p align="center">Extend images with ControlNet-Union SDXL fill — with an ImageSlider preview.</p>
366
+ """
367
+
368
+ with gr.Blocks(theme="soft", css=css) as demo:
369
+ with gr.Column():
370
+ gr.HTML(title)
371
+
372
+ with gr.Row():
373
+ with gr.Column():
374
+ input_image = gr.Image(
375
+ type="pil",
376
+ label="Input Image"
377
+ )
378
+
379
+ with gr.Row():
380
+ with gr.Column(scale=2):
381
+ prompt_input = gr.Textbox(label="Prompt (Optional)")
382
+
383
+ with gr.Row():
384
+ with gr.Column(scale=2):
385
+ target_ratio = gr.Radio(
386
+ ["9:16", "16:9", "1:1", "Custom"], value="9:16", label="Expected Ratio"
387
+ )
388
+ with gr.Row():
389
+ width_slider = gr.Slider(
390
+ label="Target Width",
391
+ minimum=512,
392
+ maximum=1536,
393
+ step=8,
394
+ value=720,
395
+ )
396
+ height_slider = gr.Slider(
397
+ label="Target Height",
398
+ minimum=720,
399
+ maximum=1536,
400
+ step=8,
401
+ value=1280,
402
+ )
403
+
404
+ num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
405
+ with gr.Group():
406
+ overlap_percentage = gr.Slider(
407
+ label="Mask overlap (%)",
408
+ minimum=1,
409
+ maximum=50,
410
+ value=10,
411
+ step=1
412
+ )
413
+ with gr.Row():
414
+ overlap_top = gr.Checkbox(label="Overlap Top", value=True)
415
+ overlap_right = gr.Checkbox(label="Overlap Right", value=True)
416
+ with gr.Row():
417
+ overlap_left = gr.Checkbox(label="Overlap Left", value=True)
418
+ overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
419
+
420
+ with gr.Column(scale=1):
421
+ with gr.Group():
422
+ resize_option = gr.Radio(
423
+ label="Resize input image",
424
+ choices=["Full", "50%", "33%", "25%", "Custom"],
425
+ value="Full"
426
+ )
427
+ # FIX: set visibility here, do NOT call .update() on a component
428
+ custom_resize_percentage = gr.Slider(
429
+ label="Custom resize (%)",
430
+ minimum=1,
431
+ maximum=100,
432
+ step=1,
433
+ value=50,
434
+ visible=False,
435
+ )
436
+
437
+ with gr.Column():
438
+ preview_button = gr.Button("Preview alignment and mask")
439
+
440
+ gr.Examples(
441
+ examples=[
442
+ ["./examples/example_2.jpg", 1440, 810, "Left"],
443
+ ["./examples/example_3.jpg", 1024, 1024, "Top"],
444
+ ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
445
+ ],
446
+ inputs=[input_image, width_slider, height_slider, target_ratio],
447
+ label="Quick examples",
448
+ )
449
+
450
+ with gr.Column():
451
+ preview_image = gr.Image(label="Preview", height=300)
452
+ result = ImageSlider(
453
+ label="Generated Image",
454
+ elem_id="gen_slider",
455
+ show_label=True,
456
+ interactive=False,
457
+ )
458
+ run_button = gr.Button("Generate", variant="primary")
459
+ use_as_input_button = gr.Button("Use output as input", visible=False)
460
+ history_gallery = gr.Gallery(label="History", columns=4, height=220)
461
+
462
+ # Radio preset to width/height
463
+ target_ratio.change(
464
+ fn=preload_presets,
465
+ inputs=[target_ratio, width_slider, height_slider],
466
+ outputs=[width_slider, height_slider, gr.State()],
467
+ queue=False
468
+ )
469
 
470
+ # Toggle custom resize slider visibility
471
  resize_option.change(
472
  fn=toggle_custom_resize_slider,
473
+ inputs=[resize_option],
474
+ outputs=[custom_resize_percentage],
475
+ queue=False
476
  )
477
 
478
+ # Generate flow: clear slider -> stream infer -> update history -> show "use as input"
479
+ run_button.click(
480
+ fn=clear_result,
481
+ inputs=None,
482
+ outputs=result,
483
+ ).then(
484
+ fn=infer,
485
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
486
+ resize_option, custom_resize_percentage, prompt_input, target_ratio,
487
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
488
+ outputs=result,
489
+ ).then(
490
+ # Safely update history only if the result is not None
491
+ fn=lambda x, history: update_history(x[1], history) if x else history,
492
+ inputs=[result, history_gallery],
493
+ outputs=history_gallery,
494
+ ).then(
495
+ fn=lambda: gr.update(visible=True),
496
+ inputs=None,
497
+ outputs=use_as_input_button,
498
+ )
499
+
500
+ # Enter in prompt also triggers generate flow
501
+ prompt_input.submit(
502
+ fn=clear_result,
503
+ inputs=None,
504
+ outputs=result,
505
+ ).then(
506
+ fn=infer,
507
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
508
+ resize_option, custom_resize_percentage, prompt_input, target_ratio,
509
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
510
+ outputs=result,
511
+ ).then(
512
+ fn=lambda x, history: update_history(x[1], history) if x else history,
513
+ inputs=[result, history_gallery],
514
+ outputs=history_gallery,
515
+ ).then(
516
+ fn=lambda: gr.update(visible=True),
517
+ inputs=None,
518
+ outputs=use_as_input_button,
519
+ )
520
+
521
+ # Preview button
522
  preview_button.click(
523
  fn=preview_image_and_mask,
524
  inputs=[input_image, width_slider, height_slider, overlap_percentage,
525
+ resize_option, custom_resize_percentage, target_ratio,
526
  overlap_left, overlap_right, overlap_top, overlap_bottom],
527
  outputs=preview_image,
528
+ queue=False
529
  )
530
 
531
+ # Use output as next input (ImageSlider -> Image)
532
+ use_as_input_button.click(
533
+ fn=use_output_as_input,
534
+ inputs=[result],
535
+ outputs=[input_image],
536
+ queue=False
 
 
 
 
 
 
 
 
 
537
  )
538
 
539
  # ===== Hidden API binding for img2img-compatible client =====
540
+ # Returns a single PATH string (so your client can copy/handle it exactly like img2img)
541
  api_output_path = gr.Textbox(visible=False)
542
  api_trigger = gr.Button(visible=False)
543
  api_trigger.click(
 
551
  width_slider, # width
552
  height_slider, # height
553
  overlap_percentage, # overlap_percentage
554
+ target_ratio # alignment (reusing same dropdown in this UI)
555
  ],
556
  outputs=[api_output_path],
557
  api_name="/process_images"