DawnC committed on
Commit 991a517 · verified · 1 Parent(s): 9b38240

Upload 15 files

app.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import sys
 import traceback
 import warnings
@@ -6,45 +5,6 @@ warnings.filterwarnings("ignore")
 
 from ui_manager import UIManager
 
-def preload_models_to_cache():
-    """
-    Pre-download models to HuggingFace cache before GPU allocation.
-    This runs on CPU and avoids downloading during @spaces.GPU execution.
-    """
-    if not os.getenv('SPACE_ID'):
-        return  # Skip if not on Spaces
-
-    print("📦 Pre-downloading models to cache (CPU only, no GPU usage)...")
-
-    try:
-        from diffusers import ControlNetModel
-        import torch
-
-        # Pre-download ControlNet models to cache
-        models_to_cache = [
-            ("diffusers/controlnet-canny-sdxl-1.0", "Canny ControlNet"),
-            ("diffusers/controlnet-depth-sdxl-1.0", "Depth ControlNet"),
-        ]
-
-        for model_id, model_name in models_to_cache:
-            print(f"  ⬇️ Downloading {model_name} ({model_id})...")
-            try:
-                _ = ControlNetModel.from_pretrained(
-                    model_id,
-                    torch_dtype=torch.float16,
-                    use_safetensors=True,
-                    local_files_only=False  # Allow download
-                )
-                print(f"  ✅ {model_name} cached")
-            except Exception as e:
-                print(f"  ⚠️ {model_name} download failed (will retry on-demand): {e}")
-
-        print("✅ Model pre-caching complete")
-
-    except Exception as e:
-        print(f"⚠️ Model pre-caching failed: {e}")
-        print("   Models will be downloaded on first use instead.")
-
 def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
     """Launch SceneWeaver Application"""
 
@@ -52,9 +12,6 @@ def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
     print("✨ AI-Powered Image Background Generation")
 
     try:
-        # Pre-download models on Spaces to avoid downloading during GPU time
-        preload_models_to_cache()
-
         # Test imports first
         print("🔍 Testing imports...")
        try:
@@ -63,13 +20,6 @@ def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
            ui = UIManager()
            print("✅ UIManager instance created successfully")
 
-            # Note: On Hugging Face Spaces, models are pre-cached at startup
-            if os.getenv('SPACE_ID'):
-                print("\n🔧 Detected Hugging Face Spaces environment")
-                print("⚡ Models pre-cached - ready for fast inference")
-                print("   Expected inference time: ~300-350s (with cached models)")
-                print()
-
            # Launch UI
            print("🚀 Launching interface...")
            interface = ui.launch(share=share, debug=debug)
@@ -128,4 +78,4 @@ def main():
         raise
 
 if __name__ == "__main__":
-    main()
+    main()
control_image_processor.py ADDED
@@ -0,0 +1,392 @@
+import logging
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image, ImageFilter
+
+from transformers import AutoImageProcessor, AutoModelForDepthEstimation
+from transformers import DPTImageProcessor, DPTForDepthEstimation
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class ControlImageProcessor:
+    """
+    Generates control images for ControlNet conditioning.
+
+    Supports Canny edge detection and depth map estimation with
+    mask-aware processing for selective structure preservation.
+
+    Attributes:
+        device: Computation device (cuda/mps/cpu)
+        canny_low_threshold: Low threshold for Canny edge detection
+        canny_high_threshold: High threshold for Canny edge detection
+
+    Example:
+        >>> processor = ControlImageProcessor(device="cuda")
+        >>> canny_image = processor.generate_canny_edges(image)
+        >>> depth_map = processor.generate_depth_map(image)
+    """
+
+    # Depth model identifiers
+    DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
+    DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
+
+    def __init__(
+        self,
+        device: str = "cuda",
+        canny_low_threshold: int = 100,
+        canny_high_threshold: int = 200
+    ):
+        """
+        Initialize the ControlImageProcessor.
+
+        Parameters
+        ----------
+        device : str
+            Computation device
+        canny_low_threshold : int
+            Low threshold for Canny edge detection
+        canny_high_threshold : int
+            High threshold for Canny edge detection
+        """
+        self.device = device
+        self.canny_low_threshold = canny_low_threshold
+        self.canny_high_threshold = canny_high_threshold
+
+        # Depth estimation models (lazy loaded)
+        self._depth_estimator = None
+        self._depth_processor = None
+        self._depth_model_loaded = False
+
+        logger.info(f"ControlImageProcessor initialized on {device}")
+
+    def generate_canny_edges(self, image: np.ndarray) -> Image.Image:
+        """
+        Generate Canny edge detection image.
+
+        Parameters
+        ----------
+        image : np.ndarray
+            Input image as numpy array (RGB)
+
+        Returns
+        -------
+        PIL.Image
+            Canny edge image (grayscale)
+        """
+        # Convert to grayscale
+        if len(image.shape) == 3:
+            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+        else:
+            gray = image
+
+        # Apply Gaussian blur to reduce noise
+        blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)
+
+        # Canny edge detection
+        edges = cv2.Canny(
+            blurred,
+            self.canny_low_threshold,
+            self.canny_high_threshold
+        )
+
+        # Convert to 3-channel for ControlNet
+        edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
+
+        return Image.fromarray(edges_3ch)
+
+    def load_depth_estimator(self) -> bool:
+        """
+        Load depth estimation model.
+
+        Returns
+        -------
+        bool
+            True if loaded successfully
+        """
+        if self._depth_model_loaded:
+            return True
+
+        logger.info("Loading depth estimation model...")
+
+        try:
+            # Try primary model first (Depth Anything)
+            self._depth_processor = AutoImageProcessor.from_pretrained(
+                self.DEPTH_MODEL_PRIMARY
+            )
+            self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
+                self.DEPTH_MODEL_PRIMARY,
+                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
+            )
+            self._depth_estimator = self._depth_estimator.to(self.device)
+            self._depth_estimator.eval()
+            self._depth_model_loaded = True
+            logger.info(f"Loaded depth model: {self.DEPTH_MODEL_PRIMARY}")
+            return True
+
+        except Exception as e:
+            logger.warning(f"Primary depth model failed: {e}, trying fallback...")
+
+            try:
+                # Fallback to DPT
+                self._depth_processor = DPTImageProcessor.from_pretrained(
+                    self.DEPTH_MODEL_FALLBACK
+                )
+                self._depth_estimator = DPTForDepthEstimation.from_pretrained(
+                    self.DEPTH_MODEL_FALLBACK,
+                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
+                )
+                self._depth_estimator = self._depth_estimator.to(self.device)
+                self._depth_estimator.eval()
+                self._depth_model_loaded = True
+                logger.info(f"Loaded fallback depth model: {self.DEPTH_MODEL_FALLBACK}")
+                return True
+
+            except Exception as e2:
+                logger.error(f"All depth models failed: {e2}")
+                return False
+
+    def generate_depth_map(self, image: Image.Image) -> Image.Image:
+        """
+        Generate depth map using depth estimation model.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Input image
+
+        Returns
+        -------
+        PIL.Image
+            Depth map image (grayscale, normalized)
+        """
+        if not self._depth_model_loaded:
+            if not self.load_depth_estimator():
+                # Fallback to simple gradient
+                logger.warning("Using fallback gradient depth")
+                return self._generate_fallback_depth(image)
+
+        try:
+            # Prepare image for model
+            inputs = self._depth_processor(
+                images=image,
+                return_tensors="pt"
+            )
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # Run inference
+            with torch.no_grad():
+                outputs = self._depth_estimator(**inputs)
+                predicted_depth = outputs.predicted_depth
+
+            # Normalize depth map
+            depth = predicted_depth.squeeze().cpu().numpy()
+            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
+            depth = (depth * 255).astype(np.uint8)
+
+            # Resize to match input
+            depth_image = Image.fromarray(depth)
+            depth_image = depth_image.resize(image.size, Image.Resampling.BILINEAR)
+
+            # Convert to 3-channel for ControlNet
+            depth_3ch = np.stack([np.array(depth_image)] * 3, axis=-1)
+
+            return Image.fromarray(depth_3ch)
+
+        except Exception as e:
+            logger.error(f"Depth estimation failed: {e}")
+            return self._generate_fallback_depth(image)
+
+    def _generate_fallback_depth(self, image: Image.Image) -> Image.Image:
+        """
+        Generate a simple fallback depth map using a gradient.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Input image
+
+        Returns
+        -------
+        PIL.Image
+            Simple gradient depth map
+        """
+        w, h = image.size
+        # Create vertical gradient (top = far, bottom = near)
+        gradient = np.linspace(50, 200, h).reshape(-1, 1)
+        gradient = np.tile(gradient, (1, w))
+        gradient = gradient.astype(np.uint8)
+
+        # Stack to 3 channels
+        depth_3ch = np.stack([gradient] * 3, axis=-1)
+        return Image.fromarray(depth_3ch)
+
+    def prepare_control_image(
+        self,
+        image: Image.Image,
+        mode: str = "canny",
+        mask: Optional[Image.Image] = None,
+        preserve_structure: bool = False,
+        edge_guidance_mode: str = "boundary"
+    ) -> Image.Image:
+        """
+        Generate ControlNet conditioning image.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Input image
+        mode : str
+            Conditioning mode: "canny" or "depth"
+        mask : PIL.Image, optional
+            If provided, can modify edges based on edge_guidance_mode
+        preserve_structure : bool
+            If True, keep all edges in masked region (for color change tasks).
+            If False, use edge_guidance_mode to determine edge handling.
+        edge_guidance_mode : str
+            How to handle edges when preserve_structure=False:
+            - "none": Completely remove edges in masked region (removal tasks)
+            - "boundary": Keep only boundary edges of masked region (replacement tasks)
+            - "soft": Gradually fade edges from boundary (default for better blending)
+
+        Returns
+        -------
+        PIL.Image
+            Generated control image
+        """
+        logger.info(f"Preparing control image: mode={mode}, preserve_structure={preserve_structure}, edge_guidance={edge_guidance_mode}")
+
+        # Convert to RGB if needed
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        img_array = np.array(image)
+
+        if mode == "canny":
+            control_image = self.generate_canny_edges(img_array)
+
+            if mask is not None:
+                control_array = np.array(control_image)
+                mask_array = np.array(mask.convert('L'))
+
+                if preserve_structure:
+                    # Keep all edges - no modification needed
+                    logger.info("Preserving all edges in masked region for color change")
+
+                elif edge_guidance_mode == "none":
+                    # Completely suppress edges in masked region (for removal)
+                    mask_region = mask_array > 128
+                    control_array[mask_region] = 0
+                    logger.info("Suppressed all edges in masked region for removal")
+
+                elif edge_guidance_mode == "mask_outline":
+                    # For object replacement: clear inside edges, draw clear mask outline.
+                    # Outline guides WHERE and WHAT SIZE the new object should be.
+                    mask_binary = (mask_array > 128).astype(np.uint8) * 255
+
+                    # Step 1: Clear all edges inside the mask
+                    mask_region = mask_array > 128
+                    control_array[mask_region] = 0
+
+                    # Step 2: Draw clear mask outline for position/size guidance
+                    contours, _ = cv2.findContours(
+                        mask_binary,
+                        cv2.RETR_EXTERNAL,
+                        cv2.CHAIN_APPROX_SIMPLE
+                    )
+
+                    if contours:
+                        # Draw visible white outline (thickness=2) for clear guidance
+                        cv2.drawContours(control_array, contours, -1, (255, 255, 255), thickness=2)
+                        logger.info(f"Drew {len(contours)} mask outline(s) for placement guidance")
+
+                elif edge_guidance_mode == "boundary":
+                    # Keep boundary edges to guide object placement and size.
+                    # This helps ControlNet understand WHERE to place the new object.
+                    mask_binary = (mask_array > 128).astype(np.uint8) * 255
+
+                    # Create boundary mask using morphological operations
+                    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
+                    dilated = cv2.dilate(mask_binary, kernel, iterations=1)
+                    eroded = cv2.erode(mask_binary, kernel, iterations=1)
+                    boundary = dilated - eroded
+
+                    # Inner region (not boundary) - suppress edges
+                    inner_region = (mask_array > 128) & (boundary == 0)
+                    control_array[inner_region] = 0
+
+                    # Keep boundary edges intact for object placement guidance
+                    logger.info("Keeping boundary edges for object replacement guidance")
+
+                elif edge_guidance_mode == "soft":
+                    # Soft fade: gradually reduce edges from boundary to center
+                    mask_binary = (mask_array > 128).astype(np.uint8) * 255
+
+                    # Calculate distance from boundary
+                    dist_transform = cv2.distanceTransform(mask_binary, cv2.DIST_L2, 5)
+                    max_dist = dist_transform.max()
+                    if max_dist > 0:
+                        # Normalize and invert: 1 at boundary, 0 at center
+                        fade_factor = 1 - (dist_transform / max_dist)
+                        fade_factor = np.clip(fade_factor, 0, 1)
+
+                        # Apply fade to masked region only
+                        mask_region = mask_array > 128
+                        for c in range(3):
+                            control_array[:, :, c][mask_region] = (
+                                control_array[:, :, c][mask_region] * fade_factor[mask_region]
+                            ).astype(np.uint8)
+
+                        logger.info("Applied soft edge fading in masked region")
+
+                control_image = Image.fromarray(control_array)
+
+            return control_image
+
+        elif mode == "depth":
+            control_image = self.generate_depth_map(image)
+
+            # For depth mode with replacement, we want to keep depth info for context
+            # but allow flexibility in the masked region
+            if mask is not None and not preserve_structure:
+                control_array = np.array(control_image)
+                mask_array = np.array(mask.convert('L'))
+
+                # Smooth the depth in masked region using surrounding context
+                if edge_guidance_mode in ["boundary", "soft"]:
+                    mask_binary = (mask_array > 128).astype(np.uint8)
+
+                    # Inpaint the depth map in masked region using surrounding values
+                    depth_gray = control_array[:, :, 0]
+                    inpainted_depth = cv2.inpaint(
+                        depth_gray,
+                        mask_binary,
+                        inpaintRadius=10,
+                        flags=cv2.INPAINT_TELEA
+                    )
+                    control_array = np.stack([inpainted_depth] * 3, axis=-1)
+                    logger.info("Inpainted depth map in masked region")
+
+                control_image = Image.fromarray(control_array)
+
+            return control_image
+
+        else:
+            raise ValueError(f"Unknown control mode: {mode}")
+
+    def unload_depth_model(self) -> None:
+        """Unload depth estimation model to free memory."""
+        if self._depth_estimator is not None:
+            del self._depth_estimator
+            self._depth_estimator = None
+
+        if self._depth_processor is not None:
+            del self._depth_processor
+            self._depth_processor = None
+
+        self._depth_model_loaded = False
+        logger.info("Depth model unloaded")
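A minimal usage sketch for the class above. The image path and the all-black mask are placeholder inputs, not part of this commit; on a CPU-only host pass device="cpu" as shown, and "cuda" on GPU hosts.

import numpy as np
from PIL import Image

from control_image_processor import ControlImageProcessor

# Placeholder inputs: any RGB image and a rough selection mask
image = Image.open("example.jpg").convert("RGB")  # hypothetical path
mask = Image.new("L", image.size, 0)

processor = ControlImageProcessor(device="cpu")

# Plain Canny control image from a numpy array
canny = processor.generate_canny_edges(np.array(image))

# Mask-aware control image for a replacement task: edges inside the
# mask are suppressed, boundary edges are kept as placement guidance
control = processor.prepare_control_image(
    image,
    mode="canny",
    mask=mask,
    preserve_structure=False,
    edge_guidance_mode="boundary",
)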
gpu_handlers.py ADDED
@@ -0,0 +1,316 @@
+import logging
+import time
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import cv2
+import numpy as np
+import spaces
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class GPUHandlers:
+    """
+    Handles all GPU-intensive generation operations.
+
+    This class encapsulates the execution logic for both background generation
+    and inpainting operations with the proper @spaces.GPU decorator for
+    HuggingFace Spaces deployment.
+
+    Supports dual-mode inpainting:
+    - Pure Inpainting (use_controlnet=False): For object replacement/removal
+    - ControlNet Inpainting (use_controlnet=True): For clothing/color change
+    """
+
+    def __init__(
+        self,
+        core: Any,
+        inpainting_template_manager: Any
+    ):
+        """
+        Initialize the GPU handlers.
+
+        Parameters
+        ----------
+        core : SceneWeaverCore
+            Main engine instance
+        inpainting_template_manager : InpaintingTemplateManager
+            Template manager for inpainting
+        """
+        self.core = core
+        self.inpainting_template_manager = inpainting_template_manager
+        logger.info("GPUHandlers initialized")
+
+    @spaces.GPU(duration=240)
+    def background_generate(
+        self,
+        image: Optional[Image.Image],
+        prompt: str,
+        negative_prompt: str,
+        composition_mode: str,
+        focus_mode: str,
+        num_steps: int,
+        guidance_scale: float,
+        progress_callback: Optional[Callable[[str, int], None]] = None
+    ) -> Dict[str, Any]:
+        """
+        Handle background generation request with GPU access.
+
+        Parameters
+        ----------
+        image : PIL.Image, optional
+            Input image
+        prompt : str
+            Generation prompt
+        negative_prompt : str
+            Negative prompt
+        composition_mode : str
+            Composition mode (center, left_half, etc.)
+        focus_mode : str
+            Focus mode (person, scene)
+        num_steps : int
+            Number of inference steps
+        guidance_scale : float
+            Guidance scale
+        progress_callback : callable, optional
+            Progress update function(message, percentage)
+
+        Returns
+        -------
+        dict
+            Result dictionary with success status and images
+        """
+        if image is None:
+            return {"success": False, "error": "Please upload an image first"}
+
+        if not prompt.strip():
+            return {"success": False, "error": "Please enter a prompt"}
+
+        try:
+            logger.info(f"Starting background generation: {prompt[:50]}...")
+            start_time = time.time()
+
+            # Initialize if needed
+            if not self.core.is_initialized:
+                if progress_callback:
+                    progress_callback("Loading AI models...", 5)
+                self.core.load_models(progress_callback=progress_callback)
+
+            # Generate and combine
+            if progress_callback:
+                progress_callback("Generating background...", 20)
+
+            result = self.core.generate_and_combine(
+                original_image=image,
+                prompt=prompt,
+                combination_mode=composition_mode,
+                focus_mode=focus_mode,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_steps,
+                guidance_scale=guidance_scale,
+                progress_callback=progress_callback
+            )
+
+            elapsed = time.time() - start_time
+            logger.info(f"Background generation complete in {elapsed:.1f}s")
+
+            return result
+
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f"Background generation error: {error_msg}")
+            return {"success": False, "error": error_msg}
+
+    @spaces.GPU(duration=420)
+    def inpainting_generate(
+        self,
+        image: Optional[Image.Image],
+        mask: Optional[Image.Image],
+        prompt: str,
+        template_key: Optional[str],
+        model_key: str,
+        conditioning_type: str,
+        conditioning_scale: float,
+        feather_radius: int,
+        guidance_scale: float,
+        num_steps: int,
+        seed: int = -1,
+        progress_callback: Optional[Callable[[str, int], None]] = None
+    ) -> Tuple[Optional[Image.Image], Optional[Image.Image], str, int]:
+        """
+        Handle inpainting request with GPU access.
+
+        Supports dual-mode operation based on template:
+        - Pure Inpainting: For object_replacement, removal
+        - ControlNet: For clothing_change, change_color
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image to inpaint
+        mask : PIL.Image
+            Inpainting mask (white = area to regenerate)
+        prompt : str
+            Inpainting prompt
+        template_key : str, optional
+            Template key if using a template
+        model_key : str
+            Model key (juggernaut_xl, realvis_xl, sdxl_base, animagine_xl)
+        conditioning_type : str
+            ControlNet conditioning type (canny/depth) - only for ControlNet mode
+        conditioning_scale : float
+            ControlNet conditioning scale
+        feather_radius : int
+            Mask feather radius
+        guidance_scale : float
+            Generation guidance scale
+        num_steps : int
+            Number of inference steps
+        seed : int
+            Random seed (-1 for random)
+        progress_callback : callable, optional
+            Progress update function
+
+        Returns
+        -------
+        tuple
+            (result_image, control_image, status_message, used_seed)
+        """
+        if image is None:
+            return None, None, "Please upload an image first", -1
+
+        if mask is None:
+            return None, None, "Please draw a mask on the image", -1
+
+        try:
+            logger.info(f"Starting inpainting: prompt='{prompt[:30]}...', template={template_key}")
+            start_time = time.time()
+
+            # Get template parameters
+            built_prompt = prompt
+            negative_prompt = ""
+            template_params = {}
+            use_controlnet = True  # Default to ControlNet mode
+
+            if template_key:
+                template = self.inpainting_template_manager.get_template(template_key)
+                if template:
+                    # For the removal template, use the template prompt directly if the user prompt is empty
+                    if template_key == "removal" and not prompt.strip():
+                        built_prompt = template.prompt_template
+                    else:
+                        built_prompt = self.inpainting_template_manager.build_prompt(template_key, prompt)
+                    negative_prompt = self.inpainting_template_manager.get_negative_prompt(template_key)
+                    template_params = self.inpainting_template_manager.get_parameters_for_template(template_key)
+                    use_controlnet = template_params.get("use_controlnet", True)
+                    logger.info(f"Template: {template_key}, use_controlnet={use_controlnet}")
+
+            # Build final parameters
+            final_params = {
+                # Pipeline mode
+                "use_controlnet": use_controlnet,
+                "mask_dilation": template_params.get("mask_dilation", 0),
+
+                # ControlNet parameters (only used if use_controlnet=True)
+                "conditioning_type": template_params.get("preferred_conditioning", conditioning_type),
+                "controlnet_conditioning_scale": template_params.get("controlnet_conditioning_scale", conditioning_scale),
+                "preserve_structure_in_mask": template_params.get("preserve_structure_in_mask", False),
+                "edge_guidance_mode": template_params.get("edge_guidance_mode", "boundary"),
+
+                # Generation parameters
+                "feather_radius": template_params.get("feather_radius", feather_radius),
+                "guidance_scale": template_params.get("guidance_scale", guidance_scale),
+                "num_inference_steps": template_params.get("num_inference_steps", num_steps),
+                "strength": template_params.get("strength", 0.99),
+                "negative_prompt": negative_prompt,
+                "seed": seed,
+            }
+
+            # Execute inpainting through core
+            result = self.core.execute_inpainting(
+                image=image,
+                mask=mask,
+                prompt=built_prompt,
+                model_key=model_key,
+                progress_callback=progress_callback,
+                **final_params
+            )
+
+            elapsed = time.time() - start_time
+
+            if result.get('success'):
+                mode_str = "Pure Inpainting" if not use_controlnet else "ControlNet"
+                # Get the actual seed used from metadata
+                used_seed = result.get('metadata', {}).get('seed', seed)
+                status = f"Complete ({mode_str}) in {elapsed:.1f}s | Seed: {used_seed}"
+
+                return (
+                    result.get('combined_image'),
+                    result.get('control_image'),
+                    status,
+                    used_seed
+                )
+            else:
+                error_msg = result.get('error', 'Unknown error')
+                return None, None, f"Error: {error_msg}", -1
+
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f"Inpainting handler error: {e}")
+            return None, None, f"Error: {error_msg}", -1
+
+
+def extract_mask_from_editor(mask_editor: Dict[str, Any]) -> Optional[Image.Image]:
+    """
+    Extract mask from Gradio ImageEditor component.
+
+    Parameters
+    ----------
+    mask_editor : dict
+        ImageEditor output with 'background' and 'layers'
+
+    Returns
+    -------
+    PIL.Image or None
+        Extracted mask image (L mode)
+    """
+    if mask_editor is None:
+        return None
+
+    try:
+        layers = mask_editor.get("layers", [])
+        if not layers:
+            return None
+
+        mask_layer = layers[0]
+        if mask_layer is None:
+            return None
+
+        # Convert to numpy array
+        if isinstance(mask_layer, Image.Image):
+            mask_array = np.array(mask_layer)
+        else:
+            mask_array = np.array(Image.open(mask_layer))
+
+        # Handle different formats
+        if len(mask_array.shape) == 3:
+            if mask_array.shape[2] == 4:
+                # RGBA - use alpha channel combined with RGB
+                alpha = mask_array[:, :, 3]
+                gray = cv2.cvtColor(mask_array[:, :, :3], cv2.COLOR_RGB2GRAY)
+                mask_gray = np.maximum(gray, alpha)
+            elif mask_array.shape[2] == 3:
+                # RGB - convert to grayscale
+                mask_gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
+            else:
+                mask_gray = mask_array[:, :, 0]
+        else:
+            mask_gray = mask_array
+
+        return Image.fromarray(mask_gray.astype(np.uint8), mode='L')
+
+    except Exception as e:
+        logger.error(f"Failed to extract mask from editor: {e}")
+        return None
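A minimal sketch of feeding extract_mask_from_editor a dict shaped like Gradio's ImageEditor output. The synthetic background and layer here are placeholders; importing this module requires the spaces package to be installed, as on HF Spaces.

from PIL import Image

from gpu_handlers import extract_mask_from_editor

# Synthetic ImageEditor payload: a background plus one RGBA drawing layer
editor_output = {
    "background": Image.new("RGB", (512, 512)),
    "layers": [Image.new("RGBA", (512, 512), (255, 255, 255, 255))],
}

mask = extract_mask_from_editor(editor_output)
if mask is not None:
    print(mask.mode, mask.size)  # "L" (512, 512)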
image_blender.py CHANGED
@@ -483,7 +483,7 @@ class ImageBlender:
         orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
         logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")
 
-        # Remove original background color contamination from foreground
+        # Remove original background color contamination from foreground
         orig_array = self._remove_background_color_contamination(
             orig_array,
             mask_array,
@@ -491,7 +491,7 @@
             tolerance=self.BACKGROUND_COLOR_TOLERANCE
         )
 
-        # Redefine trimap, optimized for cartoon characters
+        # Redefine trimap, optimized for cartoon characters
         try:
             kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
 
@@ -531,7 +531,7 @@
 
         fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
 
-        # Edge band spill suppression and repair
+        # Edge band spill suppression and repair
         if np.any(ring_zone):
             # Convert to Lab space
             orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
@@ -625,20 +625,20 @@
             delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
             delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
             delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)
-
+
             still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
-
+
             if np.any(still_contaminated):
                 # Apply stronger correction to remaining contaminated pixels
                 remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
-
+
                 # More aggressive chroma neutralization
                 remaining_chroma = remaining_pixels[:, 1:3]
                 neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7
-
+
                 # Stronger luminance matching
                 neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
-
+
                 ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
                 ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
             orig_lab[ring_zone] = ring_pixels_lab_pass2
@@ -691,7 +691,7 @@
         orig_linear = srgb_to_linear(orig_array)
         bg_linear = srgb_to_linear(bg_array)
 
-        # Cartoon-optimized Alpha calculation
+        # Cartoon-optimized Alpha calculation
         alpha = mask_array.astype(np.float32) / 255.0
 
         # Core foreground region - fully opaque
@@ -701,13 +701,13 @@
         alpha[bg_zone] = 0.0
 
         # [Key Fix] Force pixels with mask≥160 to α=1.0, avoiding white fill areas being limited to 0.9
-        high_confidence_pixels = mask_array >= 160
+        high_confidence_pixels = mask_array >= 160
        alpha[high_confidence_pixels] = 1.0
        logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
 
        # Ring area can be dehaloed, but doesn't affect already set high confidence pixels
        ring_without_high_conf = ring_zone & (~high_confidence_pixels)
-        alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
+        alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
 
        # Retain existing black outline/strong edge protection
        orig_gray = np.mean(orig_array, axis=2)
@@ -739,10 +739,10 @@
         result_srgb = linear_to_srgb(result_linear)
         result_array = (result_srgb * 255).astype(np.uint8)
 
-        # Final edge cleanup pass
+        # Final edge cleanup pass
         result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
 
-        # Protect core foreground from any background influence
+        # Protect core foreground from any background influence
         # This ensures faces and bodies retain original colors
         result_array = self._protect_foreground_core(
             result_array,
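The hunks above keep the composite in linear light: sRGB values are decoded with the 2.4-gamma transfer curve, alpha-blended, then re-encoded. A small self-contained sketch of why this matters at edge pixels (the values are illustrative, not from the commit):

import numpy as np

def srgb_to_linear(x: np.ndarray) -> np.ndarray:
    x = x / 255.0
    return np.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)

def linear_to_srgb(x: np.ndarray) -> np.ndarray:
    x = np.clip(x, 0.0, 1.0)
    return np.where(x <= 0.0031308, 12.92 * x, 1.055 * x ** (1 / 2.4) - 0.055)

fg, bg, alpha = np.array([200.0]), np.array([50.0]), 0.5

# Blending directly in sRGB darkens the transition band
naive = alpha * fg + (1 - alpha) * bg  # 125.0

# Blending in linear light gives the perceptually expected result
linear = linear_to_srgb(
    alpha * srgb_to_linear(fg) + (1 - alpha) * srgb_to_linear(bg)
) * 255  # ~150

print(float(naive), float(linear))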
inpainting_blender.py ADDED
@@ -0,0 +1,485 @@
+import logging
+from typing import Any, Dict, Optional, Tuple
+
+import cv2
+import numpy as np
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class InpaintingBlender:
+    """
+    Handles mask processing, prompt enhancement, and result blending for inpainting.
+
+    This class encapsulates all pre-processing and post-processing operations
+    needed for inpainting, separate from the main generation pipeline.
+
+    Attributes:
+        min_mask_coverage: Minimum mask coverage threshold
+        max_mask_coverage: Maximum mask coverage threshold
+
+    Example:
+        >>> blender = InpaintingBlender()
+        >>> processed_mask, info = blender.prepare_mask(mask, (512, 512), feather_radius=8)
+        >>> enhanced_prompt, negative = blender.enhance_prompt("a flower", image, mask)
+        >>> result = blender.blend_result(original, generated, mask)
+    """
+
+    def __init__(
+        self,
+        min_mask_coverage: float = 0.01,
+        max_mask_coverage: float = 0.95
+    ):
+        """
+        Initialize the InpaintingBlender.
+
+        Parameters
+        ----------
+        min_mask_coverage : float
+            Minimum mask coverage (default: 1%)
+        max_mask_coverage : float
+            Maximum mask coverage (default: 95%)
+        """
+        self.min_mask_coverage = min_mask_coverage
+        self.max_mask_coverage = max_mask_coverage
+        logger.info("InpaintingBlender initialized")
+
+    def prepare_mask(
+        self,
+        mask: Image.Image,
+        target_size: Tuple[int, int],
+        feather_radius: int = 8
+    ) -> Tuple[Image.Image, Dict[str, Any]]:
+        """
+        Prepare and validate mask for inpainting.
+
+        Parameters
+        ----------
+        mask : PIL.Image
+            Input mask (white = inpaint area)
+        target_size : tuple
+            Target (width, height) to match input image
+        feather_radius : int
+            Feathering radius in pixels
+
+        Returns
+        -------
+        tuple
+            (processed_mask, validation_info)
+
+        Raises
+        ------
+        ValueError
+            If mask coverage is outside acceptable range
+        """
+        # Convert to grayscale
+        if mask.mode != 'L':
+            mask = mask.convert('L')
+
+        # Resize to match target
+        if mask.size != target_size:
+            mask = mask.resize(target_size, Image.LANCZOS)
+
+        # Convert to array for processing
+        mask_array = np.array(mask)
+
+        # Calculate coverage
+        total_pixels = mask_array.size
+        white_pixels = np.count_nonzero(mask_array > 127)
+        coverage = white_pixels / total_pixels
+
+        validation_info = {
+            "coverage": coverage,
+            "white_pixels": white_pixels,
+            "total_pixels": total_pixels,
+            "feather_radius": feather_radius,
+            "valid": True,
+            "warning": ""
+        }
+
+        # Validate coverage
+        if coverage < self.min_mask_coverage:
+            validation_info["valid"] = False
+            validation_info["warning"] = (
+                f"Mask coverage too low ({coverage:.1%}). "
+                f"Please select a larger area to inpaint."
+            )
+            logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.min_mask_coverage:.1%}")
+
+        elif coverage > self.max_mask_coverage:
+            validation_info["valid"] = False
+            validation_info["warning"] = (
+                f"Mask coverage too high ({coverage:.1%}). "
+                f"Consider using background generation instead."
+            )
+            logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.max_mask_coverage:.1%}")
+
+        # Apply feathering
+        if feather_radius > 0:
+            mask_array = cv2.GaussianBlur(
+                mask_array,
+                (feather_radius * 2 + 1, feather_radius * 2 + 1),
+                feather_radius / 2
+            )
+            logger.debug(f"Applied {feather_radius}px feathering to mask")
+
+        processed_mask = Image.fromarray(mask_array, mode='L')
+
+        return processed_mask, validation_info
+
+    def enhance_prompt_for_inpainting(
+        self,
+        prompt: str,
+        image: Image.Image,
+        mask: Image.Image
+    ) -> Tuple[str, str]:
+        """
+        Enhance prompt based on non-masked region analysis.
+
+        Analyzes the surrounding context to generate appropriate
+        lighting and color descriptors.
+
+        Parameters
+        ----------
+        prompt : str
+            User-provided prompt
+        image : PIL.Image
+            Original image
+        mask : PIL.Image
+            Inpainting mask
+
+        Returns
+        -------
+        tuple
+            (enhanced_prompt, negative_prompt)
+        """
+        logger.info("Enhancing prompt for inpainting context...")
+
+        # Convert to arrays
+        img_array = np.array(image.convert('RGB'))
+        mask_array = np.array(mask.convert('L'))
+
+        # Analyze non-masked regions
+        non_masked = mask_array < 127
+
+        if not np.any(non_masked):
+            # No context available
+            enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
+            negative_prompt = self._get_inpainting_negative_prompt()
+            return enhanced_prompt, negative_prompt
+
+        # Extract context pixels
+        context_pixels = img_array[non_masked]
+
+        # Convert to Lab for analysis
+        context_lab = cv2.cvtColor(
+            context_pixels.reshape(-1, 1, 3),
+            cv2.COLOR_RGB2LAB
+        ).reshape(-1, 3)
+
+        # Use robust statistics (median) to avoid outlier influence
+        median_l = np.median(context_lab[:, 0])
+        median_b = np.median(context_lab[:, 2])
+
+        # Analyze lighting conditions
+        lighting_descriptors = []
+
+        if median_l > 170:
+            lighting_descriptors.append("bright")
+        elif median_l > 130:
+            lighting_descriptors.append("well-lit")
+        elif median_l > 80:
+            lighting_descriptors.append("moderate lighting")
+        else:
+            lighting_descriptors.append("dim lighting")
+
+        # Analyze color temperature (b channel: blue(-) to yellow(+))
+        if median_b > 140:
+            lighting_descriptors.append("warm golden tones")
+        elif median_b > 120:
+            lighting_descriptors.append("warm afternoon light")
+        elif median_b < 110:
+            lighting_descriptors.append("cool neutral tones")
+
+        # Calculate saturation from context
+        hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
+        median_saturation = np.median(hsv[:, :, 1])
+
+        if median_saturation > 150:
+            lighting_descriptors.append("vibrant colors")
+        elif median_saturation < 80:
+            lighting_descriptors.append("subtle muted colors")
+
+        # Build enhanced prompt
+        lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
+        quality_suffix = "high quality, detailed, photorealistic, seamless integration"
+
+        if lighting_desc:
+            enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
+        else:
+            enhanced_prompt = f"{prompt}, {quality_suffix}"
+
+        negative_prompt = self._get_inpainting_negative_prompt()
+
+        logger.info(f"Enhanced prompt with context: {lighting_desc}")
+
+        return enhanced_prompt, negative_prompt
+
+    def _get_inpainting_negative_prompt(self) -> str:
+        """Get standard negative prompt for inpainting."""
+        return (
+            "inconsistent lighting, wrong perspective, mismatched colors, "
+            "visible seams, blending artifacts, color bleeding, "
+            "blurry, low quality, distorted, deformed, "
+            "harsh edges, unnatural transition"
+        )
+
+    def blend_result(
+        self,
+        original: Image.Image,
+        generated: Image.Image,
+        mask: Image.Image
+    ) -> Image.Image:
+        """
+        Blend generated content with original image.
+
+        Uses color matching and linear color space blending for seamless results.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image
+        generated : PIL.Image
+            Generated inpainted image
+        mask : PIL.Image
+            Blending mask (white = use generated)
+
+        Returns
+        -------
+        PIL.Image
+            Blended result
+        """
+        logger.info("Blending inpainting result with color matching...")
+
+        # Ensure same size
+        if generated.size != original.size:
+            generated = generated.resize(original.size, Image.LANCZOS)
+        if mask.size != original.size:
+            mask = mask.resize(original.size, Image.LANCZOS)
+
+        # Convert to arrays
+        orig_array = np.array(original.convert('RGB')).astype(np.float32)
+        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
+        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
+
+        # Apply color matching to generated region (use original mask for accurate boundary detection)
+        gen_array = self._match_colors_at_boundary(orig_array, gen_array, mask_array)
+
+        # Create blend mask: soften edges ONLY for blending (not for generation).
+        # This ensures full generation coverage while blending smoothly at edges.
+        blend_mask = self._create_blend_mask(mask_array)
+
+        # sRGB to linear conversion
+        def srgb_to_linear(img: np.ndarray) -> np.ndarray:
+            img_norm = img / 255.0
+            return np.where(
+                img_norm <= 0.04045,
+                img_norm / 12.92,
+                np.power((img_norm + 0.055) / 1.055, 2.4)
+            )
+
+        def linear_to_srgb(img: np.ndarray) -> np.ndarray:
+            img_clipped = np.clip(img, 0, 1)
+            return np.where(
+                img_clipped <= 0.0031308,
+                12.92 * img_clipped,
+                1.055 * np.power(img_clipped, 1/2.4) - 0.055
+            )
+
+        # Convert to linear space
+        orig_linear = srgb_to_linear(orig_array)
+        gen_linear = srgb_to_linear(gen_array)
+
+        # Alpha blending in linear space using the blend mask (with softened edges)
+        alpha = blend_mask[:, :, np.newaxis]
+        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
+
+        # Convert back to sRGB
+        result_srgb = linear_to_srgb(result_linear)
+        result_array = (result_srgb * 255).astype(np.uint8)
+
+        logger.debug("Blending completed with color matching")
+
+        return Image.fromarray(result_array)
+
+    def _match_colors_at_boundary(
+        self,
+        original: np.ndarray,
+        generated: np.ndarray,
+        mask: np.ndarray
+    ) -> np.ndarray:
+        """
+        Match colors of generated content to original at the boundary.
+
+        Uses histogram matching in Lab color space for natural blending.
+
+        Parameters
+        ----------
+        original : np.ndarray
+            Original image array (float32, 0-255)
+        generated : np.ndarray
+            Generated image array (float32, 0-255)
+        mask : np.ndarray
+            Mask array (float32, 0-1)
+
+        Returns
+        -------
+        np.ndarray
+            Color-matched generated image
+        """
+        # Create boundary region mask (dilated mask - eroded mask)
+        mask_binary = (mask > 0.5).astype(np.uint8) * 255
+
+        # Create narrow boundary region for sampling original colors
+        kernel_size = 25  # Pixels to sample around boundary
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+        dilated = cv2.dilate(mask_binary, kernel, iterations=1)
+        eroded = cv2.erode(mask_binary, kernel, iterations=1)
+
+        # Outer boundary (original side)
+        outer_boundary = (dilated > 0) & (mask_binary == 0)
+        # Inner boundary (generated side)
+        inner_boundary = (mask_binary > 0) & (eroded == 0)
+
+        if not np.any(outer_boundary) or not np.any(inner_boundary):
+            logger.debug("No boundary region found, skipping color matching")
+            return generated
+
+        # Convert to Lab color space
+        orig_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
+        gen_lab = cv2.cvtColor(generated.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
+
+        # Sample colors from boundary regions
+        orig_boundary_pixels = orig_lab[outer_boundary]
+        gen_boundary_pixels = gen_lab[inner_boundary]
+
+        if len(orig_boundary_pixels) < 10 or len(gen_boundary_pixels) < 10:
+            logger.debug("Not enough boundary pixels, skipping color matching")
+            return generated
+
+        # Calculate statistics
+        orig_mean = np.mean(orig_boundary_pixels, axis=0)
+        orig_std = np.std(orig_boundary_pixels, axis=0) + 1e-6
+
+        gen_mean = np.mean(gen_boundary_pixels, axis=0)
+        gen_std = np.std(gen_boundary_pixels, axis=0) + 1e-6
+
+        # Calculate correction factors.
+        # Only correct L (lightness) and a,b (color) channels.
+        l_correction = (orig_mean[0] - gen_mean[0]) * 0.7  # 70% correction for lightness
+        a_correction = (orig_mean[1] - gen_mean[1]) * 0.5  # 50% correction for color
+        b_correction = (orig_mean[2] - gen_mean[2]) * 0.5
+
+        logger.debug(f"Color correction: L={l_correction:.1f}, a={a_correction:.1f}, b={b_correction:.1f}")
+
+        # Apply correction to masked region only
+        corrected_lab = gen_lab.copy()
+        mask_region = mask > 0.3  # Apply to most of masked region
+
+        corrected_lab[mask_region, 0] = np.clip(
+            corrected_lab[mask_region, 0] + l_correction, 0, 255
+        )
+        corrected_lab[mask_region, 1] = np.clip(
+            corrected_lab[mask_region, 1] + a_correction, 0, 255
+        )
+        corrected_lab[mask_region, 2] = np.clip(
+            corrected_lab[mask_region, 2] + b_correction, 0, 255
+        )
+
+        # Convert back to RGB
+        corrected_rgb = cv2.cvtColor(
+            corrected_lab.astype(np.uint8),
+            cv2.COLOR_LAB2RGB
+        ).astype(np.float32)
+
+        logger.info("Applied boundary color matching")
+
+        return corrected_rgb
+
+    def _create_blend_mask(self, mask: np.ndarray) -> np.ndarray:
+        """
+        Create a blend mask with softened edges for natural compositing.
+
+        The mask interior stays fully opaque (1.0) while only the edges
+        get a smooth transition. This preserves full generated content
+        while blending naturally at boundaries.
+
+        Parameters
+        ----------
+        mask : np.ndarray
+            Original mask array (float32, 0-1)
+
+        Returns
+        -------
+        np.ndarray
+            Blend mask with soft edges but solid interior
+        """
+        # Convert to uint8 for morphological operations
+        mask_uint8 = (mask * 255).astype(np.uint8)
+
+        # Create eroded version (solid interior)
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
+        eroded = cv2.erode(mask_uint8, kernel, iterations=1)
+
+        # Create smooth transition zone at edges only:
+        # blur the original mask for edge softness
+        blurred = cv2.GaussianBlur(mask_uint8, (15, 15), 4)
+
+        # Combine: use eroded (solid) for interior, blurred for edges.
+        # Where eroded > 0, use full opacity; elsewhere use blurred transition.
+        result = np.where(eroded > 128, mask_uint8, blurred)
+
+        # Final light smoothing
+        result = cv2.GaussianBlur(result, (5, 5), 1)
+
+        # Convert back to float
+        blend_mask = result.astype(np.float32) / 255.0
+
+        logger.debug("Created blend mask with soft edges and solid interior")
+
+        return blend_mask
+
+    def validate_inputs(
+        self,
+        image: Image.Image,
+        mask: Image.Image
+    ) -> Tuple[bool, str]:
+        """
+        Validate image and mask inputs before processing.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Input image
+        mask : PIL.Image
+            Input mask
+
+        Returns
+        -------
+        tuple
+            (is_valid, error_message)
+        """
+        if image is None:
+            return False, "No image provided"
+
+        if mask is None:
+            return False, "No mask provided"
+
+        # Check sizes match
+        if image.size != mask.size:
+            # Will be resized later, so just log a warning
+            logger.warning(f"Image size {image.size} != mask size {mask.size}, will resize")
+
+        return True, ""
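A minimal sketch of the InpaintingBlender round trip with synthetic inputs; the solid-color images below stand in for a real photo and a diffusion result and are not part of the commit.

from PIL import Image

from inpainting_blender import InpaintingBlender

blender = InpaintingBlender()

# Synthetic stand-ins for an original photo and a generated patch
original = Image.new("RGB", (512, 512), (120, 120, 120))
generated = Image.new("RGB", (512, 512), (200, 60, 60))
mask = Image.new("L", (512, 512), 0)
mask.paste(255, (128, 128, 384, 384))  # inpaint the center square

processed_mask, info = blender.prepare_mask(mask, (512, 512), feather_radius=8)
if not info["valid"]:
    print(info["warning"])

result = blender.blend_result(original, generated, processed_mask)
print(result.size)  # (512, 512)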
inpainting_models.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import logging
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Any, Dict, Optional, Tuple
6
+ from diffusers import StableDiffusionXLControlNetInpaintPipeline
7
+ import torch
8
+
9
+ logger = logging.getLogger(__name__)
10
+ logger.setLevel(logging.INFO)
11
+
12
+
13
+ class ImageMode(Enum):
14
+ """Image style modes for model selection."""
15
+ PHOTO = "photo"
16
+ ANIME = "anime"
17
+
18
+
19
+ @dataclass
20
+ class ModelConfig:
21
+ """Configuration for an inpainting model."""
22
+
23
+ model_id: str
24
+ name: str
25
+ description: str
26
+ mode: ImageMode
27
+ requires_variant: bool = True
28
+ variant: str = "fp16"
29
+ recommended_for: str = ""
30
+
31
+ # Model-specific settings
32
+ default_guidance_scale: float = 7.5
33
+ default_num_inference_steps: int = 25
34
+
35
+
36
+ class InpaintingModelManager:
37
+ """
38
+ Manages multiple inpainting models for different image styles.
39
+
40
+ Provides lazy loading and switching between models optimized for
41
+ photorealistic images vs anime/illustration styles.
42
+
43
+ Attributes:
44
+ AVAILABLE_MODELS: Dictionary of all supported models
45
+ current_model: Currently loaded model identifier
46
+
47
+ Example:
48
+ >>> manager = InpaintingModelManager(device="cuda")
49
+ >>> pipeline = manager.get_pipeline(ImageMode.PHOTO)
50
+ >>> # Use pipeline for inpainting
51
+ >>> manager.switch_model(ImageMode.ANIME)
52
+ """
53
+
54
+ # Available models configuration
55
+ AVAILABLE_MODELS: Dict[str, ModelConfig] = {
56
+ # Photo-realistic models
57
+ "juggernaut_xl": ModelConfig(
58
+ model_id="RunDiffusion/Juggernaut-XL-v9",
59
+ name="JuggernautXL v9",
60
+ description="Best for photorealistic images, portraits, and real photos",
61
+ mode=ImageMode.PHOTO,
62
+ requires_variant=True,
63
+ variant="fp16",
64
+ recommended_for="Real photos, portraits, professional photography",
65
+ default_guidance_scale=7.0,
66
+ default_num_inference_steps=25
67
+ ),
68
+ "realvis_xl": ModelConfig(
69
+ model_id="SG161222/RealVisXL_V4.0",
70
+ name="RealVisXL v4",
71
+ description="Excellent for realistic images with fine details",
72
+ mode=ImageMode.PHOTO,
73
+ requires_variant=True,
74
+ variant="fp16",
75
+ recommended_for="Realistic scenes, product photos, nature",
76
+ default_guidance_scale=7.0,
77
+ default_num_inference_steps=25
78
+ ),
79
+ # Anime/Illustration models
80
+ "sdxl_base": ModelConfig(
81
+ model_id="stabilityai/stable-diffusion-xl-base-1.0",
82
+ name="SDXL Base",
83
+ description="Versatile model for general use and illustrations",
84
+ mode=ImageMode.ANIME,
85
+ requires_variant=True,
86
+ variant="fp16",
87
+ recommended_for="General illustrations, digital art, versatile use",
88
+ default_guidance_scale=7.5,
89
+ default_num_inference_steps=25
90
+ ),
91
+ "animagine_xl": ModelConfig(
92
+ model_id="cagliostrolab/animagine-xl-3.1",
93
+ name="Animagine XL 3.1",
94
+ description="Specialized for anime and manga style images",
95
+ mode=ImageMode.ANIME,
96
+ requires_variant=False,
97
+ recommended_for="Anime, manga, cartoon style images",
98
+ default_guidance_scale=7.0,
99
+ default_num_inference_steps=25
100
+ ),
101
+ }
102
+
103
+ # Default model for each mode
104
+ DEFAULT_MODELS = {
105
+ ImageMode.PHOTO: "juggernaut_xl",
106
+ ImageMode.ANIME: "sdxl_base"
107
+ }
108
+
109
+ def __init__(self, device: Optional[str] = None):
110
+ """
111
+ Initialize the model manager.
112
+
113
+ Parameters
114
+ ----------
115
+ device : str, optional
116
+ Device to load models on. Auto-detected if not specified.
117
+ """
118
+ self.device = device or self._detect_device()
119
+ self._current_model_key: Optional[str] = None
120
+ self._pipeline: Optional[Any] = None
121
+ self._controlnet: Optional[Any] = None
122
+ self._controlnet_loaded: bool = False
123
+
124
+ logger.info(f"InpaintingModelManager initialized on device: {self.device}")
125
+
126
+ def _detect_device(self) -> str:
127
+ """Detect the best available device."""
128
+ if torch.cuda.is_available():
129
+ return "cuda"
130
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
131
+ return "mps"
132
+ return "cpu"
133
+
134
+ def get_models_for_mode(self, mode: ImageMode) -> Dict[str, ModelConfig]:
135
+ """
136
+ Get all available models for a specific mode.
137
+
138
+ Parameters
139
+ ----------
140
+ mode : ImageMode
141
+ The image mode (PHOTO or ANIME)
142
+
143
+ Returns
144
+ -------
145
+ dict
146
+ Dictionary of model configs for the mode
147
+ """
148
+ return {
149
+ key: config
150
+ for key, config in self.AVAILABLE_MODELS.items()
151
+ if config.mode == mode
152
+ }
153
+
154
+ def get_model_choices(self) -> Dict[str, list]:
155
+ """
156
+ Get model choices formatted for UI dropdown.
157
+
158
+ Returns
159
+ -------
160
+ dict
161
+ Dictionary with 'photo' and 'anime' lists of (display_name, key) tuples
162
+ """
163
+ choices = {
164
+ "photo": [],
165
+ "anime": []
166
+ }
167
+
168
+ for key, config in self.AVAILABLE_MODELS.items():
169
+ display = f"{config.name} - {config.description}"
170
+ if config.mode == ImageMode.PHOTO:
171
+ choices["photo"].append((display, key))
172
+ else:
173
+ choices["anime"].append((display, key))
174
+
175
+ return choices
176
+
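# Sketch of wiring these choices into a dropdown (assumes Gradio, which the
# Space's UI layer is built on; the component names here are illustrative):
#
#   import gradio as gr
#   choices = manager.get_model_choices()
#   photo_dropdown = gr.Dropdown(
#       choices=choices["photo"],   # (display_name, key) tuples map label -> value
#       value=manager.get_default_model(ImageMode.PHOTO),
#       label="Photo model",
#   )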
177
+ def get_default_model(self, mode: ImageMode) -> str:
178
+ """Get the default model key for a mode."""
179
+ return self.DEFAULT_MODELS.get(mode, "sdxl_base")
180
+
181
+ def load_controlnet(self) -> Any:
182
+ """
183
+ Load the ControlNet model (shared across all base models).
184
+
185
+ Returns
186
+ -------
187
+ ControlNetModel
188
+ Loaded ControlNet model
189
+ """
190
+ if self._controlnet_loaded and self._controlnet is not None:
191
+ return self._controlnet
192
+
193
+ try:
194
+ from diffusers import ControlNetModel
195
+
196
+ logger.info("Loading ControlNet Canny model...")
197
+ self._controlnet = ControlNetModel.from_pretrained(
198
+ "diffusers/controlnet-canny-sdxl-1.0",
199
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
200
+ use_safetensors=True
201
+ )
202
+ self._controlnet_loaded = True
203
+ logger.info("ControlNet loaded successfully")
204
+ return self._controlnet
205
+
206
+ except Exception as e:
207
+ logger.error(f"Failed to load ControlNet: {e}")
208
+ raise
209
+
210
+ def load_pipeline(
211
+ self,
212
+ model_key: Optional[str] = None,
213
+ mode: Optional[ImageMode] = None
214
+ ) -> Any:
215
+ """
216
+ Load an inpainting pipeline for the specified model.
217
+
218
+ Parameters
219
+ ----------
220
+ model_key : str, optional
221
+ Specific model key to load
222
+ mode : ImageMode, optional
223
+ If model_key not specified, load default for this mode
224
+
225
+ Returns
226
+ -------
227
+ StableDiffusionXLControlNetInpaintPipeline
228
+ Loaded pipeline ready for inference
229
+ """
230
+ # Determine which model to load
231
+ if model_key is None:
232
+ if mode is None:
233
+ mode = ImageMode.PHOTO
234
+ model_key = self.get_default_model(mode)
235
+
236
+ # Check if already loaded
237
+ if self._current_model_key == model_key and self._pipeline is not None:
238
+ logger.info(f"Model {model_key} already loaded")
239
+ return self._pipeline
240
+
241
+ # Unload current model if different
242
+ if self._current_model_key != model_key:
243
+ self.unload_pipeline()
244
+
245
+ # Get model config
246
+ config = self.AVAILABLE_MODELS.get(model_key)
247
+ if config is None:
248
+ raise ValueError(f"Unknown model key: {model_key}")
249
+
250
+ logger.info(f"Loading model: {config.name} ({config.model_id})")
251
+
252
+ try:
253
+ # Ensure ControlNet is loaded
254
+ controlnet = self.load_controlnet()
255
+
256
+ # Load pipeline
257
+ dtype = torch.float16 if self.device == "cuda" else torch.float32
258
+
259
+ load_kwargs = {
260
+ "controlnet": controlnet,
261
+ "torch_dtype": dtype,
262
+ "use_safetensors": True,
263
+ }
264
+
265
+ if config.requires_variant:
266
+ load_kwargs["variant"] = config.variant
267
+
268
+ self._pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
269
+ config.model_id,
270
+ **load_kwargs
271
+ )
272
+
273
+ # Move to device and optimize
274
+ self._pipeline = self._pipeline.to(self.device)
275
+
276
+ if self.device == "cuda":
277
+ self._pipeline.enable_vae_tiling()
278
+ try:
279
+ self._pipeline.enable_xformers_memory_efficient_attention()
280
+ logger.info("xformers enabled")
281
+ except Exception:
282
+ logger.info("xformers not available, using default attention")
283
+
284
+ self._current_model_key = model_key
285
+ logger.info(f"Model {config.name} loaded successfully")
286
+
287
+ return self._pipeline
288
+
289
+ except Exception as e:
290
+ logger.error(f"Failed to load model {model_key}: {e}")
291
+ raise
292
+
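# Minimal usage sketch: load once, then call the pipeline with the standard
# StableDiffusionXLControlNetInpaintPipeline arguments (the image/mask/control
# variables below are placeholders):
#
#   pipe = manager.load_pipeline(mode=ImageMode.PHOTO)
#   result = pipe(
#       prompt="a sunlit garden",
#       image=source_image,
#       mask_image=mask_image,
#       control_image=canny_image,
#       num_inference_steps=25,
#       guidance_scale=7.0,
#   ).images[0]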
293
+ def unload_pipeline(self) -> None:
294
+ """Unload the current pipeline to free memory."""
295
+ if self._pipeline is not None:
296
+ logger.info(f"Unloading model: {self._current_model_key}")
297
+ del self._pipeline
298
+ self._pipeline = None
299
+ self._current_model_key = None
300
+
301
+ if self.device == "cuda":
302
+ torch.cuda.empty_cache()
303
+ gc.collect()
304
+
305
+ def switch_model(self, model_key: str) -> Any:
306
+ """
307
+ Switch to a different model.
308
+
309
+ Parameters
310
+ ----------
311
+ model_key : str
312
+ Model key to switch to
313
+
314
+ Returns
315
+ -------
316
+ Pipeline
317
+ Newly loaded pipeline
318
+ """
319
+ return self.load_pipeline(model_key=model_key)
320
+
321
+ def get_current_model_config(self) -> Optional[ModelConfig]:
322
+ """Get the configuration of the currently loaded model."""
323
+ if self._current_model_key is None:
324
+ return None
325
+ return self.AVAILABLE_MODELS.get(self._current_model_key)
326
+
327
+ def get_pipeline(self) -> Optional[Any]:
328
+ """Get the currently loaded pipeline."""
329
+ return self._pipeline
330
+
331
+ def is_loaded(self) -> bool:
332
+ """Check if a pipeline is currently loaded."""
333
+ return self._pipeline is not None
334
+
335
+ def get_status(self) -> Dict[str, Any]:
336
+ """
337
+ Get current status of the model manager.
338
+
339
+ Returns
340
+ -------
341
+ dict
342
+ Status information
343
+ """
344
+ current_config = self.get_current_model_config()
345
+ return {
346
+ "device": self.device,
347
+ "current_model": self._current_model_key,
348
+ "current_model_name": current_config.name if current_config else None,
349
+ "is_loaded": self.is_loaded(),
350
+ "controlnet_loaded": self._controlnet_loaded,
351
+ "available_models": list(self.AVAILABLE_MODELS.keys())
352
+ }
353
+
354
+
355
+ def get_model_selection_guide() -> str:
356
+ """
357
+ Get HTML guide for model selection to display in UI.
358
+
359
+ Returns
360
+ -------
361
+ str
362
+ HTML formatted guide
363
+ """
364
+ return """
365
+ <div style="background: linear-gradient(135deg, #f5f7fa 0%, #e4e8ec 100%);
366
+ padding: 16px;
367
+ border-radius: 12px;
368
+ margin: 12px 0;
369
+ border: 1px solid #ddd;">
370
+ <h4 style="margin: 0 0 12px 0; color: #333; font-size: 16px;">
371
+ 📸 Model Selection Guide
372
+ </h4>
373
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;">
374
+ <div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #4CAF50;">
375
+ <p style="margin: 0 0 8px 0; font-weight: bold; color: #4CAF50;">
376
+ 🖼️ Photo Mode
377
+ </p>
378
+ <p style="margin: 0; font-size: 13px; color: #555;">
379
+ <strong>Best for:</strong> Real photographs, portraits, product shots, nature photos
380
+ </p>
381
+ <p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
382
+ Recommended: JuggernautXL for portraits, RealVisXL for scenes
383
+ </p>
384
+ </div>
385
+ <div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #9C27B0;">
386
+ <p style="margin: 0 0 8px 0; font-weight: bold; color: #9C27B0;">
387
+ 🎨 Anime Mode
388
+ </p>
389
+ <p style="margin: 0; font-size: 13px; color: #555;">
390
+ <strong>Best for:</strong> Anime, manga, illustrations, digital art, cartoons
391
+ </p>
392
+ <p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
393
+ Recommended: Animagine XL for anime, SDXL Base for general art
394
+ </p>
395
+ </div>
396
+ </div>
397
+ </div>
398
+ """
inpainting_module.py CHANGED
@@ -4,55 +4,57 @@ import os
4
  import time
5
  import traceback
6
  from dataclasses import dataclass, field
7
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
 
9
  import cv2
10
  import numpy as np
11
  import torch
12
- from PIL import Image, ImageFilter
13
 
14
- from diffusers import ControlNetModel, DPMSolverMultistepScheduler
 
 
15
  from diffusers import StableDiffusionXLControlNetInpaintPipeline
16
- from diffusers import StableDiffusionXLInpaintPipeline
17
- from transformers import AutoImageProcessor, AutoModelForDepthEstimation
18
- from transformers import DPTImageProcessor, DPTForDepthEstimation
 
 
 
 
19
 
20
  logger = logging.getLogger(__name__)
21
  logger.setLevel(logging.INFO)
22
 
23
 
 
 
 
 
24
  @dataclass
25
  class InpaintingConfig:
26
  """Configuration for inpainting operations."""
27
 
28
- # ControlNet settings
29
  controlnet_conditioning_scale: float = 0.7
30
- conditioning_type: str = "canny" # "canny" or "depth"
31
 
32
  # Canny edge detection parameters
33
  canny_low_threshold: int = 100
34
  canny_high_threshold: int = 200
35
 
36
  # Mask settings
37
- feather_radius: int = 8
38
  min_mask_coverage: float = 0.01
39
  max_mask_coverage: float = 0.95
40
 
41
  # Generation settings
42
  num_inference_steps: int = 25
43
  guidance_scale: float = 7.5
44
- strength: float = 1.0 # Inpainting strength (0.0-1.0), 1.0 = full repaint
45
- preview_steps: int = 15
46
- preview_guidance_scale: float = 8.0
47
-
48
- # Quality settings
49
- enable_auto_optimization: bool = True
50
- max_optimization_retries: int = 3
51
- min_quality_score: float = 70.0
52
 
53
  # Memory settings
54
  enable_vae_tiling: bool = True
55
- enable_attention_slicing: bool = True
56
  max_resolution: int = 1024
57
 
58
 
@@ -66,94 +68,81 @@ class InpaintingResult:
66
  control_image: Optional[Image.Image] = None
67
  blended_image: Optional[Image.Image] = None
68
  quality_score: float = 0.0
69
- quality_details: Dict[str, Any] = field(default_factory=dict)
70
  generation_time: float = 0.0
71
- retries: int = 0
72
  error_message: str = ""
73
  metadata: Dict[str, Any] = field(default_factory=dict)
74
 
75
 
76
  class InpaintingModule:
77
  """
78
- ControlNet-based Inpainting Module for SceneWeaver.
79
 
80
- Implements StableDiffusionXLControlNetInpaintPipeline with support for
81
- Canny edge and depth map conditioning. Features two-stage generation
82
- (preview + full quality) and automatic quality optimization.
 
83
 
84
- Attributes:
85
- device: Computation device (cuda/mps/cpu)
86
- config: InpaintingConfig instance
87
- is_initialized: Whether pipeline is loaded
88
 
89
  Example:
90
  >>> module = InpaintingModule(device="cuda")
91
- >>> module.load_inpainting_pipeline(progress_callback=my_callback)
92
- >>> result = module.execute_inpainting(
93
- ... image=my_image,
94
- ... mask=my_mask,
95
- ... prompt="a beautiful garden"
96
- ... )
97
  """
98
 
99
- # Model identifiers
100
  CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
101
  CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
102
  DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
103
  DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
104
- BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
 
 
 
 
 
 
 
105
 
106
  def __init__(
107
  self,
108
  device: str = "auto",
109
  config: Optional[InpaintingConfig] = None
110
  ):
111
- """
112
- Initialize the InpaintingModule.
113
-
114
- Parameters
115
- ----------
116
- device : str, optional
117
- Computation device. "auto" for automatic detection.
118
- config : InpaintingConfig, optional
119
- Configuration object. Uses defaults if not provided.
120
- """
121
  self.device = self._setup_device(device)
122
  self.config = config or InpaintingConfig()
123
 
124
- # Pipeline instances (lazy loaded)
125
- self._inpaint_pipeline = None
126
- self._controlnet_canny = None
127
- self._controlnet_depth = None
 
 
 
 
 
 
 
 
 
 
128
  self._depth_estimator = None
129
  self._depth_processor = None
130
 
131
  # State tracking
132
  self.is_initialized = False
 
133
  self._current_conditioning_type = None
134
- self._last_seed = None
135
- self._cached_latents = None
136
- self._use_controlnet = True # Track if ControlNet is available
137
-
138
- # Reference to model manager (set by SceneWeaverCore)
139
- self._model_manager = None
140
 
141
  logger.info(f"InpaintingModule initialized on {self.device}")
142
 
143
  def _setup_device(self, device: str) -> str:
144
- """
145
- Setup computation device.
146
-
147
- Parameters
148
- ----------
149
- device : str
150
- Device specification or "auto"
151
-
152
- Returns
153
- -------
154
- str
155
- Resolved device name
156
- """
157
  if device == "auto":
158
  if torch.cuda.is_available():
159
  return "cuda"
@@ -162,224 +151,159 @@ class InpaintingModule:
162
  return "cpu"
163
  return device
164
 
165
- def set_model_manager(self, manager: Any) -> None:
166
- """
167
- Set reference to ModelManager for coordinated model lifecycle.
168
-
169
- Parameters
170
- ----------
171
- manager : ModelManager
172
- The global model manager instance
173
- """
174
- self._model_manager = manager
175
- logger.info("ModelManager reference set for InpaintingModule")
176
-
177
  def _memory_cleanup(self, aggressive: bool = False) -> None:
178
- """
179
- Perform memory cleanup.
180
-
181
- Parameters
182
- ----------
183
- aggressive : bool
184
- If True, perform multiple GC rounds and sync CUDA
185
- """
186
- rounds = 5 if aggressive else 2
187
- for _ in range(rounds):
188
  gc.collect()
189
 
190
- # On Hugging Face Spaces, avoid CUDA operations in main process
191
- # CUDA operations must only happen within @spaces.GPU decorated functions
192
  is_spaces = os.getenv('SPACE_ID') is not None
193
-
194
  if not is_spaces and torch.cuda.is_available():
195
  torch.cuda.empty_cache()
196
  if aggressive:
197
  torch.cuda.ipc_collect()
198
- torch.cuda.synchronize()
199
-
200
- logger.debug(f"Memory cleanup completed (aggressive={aggressive}, spaces={is_spaces})")
201
-
202
- def _check_memory_status(self) -> Dict[str, float]:
203
- """
204
- Check current GPU memory status.
205
-
206
- Returns
207
- -------
208
- dict
209
- Memory statistics including allocated, total, and usage ratio
210
- """
211
- # On Spaces, skip CUDA checks in main process
212
- is_spaces = os.getenv('SPACE_ID') is not None
213
-
214
- if is_spaces or not torch.cuda.is_available():
215
- return {"available": True, "usage_ratio": 0.0}
216
 
217
- allocated = torch.cuda.memory_allocated() / 1024**3
218
- total = torch.cuda.get_device_properties(0).total_memory / 1024**3
219
- usage_ratio = allocated / total
220
-
221
- return {
222
- "allocated_gb": round(allocated, 2),
223
- "total_gb": round(total, 2),
224
- "free_gb": round(total - allocated, 2),
225
- "usage_ratio": round(usage_ratio, 3),
226
- "available": usage_ratio < 0.9
227
- }
228
-
229
- def load_inpainting_pipeline(
230
  self,
 
231
  conditioning_type: str = "canny",
 
232
  progress_callback: Optional[Callable[[str, int], None]] = None
233
  ) -> Tuple[bool, str]:
234
  """
235
- Load the ControlNet inpainting pipeline.
236
-
237
- Implements mutual exclusion with background generation pipeline.
238
- Only one pipeline can be loaded at a time.
239
 
240
  Parameters
241
  ----------
 
 
 
242
  conditioning_type : str
243
- Type of ControlNet conditioning: "canny" or "depth"
 
 
244
  progress_callback : callable, optional
245
- Function(message, percentage) for progress updates
246
 
247
  Returns
248
  -------
249
  tuple
250
  (success: bool, error_message: str)
251
  """
252
- if self.is_initialized and self._current_conditioning_type == conditioning_type:
253
- logger.info(f"Inpainting pipeline already loaded with {conditioning_type}")
 
 
 
 
 
 
 
254
  return True, ""
255
 
256
- logger.info(f"Loading inpainting pipeline with {conditioning_type} conditioning...")
257
 
258
  try:
259
  self._memory_cleanup(aggressive=True)
260
 
261
  if progress_callback:
262
- progress_callback("Preparing to load inpainting models...", 5)
 
 
 
263
 
264
- # Unload existing pipeline if different conditioning type
265
- if self._inpaint_pipeline is not None:
266
- self._unload_pipeline()
267
 
268
- # Use ControlNet inpainting by default
269
- use_controlnet_inpaint = True
270
- logger.info("Using StableDiffusionXLControlNetInpaintPipeline")
 
271
 
272
- if progress_callback:
273
- progress_callback("Loading ControlNet model...", 20)
 
 
 
 
 
 
274
 
275
- # Load appropriate ControlNet
276
- dtype = torch.float16 if self.device == "cuda" else torch.float32
277
- controlnet = None
 
 
278
 
279
- if use_controlnet_inpaint:
 
 
 
280
  if conditioning_type == "canny":
281
- controlnet = ControlNetModel.from_pretrained(
282
  self.CONTROLNET_CANNY_MODEL,
283
  torch_dtype=dtype,
284
  use_safetensors=True
285
  )
286
- self._controlnet_canny = controlnet
287
- logger.info("Loaded ControlNet Canny model")
288
-
289
  elif conditioning_type == "depth":
290
- controlnet = ControlNetModel.from_pretrained(
291
  self.CONTROLNET_DEPTH_MODEL,
292
  torch_dtype=dtype,
293
  use_safetensors=True
294
  )
295
- self._controlnet_depth = controlnet
296
-
297
- # Load depth estimator
298
- if progress_callback:
299
- progress_callback("Loading depth estimation model...", 35)
300
  self._load_depth_estimator()
301
- logger.info("Loaded ControlNet Depth model")
302
  else:
303
  raise ValueError(f"Unknown conditioning type: {conditioning_type}")
304
- else:
305
- # Skip ControlNet loading for fallback mode
306
- logger.info(f"Skipping ControlNet loading (fallback mode)")
307
 
308
- if progress_callback:
309
- progress_callback("Loading SDXL Inpainting pipeline...", 50)
 
 
 
 
 
 
 
 
 
 
310
 
311
- # Load the inpainting pipeline
312
- if use_controlnet_inpaint and controlnet is not None:
313
- self._inpaint_pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
314
- self.BASE_MODEL,
315
- controlnet=controlnet,
316
- torch_dtype=dtype,
317
- use_safetensors=True,
318
- variant="fp16" if dtype == torch.float16 else None
319
  )
320
- else:
321
- # Fallback: Use dedicated inpainting model without ControlNet
322
- self._inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
323
- "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
324
- torch_dtype=dtype,
325
- use_safetensors=True,
326
- variant="fp16" if dtype == torch.float16 else None
327
- )
328
- self._use_controlnet = False
329
-
330
- # Track ControlNet usage
331
- self._use_controlnet = use_controlnet_inpaint and controlnet is not None
332
 
333
  if progress_callback:
334
- progress_callback("Configuring scheduler...", 70)
335
 
336
- # Configure scheduler for faster generation
337
- self._inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
338
- self._inpaint_pipeline.scheduler.config
339
  )
340
 
341
- # Move to device
342
- self._inpaint_pipeline = self._inpaint_pipeline.to(self.device)
343
-
344
- if progress_callback:
345
- progress_callback("Applying optimizations...", 85)
346
-
347
- # Apply memory optimizations
348
- self._apply_pipeline_optimizations()
349
-
350
- # Set eval mode
351
- self._inpaint_pipeline.unet.eval()
352
- if hasattr(self._inpaint_pipeline, 'vae'):
353
- self._inpaint_pipeline.vae.eval()
354
 
355
  self.is_initialized = True
356
- self._current_conditioning_type = conditioning_type if self._use_controlnet else "none"
357
 
358
  if progress_callback:
359
- progress_callback("Inpainting pipeline ready!", 100)
360
-
361
- # Log memory status
362
- mem_status = self._check_memory_status()
363
- logger.info(f"Pipeline loaded. GPU memory: {mem_status.get('allocated_gb', 0):.1f}GB used")
364
 
365
  return True, ""
366
 
367
  except Exception as e:
368
  error_msg = str(e)
369
- logger.error(f"Failed to load inpainting pipeline: {error_msg}")
370
  traceback.print_exc()
371
  self._unload_pipeline()
372
  return False, error_msg
373
 
374
  def _load_depth_estimator(self) -> None:
375
- """
376
- Load depth estimation model with fallback strategy.
377
-
378
- Tries Depth-Anything first, falls back to MiDaS if unavailable.
379
- """
380
  try:
381
- logger.info(f"Attempting to load depth model: {self.DEPTH_MODEL_PRIMARY}")
382
-
383
  self._depth_processor = AutoImageProcessor.from_pretrained(
384
  self.DEPTH_MODEL_PRIMARY
385
  )
@@ -389,70 +313,50 @@ class InpaintingModule:
389
  )
390
  self._depth_estimator.to(self.device)
391
  self._depth_estimator.eval()
392
-
393
- logger.info("Successfully loaded Depth-Anything model")
394
-
395
  except Exception as e:
396
  logger.warning(f"Primary depth model failed: {e}, trying fallback...")
 
 
 
 
 
 
 
 
 
 
397
 
398
- try:
399
- self._depth_processor = DPTImageProcessor.from_pretrained(
400
- self.DEPTH_MODEL_FALLBACK
401
- )
402
- self._depth_estimator = DPTForDepthEstimation.from_pretrained(
403
- self.DEPTH_MODEL_FALLBACK,
404
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
405
- )
406
- self._depth_estimator.to(self.device)
407
- self._depth_estimator.eval()
408
-
409
- logger.info("Successfully loaded MiDaS fallback model")
410
-
411
- except Exception as fallback_e:
412
- logger.error(f"Fallback depth model also failed: {fallback_e}")
413
- raise RuntimeError("Unable to load any depth estimation model")
414
-
415
- def _apply_pipeline_optimizations(self) -> None:
416
- """Apply memory and performance optimizations to the pipeline."""
417
- if self._inpaint_pipeline is None:
418
  return
419
 
420
- # Try xformers first
421
  try:
422
- self._inpaint_pipeline.enable_xformers_memory_efficient_attention()
423
- logger.info("Enabled xformers memory efficient attention")
424
  except Exception:
425
  try:
426
- self._inpaint_pipeline.enable_attention_slicing()
427
  logger.info("Enabled attention slicing")
428
  except Exception:
429
- logger.warning("No attention optimization available")
430
 
431
- # VAE optimizations
432
  if self.config.enable_vae_tiling:
433
- if hasattr(self._inpaint_pipeline, 'enable_vae_tiling'):
434
- self._inpaint_pipeline.enable_vae_tiling()
435
- logger.debug("Enabled VAE tiling")
436
-
437
- if hasattr(self._inpaint_pipeline, 'enable_vae_slicing'):
438
- self._inpaint_pipeline.enable_vae_slicing()
439
- logger.debug("Enabled VAE slicing")
440
 
441
  def _unload_pipeline(self) -> None:
442
- """Unload the inpainting pipeline and free memory."""
443
- logger.info("Unloading inpainting pipeline...")
 
 
444
 
445
- if self._inpaint_pipeline is not None:
446
- del self._inpaint_pipeline
447
- self._inpaint_pipeline = None
448
-
449
- if self._controlnet_canny is not None:
450
- del self._controlnet_canny
451
- self._controlnet_canny = None
452
-
453
- if self._controlnet_depth is not None:
454
- del self._controlnet_depth
455
- self._controlnet_depth = None
456
 
457
  if self._depth_estimator is not None:
458
  del self._depth_estimator
@@ -463,942 +367,300 @@ class InpaintingModule:
463
  self._depth_processor = None
464
 
465
  self.is_initialized = False
 
466
  self._current_conditioning_type = None
467
- self._cached_latents = None
468
 
469
  self._memory_cleanup(aggressive=True)
470
- logger.info("Inpainting pipeline unloaded")
471
-
472
- def prepare_control_image(
473
- self,
474
- image: Image.Image,
475
- mode: str = "canny",
476
- mask: Optional[Image.Image] = None,
477
- preserve_structure: bool = False
478
- ) -> Image.Image:
479
- """
480
- Generate ControlNet conditioning image.
481
-
482
- Parameters
483
- ----------
484
- image : PIL.Image
485
- Input image
486
- mode : str
487
- Conditioning mode: "canny" or "depth"
488
- mask : PIL.Image, optional
489
- If provided, can suppress edges in masked region (when preserve_structure=False).
490
- preserve_structure : bool
491
- If True, keep edges in masked region (for color change tasks).
492
- If False, suppress edges in masked region (for replacement/removal tasks).
493
-
494
- Returns
495
- -------
496
- PIL.Image
497
- Generated control image (edges or depth map)
498
- """
499
- logger.info(f"Preparing control image with mode: {mode}, preserve_structure: {preserve_structure}")
500
-
501
- # Convert to RGB if needed
502
- if image.mode != 'RGB':
503
- image = image.convert('RGB')
504
-
505
- img_array = np.array(image)
506
-
507
- if mode == "canny":
508
- canny_image = self._generate_canny_edges(img_array)
509
-
510
- # Mask-aware processing: suppress edges in masked region ONLY if not preserving structure
511
- if mask is not None and not preserve_structure:
512
- canny_array = np.array(canny_image)
513
- mask_array = np.array(mask.convert('L'))
514
-
515
- # In masked region, completely suppress Canny edges
516
- # This allows complete replacement/removal of the object
517
- mask_region = mask_array > 128 # White = masked area
518
- canny_array[mask_region] = 0
519
-
520
- canny_image = Image.fromarray(canny_array)
521
- logger.info("Suppressed edges in masked region for replacement/removal")
522
- elif preserve_structure:
523
- logger.info("Preserving edges in masked region for color change")
524
-
525
- return canny_image
526
-
527
- elif mode == "depth":
528
- return self._generate_depth_map(image)
529
- else:
530
- raise ValueError(f"Unknown control mode: {mode}")
531
-
532
- def _generate_canny_edges(self, img_array: np.ndarray) -> Image.Image:
533
- """
534
- Generate Canny edge detection image.
535
-
536
- Parameters
537
- ----------
538
- img_array : np.ndarray
539
- Input image as RGB numpy array
540
-
541
- Returns
542
- -------
543
- PIL.Image
544
- Edge detection result as grayscale image
545
- """
546
- # Convert to grayscale
547
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
548
-
549
- # Apply Gaussian blur to reduce noise
550
- blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)
551
-
552
- # Canny edge detection
553
- edges = cv2.Canny(
554
- blurred,
555
- self.config.canny_low_threshold,
556
- self.config.canny_high_threshold
557
- )
558
-
559
- # Convert to 3-channel for ControlNet
560
- edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
561
-
562
- logger.debug(f"Generated Canny edges with thresholds "
563
- f"{self.config.canny_low_threshold}/{self.config.canny_high_threshold}")
564
-
565
- return Image.fromarray(edges_3ch)
566
-
567
- def _generate_depth_map(self, image: Image.Image) -> Image.Image:
568
- """
569
- Generate depth map using depth estimation model.
570
-
571
- Parameters
572
- ----------
573
- image : PIL.Image
574
- Input RGB image
575
-
576
- Returns
577
- -------
578
- PIL.Image
579
- Depth map as grayscale image
580
- """
581
- if self._depth_estimator is None or self._depth_processor is None:
582
- raise RuntimeError("Depth estimator not loaded")
583
-
584
- # Preprocess
585
- inputs = self._depth_processor(images=image, return_tensors="pt")
586
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
587
-
588
- # Inference
589
- with torch.no_grad():
590
- outputs = self._depth_estimator(**inputs)
591
- predicted_depth = outputs.predicted_depth
592
-
593
- # Interpolate to original size
594
- prediction = torch.nn.functional.interpolate(
595
- predicted_depth.unsqueeze(1),
596
- size=image.size[::-1], # (H, W)
597
- mode="bicubic",
598
- align_corners=False
599
- )
600
-
601
- # Normalize to 0-255
602
- depth_array = prediction.squeeze().cpu().numpy()
603
- depth_min = depth_array.min()
604
- depth_max = depth_array.max()
605
-
606
- if depth_max - depth_min > 0:
607
- depth_normalized = ((depth_array - depth_min) / (depth_max - depth_min) * 255)
608
- else:
609
- depth_normalized = np.zeros_like(depth_array)
610
-
611
- depth_normalized = depth_normalized.astype(np.uint8)
612
-
613
- # Convert to 3-channel for ControlNet
614
- depth_3ch = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
615
-
616
- logger.debug(f"Generated depth map, range: {depth_min:.2f} - {depth_max:.2f}")
617
-
618
- return Image.fromarray(depth_3ch)
619
-
620
- def prepare_mask(
621
- self,
622
- mask: Image.Image,
623
- target_size: Tuple[int, int],
624
- feather_radius: Optional[int] = None
625
- ) -> Tuple[Image.Image, Dict[str, Any]]:
626
- """
627
- Prepare and validate mask for inpainting.
628
-
629
- Parameters
630
- ----------
631
- mask : PIL.Image
632
- Input mask (white = inpaint area)
633
- target_size : tuple
634
- Target (width, height) to match input image
635
- feather_radius : int, optional
636
- Feathering radius in pixels. Uses config default if None.
637
-
638
- Returns
639
- -------
640
- tuple
641
- (processed_mask, validation_info)
642
-
643
- Raises
644
- ------
645
- ValueError
646
- If mask coverage is outside acceptable range
647
- """
648
- feather = feather_radius if feather_radius is not None else self.config.feather_radius
649
-
650
- # Convert to grayscale
651
- if mask.mode != 'L':
652
- mask = mask.convert('L')
653
-
654
- # Resize to match target
655
- if mask.size != target_size:
656
- mask = mask.resize(target_size, Image.LANCZOS)
657
-
658
- # Convert to array for processing
659
- mask_array = np.array(mask)
660
-
661
- # Calculate coverage
662
- total_pixels = mask_array.size
663
- white_pixels = np.count_nonzero(mask_array > 127)
664
- coverage = white_pixels / total_pixels
665
-
666
- validation_info = {
667
- "coverage": coverage,
668
- "white_pixels": white_pixels,
669
- "total_pixels": total_pixels,
670
- "feather_radius": feather,
671
- "valid": True,
672
- "warning": ""
673
- }
674
-
675
- # Validate coverage
676
- if coverage < self.config.min_mask_coverage:
677
- validation_info["valid"] = False
678
- validation_info["warning"] = (
679
- f"Mask coverage too low ({coverage:.1%}). "
680
- f"Please select a larger area to inpaint."
681
- )
682
- logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.config.min_mask_coverage:.1%}")
683
-
684
- elif coverage > self.config.max_mask_coverage:
685
- validation_info["valid"] = False
686
- validation_info["warning"] = (
687
- f"Mask coverage too high ({coverage:.1%}). "
688
- f"Consider using background generation instead."
689
- )
690
- logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.config.max_mask_coverage:.1%}")
691
-
692
- # Apply feathering
693
- if feather > 0:
694
- mask_array = cv2.GaussianBlur(
695
- mask_array,
696
- (feather * 2 + 1, feather * 2 + 1),
697
- feather / 2
698
- )
699
- logger.debug(f"Applied {feather}px feathering to mask")
700
-
701
- processed_mask = Image.fromarray(mask_array, mode='L')
702
-
703
- return processed_mask, validation_info
704
-
705
- def enhance_prompt_for_inpainting(
706
- self,
707
- prompt: str,
708
- image: Image.Image,
709
- mask: Image.Image
710
- ) -> Tuple[str, str]:
711
- """
712
- Enhance prompt based on non-masked region analysis.
713
-
714
- Analyzes the surrounding context to generate appropriate
715
- lighting and color descriptors.
716
-
717
- Parameters
718
- ----------
719
- prompt : str
720
- User-provided prompt
721
- image : PIL.Image
722
- Original image
723
- mask : PIL.Image
724
- Inpainting mask
725
-
726
- Returns
727
- -------
728
- tuple
729
- (enhanced_prompt, negative_prompt)
730
- """
731
- logger.info("Enhancing prompt for inpainting context...")
732
-
733
- # Convert to arrays
734
- img_array = np.array(image.convert('RGB'))
735
- mask_array = np.array(mask.convert('L'))
736
-
737
- # Analyze non-masked regions
738
- non_masked = mask_array < 127
739
-
740
- if not np.any(non_masked):
741
- # No context available
742
- enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
743
- negative_prompt = self._get_inpainting_negative_prompt()
744
- return enhanced_prompt, negative_prompt
745
-
746
- # Extract context pixels
747
- context_pixels = img_array[non_masked]
748
-
749
- # Convert to Lab for analysis
750
- context_lab = cv2.cvtColor(
751
- context_pixels.reshape(-1, 1, 3),
752
- cv2.COLOR_RGB2LAB
753
- ).reshape(-1, 3)
754
-
755
- # Use robust statistics (median) to avoid outlier influence
756
- median_l = np.median(context_lab[:, 0])
757
- median_a = np.median(context_lab[:, 1])
758
- median_b = np.median(context_lab[:, 2])
759
-
760
- # Analyze lighting conditions
761
- lighting_descriptors = []
762
-
763
- if median_l > 170:
764
- lighting_descriptors.append("bright")
765
- elif median_l > 130:
766
- lighting_descriptors.append("well-lit")
767
- elif median_l > 80:
768
- lighting_descriptors.append("moderate lighting")
769
- else:
770
- lighting_descriptors.append("dim lighting")
771
-
772
- # Analyze color temperature (b channel: blue(-) to yellow(+))
773
- if median_b > 140:
774
- lighting_descriptors.append("warm golden tones")
775
- elif median_b > 120:
776
- lighting_descriptors.append("warm afternoon light")
777
- elif median_b < 110:
778
- lighting_descriptors.append("cool neutral tones")
779
-
780
- # Calculate saturation from context
781
- hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
782
- median_saturation = np.median(hsv[:, :, 1])
783
-
784
- if median_saturation > 150:
785
- lighting_descriptors.append("vibrant colors")
786
- elif median_saturation < 80:
787
- lighting_descriptors.append("subtle muted colors")
788
-
789
- # Build enhanced prompt
790
- lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
791
- quality_suffix = "high quality, detailed, photorealistic, seamless integration"
792
-
793
- if lighting_desc:
794
- enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
795
- else:
796
- enhanced_prompt = f"{prompt}, {quality_suffix}"
797
-
798
- negative_prompt = self._get_inpainting_negative_prompt()
799
-
800
- logger.info(f"Enhanced prompt with context: {lighting_desc}")
801
-
802
- return enhanced_prompt, negative_prompt
803
-
804
- def _get_inpainting_negative_prompt(self) -> str:
805
- """Get standard negative prompt for inpainting."""
806
- return (
807
- "inconsistent lighting, wrong perspective, mismatched colors, "
808
- "visible seams, blending artifacts, color bleeding, "
809
- "blurry, low quality, distorted, deformed, "
810
- "harsh edges, unnatural transition"
811
- )
812
 
813
  def execute_inpainting(
814
  self,
815
  image: Image.Image,
816
  mask: Image.Image,
817
  prompt: str,
818
- preview_only: bool = False,
819
- seed: Optional[int] = None,
820
  progress_callback: Optional[Callable[[str, int], None]] = None,
821
  **kwargs
822
  ) -> InpaintingResult:
823
  """
824
- Execute the inpainting operation.
825
-
826
- Implements two-stage generation: fast preview followed by
827
- full quality generation if requested.
828
 
829
  Parameters
830
  ----------
831
  image : PIL.Image
832
- Original image to inpaint
833
  mask : PIL.Image
834
  Inpainting mask (white = area to regenerate)
835
  prompt : str
836
- Text description of desired content
837
- preview_only : bool
838
- If True, only generate preview (faster)
839
- seed : int, optional
840
- Random seed for reproducibility
841
  progress_callback : callable, optional
842
- Progress update function(message, percentage)
843
  **kwargs
844
- Additional parameters:
845
- - controlnet_conditioning_scale: float
846
- - feather_radius: int
847
- - num_inference_steps: int
848
- - guidance_scale: float
849
 
850
  Returns
851
  -------
852
  InpaintingResult
853
- Result container with generated images and metadata
854
  """
855
  start_time = time.time()
856
 
857
  if not self.is_initialized:
858
  return InpaintingResult(
859
  success=False,
860
- error_message="Inpainting pipeline not initialized. Call load_inpainting_pipeline() first."
861
  )
862
 
863
- logger.info(f"Starting inpainting: prompt='{prompt[:50]}...', preview_only={preview_only}")
864
 
865
  try:
866
- # Update config with kwargs
867
- conditioning_scale = kwargs.get(
868
- 'controlnet_conditioning_scale',
869
- self.config.controlnet_conditioning_scale
870
- )
871
- feather_radius = kwargs.get('feather_radius', self.config.feather_radius)
872
- strength = kwargs.get('strength', self.config.strength)
873
- preserve_structure = kwargs.get('preserve_structure_in_mask', False)
874
-
875
  if progress_callback:
876
- progress_callback("Preparing images...", 5)
877
 
878
  # Prepare image
879
  if image.mode != 'RGB':
880
  image = image.convert('RGB')
881
 
882
- # Ensure dimensions are multiple of 8
 
 
 
883
  width, height = image.size
884
  new_width = (width // 8) * 8
885
  new_height = (height // 8) * 8
886
-
887
  if new_width != width or new_height != height:
888
  image = image.resize((new_width, new_height), Image.LANCZOS)
889
 
890
- # Check and potentially reduce resolution for memory
891
  max_res = self.config.max_resolution
892
  if max(new_width, new_height) > max_res:
893
  scale = max_res / max(new_width, new_height)
894
  new_width = int(new_width * scale) // 8 * 8
895
  new_height = int(new_height * scale) // 8 * 8
896
  image = image.resize((new_width, new_height), Image.LANCZOS)
897
- logger.info(f"Reduced resolution to {new_width}x{new_height} for memory")
898
 
899
- # Prepare mask
900
- if progress_callback:
901
- progress_callback("Processing mask...", 10)
902
-
903
- processed_mask, mask_info = self.prepare_mask(
904
  mask,
905
  (new_width, new_height),
906
- feather_radius
907
- )
908
-
909
- if not mask_info["valid"]:
910
- return InpaintingResult(
911
- success=False,
912
- error_message=mask_info["warning"]
913
- )
914
-
915
- # Generate control image
916
- if progress_callback:
917
- progress_callback("Generating control image...", 20)
918
-
919
- control_image = self.prepare_control_image(
920
- image,
921
- self._current_conditioning_type,
922
- mask=processed_mask,
923
- preserve_structure=preserve_structure # True for color change, False for replacement/removal
924
  )
925
 
926
- # Conditional prompt enhancement based on template
927
- # Check if we should enhance the prompt or use it directly
928
- should_enhance = kwargs.get('enhance_prompt', False) # Default: no enhancement
 
 
929
 
930
- if should_enhance:
931
- if progress_callback:
932
- progress_callback("Enhancing prompt...", 25)
933
- enhanced_prompt, negative_prompt = self.enhance_prompt_for_inpainting(
934
- prompt, image, processed_mask
935
- )
936
- logger.info(f"Prompt enhanced with OpenCLIP context")
937
- else:
938
- # Use prompt directly without enhancement
939
- enhanced_prompt = prompt
940
- negative_prompt = self._get_inpainting_negative_prompt()
941
- logger.info("Prompt enhancement disabled for this template")
942
 
943
- # Setup generator for reproducibility
944
- if seed is None:
 
 
945
  seed = int(time.time() * 1000) % (2**32)
946
- self._last_seed = seed
 
947
  generator = torch.Generator(device=self.device).manual_seed(seed)
 
948
 
949
- # Check if running on Hugging Face Spaces
950
- is_spaces = os.getenv('SPACE_ID') is not None
951
-
952
- # Stage 1: Preview generation
953
- # On Spaces, skip preview to save time (300s hard limit)
954
- preview_result = None
955
-
956
- if preview_only or not is_spaces:
957
  if progress_callback:
958
- progress_callback("Generating preview...", 30)
959
-
960
- # Optimize preview steps for Hugging Face Spaces
961
- preview_steps = self.config.preview_steps
962
- if is_spaces:
963
- # On Spaces, use minimal preview steps
964
- preview_steps = min(preview_steps, 8)
965
- logger.debug(f"Spaces environment - using {preview_steps} preview steps")
966
 
967
- preview_result = self._generate_inpaint(
968
  image=image,
969
  mask=processed_mask,
970
- control_image=control_image,
971
- prompt=enhanced_prompt,
972
  negative_prompt=negative_prompt,
973
- num_inference_steps=preview_steps,
974
- guidance_scale=self.config.preview_guidance_scale,
975
- controlnet_conditioning_scale=conditioning_scale,
976
  strength=strength,
977
  generator=generator
978
  )
 
 
979
  else:
980
- logger.debug("Spaces environment - skipping preview to fit 300s limit")
 
 
981
 
982
- if preview_only:
983
- generation_time = time.time() - start_time
 
984
 
985
- return InpaintingResult(
986
- success=True,
987
- preview_image=preview_result,
988
- control_image=control_image,
989
- generation_time=generation_time,
990
- metadata={
991
- "seed": seed,
992
- "prompt": enhanced_prompt,
993
- "conditioning_type": self._current_conditioning_type,
994
- "conditioning_scale": conditioning_scale,
995
- "preview_only": True
996
- }
997
  )
998
 
999
- # Stage 2: Full quality generation
1000
- if progress_callback:
1001
- progress_callback("Generating full quality...", 60)
1002
-
1003
- # Use same seed for reproducibility
1004
- generator = torch.Generator(device=self.device).manual_seed(seed)
1005
-
1006
- num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
1007
- guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
1008
-
1009
- # Optimize for Hugging Face Spaces ZeroGPU (stateless, 300s hard limit)
1010
- if is_spaces:
1011
- # ZeroGPU timing breakdown with model caching (actual measurements):
1012
- # - Model loading from cache: ~60s (cached models, CPU to GPU transfer)
1013
- # - Inference: ~28-29s/step (observed on shared H200)
1014
- # - Blending & overhead: ~35s
1015
- # - Platform limit: 300s hard limit (Pro tier)
1016
- #
1017
- # Strategy with unified 10-step approach:
1018
- # - Skip preview completely (done above)
1019
- # - Use 10 steps for balance of quality and speed
1020
- # - Time budget: 60s (load) + 285s (10 steps) + 35s (blend) = 380s
1021
- # - Note: Still may timeout, but parameter optimization is more important than step count
1022
- # - Quality comes from correct conditioning_scale, not high step count
1023
-
1024
- spaces_max_steps = 10 # Optimized: 10 steps sufficient with proper parameters
1025
-
1026
- if num_steps > spaces_max_steps:
1027
- num_steps = spaces_max_steps
1028
- logger.debug(f"Spaces deployment: using {num_steps} steps (optimized for parameter quality)")
1029
-
1030
- full_result = self._generate_inpaint(
1031
- image=image,
1032
- mask=processed_mask,
1033
- control_image=control_image,
1034
- prompt=enhanced_prompt,
1035
- negative_prompt=negative_prompt,
1036
- num_inference_steps=num_steps,
1037
- guidance_scale=guidance,
1038
- controlnet_conditioning_scale=conditioning_scale,
1039
- strength=strength,
1040
- generator=generator
1041
- )
1042
 
1043
- if progress_callback:
1044
- progress_callback("Blending result...", 90)
 
 
1045
 
1046
- # Blend result
1047
- blended = self.blend_result(image, full_result, processed_mask)
 
 
 
 
 
 
 
 
 
 
1048
 
1049
  generation_time = time.time() - start_time
1050
 
 
 
 
 
 
1051
  if progress_callback:
1052
  progress_callback("Complete!", 100)
1053
 
1054
  return InpaintingResult(
1055
  success=True,
1056
- result_image=full_result,
1057
- preview_image=preview_result,
1058
  control_image=control_image,
1059
- blended_image=blended,
1060
  generation_time=generation_time,
1061
  metadata={
1062
  "seed": seed,
1063
- "prompt": enhanced_prompt,
1064
- "negative_prompt": negative_prompt,
1065
- "conditioning_type": self._current_conditioning_type,
1066
- "conditioning_scale": conditioning_scale,
1067
  "strength": strength,
1068
- "preserve_structure": preserve_structure,
1069
- "num_inference_steps": num_steps,
1070
- "guidance_scale": guidance,
1071
- "feather_radius": feather_radius,
1072
- "mask_coverage": mask_info["coverage"],
1073
- "preview_only": False
1074
  }
1075
  )
1076
 
1077
  except torch.cuda.OutOfMemoryError:
1078
- logger.error("CUDA out of memory during inpainting")
1079
  self._memory_cleanup(aggressive=True)
1080
  return InpaintingResult(
1081
  success=False,
1082
- error_message="GPU memory exhausted. Try reducing image size or closing other applications."
1083
  )
1084
-
1085
  except Exception as e:
1086
  logger.error(f"Inpainting failed: {e}")
1087
- logger.error(traceback.format_exc())
1088
  return InpaintingResult(
1089
  success=False,
1090
- error_message=f"Inpainting failed: {str(e)}"
1091
  )
1092
 
1093
- def _generate_inpaint(
1094
  self,
1095
- image: Image.Image,
1096
  mask: Image.Image,
1097
- control_image: Image.Image,
1098
- prompt: str,
1099
- negative_prompt: str,
1100
- num_inference_steps: int,
1101
- guidance_scale: float,
1102
- controlnet_conditioning_scale: float,
1103
- strength: float,
1104
- generator: torch.Generator
1105
- ) -> Image.Image:
1106
- """
1107
- Internal method to run the inpainting pipeline.
1108
-
1109
- Supports both ControlNet and non-ControlNet pipelines.
1110
-
1111
- Parameters
1112
- ----------
1113
- image : PIL.Image
1114
- Original image
1115
- mask : PIL.Image
1116
- Processed mask
1117
- control_image : PIL.Image
1118
- ControlNet conditioning image (ignored if ControlNet not available)
1119
- prompt : str
1120
- Enhanced prompt
1121
- negative_prompt : str
1122
- Negative prompt
1123
- num_inference_steps : int
1124
- Number of denoising steps
1125
- guidance_scale : float
1126
- Classifier-free guidance scale
1127
- controlnet_conditioning_scale : float
1128
- ControlNet influence strength (ignored if ControlNet not available)
1129
- strength : float
1130
- Inpainting strength (0.0-1.0). 1.0 = fully repaint masked area.
1131
- generator : torch.Generator
1132
- Random generator for reproducibility
1133
-
1134
- Returns
1135
- -------
1136
- PIL.Image
1137
- Generated image
1138
- """
1139
- with torch.inference_mode():
1140
- if self._use_controlnet:
1141
- # Full ControlNet inpainting pipeline
1142
- result = self._inpaint_pipeline(
1143
- prompt=prompt,
1144
- negative_prompt=negative_prompt,
1145
- image=image,
1146
- mask_image=mask,
1147
- control_image=control_image,
1148
- num_inference_steps=num_inference_steps,
1149
- guidance_scale=guidance_scale,
1150
- controlnet_conditioning_scale=controlnet_conditioning_scale,
1151
- strength=strength,
1152
- generator=generator
1153
- )
1154
- else:
1155
- # Fallback: Standard SDXL inpainting without ControlNet
1156
- result = self._inpaint_pipeline(
1157
- prompt=prompt,
1158
- negative_prompt=negative_prompt,
1159
- image=image,
1160
- mask_image=mask,
1161
- num_inference_steps=num_inference_steps,
1162
- guidance_scale=guidance_scale,
1163
- strength=strength,
1164
- generator=generator
1165
- )
1166
-
1167
- return result.images[0]
1168
-
1169
- def blend_result(
1170
- self,
1171
- original: Image.Image,
1172
- generated: Image.Image,
1173
- mask: Image.Image
1174
  ) -> Image.Image:
1175
- """
1176
- Blend generated content with original image.
1177
-
1178
- Uses linear color space blending for accurate results.
 
 
1179
 
1180
- Parameters
1181
- ----------
1182
- original : PIL.Image
1183
- Original image
1184
- generated : PIL.Image
1185
- Generated inpainted image
1186
- mask : PIL.Image
1187
- Blending mask (white = use generated)
1188
 
1189
- Returns
1190
- -------
1191
- PIL.Image
1192
- Blended result
1193
- """
1194
- logger.info("Blending inpainting result...")
1195
-
1196
- # Ensure same size
1197
- if generated.size != original.size:
1198
- generated = generated.resize(original.size, Image.LANCZOS)
1199
- if mask.size != original.size:
1200
- mask = mask.resize(original.size, Image.LANCZOS)
1201
-
1202
- # Convert to arrays
1203
- orig_array = np.array(original.convert('RGB')).astype(np.float32)
1204
- gen_array = np.array(generated.convert('RGB')).astype(np.float32)
1205
- mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
1206
-
1207
- # sRGB to linear conversion
1208
- def srgb_to_linear(img):
1209
- img_norm = img / 255.0
1210
- return np.where(
1211
- img_norm <= 0.04045,
1212
- img_norm / 12.92,
1213
- np.power((img_norm + 0.055) / 1.055, 2.4)
1214
  )
 
 
1215
 
1216
- def linear_to_srgb(img):
1217
- img_clipped = np.clip(img, 0, 1)
1218
- return np.where(
1219
- img_clipped <= 0.0031308,
1220
- 12.92 * img_clipped,
1221
- 1.055 * np.power(img_clipped, 1/2.4) - 0.055
1222
  )
1223
 
1224
- # Convert to linear space
1225
- orig_linear = srgb_to_linear(orig_array)
1226
- gen_linear = srgb_to_linear(gen_array)
1227
-
1228
- # Alpha blending in linear space
1229
- alpha = mask_array[:, :, np.newaxis]
1230
- result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
1231
-
1232
- # Convert back to sRGB
1233
- result_srgb = linear_to_srgb(result_linear)
1234
- result_array = (result_srgb * 255).astype(np.uint8)
1235
 
1236
- logger.debug("Blending completed in linear color space")
1237
-
1238
- return Image.fromarray(result_array)
1239
-
1240
- def execute_with_auto_optimization(
1241
  self,
1242
  image: Image.Image,
1243
  mask: Image.Image,
1244
  prompt: str,
1245
- quality_checker: Any,
1246
- progress_callback: Optional[Callable[[str, int], None]] = None,
1247
- **kwargs
1248
- ) -> InpaintingResult:
1249
- """
1250
- Execute inpainting with automatic quality-based optimization.
1251
-
1252
- Retries with adjusted parameters if quality score is below threshold.
1253
-
1254
- Parameters
1255
- ----------
1256
- image : PIL.Image
1257
- Original image
1258
- mask : PIL.Image
1259
- Inpainting mask
1260
- prompt : str
1261
- Text prompt
1262
- quality_checker : QualityChecker
1263
- Quality assessment instance
1264
- progress_callback : callable, optional
1265
- Progress update function
1266
- **kwargs
1267
- Additional inpainting parameters
1268
-
1269
- Returns
1270
- -------
1271
- InpaintingResult
1272
- Best result achieved (may include retry information)
1273
- """
1274
- if not self.config.enable_auto_optimization:
1275
- return self.execute_inpainting(
1276
- image, mask, prompt,
1277
- progress_callback=progress_callback,
1278
- **kwargs
1279
  )
 
1280
 
1281
- best_result = None
1282
- best_score = 0.0
1283
- retry_count = 0
1284
- prev_score = 0.0
1285
-
1286
- # Mutable parameters for optimization
1287
- current_feather = kwargs.get('feather_radius', self.config.feather_radius)
1288
- current_scale = kwargs.get(
1289
- 'controlnet_conditioning_scale',
1290
- self.config.controlnet_conditioning_scale
1291
- )
1292
- current_guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
1293
- current_prompt = prompt
1294
-
1295
- while retry_count <= self.config.max_optimization_retries:
1296
- if progress_callback and retry_count > 0:
1297
- progress_callback(f"Optimizing (attempt {retry_count + 1})...", 5)
1298
-
1299
- # Execute inpainting
1300
- result = self.execute_inpainting(
1301
- image, mask, current_prompt,
1302
- preview_only=False,
1303
- feather_radius=current_feather,
1304
- controlnet_conditioning_scale=current_scale,
1305
- guidance_scale=current_guidance,
1306
- progress_callback=progress_callback if retry_count == 0 else None,
1307
- **{k: v for k, v in kwargs.items()
1308
- if k not in ['feather_radius', 'controlnet_conditioning_scale',
1309
- 'guidance_scale']}
1310
  )
1311
-
1312
- if not result.success:
1313
- return result
1314
-
1315
- # Evaluate quality
1316
- if result.blended_image is not None:
1317
- quality_results = quality_checker.run_all_checks(
1318
- foreground=image,
1319
- background=result.result_image,
1320
- mask=mask,
1321
- combined=result.blended_image
1322
- )
1323
- quality_score = quality_results.get("overall_score", 0)
1324
- else:
1325
- quality_score = 50.0 # Default if no blended image
1326
-
1327
- result.quality_score = quality_score
1328
- result.quality_details = quality_results if result.blended_image else {}
1329
- result.retries = retry_count
1330
-
1331
- logger.info(f"Quality score: {quality_score:.1f} (attempt {retry_count + 1})")
1332
-
1333
- # Track best result
1334
- if quality_score > best_score:
1335
- best_score = quality_score
1336
- best_result = result
1337
-
1338
- # Check if quality is acceptable
1339
- if quality_score >= self.config.min_quality_score:
1340
- logger.info(f"Quality threshold met: {quality_score:.1f}")
1341
- return best_result
1342
-
1343
- # Check for minimal improvement (early termination)
1344
- if retry_count > 0 and abs(quality_score - prev_score) < 5.0:
1345
- logger.info("Minimal improvement, stopping optimization")
1346
- return best_result
1347
-
1348
- prev_score = quality_score
1349
- retry_count += 1
1350
-
1351
- if retry_count > self.config.max_optimization_retries:
1352
- break
1353
-
1354
- # Adjust parameters based on quality issues
1355
- checks = quality_results.get("checks", {})
1356
-
1357
- edge_score = checks.get("edge_continuity", {}).get("score", 100)
1358
- harmony_score = checks.get("color_harmony", {}).get("score", 100)
1359
-
1360
- if edge_score < 60:
1361
- # Edge issues: increase feathering, decrease control strength
1362
- current_feather = min(20, current_feather + 3)
1363
- current_scale = max(0.5, current_scale - 0.1)
1364
- logger.debug(f"Adjusting for edges: feather={current_feather}, scale={current_scale}")
1365
-
1366
- if harmony_score < 60:
1367
- # Color harmony issues: emphasize consistency in prompt
1368
- if "color consistent" not in current_prompt.lower():
1369
- current_prompt = f"{current_prompt}, color consistent with surroundings, matching lighting"
1370
- current_guidance = min(12.0, current_guidance + 1.0)
1371
- logger.debug(f"Adjusting for harmony: guidance={current_guidance}")
1372
-
1373
- if edge_score < 60 and harmony_score < 60:
1374
- # Both issues: stronger guidance
1375
- current_guidance = min(12.0, current_guidance + 1.5)
1376
-
1377
- logger.info(f"Optimization complete. Best score: {best_score:.1f}")
1378
- return best_result
1379
 
1380
  def get_status(self) -> Dict[str, Any]:
1381
- """
1382
- Get current module status.
1383
-
1384
- Returns
1385
- -------
1386
- dict
1387
- Status information including initialization state and memory usage
1388
- """
1389
- status = {
1390
  "initialized": self.is_initialized,
1391
  "device": self.device,
 
1392
  "conditioning_type": self._current_conditioning_type,
1393
- "last_seed": self._last_seed,
1394
- "config": {
1395
- "controlnet_conditioning_scale": self.config.controlnet_conditioning_scale,
1396
- "feather_radius": self.config.feather_radius,
1397
- "num_inference_steps": self.config.num_inference_steps,
1398
- "guidance_scale": self.config.guidance_scale
1399
- }
1400
  }
1401
-
1402
- status["memory"] = self._check_memory_status()
1403
-
1404
- return status
 
4
  import time
5
  import traceback
6
  from dataclasses import dataclass, field
7
+ from typing import Any, Callable, Dict, Optional, Tuple
8
 
9
  import cv2
10
  import numpy as np
11
  import torch
12
+ from PIL import Image
13
 
14
+ from diffusers import AutoPipelineForInpainting
15
+ from diffusers import ControlNetModel
16
+ from diffusers import DPMSolverMultistepScheduler
17
  from diffusers import StableDiffusionXLControlNetInpaintPipeline
18
+ from transformers import AutoImageProcessor
19
+ from transformers import AutoModelForDepthEstimation
20
+ from transformers import DPTForDepthEstimation
21
+ from transformers import DPTImageProcessor
22
+
23
+ from control_image_processor import ControlImageProcessor
24
+ from inpainting_blender import InpaintingBlender
25
 
26
  logger = logging.getLogger(__name__)
27
  logger.setLevel(logging.INFO)
28
 
29
 
30
+ # Dedicated SDXL Inpainting model - trained specifically for inpainting
31
+ SDXL_INPAINTING_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
32
+
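# Illustrative load for the pure-inpainting mode (a sketch using the
# AutoPipelineForInpainting import above; dtype/variant follow the fp16
# convention used elsewhere in this module):
#
#   pipe = AutoPipelineForInpainting.from_pretrained(
#       SDXL_INPAINTING_MODEL,
#       torch_dtype=torch.float16,
#       variant="fp16",
#   )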
33
+
34
  @dataclass
35
  class InpaintingConfig:
36
  """Configuration for inpainting operations."""
37
 
38
+ # ControlNet settings (for ControlNet mode only)
39
  controlnet_conditioning_scale: float = 0.7
40
+ conditioning_type: str = "canny"
41
 
42
  # Canny edge detection parameters
43
  canny_low_threshold: int = 100
44
  canny_high_threshold: int = 200
45
 
46
  # Mask settings
47
+ feather_radius: int = 3
48
  min_mask_coverage: float = 0.01
49
  max_mask_coverage: float = 0.95
50
 
51
  # Generation settings
52
  num_inference_steps: int = 25
53
  guidance_scale: float = 7.5
54
+ strength: float = 0.99 # 0.99 still fully repaints the mask while avoiding the residual noise that strength=1.0 can produce
 
 
 
 
 
 
 
55
 
56
  # Memory settings
57
  enable_vae_tiling: bool = True
 
58
  max_resolution: int = 1024
59
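# Override sketch (field names as defined above; the values are illustrative):
#
#   config = InpaintingConfig(strength=0.99, feather_radius=3, max_resolution=768)
#   module = InpaintingModule(device="auto", config=config)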
 
60
 
 
68
  control_image: Optional[Image.Image] = None
69
  blended_image: Optional[Image.Image] = None
70
  quality_score: float = 0.0
 
71
  generation_time: float = 0.0
 
72
  error_message: str = ""
73
  metadata: Dict[str, Any] = field(default_factory=dict)
74
 
75
 
76
  class InpaintingModule:
77
  """
78
+ Dual-mode Inpainting Module for SceneWeaver.
79
 
80
+ Supports two modes:
81
+ 1. Pure Inpainting (use_controlnet=False): Uses dedicated SDXL Inpainting model
82
+ - Best for: Object replacement, Object removal
83
+ - More stable, better edge blending
84
 
85
+ 2. ControlNet Inpainting (use_controlnet=True): Uses ControlNet + SDXL
86
+ - Best for: Clothing change (depth), Color change (canny)
87
+ - Preserves structure in masked region
 
88
 
89
  Example:
90
  >>> module = InpaintingModule(device="cuda")
91
+ >>> # For object replacement (no ControlNet)
92
+ >>> module.load_pipeline(use_controlnet=False)
93
+ >>> result = module.execute_inpainting(image, mask, "a vase with flowers")
 
 
 
94
  """
95
 
96
+ # ControlNet model identifiers
97
  CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
98
  CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
99
  DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
100
  DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
101
+
102
+ # Base models for ControlNet mode
103
+ SUPPORTED_MODELS = {
104
+ "juggernaut_xl": "RunDiffusion/Juggernaut-XL-v9",
105
+ "realvis_xl": "SG161222/RealVisXL_V4.0",
106
+ "sdxl_base": "stabilityai/stable-diffusion-xl-base-1.0",
107
+ "animagine_xl": "cagliostrolab/animagine-xl-3.1",
108
+ }
109
 
110
  def __init__(
111
  self,
112
  device: str = "auto",
113
  config: Optional[InpaintingConfig] = None
114
  ):
115
+ """Initialize the InpaintingModule."""
 
 
 
 
 
 
 
 
 
116
  self.device = self._setup_device(device)
117
  self.config = config or InpaintingConfig()
118
 
119
+ # Sub-modules
120
+ self._control_processor = ControlImageProcessor(
121
+ device=self.device,
122
+ canny_low_threshold=self.config.canny_low_threshold,
123
+ canny_high_threshold=self.config.canny_high_threshold
124
+ )
125
+ self._blender = InpaintingBlender(
126
+ min_mask_coverage=self.config.min_mask_coverage,
127
+ max_mask_coverage=self.config.max_mask_coverage
128
+ )
129
+
130
+ # Pipeline instances
131
+ self._pipeline = None
132
+ self._controlnet = None
133
  self._depth_estimator = None
134
  self._depth_processor = None
135
 
136
  # State tracking
137
  self.is_initialized = False
138
+ self._current_mode = None # "pure" or "controlnet"
139
  self._current_conditioning_type = None
140
+ self._current_model_key = None
 
141
 
142
  logger.info(f"InpaintingModule initialized on {self.device}")
143
 
144
  def _setup_device(self, device: str) -> str:
145
+ """Setup computation device."""
 
 
 
 
 
 
 
 
 
 
 
 
146
  if device == "auto":
147
  if torch.cuda.is_available():
148
  return "cuda"
 
151
  return "cpu"
152
  return device
153
 
154
  def _memory_cleanup(self, aggressive: bool = False) -> None:
155
+ """Perform memory cleanup."""
156
+ for _ in range(5 if aggressive else 2):
 
157
  gc.collect()
158
 
 
 
159
  is_spaces = os.getenv('SPACE_ID') is not None
 
160
  if not is_spaces and torch.cuda.is_available():
161
  torch.cuda.empty_cache()
162
  if aggressive:
163
  torch.cuda.ipc_collect()
 
164
 
165
+ def load_pipeline(
 
166
  self,
167
+ use_controlnet: bool = False,
168
  conditioning_type: str = "canny",
169
+ model_key: str = "sdxl_base",
170
  progress_callback: Optional[Callable[[str, int], None]] = None
171
  ) -> Tuple[bool, str]:
172
  """
173
+ Load the appropriate inpainting pipeline.
 
174
 
175
  Parameters
176
  ----------
177
+ use_controlnet : bool
178
+ If False, use dedicated SDXL Inpainting model (for replacement/removal)
179
+ If True, use ControlNet pipeline (for clothing/color change)
180
  conditioning_type : str
181
+ ControlNet type: "canny" or "depth" (only used when use_controlnet=True)
182
+ model_key : str
183
+ Base model for ControlNet mode
184
  progress_callback : callable, optional
185
+ Progress update function
186
 
187
  Returns
188
  -------
189
  tuple
190
  (success: bool, error_message: str)
191
  """
192
+ mode = "controlnet" if use_controlnet else "pure"
193
+
194
+ # Check if already loaded with same config
195
+ if (self.is_initialized and
196
+ self._current_mode == mode and
197
+ (not use_controlnet or
198
+ (self._current_conditioning_type == conditioning_type and
199
+ self._current_model_key == model_key))):
200
+ logger.info(f"Pipeline already loaded: mode={mode}")
201
  return True, ""
202
 
203
+ logger.info(f"Loading pipeline: mode={mode}, conditioning={conditioning_type}")
204
 
205
  try:
206
  self._memory_cleanup(aggressive=True)
207
 
208
  if progress_callback:
209
+ progress_callback("Preparing pipeline...", 10)
210
+
211
+ # Unload existing pipeline
212
+ self._unload_pipeline()
213
 
214
+ dtype = torch.float16 if self.device == "cuda" else torch.float32
 
 
215
 
216
+ if not use_controlnet:
217
+ # Mode A: Pure SDXL Inpainting (for replacement/removal)
218
+ if progress_callback:
219
+ progress_callback("Loading SDXL Inpainting model...", 30)
220
 
221
+ self._pipeline = AutoPipelineForInpainting.from_pretrained(
222
+ SDXL_INPAINTING_MODEL,
223
+ torch_dtype=dtype,
224
+ variant="fp16" if dtype == torch.float16 else None,
225
+ )
226
+ self._current_mode = "pure"
227
+ self._current_conditioning_type = None
228
+ logger.info("Loaded pure SDXL Inpainting pipeline")
229
 
230
+ else:
231
+ # Mode B: ControlNet Inpainting (for structure-preserving tasks)
232
+ if model_key not in self.SUPPORTED_MODELS:
233
+ model_key = "sdxl_base"
234
+ base_model_id = self.SUPPORTED_MODELS[model_key]
235
 
236
+ if progress_callback:
237
+ progress_callback("Loading ControlNet model...", 30)
238
+
239
+ # Load ControlNet
240
  if conditioning_type == "canny":
241
+ self._controlnet = ControlNetModel.from_pretrained(
242
  self.CONTROLNET_CANNY_MODEL,
243
  torch_dtype=dtype,
244
  use_safetensors=True
245
  )
 
246
  elif conditioning_type == "depth":
247
+ self._controlnet = ControlNetModel.from_pretrained(
248
  self.CONTROLNET_DEPTH_MODEL,
249
  torch_dtype=dtype,
250
  use_safetensors=True
251
  )
 
252
  self._load_depth_estimator()
 
253
  else:
254
  raise ValueError(f"Unknown conditioning type: {conditioning_type}")
 
255
 
256
+ if progress_callback:
257
+ progress_callback(f"Loading {model_key}...", 60)
258
+
259
+ # Load pipeline with ControlNet
260
+ use_variant = model_key != "animagine_xl"
261
+ load_kwargs = {
262
+ "controlnet": self._controlnet,
263
+ "torch_dtype": dtype,
264
+ "use_safetensors": True,
265
+ }
266
+ if use_variant and dtype == torch.float16:
267
+ load_kwargs["variant"] = "fp16"
268
 
269
+ self._pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
270
+ base_model_id,
271
+ **load_kwargs
 
272
  )
273
+ self._current_mode = "controlnet"
274
+ self._current_conditioning_type = conditioning_type
275
+ self._current_model_key = model_key
276
+ logger.info(f"Loaded ControlNet pipeline: {model_key} + {conditioning_type}")
 
 
 
 
 
 
 
 
277
 
278
  if progress_callback:
279
+ progress_callback("Configuring pipeline...", 80)
280
 
281
+ # Configure scheduler
282
+ self._pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
283
+ self._pipeline.scheduler.config
284
  )
285
 
286
+ # Move to device and optimize
287
+ self._pipeline = self._pipeline.to(self.device)
288
+ self._apply_optimizations()
 
289
 
290
  self.is_initialized = True
 
291
 
292
  if progress_callback:
293
+ progress_callback("Pipeline ready!", 100)
 
 
 
 
294
 
295
  return True, ""
296
 
297
  except Exception as e:
298
  error_msg = str(e)
299
+ logger.error(f"Failed to load pipeline: {error_msg}")
300
  traceback.print_exc()
301
  self._unload_pipeline()
302
  return False, error_msg
303
 
304
  def _load_depth_estimator(self) -> None:
305
+ """Load depth estimation model."""
 
 
 
 
306
  try:
 
 
307
  self._depth_processor = AutoImageProcessor.from_pretrained(
308
  self.DEPTH_MODEL_PRIMARY
309
  )
 
313
  )
314
  self._depth_estimator.to(self.device)
315
  self._depth_estimator.eval()
316
+ logger.info("Loaded Depth-Anything model")
 
 
317
  except Exception as e:
318
  logger.warning(f"Primary depth model failed: {e}, trying fallback...")
319
+ self._depth_processor = DPTImageProcessor.from_pretrained(
320
+ self.DEPTH_MODEL_FALLBACK
321
+ )
322
+ self._depth_estimator = DPTForDepthEstimation.from_pretrained(
323
+ self.DEPTH_MODEL_FALLBACK,
324
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
325
+ )
326
+ self._depth_estimator.to(self.device)
327
+ self._depth_estimator.eval()
328
+ logger.info("Loaded MiDaS fallback model")
329
 
330
+ def _apply_optimizations(self) -> None:
331
+ """Apply memory and performance optimizations."""
332
+ if self._pipeline is None:
 
333
  return
334
 
 
335
  try:
336
+ self._pipeline.enable_xformers_memory_efficient_attention()
337
+ logger.info("Enabled xformers attention")
338
  except Exception:
339
  try:
340
+ self._pipeline.enable_attention_slicing()
341
  logger.info("Enabled attention slicing")
342
  except Exception:
343
+ pass
344
 
 
345
  if self.config.enable_vae_tiling:
346
+ if hasattr(self._pipeline, 'enable_vae_tiling'):
347
+ self._pipeline.enable_vae_tiling()
348
+ if hasattr(self._pipeline, 'enable_vae_slicing'):
349
+ self._pipeline.enable_vae_slicing()
 
350
 
351
  def _unload_pipeline(self) -> None:
352
+ """Unload pipeline and free memory."""
353
+ if self._pipeline is not None:
354
+ del self._pipeline
355
+ self._pipeline = None
356
 
357
+ if self._controlnet is not None:
358
+ del self._controlnet
359
+ self._controlnet = None
 
360
 
361
  if self._depth_estimator is not None:
362
  del self._depth_estimator
 
367
  self._depth_processor = None
368
 
369
  self.is_initialized = False
370
+ self._current_mode = None
371
  self._current_conditioning_type = None
 
372
 
373
  self._memory_cleanup(aggressive=True)
374
+ logger.info("Pipeline unloaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
  def execute_inpainting(
377
  self,
378
  image: Image.Image,
379
  mask: Image.Image,
380
  prompt: str,
 
 
381
  progress_callback: Optional[Callable[[str, int], None]] = None,
382
  **kwargs
383
  ) -> InpaintingResult:
384
  """
385
+ Execute inpainting operation.
 
386
 
387
  Parameters
388
  ----------
389
  image : PIL.Image
390
+ Original image
391
  mask : PIL.Image
392
  Inpainting mask (white = area to regenerate)
393
  prompt : str
394
+ Text description
 
395
  progress_callback : callable, optional
396
+ Progress update function
397
  **kwargs
398
+ Additional parameters from template
 
399
 
400
  Returns
401
  -------
402
  InpaintingResult
403
+ Result with generated image
404
  """
405
  start_time = time.time()
406
 
407
  if not self.is_initialized:
408
  return InpaintingResult(
409
  success=False,
410
+ error_message="Pipeline not initialized. Call load_pipeline() first."
411
  )
412
 
413
+ logger.info(f"Inpainting: mode={self._current_mode}, prompt='{prompt[:50]}...'")
414
 
415
  try:
 
416
  if progress_callback:
417
+ progress_callback("Preparing images...", 10)
418
 
419
  # Prepare image
420
  if image.mode != 'RGB':
421
  image = image.convert('RGB')
422
 
423
+ # Store original size for later restoration
424
+ original_size = image.size # (width, height)
425
+
426
+ # Ensure dimensions are multiple of 8 for model compatibility
427
  width, height = image.size
428
  new_width = (width // 8) * 8
429
  new_height = (height // 8) * 8
 
430
  if new_width != width or new_height != height:
431
  image = image.resize((new_width, new_height), Image.LANCZOS)
432
 
433
+ # Limit resolution for memory efficiency
434
  max_res = self.config.max_resolution
435
  if max(new_width, new_height) > max_res:
436
  scale = max_res / max(new_width, new_height)
437
  new_width = int(new_width * scale) // 8 * 8
438
  new_height = int(new_height * scale) // 8 * 8
439
  image = image.resize((new_width, new_height), Image.LANCZOS)
 
440
 
441
+ # Prepare mask with dilation
442
+ mask_dilation = kwargs.get('mask_dilation', 0)
443
+ processed_mask = self._prepare_mask(
 
 
444
  mask,
445
  (new_width, new_height),
446
+ dilation=mask_dilation,
447
+ feather_radius=kwargs.get('feather_radius', self.config.feather_radius)
 
448
  )
449
 
450
+ # Get generation parameters
451
+ strength = kwargs.get('strength', self.config.strength)
452
+ guidance_scale = kwargs.get('guidance_scale', self.config.guidance_scale)
453
+ num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
454
+ negative_prompt = kwargs.get('negative_prompt', "")
455
 
456
+ # Optimize for HuggingFace Spaces
457
+ is_spaces = os.getenv('SPACE_ID') is not None
458
+ if is_spaces:
459
+ num_steps = min(num_steps, 15)
 
460
 
461
+ # Setup generator with seed
462
+ # If seed is -1 or None, use random seed based on current time
463
+ input_seed = kwargs.get('seed', -1)
464
+ if input_seed is None or input_seed < 0:
465
  seed = int(time.time() * 1000) % (2**32)
466
+ else:
467
+ seed = int(input_seed)
468
  generator = torch.Generator(device=self.device).manual_seed(seed)
469
+ logger.info(f"Using seed: {seed}")
470
 
471
+ # Generate based on mode
472
+ if self._current_mode == "pure":
473
+ # Pure inpainting - no ControlNet
 
474
  if progress_callback:
475
+ progress_callback("Generating (Pure Inpainting)...", 40)
 
 
 
 
 
 
 
476
 
477
+ result_image = self._generate_pure_inpaint(
478
  image=image,
479
  mask=processed_mask,
480
+ prompt=prompt,
 
481
  negative_prompt=negative_prompt,
482
+ num_steps=num_steps,
483
+ guidance_scale=guidance_scale,
 
484
  strength=strength,
485
  generator=generator
486
  )
487
+ control_image = None
488
+
489
  else:
490
+ # ControlNet inpainting
491
+ if progress_callback:
492
+ progress_callback("Generating control image...", 30)
493
 
494
+ # Prepare control image
495
+ preserve_structure = kwargs.get('preserve_structure_in_mask', False)
496
+ edge_guidance_mode = kwargs.get('edge_guidance_mode', 'boundary')
497
 
498
+ control_image = self._control_processor.prepare_control_image(
499
+ image=image,
500
+ mode=self._current_conditioning_type,
501
+ mask=processed_mask,
502
+ preserve_structure=preserve_structure,
503
+ edge_guidance_mode=edge_guidance_mode
 
504
  )
505
 
506
+ if progress_callback:
507
+ progress_callback("Generating (ControlNet)...", 50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
 
509
+ conditioning_scale = kwargs.get(
510
+ 'controlnet_conditioning_scale',
511
+ self.config.controlnet_conditioning_scale
512
+ )
513
 
514
+ result_image = self._generate_controlnet_inpaint(
515
+ image=image,
516
+ mask=processed_mask,
517
+ control_image=control_image,
518
+ prompt=prompt,
519
+ negative_prompt=negative_prompt,
520
+ num_steps=num_steps,
521
+ guidance_scale=guidance_scale,
522
+ conditioning_scale=conditioning_scale,
523
+ strength=strength,
524
+ generator=generator
525
+ )
526
 
527
  generation_time = time.time() - start_time
528
 
529
+ # Restore original size if it was changed
530
+ if result_image.size != original_size:
531
+ result_image = result_image.resize(original_size, Image.LANCZOS)
532
+ logger.info(f"Restored result to original size: {original_size}")
533
+
534
  if progress_callback:
535
  progress_callback("Complete!", 100)
536
 
537
  return InpaintingResult(
538
  success=True,
539
+ result_image=result_image,
540
+ blended_image=result_image, # Pipeline output is already blended
541
  control_image=control_image,
 
542
  generation_time=generation_time,
543
  metadata={
544
  "seed": seed,
545
+ "prompt": prompt,
546
+ "mode": self._current_mode,
547
+ "num_steps": num_steps,
548
+ "guidance_scale": guidance_scale,
549
  "strength": strength,
550
+ "original_size": original_size,
 
 
 
 
 
551
  }
552
  )
553
 
554
  except torch.cuda.OutOfMemoryError:
555
+ logger.error("CUDA out of memory")
556
  self._memory_cleanup(aggressive=True)
557
  return InpaintingResult(
558
  success=False,
559
+ error_message="GPU memory exhausted."
560
  )
 
561
  except Exception as e:
562
  logger.error(f"Inpainting failed: {e}")
563
+ traceback.print_exc()
564
  return InpaintingResult(
565
  success=False,
566
+ error_message=str(e)
567
  )
568
 
569
+ def _prepare_mask(
570
  self,
 
571
  mask: Image.Image,
572
+ target_size: Tuple[int, int],
573
+ dilation: int = 0,
574
+ feather_radius: int = 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
  ) -> Image.Image:
576
+ """Prepare mask with optional dilation and feathering."""
577
+ # Convert and resize
578
+ if mask.mode != 'L':
579
+ mask = mask.convert('L')
580
+ if mask.size != target_size:
581
+ mask = mask.resize(target_size, Image.LANCZOS)
582
 
583
+ mask_array = np.array(mask)
 
584
 
585
+ # Apply dilation to expand mask
586
+ if dilation > 0:
587
+ kernel = cv2.getStructuringElement(
588
+ cv2.MORPH_ELLIPSE,
589
+ (dilation * 2 + 1, dilation * 2 + 1)
 
590
  )
591
+ mask_array = cv2.dilate(mask_array, kernel, iterations=1)
592
+ logger.debug(f"Applied mask dilation: {dilation}px")
593
 
594
+ # Apply feathering
595
+ if feather_radius > 0:
596
+ mask_array = cv2.GaussianBlur(
597
+ mask_array,
598
+ (feather_radius * 2 + 1, feather_radius * 2 + 1),
599
+ feather_radius / 2
600
  )
601
 
602
+ return Image.fromarray(mask_array, mode='L')
 
603
 
604
+ def _generate_pure_inpaint(
 
605
  self,
606
  image: Image.Image,
607
  mask: Image.Image,
608
  prompt: str,
609
+ negative_prompt: str,
610
+ num_steps: int,
611
+ guidance_scale: float,
612
+ strength: float,
613
+ generator: torch.Generator
614
+ ) -> Image.Image:
615
+ """Generate using pure SDXL Inpainting pipeline."""
616
+ with torch.inference_mode():
617
+ result = self._pipeline(
618
+ prompt=prompt,
619
+ negative_prompt=negative_prompt,
620
+ image=image,
621
+ mask_image=mask,
622
+ num_inference_steps=num_steps,
623
+ guidance_scale=guidance_scale,
624
+ strength=strength,
625
+ generator=generator
 
626
  )
627
+ return result.images[0]
628
 
629
+ def _generate_controlnet_inpaint(
630
+ self,
631
+ image: Image.Image,
632
+ mask: Image.Image,
633
+ control_image: Image.Image,
634
+ prompt: str,
635
+ negative_prompt: str,
636
+ num_steps: int,
637
+ guidance_scale: float,
638
+ conditioning_scale: float,
639
+ strength: float,
640
+ generator: torch.Generator
641
+ ) -> Image.Image:
642
+ """Generate using ControlNet Inpainting pipeline."""
643
+ with torch.inference_mode():
644
+ result = self._pipeline(
645
+ prompt=prompt,
646
+ negative_prompt=negative_prompt,
647
+ image=image,
648
+ mask_image=mask,
649
+ control_image=control_image,
650
+ num_inference_steps=num_steps,
651
+ guidance_scale=guidance_scale,
652
+ controlnet_conditioning_scale=conditioning_scale,
653
+ strength=strength,
654
+ generator=generator
 
655
  )
656
+ return result.images[0]
 
657
 
658
  def get_status(self) -> Dict[str, Any]:
659
+ """Get current module status."""
660
+ return {
 
661
  "initialized": self.is_initialized,
662
  "device": self.device,
663
+ "mode": self._current_mode,
664
  "conditioning_type": self._current_conditioning_type,
665
+ "model_key": self._current_model_key,
 
 
 
 
 
 
666
  }
 
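A minimal usage sketch for the dual-mode module above (editor's illustration, not part of the commit; the import path, image file names, and CUDA device are assumptions, while the InpaintingModule API itself is taken from the diff):

from PIL import Image
from inpainting_module import InpaintingModule  # import path assumed

image = Image.open("photo.png").convert("RGB")  # hypothetical input image
mask = Image.open("mask.png").convert("L")      # white = area to regenerate

module = InpaintingModule(device="cuda")

# Mode A: pure SDXL inpainting (object replacement / removal)
ok, err = module.load_pipeline(use_controlnet=False)
if ok:
    result = module.execute_inpainting(image, mask, "a vase with flowers", seed=42)
    if result.success:
        result.result_image.save("replaced.png")

# Mode B: ControlNet inpainting (structure-preserving, e.g. a color change)
ok, err = module.load_pipeline(use_controlnet=True, conditioning_type="canny",
                               model_key="sdxl_base")
if ok:
    result = module.execute_inpainting(
        image, mask, "vibrant red",
        controlnet_conditioning_scale=0.85,
        preserve_structure_in_mask=True,
    )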
inpainting_templates.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
  from dataclasses import dataclass, field
3
- from typing import Dict, List, Optional
4
 
5
  logger = logging.getLogger(__name__)
6
 
@@ -19,30 +19,31 @@ class InpaintingTemplate:
19
  prompt_template: str
20
  negative_prompt: str
21
 
22
- # Recommended parameters
23
- controlnet_conditioning_scale: float = 0.7
24
- feather_radius: int = 8
25
- guidance_scale: float = 7.5
26
- num_inference_steps: int = 25
27
 
28
- # Inpainting strength (0.0-1.0)
29
- # 1.0 = fully repaint masked area, 0.0 = keep original
30
- strength: float = 1.0
31
-
32
- # Conditioning type preference
33
  preferred_conditioning: str = "canny" # "canny" or "depth"
34
-
35
- # Structure preservation in masked area
36
- # True = keep edges in mask (for color change), False = clear edges (for replacement/removal)
37
  preserve_structure_in_mask: bool = False
 
38
 
39
- # Prompt enhancement control
40
- enhance_prompt: bool = True # Whether to use OpenCLIP prompt enhancement
 
41
 
42
- # Difficulty level for UI display
43
- difficulty: str = "medium" # "easy", "medium", "advanced"
44
 
45
- # Tips for users
 
46
  usage_tips: List[str] = field(default_factory=list)
47
 
48
 
@@ -50,417 +51,338 @@ class InpaintingTemplateManager:
50
  """
51
  Manages inpainting templates for various use cases.
52
 
53
- Provides categorized presets optimized for different inpainting scenarios
54
- including object replacement, removal, style transfer, and enhancement.
55
-
56
- Attributes:
57
- TEMPLATES: Dictionary of all available templates
58
- CATEGORIES: List of category names in display order
59
 
60
  Example:
61
  >>> manager = InpaintingTemplateManager()
62
  >>> template = manager.get_template("object_replacement")
63
- >>> print(template.prompt_template)
 
 
64
  """
65
 
66
  TEMPLATES: Dict[str, InpaintingTemplate] = {
67
- # ========================================
68
- # 4 CORE TEMPLATES - Optimized for Speed & Quality
69
- # ========================================
70
-
71
- # 1. CHANGE COLOR - Pure color transformation
72
- "change_color": InpaintingTemplate(
73
- key="change_color",
74
- name="Change Color",
75
- category="Color",
76
- icon="🎨",
77
- description="Change color ONLY - fills the masked area with a solid, flat color",
78
- prompt_template="{content} color, solid flat {content}, uniform color, no patterns, smooth surface",
79
  negative_prompt=(
80
- "original color, keeping same color, unchanged color, "
81
- "black, dark, keeping black, maintaining black color, "
82
- "black clothing, dark colors, dark fabric, black fabric, "
83
- "patterns, floral, stripes, plaid, checkered, decorative patterns, "
84
- "diamond pattern, grid pattern, geometric patterns, "
85
- "texture, textured, wrinkles, folds, creases, "
86
- "gradients, shading variations, color variations, "
87
- "complex patterns, printed patterns, embroidery"
88
  ),
89
- controlnet_conditioning_scale=0.3, # Low-medium: allow color freedom in masked area
90
- feather_radius=4, # Low: clean color boundaries
91
- guidance_scale=15.0, # Very high: strongly follow color prompt
92
- num_inference_steps=10, # Optimized for speed
93
- strength=1.0, # Full repaint for color change
94
- preferred_conditioning="canny", # Edge-based
95
- preserve_structure_in_mask=False, # KEY: clear edges in mask for pure color fill
96
- enhance_prompt=False, # Disabled: use color prompt directly
 
 
 
 
 
 
 
 
 
 
 
97
  difficulty="easy",
 
 
 
 
 
 
98
  usage_tips=[
99
- "🎯 Purpose: Fill the masked area with a solid, uniform color.",
100
  "",
101
- "📝 Example Prompts:",
102
- " • 'vibrant red' - bold, saturated red",
103
- " • 'soft pastel pink' - gentle, light pink",
104
- " • 'deep navy blue' - rich, dark blue",
105
- " • 'bright yellow' - eye-catching yellow",
106
- " • 'pure white' - clean, solid white",
107
  "",
108
  "💡 Tips:",
109
- " • Describe ONLY the color, not the object",
110
- " • Paint the entire area you want to recolor",
111
- " • Use modifiers: 'bright', 'dark', 'pastel', 'vivid'"
112
  ]
113
  ),
114
 
115
- # 2. CLOTHING CHANGE - Style and garment transformation
116
- "clothing_change": InpaintingTemplate(
117
- key="clothing_change",
118
- name="Clothing Change",
119
- category="Replacement",
120
- icon="👕",
121
- description="Change clothing style, material, or design - can include color change",
122
- prompt_template="{content}, photorealistic, realistic fabric texture, natural fit, high quality",
123
  negative_prompt=(
124
- "wrong body proportions, floating fabric, unrealistic wrinkles, "
125
- "mismatched lighting, visible edges, original clothing style, "
126
- "keeping same color, original color, faded colors, unchanged appearance, partial change, "
127
- "black clothing, dark original color, distorted body, naked, nudity, "
128
- "cartoon, anime, illustration, drawing, painted"
129
  ),
130
- controlnet_conditioning_scale=0.30, # Medium: preserves body structure, allows clothing change
131
- feather_radius=14, # Medium: natural blending with body
132
- guidance_scale=11.5, # Medium-high: accurate clothing generation
133
- num_inference_steps=10, # Optimized for speed
134
- strength=1.0, # Full repaint: completely replace clothing
135
- preferred_conditioning="depth", # Depth: preserves fabric folds and body structure
136
- enhance_prompt=True, # Enabled: enriches clothing details
 
137
  difficulty="easy",
 
 
138
  usage_tips=[
139
- "🎯 Purpose: Replace clothing with a different style, material, or design.",
140
  "",
141
- "📝 Example Prompts:",
142
- " 'tailored charcoal suit with silk tie and white shirt' - formal business",
143
- " 'navy blazer with gold buttons over light blue oxford shirt' - smart casual",
144
- " 'black tuxedo with bow tie and white dress shirt' - elegant formal",
145
- " • 'white polo shirt with collar' - casual business",
146
- " • 'cozy cream knit sweater' - warm casual style",
147
- " • 'vintage denim jacket' - retro fashion",
148
  "",
149
  "💡 Tips:",
150
- " • Include clothing type + color + details for best results",
151
- " • For suits: mention 'tailored', 'fitted', specific fabric like 'wool' or 'silk'",
152
- " • Body structure is preserved automatically"
153
  ]
154
  ),
155
 
156
- # 3. OBJECT REPLACEMENT - Replace one object with another
157
- "object_replacement": InpaintingTemplate(
158
- key="object_replacement",
159
- name="Object Replacement",
 
160
  category="Replacement",
161
- icon="🔄",
162
- description="Replace objects (one type at a time) - all masked areas become the SAME object",
163
- prompt_template="{content}, photorealistic, natural lighting, seamlessly integrated into scene, high quality",
164
  negative_prompt=(
165
- "inconsistent lighting, wrong perspective, mismatched colors, "
166
- "visible seams, floating objects, unrealistic placement, original object, "
167
- "poorly integrated, disconnected from scene, keeping original, remnants of original, "
168
- "multiple different objects, mixed objects, various items, "
169
- "cartoon, anime, illustration, drawing, painted"
170
  ),
171
- controlnet_conditioning_scale=0.25, # Low-medium: allows complete object replacement
172
- feather_radius=10, # Medium: natural scene integration
173
- guidance_scale=13.0, # Medium-high: accurate object generation
174
- num_inference_steps=10, # Optimized for speed
175
- strength=1.0, # Full repaint: completely replace object
176
- preferred_conditioning="canny", # Edge-based: preserves scene perspective
177
- enhance_prompt=True, # Enabled: enriches object details
 
178
  difficulty="medium",
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  usage_tips=[
180
- "🎯 Purpose: Replace an object with something completely different.",
181
  "",
182
- "📝 Example Prompts:",
183
- " • 'elegant ceramic vase with fresh roses' - decorative item",
184
- " • 'modern silver laptop on wooden stand' - tech gadget",
185
- " • 'stack of leather-bound vintage books' - classic decoration",
186
- " • 'healthy green potted succulent' - natural element",
187
- " • 'antique brass table lamp with fabric shade' - lighting",
188
  "",
189
  "💡 Tips:",
190
- " • Replace ONE object type at a time",
191
- " • Describe what you want, not what you're removing",
192
- " • Include material and style for realistic results"
193
  ]
194
  ),
195
 
196
- # 4. REMOVAL - Remove objects and fill with background
197
- "removal": InpaintingTemplate(
198
- key="removal",
199
- name="Remove Object",
200
- category="Removal",
201
- icon="🗑️",
202
- description="Remove objects and naturally fill with background - describe the background material",
203
- prompt_template="continue the background with {content}, photorealistic, seamless blending, natural texture continuation, high quality",
204
  negative_prompt=(
205
- "new object appearing, adding items, inserting objects, "
206
- "foreground elements, visible object, thing, item, "
207
- "unnatural filling, visible patches, inconsistent texture, "
208
- "mismatched pattern, color discontinuity, artificial blending, "
209
- "cartoon, anime, illustration, drawing, painted"
210
  ),
211
- controlnet_conditioning_scale=0.20, # Low: allows creative background filling
212
- feather_radius=12, # Medium: smooth background blending
213
- guidance_scale=12.0, # Medium: balanced control and naturalness
214
- num_inference_steps=10, # Optimized for speed
215
- strength=1.0, # Full repaint: completely remove and fill
216
- preferred_conditioning="depth", # Depth: preserves spatial perspective
217
- enhance_prompt=False, # Disabled: avoid generating new objects
218
- difficulty="medium",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  usage_tips=[
220
- "🎯 Purpose: Remove unwanted objects and fill with background.",
221
- "",
222
- "📝 Example Prompts:",
223
- " • 'polished hardwood floor with natural grain' - indoor floors",
224
- " • 'smooth white painted wall' - wall backgrounds",
225
- " • 'lush green grass lawn' - outdoor areas",
226
- " • 'soft beige carpet texture' - carpeted floors",
227
- " • 'clear blue sky with soft clouds' - sky backgrounds",
228
  "",
229
  "💡 Tips:",
230
- " • Describe the BACKGROUND texture, not the object",
231
- " • Leave empty to auto-match surrounding area",
232
- " • Works best with uniform backgrounds"
233
  ]
234
  ),
235
  }
236
 
237
-
238
  # Category display order
239
- CATEGORIES = ["Color", "Replacement", "Removal"] # 4 core templates only
240
 
241
  def __init__(self):
242
  """Initialize the InpaintingTemplateManager."""
243
  logger.info(f"InpaintingTemplateManager initialized with {len(self.TEMPLATES)} templates")
244
 
245
  def get_all_templates(self) -> Dict[str, InpaintingTemplate]:
246
- """
247
- Get all available templates.
248
-
249
- Returns
250
- -------
251
- dict
252
- Dictionary of all templates keyed by template key
253
- """
254
  return self.TEMPLATES
255
 
256
  def get_template(self, key: str) -> Optional[InpaintingTemplate]:
257
- """
258
- Get a specific template by key.
259
-
260
- Parameters
261
- ----------
262
- key : str
263
- Template identifier
264
-
265
- Returns
266
- -------
267
- InpaintingTemplate or None
268
- Template if found, None otherwise
269
- """
270
  return self.TEMPLATES.get(key)
271
 
272
  def get_templates_by_category(self, category: str) -> List[InpaintingTemplate]:
273
- """
274
- Get all templates in a specific category.
275
-
276
- Parameters
277
- ----------
278
- category : str
279
- Category name
280
-
281
- Returns
282
- -------
283
- list
284
- List of templates in the category
285
- """
286
  return [t for t in self.TEMPLATES.values() if t.category == category]
287
 
288
  def get_categories(self) -> List[str]:
289
- """
290
- Get list of all categories in display order.
291
-
292
- Returns
293
- -------
294
- list
295
- Category names
296
- """
297
  return self.CATEGORIES
298
 
299
  def get_template_choices_sorted(self) -> List[str]:
300
- """
301
- Get template choices formatted for Gradio dropdown.
302
-
303
- Returns list of display strings sorted by category then A-Z.
304
- Format: "icon Name"
305
-
306
- Returns
307
- -------
308
- list
309
- Formatted display strings for dropdown
310
- """
311
  display_list = []
312
-
313
  for category in self.CATEGORIES:
314
  templates = self.get_templates_by_category(category)
315
  for template in sorted(templates, key=lambda t: t.name):
316
  display_name = f"{template.icon} {template.name}"
317
  display_list.append(display_name)
318
-
319
  return display_list
320
 
321
  def get_template_key_from_display(self, display_name: str) -> Optional[str]:
322
- """
323
- Get template key from display name.
324
-
325
- Parameters
326
- ----------
327
- display_name : str
328
- Display string like "🔄 Object Replacement"
329
-
330
- Returns
331
- -------
332
- str or None
333
- Template key if found
334
- """
335
  if not display_name:
336
  return None
337
-
338
  for key, template in self.TEMPLATES.items():
339
  if f"{template.icon} {template.name}" == display_name:
340
  return key
341
  return None
342
 
343
- def get_parameters_for_template(self, key: str) -> Dict[str, any]:
344
- """
345
- Get recommended parameters for a template.
346
-
347
- Parameters
348
- ----------
349
- key : str
350
- Template key
351
-
352
- Returns
353
- -------
354
- dict
355
- Dictionary of parameter names and values
356
- """
357
  template = self.get_template(key)
358
  if not template:
359
  return {}
360
 
361
  return {
 
 
362
  "controlnet_conditioning_scale": template.controlnet_conditioning_scale,
363
- "feather_radius": template.feather_radius,
 
 
364
  "guidance_scale": template.guidance_scale,
365
  "num_inference_steps": template.num_inference_steps,
366
  "strength": template.strength,
367
- "preferred_conditioning": template.preferred_conditioning,
368
- "preserve_structure_in_mask": template.preserve_structure_in_mask,
369
- "enhance_prompt": template.enhance_prompt
370
  }
371
 
372
  def build_prompt(self, key: str, content: str) -> str:
373
- """
374
- Build complete prompt from template and user content.
375
-
376
- Parameters
377
- ----------
378
- key : str
379
- Template key
380
- content : str
381
- User-provided content description
382
-
383
- Returns
384
- -------
385
- str
386
- Formatted prompt with content inserted
387
- """
388
  template = self.get_template(key)
389
  if not template:
390
  return content
391
-
392
  return template.prompt_template.format(content=content)
393
 
394
  def get_negative_prompt(self, key: str) -> str:
395
- """
396
- Get negative prompt for a template.
397
-
398
- Parameters
399
- ----------
400
- key : str
401
- Template key
402
-
403
- Returns
404
- -------
405
- str
406
- Negative prompt string
407
- """
408
  template = self.get_template(key)
409
  if not template:
410
  return ""
411
  return template.negative_prompt
412
 
413
  def get_usage_tips(self, key: str) -> List[str]:
414
- """
415
- Get usage tips for a template.
416
-
417
- Parameters
418
- ----------
419
- key : str
420
- Template key
421
-
422
- Returns
423
- -------
424
- list
425
- List of tip strings
426
- """
427
  template = self.get_template(key)
428
  if not template:
429
  return []
430
  return template.usage_tips
431
 
432
- def build_gallery_html(self) -> str:
433
- """
434
- Build HTML for template gallery display.
435
-
436
- Returns
437
- -------
438
- str
439
- HTML string for Gradio display
440
- """
441
- html_parts = ['<div class="inpainting-gallery">']
442
-
443
- for category in self.CATEGORIES:
444
- templates = self.get_templates_by_category(category)
445
- if not templates:
446
- continue
447
 
448
- html_parts.append(f'''
449
- <div class="inpainting-category">
450
- <h4 class="inpainting-category-title">{category}</h4>
451
- <div class="inpainting-grid">
452
- ''')
 
453
 
454
- for template in sorted(templates, key=lambda t: t.name):
455
- html_parts.append(f'''
456
- <div class="inpainting-card" data-template="{template.key}">
457
- <span class="inpainting-icon">{template.icon}</span>
458
- <span class="inpainting-name">{template.name}</span>
459
- <span class="inpainting-desc">{template.description[:50]}...</span>
460
- </div>
461
- ''')
462
-
463
- html_parts.append('</div></div>')
464
-
465
- html_parts.append('</div>')
466
- return ''.join(html_parts)
 
1
  import logging
2
  from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Optional
4
 
5
  logger = logging.getLogger(__name__)
6
 
 
19
  prompt_template: str
20
  negative_prompt: str
21
 
22
+ # Pipeline mode selection
23
+ use_controlnet: bool = True # False = use pure SDXL Inpainting model (more stable)
24
+ mask_dilation: int = 0 # Pixels to expand mask for better edge blending
 
 
25
 
26
+ # ControlNet parameters (only used when use_controlnet=True)
27
+ controlnet_conditioning_scale: float = 0.7
 
28
  preferred_conditioning: str = "canny" # "canny" or "depth"
 
29
  preserve_structure_in_mask: bool = False
30
+ edge_guidance_mode: str = "boundary"
31
 
32
+ # Generation parameters
33
+ guidance_scale: float = 7.5
34
+ num_inference_steps: int = 25
35
+ strength: float = 0.99 # Use 0.99 instead of 1.0 to avoid noise issues
36
+
37
+ # Mask parameters
38
+ feather_radius: int = 3 # Minimal feathering, let pipeline handle blending
39
 
40
+ # Prompt enhancement
41
+ enhance_prompt: bool = True
42
 
43
+ # UI metadata
44
+ difficulty: str = "medium"
45
+ recommended_models: List[str] = field(default_factory=lambda: ["sdxl_base"])
46
+ example_prompts: List[str] = field(default_factory=list)
47
  usage_tips: List[str] = field(default_factory=list)
48
 
49
 
 
51
  """
52
  Manages inpainting templates for various use cases.
53
 
54
+ Templates are categorized into two pipeline modes:
55
+ - Pure Inpainting (use_controlnet=False): For replacement/removal tasks
56
+ - ControlNet Inpainting (use_controlnet=True): For structure-preserving tasks
 
57
 
58
  Example:
59
  >>> manager = InpaintingTemplateManager()
60
  >>> template = manager.get_template("object_replacement")
61
+ >>> if not template.use_controlnet:
62
+ ... # Use pure SDXL Inpainting pipeline
63
+ ... pass
64
  """
65
 
66
  TEMPLATES: Dict[str, InpaintingTemplate] = {
67
+ # 1. OBJECT REPLACEMENT - Replace one object with another
68
+ "object_replacement": InpaintingTemplate(
69
+ key="object_replacement",
70
+ name="Object Replacement",
71
+ category="Replacement",
72
+ icon="🔄",
73
+ description="Replace objects naturally - uses dedicated inpainting model for best results",
74
+ prompt_template="{content}, photorealistic, natural lighting, seamlessly integrated, high quality, detailed",
 
 
 
 
75
  negative_prompt=(
76
+ "blurry, low quality, distorted, deformed, "
77
+ "visible seams, harsh edges, unnatural, "
78
+ "cartoon, anime, illustration, drawing"
 
 
 
 
 
79
  ),
80
+ # Pipeline mode
81
+ use_controlnet=False, # Pure inpainting for stable results
82
+ mask_dilation=5, # Expand mask for seamless blending
83
+
84
+ # Generation parameters
85
+ guidance_scale=8.0,
86
+ num_inference_steps=25,
87
+ strength=0.99,
88
+
89
+ # Mask parameters
90
+ feather_radius=3,
91
+
92
+ # Not used for Pure Inpainting but kept for compatibility
93
+ controlnet_conditioning_scale=0.0,
94
+ preferred_conditioning="canny", # Placeholder, not used in Pure Inpainting mode
95
+ preserve_structure_in_mask=False,
96
+ edge_guidance_mode="none",
97
+
98
+ enhance_prompt=True,
99
  difficulty="easy",
100
+ recommended_models=["realvis_xl", "juggernaut_xl"],
101
+ example_prompts=[
102
+ "elegant ceramic vase with fresh roses",
103
+ "modern minimalist desk lamp, chrome finish",
104
+ "vintage leather-bound book with gold lettering"
105
+ ],
106
  usage_tips=[
107
+ "🎯 Purpose: Replace an object with something completely different.",
108
  "",
109
+ "💡 Example Prompts:",
110
+ " • elegant ceramic vase with fresh roses",
111
+ " • modern minimalist desk lamp, chrome finish",
112
+ " • vintage leather-bound book with gold lettering",
 
 
113
  "",
114
  "💡 Tips:",
115
+ " • Draw mask slightly larger than the object",
116
+ " • Describe the NEW object in detail",
117
+ " • Include material, color, style for better results"
118
  ]
119
  ),
120
 
121
+ # 2. OBJECT REMOVAL - Remove and fill with background (NO PROMPT NEEDED)
122
+ "removal": InpaintingTemplate(
123
+ key="removal",
124
+ name="Remove Object",
125
+ category="Removal",
126
+ icon="🗑️",
127
+ description="Remove unwanted objects - just draw mask, no prompt needed",
128
+ prompt_template="seamless background, natural texture continuation, photorealistic, high quality",
129
  negative_prompt=(
130
+ "object, item, thing, foreground element, new object, "
131
+ "visible patch, inconsistent texture, "
132
+ "blurry, low quality, artificial"
 
 
133
  ),
134
+ # Pipeline mode
135
+ use_controlnet=False, # Pure inpainting for clean removal
136
+ mask_dilation=8, # Larger expansion to cover shadows/reflections
137
+
138
+ # Generation parameters
139
+ guidance_scale=7.0, # Lower guidance for natural fill
140
+ num_inference_steps=20,
141
+ strength=0.99,
142
+
143
+ # Mask parameters
144
+ feather_radius=5, # More feathering for seamless blend
145
+
146
+ # Not used for Pure Inpainting but kept for compatibility
147
+ controlnet_conditioning_scale=0.0,
148
+ preferred_conditioning="canny",
149
+ preserve_structure_in_mask=False,
150
+ edge_guidance_mode="none",
151
+
152
+ enhance_prompt=False, # Do NOT enhance - keep it simple
153
  difficulty="easy",
154
+ recommended_models=["realvis_xl", "juggernaut_xl"],
155
+ example_prompts=[], # No prompts needed for removal
156
  usage_tips=[
157
+ "🎯 Purpose: Remove unwanted objects from image.",
158
  "",
159
+ "📝 No prompt needed! Just:",
160
+ " 1. Draw white mask over the object",
161
+ " 2. Include shadows in your mask",
162
+ " 3. Click Generate",
 
 
 
163
  "",
164
  "💡 Tips:",
165
+ " • Make mask larger than the object",
166
+ " • If artifacts remain, draw a bigger mask and retry"
 
167
  ]
168
  ),
169
 
170
+ # CONTROLNET TEMPLATES (Structure Preserving)
171
+ # 3. CLOTHING CHANGE - Change clothes while keeping body
172
+ "clothing_change": InpaintingTemplate(
173
+ key="clothing_change",
174
+ name="Clothing Change",
175
  category="Replacement",
176
+ icon="👕",
177
+ description="Change clothing style while preserving body structure",
178
+ prompt_template="{content}, photorealistic, realistic fabric, natural fit, high quality",
179
  negative_prompt=(
180
+ "wrong proportions, distorted body, floating fabric, "
181
+ "mismatched lighting, naked, nudity, "
182
+ "cartoon, anime, illustration"
 
 
183
  ),
184
+ # Pipeline mode
185
+ use_controlnet=True, # Need ControlNet to preserve body
186
+ mask_dilation=3, # Small expansion for clothing edges
187
+
188
+ # ControlNet parameters
189
+ controlnet_conditioning_scale=0.4,
190
+ preferred_conditioning="depth", # Depth preserves body structure
191
+ preserve_structure_in_mask=False,
192
+ edge_guidance_mode="soft",
193
+
194
+ # Generation parameters
195
+ guidance_scale=8.0,
196
+ num_inference_steps=25,
197
+ strength=1.0, # Full repaint for clothing
198
+
199
+ # Mask parameters
200
+ feather_radius=5,
201
+
202
+ enhance_prompt=True,
203
  difficulty="medium",
204
+ recommended_models=["juggernaut_xl", "realvis_xl"],
205
+ example_prompts=[
206
+ "tailored charcoal suit with silk tie",
207
+ "navy blazer with gold buttons",
208
+ "elegant black evening dress",
209
+ "casual white t-shirt",
210
+ "cozy cream sweater",
211
+ "leather motorcycle jacket",
212
+ "formal white dress shirt",
213
+ "vintage denim jacket",
214
+ "red cocktail dress",
215
+ "professional grey blazer"
216
+ ],
217
  usage_tips=[
218
+ "🎯 Purpose: Change clothing while keeping body shape.",
219
  "",
220
+ "🤖 Recommended Models:",
221
+ " • JuggernautXL - Best for formal wear",
222
+ " • RealVisXL - Great for casual clothing",
 
 
 
223
  "",
224
  "💡 Tips:",
225
+ " • Mask only the clothing area",
226
+ " • Include fabric type: 'silk', 'cotton', 'wool'",
227
+ " • Body proportions are preserved automatically"
228
  ]
229
  ),
230
 
231
+ # 4. COLOR CHANGE - Change color only, keep structure
232
+ "change_color": InpaintingTemplate(
233
+ key="change_color",
234
+ name="Change Color",
235
+ category="Color",
236
+ icon="🎨",
237
+ description="Change color only - strictly preserves shape and texture",
238
+ prompt_template="{content} color, solid uniform {content}, flat color, smooth surface",
239
  negative_prompt=(
240
+ "different shape, changed structure, new pattern, "
241
+ "texture change, deformed, distorted, "
242
+ "gradient, multiple colors, pattern"
 
 
243
  ),
244
+ # Pipeline mode
245
+ use_controlnet=True, # Need ControlNet to preserve exact shape
246
+ mask_dilation=0, # No expansion - precise color change
247
+
248
+ # ControlNet parameters
249
+ controlnet_conditioning_scale=0.85, # High: strict structure preservation
250
+ preferred_conditioning="canny", # Canny preserves edges exactly
251
+ preserve_structure_in_mask=True, # Keep all edges
252
+ edge_guidance_mode="boundary",
253
+
254
+ # Generation parameters
255
+ guidance_scale=12.0, # High: force the exact color
256
+ num_inference_steps=15,
257
+ strength=1.0,
258
+
259
+ # Mask parameters
260
+ feather_radius=2, # Very small
261
+
262
+ enhance_prompt=False, # Use color prompt directly
263
+ difficulty="easy",
264
+ recommended_models=["juggernaut_xl", "realvis_xl"],
265
+ example_prompts=[
266
+ "vibrant red",
267
+ "deep navy blue",
268
+ "bright yellow",
269
+ "emerald green",
270
+ "soft pink",
271
+ "pure white",
272
+ "charcoal grey",
273
+ "royal purple",
274
+ "coral orange",
275
+ "golden brown"
276
+ ],
277
  usage_tips=[
278
+ "🎯 Purpose: Change color only, shape stays exactly the same.",
 
 
 
 
 
 
 
279
  "",
280
  "💡 Tips:",
281
+ " • Enter ONLY the color name",
282
+ " • Use modifiers: 'bright', 'dark', 'pastel'",
283
+ " • Shape and texture are preserved exactly"
284
  ]
285
  ),
286
  }
287
 
 
288
  # Category display order
289
+ CATEGORIES = ["Color", "Replacement", "Removal"]
290
 
291
  def __init__(self):
292
  """Initialize the InpaintingTemplateManager."""
293
  logger.info(f"InpaintingTemplateManager initialized with {len(self.TEMPLATES)} templates")
294
 
295
  def get_all_templates(self) -> Dict[str, InpaintingTemplate]:
296
+ """Get all available templates."""
 
 
 
 
 
 
 
297
  return self.TEMPLATES
298
 
299
  def get_template(self, key: str) -> Optional[InpaintingTemplate]:
300
+ """Get a specific template by key."""
 
 
 
 
 
 
 
 
 
 
 
 
301
  return self.TEMPLATES.get(key)
302
 
303
  def get_templates_by_category(self, category: str) -> List[InpaintingTemplate]:
304
+ """Get all templates in a specific category."""
 
 
 
 
 
 
 
 
 
 
 
 
305
  return [t for t in self.TEMPLATES.values() if t.category == category]
306
 
307
  def get_categories(self) -> List[str]:
308
+ """Get list of all categories in display order."""
 
 
 
 
 
 
 
309
  return self.CATEGORIES
310
 
311
  def get_template_choices_sorted(self) -> List[str]:
312
+ """Get template choices formatted for Gradio dropdown."""
 
 
 
 
 
 
 
 
 
 
313
  display_list = []
 
314
  for category in self.CATEGORIES:
315
  templates = self.get_templates_by_category(category)
316
  for template in sorted(templates, key=lambda t: t.name):
317
  display_name = f"{template.icon} {template.name}"
318
  display_list.append(display_name)
 
319
  return display_list
320
 
321
  def get_template_key_from_display(self, display_name: str) -> Optional[str]:
322
+ """Get template key from display name."""
 
 
 
 
 
 
 
 
 
 
 
 
323
  if not display_name:
324
  return None
 
325
  for key, template in self.TEMPLATES.items():
326
  if f"{template.icon} {template.name}" == display_name:
327
  return key
328
  return None
329
 
330
+ def get_parameters_for_template(self, key: str) -> Dict[str, Any]:
331
+ """Get recommended parameters for a template."""
 
 
 
 
 
 
 
 
 
 
 
 
332
  template = self.get_template(key)
333
  if not template:
334
  return {}
335
 
336
  return {
337
+ "use_controlnet": template.use_controlnet,
338
+ "mask_dilation": template.mask_dilation,
339
  "controlnet_conditioning_scale": template.controlnet_conditioning_scale,
340
+ "preferred_conditioning": template.preferred_conditioning,
341
+ "preserve_structure_in_mask": template.preserve_structure_in_mask,
342
+ "edge_guidance_mode": template.edge_guidance_mode,
343
  "guidance_scale": template.guidance_scale,
344
  "num_inference_steps": template.num_inference_steps,
345
  "strength": template.strength,
346
+ "feather_radius": template.feather_radius,
347
+ "enhance_prompt": template.enhance_prompt,
 
348
  }
349
 
350
  def build_prompt(self, key: str, content: str) -> str:
351
+ """Build complete prompt from template and user content."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  template = self.get_template(key)
353
  if not template:
354
  return content
 
355
  return template.prompt_template.format(content=content)
356
 
357
  def get_negative_prompt(self, key: str) -> str:
358
+ """Get negative prompt for a template."""
 
 
 
 
 
 
 
 
 
 
 
 
359
  template = self.get_template(key)
360
  if not template:
361
  return ""
362
  return template.negative_prompt
363
 
364
  def get_usage_tips(self, key: str) -> List[str]:
365
+ """Get usage tips for a template."""
 
 
 
 
 
 
 
 
 
 
 
 
366
  template = self.get_template(key)
367
  if not template:
368
  return []
369
  return template.usage_tips
370
 
371
+ def get_recommended_models(self, key: str) -> List[str]:
372
+ """Get recommended models for a template."""
373
+ template = self.get_template(key)
374
+ if not template:
375
+ return ["sdxl_base"]
376
+ return template.recommended_models
 
377
 
378
+ def get_example_prompts(self, key: str) -> List[str]:
379
+ """Get example prompts for a template."""
380
+ template = self.get_template(key)
381
+ if not template:
382
+ return []
383
+ return template.example_prompts
384
 
385
+ def get_primary_recommended_model(self, key: str) -> str:
386
+ """Get the primary recommended model for a template."""
387
+ models = self.get_recommended_models(key)
388
+ return models[0] if models else "sdxl_base"
 
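A sketch of how these template fields are meant to feed the dual-mode InpaintingModule (editor's illustration, not part of the commit; `module`, `image`, and `mask` are assumed to exist as in the earlier sketch):

manager = InpaintingTemplateManager()
key = manager.get_template_key_from_display("🎨 Change Color")  # -> "change_color"
params = manager.get_parameters_for_template(key)

# Mode and conditioning choices belong to load_pipeline(); enhance_prompt is
# consumed upstream, so all three are removed before forwarding the rest.
module.load_pipeline(
    use_controlnet=params.pop("use_controlnet"),
    conditioning_type=params.pop("preferred_conditioning"),
    model_key=manager.get_primary_recommended_model(key),
)
params.pop("enhance_prompt")

result = module.execute_inpainting(
    image, mask,
    prompt=manager.build_prompt(key, "vibrant red"),
    negative_prompt=manager.get_negative_prompt(key),
    **params,  # mask_dilation, feather_radius, strength, guidance_scale, ...
)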
mask_generator.py CHANGED
@@ -298,7 +298,7 @@ class MaskGenerator:
298
  # High confidence areas - keep at full opacity
299
  final_alpha[high_confidence] = 255
300
 
301
- # Medium confidence - boost significantly
302
  final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)
303
 
304
  # Low confidence - moderate boost (catches faint extremities)
 
298
  # High confidence areas - keep at full opacity
299
  final_alpha[high_confidence] = 255
300
 
301
+ # Medium confidence - boost significantly
302
  final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)
303
 
304
  # Low confidence - moderate boost (catches faint extremities)
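Editor's note on the boost in this hunk (illustration only, not part of the commit): the medium-confidence branch maps a stretched alpha value a to min(max(1.8 * a, 200), 255), for example:

import numpy as np
a = np.array([100, 120, 150], dtype=np.float32)
print(np.clip(a * 1.8, 200, 255))  # [200. 216. 255.]: floor at 200, ceiling at 255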
scene_templates.py CHANGED
@@ -24,7 +24,7 @@ class SceneTemplateManager:
24
 
25
  # Scene template definitions
26
  TEMPLATES: Dict[str, SceneTemplate] = {
27
- # Professional Category
28
  "office_modern": SceneTemplate(
29
  key="office_modern",
30
  name="Modern Office",
@@ -71,7 +71,7 @@ class SceneTemplateManager:
71
  guidance_scale=7.5
72
  ),
73
 
74
- # Nature Category
75
  "beach_sunset": SceneTemplate(
76
  key="beach_sunset",
77
  name="Sunset Beach",
@@ -127,7 +127,7 @@ class SceneTemplateManager:
127
  guidance_scale=7.0
128
  ),
129
 
130
- # Urban Category
131
  "city_skyline": SceneTemplate(
132
  key="city_skyline",
133
  name="City Skyline",
@@ -174,7 +174,7 @@ class SceneTemplateManager:
174
  guidance_scale=7.5
175
  ),
176
 
177
- # Artistic Category
178
  "gradient_soft": SceneTemplate(
179
  key="gradient_soft",
180
  name="Soft Gradient",
@@ -212,7 +212,7 @@ class SceneTemplateManager:
212
  guidance_scale=6.5
213
  ),
214
 
215
- # Seasonal Category
216
  "autumn_foliage": SceneTemplate(
217
  key="autumn_foliage",
218
  name="Autumn Foliage",
@@ -425,4 +425,4 @@ class SceneTemplateManager:
425
  grid-template-columns: repeat(3, 1fr);
426
  }
427
  }
428
- """
 
24
 
25
  # Scene template definitions
26
  TEMPLATES: Dict[str, SceneTemplate] = {
27
+ # Professional Category
28
  "office_modern": SceneTemplate(
29
  key="office_modern",
30
  name="Modern Office",
 
71
  guidance_scale=7.5
72
  ),
73
 
74
+ # Nature Category
75
  "beach_sunset": SceneTemplate(
76
  key="beach_sunset",
77
  name="Sunset Beach",
 
127
  guidance_scale=7.0
128
  ),
129
 
130
+ # Urban Category
131
  "city_skyline": SceneTemplate(
132
  key="city_skyline",
133
  name="City Skyline",
 
174
  guidance_scale=7.5
175
  ),
176
 
177
+ # Artistic Category
178
  "gradient_soft": SceneTemplate(
179
  key="gradient_soft",
180
  name="Soft Gradient",
 
212
  guidance_scale=6.5
213
  ),
214
 
215
+ # Seasonal Category
216
  "autumn_foliage": SceneTemplate(
217
  key="autumn_foliage",
218
  name="Autumn Foliage",
 
425
  grid-template-columns: repeat(3, 1fr);
426
  }
427
  }
428
+ """
scene_weaver_core.py CHANGED
@@ -321,7 +321,7 @@ class SceneWeaverCore:
321
  # Analyze image characteristics
322
  img_array = np.array(foreground_image.convert('RGB'))
323
 
324
- # Analyze color temperature
325
  # Convert to LAB to analyze color temperature
326
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
327
  avg_a = np.mean(lab[:, :, 1]) # a channel: green(-) to red(+)
@@ -330,12 +330,12 @@ class SceneWeaverCore:
330
  # Determine warm/cool tone
331
  is_warm = avg_b > 128 # b > 128 means more yellow/warm
332
 
333
- # Analyze brightness
334
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
335
  avg_brightness = np.mean(gray)
336
  is_bright = avg_brightness > 127
337
 
338
- # Get subject type from CLIP
339
  clip_analysis = self.analyze_image_with_clip(foreground_image)
340
  subject_type = "unknown"
341
 
@@ -369,7 +369,7 @@ class SceneWeaverCore:
369
 
370
  quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
371
 
372
- # Select appropriate fragments
373
  # Lighting based on color temperature and brightness
374
  if is_warm and is_bright:
375
  lighting = lighting_options["warm_bright"]
@@ -383,7 +383,7 @@ class SceneWeaverCore:
383
  # Atmosphere based on subject type
384
  atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])
385
 
386
- # Check for conflicts in user prompt
387
  user_prompt_lower = user_prompt.lower()
388
 
389
  # Avoid adding conflicting descriptions
@@ -392,7 +392,7 @@ class SceneWeaverCore:
392
  if "dark" in user_prompt_lower or "night" in user_prompt_lower:
393
  lighting = lighting.replace("bright", "").replace("daylight", "")
394
 
395
- # Combine enhanced prompt
396
  fragments = [user_prompt]
397
 
398
  if lighting:
@@ -864,25 +864,33 @@ class SceneWeaverCore:
864
  """
865
  if self._inpainting_module is None:
866
  self._inpainting_module = InpaintingModule(device=self.device)
867
- self._inpainting_module.set_model_manager(self._model_manager)
868
  logger.info("InpaintingModule created (lazy load)")
869
 
870
  return self._inpainting_module
871
 
872
  def switch_to_inpainting_mode(
873
  self,
 
874
  conditioning_type: str = "canny",
 
875
  progress_callback: Optional[Callable[[str, int], None]] = None
876
  ) -> bool:
877
  """
878
  Switch to inpainting mode, unloading background pipeline.
879
 
880
- Implements mutual exclusion between pipelines to conserve memory.
 
 
881
 
882
  Parameters
883
  ----------
 
 
 
884
  conditioning_type : str
885
- ControlNet conditioning type: "canny" or "depth"
 
 
886
  progress_callback : callable, optional
887
  Progress update function(message, percentage)
888
 
@@ -891,7 +899,8 @@ class SceneWeaverCore:
891
  bool
892
  True if switch was successful
893
  """
894
- logger.info(f"Switching to inpainting mode (conditioning: {conditioning_type})")
 
895
 
896
  try:
897
  # Unload background pipeline first
@@ -912,12 +921,14 @@ class SceneWeaverCore:
912
 
913
  def inpaint_progress(msg, pct):
914
  if progress_callback:
915
- # Map inpainting progress (0-100) to (20-90)
916
  mapped_pct = 20 + int(pct * 0.7)
917
  progress_callback(msg, mapped_pct)
918
 
919
- success, error_msg = inpaint_module.load_inpainting_pipeline(
 
 
920
  conditioning_type=conditioning_type,
 
921
  progress_callback=inpaint_progress
922
  )
923
 
@@ -997,6 +1008,7 @@ class SceneWeaverCore:
997
  prompt: str,
998
  preview_only: bool = False,
999
  template_key: Optional[str] = None,
 
1000
  progress_callback: Optional[Callable[[str, int], None]] = None,
1001
  **kwargs
1002
  ) -> Dict[str, Any]:
@@ -1017,6 +1029,8 @@ class SceneWeaverCore:
1017
  If True, generate quick preview only
1018
  template_key : str, optional
1019
  Inpainting template key to use
 
 
1020
  progress_callback : callable, optional
1021
  Progress update function
1022
  **kwargs
@@ -1027,10 +1041,30 @@ class SceneWeaverCore:
1027
  dict
1028
  Result dictionary with images and metadata
1029
  """
1030
- # Ensure inpainting mode is active
1031
- if self._current_mode != "inpainting" or not self._inpainting_initialized:
1032
- conditioning = kwargs.get('conditioning_type', 'canny')
1033
- if not self.switch_to_inpainting_mode(conditioning, progress_callback):
1034
  error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
1035
  return {
1036
  "success": False,
@@ -1038,33 +1072,11 @@ class SceneWeaverCore:
1038
  }
1039
 
1040
  inpaint_module = self.get_inpainting_module()
1041
-
1042
- # Apply template if specified
1043
- if template_key:
1044
- template_mgr = InpaintingTemplateManager()
1045
- template = template_mgr.get_template(template_key)
1046
-
1047
- if template:
1048
- # Build prompt from template
1049
- prompt = template_mgr.build_prompt(template_key, prompt)
1050
- # Apply template parameters as defaults
1051
- params = template_mgr.get_parameters_for_template(template_key)
1052
- for key, value in params.items():
1053
- if key not in kwargs:
1054
- kwargs[key] = value
1055
-
1056
- # Pass enhance_prompt flag to inpainting module
1057
- if 'enhance_prompt' not in kwargs:
1058
- kwargs['enhance_prompt'] = template.enhance_prompt
1059
-
1060
- # Execute inpainting
1061
  result = inpaint_module.execute_inpainting(
1062
  image=image,
1063
  mask=mask,
1064
  prompt=prompt,
1065
- preview_only=preview_only,
1066
  progress_callback=progress_callback,
1067
- template_key=template_key, # Pass template_key for conditional prompt enhancement
1068
  **kwargs
1069
  )
1070
 
@@ -1191,4 +1203,4 @@ class SceneWeaverCore:
1191
 
1192
  status = self._inpainting_module.get_status()
1193
  status["mode"] = self._current_mode
1194
- return status
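The deleted docstring above described mutual exclusion between the background and inpainting pipelines, and the new code keeps that discipline: the background pipeline is unloaded before any inpainting pipeline is loaded. A minimal sketch of the pattern, with illustrative names (this is not the module's actual API), assuming PyTorch with optional CUDA:

```python
import gc
import torch

class ExclusivePipelines:
    """Illustrative only: keep at most one heavy SDXL pipeline resident."""

    def __init__(self):
        self.active_name = None
        self.active_pipe = None

    def _unload_active(self):
        # Drop the Python reference, then reclaim host and CUDA memory.
        if self.active_pipe is not None:
            self.active_pipe = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def switch(self, name, loader):
        # Reload only when a different pipeline is requested.
        if name != self.active_name:
            self._unload_active()
            self.active_pipe = loader()  # e.g. a diffusers from_pretrained call
            self.active_name = name
        return self.active_pipe
```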
 
321
  # Analyze image characteristics
322
  img_array = np.array(foreground_image.convert('RGB'))
323
 
324
+ # Analyze color temperature
325
  # Convert to LAB to analyze color temperature
326
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
327
  avg_a = np.mean(lab[:, :, 1]) # a channel: green(-) to red(+)
 
330
  # Determine warm/cool tone
331
  is_warm = avg_b > 128 # b > 128 means more yellow/warm
332
 
333
+ # Analyze brightness
334
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
335
  avg_brightness = np.mean(gray)
336
  is_bright = avg_brightness > 127
337
 
338
+ # Get subject type from CLIP
339
  clip_analysis = self.analyze_image_with_clip(foreground_image)
340
  subject_type = "unknown"
341
 
 
369
 
370
  quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
371
 
372
+ # Select appropriate fragments
373
  # Lighting based on color temperature and brightness
374
  if is_warm and is_bright:
375
  lighting = lighting_options["warm_bright"]
 
383
  # Atmosphere based on subject type
384
  atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])
385
 
386
+ # Check for conflicts in user prompt
387
  user_prompt_lower = user_prompt.lower()
388
 
389
  # Avoid adding conflicting descriptions
 
392
  if "dark" in user_prompt_lower or "night" in user_prompt_lower:
393
  lighting = lighting.replace("bright", "").replace("daylight", "")
394
 
395
+ # Combine enhanced prompt
396
  fragments = [user_prompt]
397
 
398
  if lighting:
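The flags driving these prompt fragments come from plain channel statistics; a self-contained sketch of the same measurements, with the 128/127 midpoint thresholds taken from the code above (OpenCV's 8-bit LAB stores the b channel offset by 128):

```python
import cv2
import numpy as np
from PIL import Image

def analyze_tone(image: Image.Image) -> dict:
    """Warm/cool and bright/dark flags used to pick prompt fragments."""
    arr = np.array(image.convert("RGB"))
    # LAB b channel runs blue(-) to yellow(+); above the midpoint = warm tone.
    lab = cv2.cvtColor(arr, cv2.COLOR_RGB2LAB)
    is_warm = float(np.mean(lab[:, :, 2])) > 128
    # Mean grayscale intensity above the midpoint = bright image.
    gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    is_bright = float(np.mean(gray)) > 127
    return {"is_warm": is_warm, "is_bright": is_bright}
```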
 
864
  """
865
  if self._inpainting_module is None:
866
  self._inpainting_module = InpaintingModule(device=self.device)
 
867
  logger.info("InpaintingModule created (lazy load)")
868
 
869
  return self._inpainting_module
870
 
871
  def switch_to_inpainting_mode(
872
  self,
873
+ use_controlnet: bool = True,
874
  conditioning_type: str = "canny",
875
+ model_key: str = "sdxl_base",
876
  progress_callback: Optional[Callable[[str, int], None]] = None
877
  ) -> bool:
878
  """
879
  Switch to inpainting mode, unloading background pipeline.
880
 
881
+ Supports dual-mode inpainting:
882
+ - Pure Inpainting (use_controlnet=False): For object replacement or removal
883
+ - ControlNet Inpainting (use_controlnet=True): For clothing or color changes that preserve structure
884
 
885
  Parameters
886
  ----------
887
+ use_controlnet : bool
888
+ If False, use dedicated SDXL Inpainting model
889
+ If True, use ControlNet + SDXL model
890
  conditioning_type : str
891
+ ControlNet conditioning type: "canny" or "depth" (only for ControlNet mode)
892
+ model_key : str
893
+ Model key for ControlNet mode base model
894
  progress_callback : callable, optional
895
  Progress update function(message, percentage)
896
 
 
899
  bool
900
  True if switch was successful
901
  """
902
+ mode_str = "ControlNet" if use_controlnet else "Pure Inpainting"
903
+ logger.info(f"Switching to inpainting mode: {mode_str} (model: {model_key})")
904
 
905
  try:
906
  # Unload background pipeline first
 
921
 
922
  def inpaint_progress(msg, pct):
923
  if progress_callback:
 
924
  mapped_pct = 20 + int(pct * 0.7)
925
  progress_callback(msg, mapped_pct)
926
 
927
+ # Use the new load_pipeline method with dual-mode support
928
+ success, error_msg = inpaint_module.load_pipeline(
929
+ use_controlnet=use_controlnet,
930
  conditioning_type=conditioning_type,
931
+ model_key=model_key,
932
  progress_callback=inpaint_progress
933
  )
934
 
 
1008
  prompt: str,
1009
  preview_only: bool = False,
1010
  template_key: Optional[str] = None,
1011
+ model_key: str = "sdxl_base",
1012
  progress_callback: Optional[Callable[[str, int], None]] = None,
1013
  **kwargs
1014
  ) -> Dict[str, Any]:
 
1029
  If True, generate quick preview only
1030
  template_key : str, optional
1031
  Inpainting template key to use
1032
+ model_key : str
1033
+ Key of the base model to use: juggernaut_xl, realvis_xl, sdxl_base, or animagine_xl
1034
  progress_callback : callable, optional
1035
  Progress update function
1036
  **kwargs
 
1041
  dict
1042
  Result dictionary with images and metadata
1043
  """
1044
+ # Get pipeline mode from kwargs
1045
+ use_controlnet = kwargs.get('use_controlnet', True)
1046
+ conditioning_type = kwargs.get('conditioning_type', 'canny')
1047
+
1048
+ # Check if we need to reinitialize
1049
+ inpaint_module = self.get_inpainting_module()
1050
+ current_mode = getattr(inpaint_module, '_current_mode', None)
1051
+ current_model = getattr(inpaint_module, '_current_model_key', None)
1052
+
1053
+ expected_mode = "controlnet" if use_controlnet else "pure"
1054
+ needs_reinit = (
1055
+ self._current_mode != "inpainting" or
1056
+ not self._inpainting_initialized or
1057
+ current_mode != expected_mode or
1058
+ (use_controlnet and current_model != model_key)
1059
+ )
1060
+
1061
+ if needs_reinit:
1062
+ if not self.switch_to_inpainting_mode(
1063
+ use_controlnet=use_controlnet,
1064
+ conditioning_type=conditioning_type,
1065
+ model_key=model_key,
1066
+ progress_callback=progress_callback
1067
+ ):
1068
  error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
1069
  return {
1070
  "success": False,
 
1072
  }
1073
 
1074
  inpaint_module = self.get_inpainting_module()
1075
  result = inpaint_module.execute_inpainting(
1076
  image=image,
1077
  mask=mask,
1078
  prompt=prompt,
 
1079
  progress_callback=progress_callback,
 
1080
  **kwargs
1081
  )
1082
 
 
1203
 
1204
  status = self._inpainting_module.get_status()
1205
  status["mode"] = self._current_mode
1206
+ return status
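Before moving on to ui_manager.py, note how the new execute_inpainting decides whether to rebuild the pipeline: it reuses a loaded pipeline unless the requested mode or base model differs. The check above reduces to:

```python
def needs_reinit(app_mode: str, initialized: bool,
                 loaded_mode: str, loaded_model: str,
                 use_controlnet: bool, model_key: str) -> bool:
    """Mirror of the reuse check in execute_inpainting."""
    expected_mode = "controlnet" if use_controlnet else "pure"
    return (
        app_mode != "inpainting"
        or not initialized
        or loaded_mode != expected_mode
        # Pure mode always uses the dedicated inpainting model, so the
        # model key only matters when ControlNet is active.
        or (use_controlnet and loaded_model != model_key)
    )
```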
ui_manager.py CHANGED
@@ -3,16 +3,17 @@ import time
3
  import traceback
4
  from pathlib import Path
5
  from typing import Optional, Tuple, Dict, Any, List
6
- from PIL import Image
7
- import numpy as np
8
  import cv2
9
  import gradio as gr
10
- import spaces
 
11
 
12
- from scene_weaver_core import SceneWeaverCore
13
  from css_styles import CSSStyles
14
  from scene_templates import SceneTemplateManager
15
  from inpainting_templates import InpaintingTemplateManager
 
 
16
 
17
  logger = logging.getLogger(__name__)
18
  logger.setLevel(logging.INFO)
@@ -29,16 +30,20 @@ class UIManager:
29
  Gradio UI Manager with support for background generation and inpainting.
30
 
31
  Provides a professional interface with mode switching, template selection,
32
- and advanced parameter controls.
33
 
34
  Attributes:
35
- sceneweaver: SceneWeaverCore instance
36
  template_manager: Scene template manager
37
  inpainting_template_manager: Inpainting template manager
38
  """
39
 
40
  def __init__(self):
41
  self.sceneweaver = SceneWeaverCore()
 
 
 
 
42
  self.template_manager = SceneTemplateManager()
43
  self.inpainting_template_manager = InpaintingTemplateManager()
44
  self.generation_history = []
@@ -173,7 +178,6 @@ class UIManager:
173
  if len(self.generation_history) > max_history:
174
  self.generation_history = self.generation_history[-max_history:]
175
 
176
- @spaces.GPU(duration=240)
177
  def generate_handler(
178
  self,
179
  uploaded_image: Optional[Image.Image],
@@ -185,8 +189,33 @@ class UIManager:
185
  guidance: float,
186
  progress=gr.Progress()
187
  ):
188
- """Enhanced generation handler with memory management and ZeroGPU support"""
 
189
 
190
  if uploaded_image is None:
191
  return None, None, None, "Please upload an image to get started!", gr.update(visible=False)
192
 
@@ -194,44 +223,19 @@ class UIManager:
194
  return None, None, None, "Please describe the background scene you'd like!", gr.update(visible=False)
195
 
196
  try:
197
- if not self.sceneweaver.is_initialized:
198
- progress(0.05, desc="Loading AI models (first time may take 2-3 minutes)...")
199
-
200
- def init_progress(msg, pct):
201
- if pct < 30:
202
- desc = "Loading image analysis models..."
203
- elif pct < 60:
204
- desc = "Loading Stable Diffusion XL..."
205
- elif pct < 90:
206
- desc = "Applying memory optimizations..."
207
- else:
208
- desc = "Almost ready..."
209
- progress(0.05 + (pct/100) * 0.2, desc=desc)
210
-
211
- self.sceneweaver.load_models(progress_callback=init_progress)
212
-
213
- def gen_progress(msg, pct):
214
- if pct < 20:
215
- desc = "Analyzing your image..."
216
- elif pct < 50:
217
- desc = "Generating background scene..."
218
- elif pct < 80:
219
- desc = "Blending foreground and background..."
220
- elif pct < 95:
221
- desc = "Applying final touches..."
222
- else:
223
- desc = "Complete!"
224
- progress(0.25 + (pct/100) * 0.75, desc=desc)
225
 
226
- result = self.sceneweaver.generate_and_combine(
227
- original_image=uploaded_image,
 
228
  prompt=prompt,
229
- combination_mode=combination_mode,
230
- focus_mode=focus_mode,
231
  negative_prompt=negative_prompt,
232
- num_inference_steps=int(steps),
 
 
233
  guidance_scale=float(guidance),
234
- progress_callback=gen_progress
235
  )
236
 
237
  if result["success"]:
@@ -547,7 +551,7 @@ class UIManager:
547
  self,
548
  display_name: str,
549
  current_prompt: str
550
- ) -> Tuple[str, float, int, str]:
551
  """
552
  Apply an inpainting template to the UI fields.
553
 
@@ -561,26 +565,76 @@ class UIManager:
561
  Returns
562
  -------
563
  tuple
564
- (prompt, conditioning_scale, feather_radius, conditioning_type)
 
565
  """
 
  if not display_name:
567
- return current_prompt, 0.7, 8, "canny"
568
 
569
  template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
570
  if not template_key:
571
- return current_prompt, 0.7, 8, "canny"
572
 
573
  template = self.inpainting_template_manager.get_template(template_key)
574
  if template:
575
  params = self.inpainting_template_manager.get_parameters_for_template(template_key)
576
  return (
577
  current_prompt,
578
  params.get('controlnet_conditioning_scale', 0.7),
579
  params.get('feather_radius', 8),
580
- params.get('preferred_conditioning', 'canny')
 
 
 
581
  )
582
 
583
- return current_prompt, 0.7, 8, "canny"
584
 
585
  def extract_mask_from_editor(self, editor_output: Dict[str, Any]) -> Optional[Image.Image]:
586
  """
@@ -664,22 +718,23 @@ class UIManager:
664
  logger.error(f"Failed to extract mask from editor: {e}")
665
  return None
666
 
667
- @spaces.GPU(duration=420)
668
  def inpainting_handler(
669
  self,
670
  image: Optional[Image.Image],
671
  mask_editor: Dict[str, Any],
672
  prompt: str,
673
  template_dropdown: str,
 
674
  conditioning_type: str,
675
  conditioning_scale: float,
676
  feather_radius: int,
677
  guidance_scale: float,
678
  num_steps: int,
 
679
  progress: gr.Progress = gr.Progress()
680
- ) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
681
  """
682
- Handle inpainting generation request.
683
 
684
  Parameters
685
  ----------
@@ -691,6 +746,8 @@ class UIManager:
691
  Text description of desired content
692
  template_dropdown : str
693
  Selected template (optional)
 
 
694
  conditioning_type : str
695
  ControlNet conditioning type
696
  conditioning_scale : float
@@ -701,36 +758,36 @@ class UIManager:
701
  Guidance scale for generation
702
  num_steps : int
703
  Number of inference steps
 
 
704
  progress : gr.Progress
705
  Progress callback
706
 
707
  Returns
708
  -------
709
  tuple
710
- (result_image, control_image, status_message)
711
  """
712
  if image is None:
713
- return None, None, "⚠️ Please upload an image first"
714
 
715
  # Extract mask
716
  mask = self.extract_mask_from_editor(mask_editor)
717
  if mask is None:
718
- return None, None, "⚠️ Please draw a mask on the image"
719
 
720
  # Validate mask
721
  mask_array = np.array(mask)
722
  coverage = np.count_nonzero(mask_array > 127) / mask_array.size
723
  if coverage < 0.01:
724
- return None, None, "⚠️ Mask too small - please select a larger area"
725
  if coverage > 0.95:
726
- return None, None, "⚠️ Mask too large - consider using background generation instead"
727
 
728
  def progress_callback(msg: str, pct: int):
729
  progress(pct / 100, desc=msg)
730
 
731
  try:
732
- start_time = time.time()
733
-
734
  # Get template key if selected
735
  template_key = None
736
  if template_dropdown:
@@ -738,53 +795,39 @@ class UIManager:
738
  template_dropdown
739
  )
740
 
741
- # Execute inpainting through SceneWeaverCore facade
742
- result = self.sceneweaver.execute_inpainting(
743
  image=image,
744
  mask=mask,
745
  prompt=prompt,
746
- preview_only=False,
747
  template_key=template_key,
 
748
  conditioning_type=conditioning_type,
749
- controlnet_conditioning_scale=conditioning_scale,
750
  feather_radius=feather_radius,
751
  guidance_scale=guidance_scale,
752
- num_inference_steps=num_steps,
 
753
  progress_callback=progress_callback
754
  )
755
 
756
- elapsed = time.time() - start_time
757
-
758
- if result.get('success'):
759
- # Store in history
760
  self.inpainting_history.append({
761
- 'result': result.get('combined_image'),
762
  'prompt': prompt,
763
- 'time': elapsed
 
764
  })
765
  if len(self.inpainting_history) > 3:
766
  self.inpainting_history.pop(0)
767
 
768
- quality_score = result.get('quality_score', 0)
769
-
770
- # Clean, simple status message
771
- status = f"✅ Inpainting complete in {elapsed:.1f}s"
772
- if quality_score > 0:
773
- status += f" | Quality: {quality_score:.0f}/100"
774
-
775
- return (
776
- result.get('combined_image'),
777
- result.get('control_image'),
778
- status
779
- )
780
- else:
781
- error_msg = result.get('error', 'Unknown error')
782
- return None, None, f"❌ Inpainting failed: {error_msg}"
783
 
784
  except Exception as e:
785
  logger.error(f"Inpainting handler error: {e}")
786
  logger.error(traceback.format_exc())
787
- return None, None, f"❌ Error: {str(e)}"
788
 
789
  def create_inpainting_tab(self) -> gr.Tab:
790
  """
@@ -812,17 +855,44 @@ class UIManager:
812
  </span>
813
  </h3>
814
  <p style="color: #666; margin-bottom: 12px;">Draw a mask to select the area you want to regenerate</p>
815
- <div style="background: linear-gradient(to right, #FFF4E6, #FFE8CC);
816
- border-left: 4px solid #FF9500;
817
- padding: 12px 15px;
818
- border-radius: 6px;
819
- margin-top: 10px;
820
- box-shadow: 0 2px 4px rgba(255, 149, 0, 0.1);">
821
- <p style="color: #8B4513; font-size: 0.9em; margin: 0; line-height: 1.5;">
822
- <strong>⚠️ Beta Feature - Continuously Optimizing</strong><br>
823
- Results may vary depending on complexity. Use templates and detailed prompts for best results.
824
- Advanced features (like Add Accessories) may require multiple attempts.
825
- </p>
826
  </div>
827
  </div>
828
  """)
@@ -859,6 +929,9 @@ class UIManager:
859
  )
860
  template_tips = gr.Markdown("")
861
 
 
 
 
862
  # Prompt
863
  inpaint_prompt = gr.Textbox(
864
  label="Prompt",
@@ -868,28 +941,49 @@ class UIManager:
868
 
869
  # Right column - Settings and Output
870
  with gr.Column(scale=1):
871
- # Settings
872
- with gr.Accordion("Generation Settings", open=True):
873
- conditioning_type = gr.Radio(
874
- choices=["canny", "depth"],
875
- value="canny",
876
- label="ControlNet Mode"
877
- )
 
 
 
 
 
 
 
878
 
879
- conditioning_scale = gr.Slider(
880
- minimum=0.05,
881
- maximum=1.0,
882
- value=0.7,
883
- step=0.05,
884
- label="ControlNet Strength"
885
- )
 
886
 
 
 
887
  feather_radius = gr.Slider(
888
  minimum=0,
889
  maximum=20,
890
  value=8,
891
  step=1,
892
- label="Feather Radius (px)"
 
893
  )
894
 
895
  with gr.Accordion("Advanced Settings", open=False):
@@ -909,6 +1003,14 @@ class UIManager:
909
  label="Inference Steps"
910
  )
911
 
 
 
 
 
 
 
 
 
912
  # Generate button
913
  inpaint_btn = gr.Button(
914
  "Generate Inpainting",
@@ -925,9 +1027,9 @@ class UIManager:
925
  border-radius: 8px;
926
  margin: 12px 0;">
927
  <p style="margin: 0; color: #5d4037; font-size: 14px;">
928
- ⏳ <strong>Please be patient!</strong> Inpainting typically takes <strong>5-7 minutes</strong>
929
- depending on GPU availability and image complexity.
930
- Please don't refresh the page while processing.
931
  </p>
932
  </div>
933
  <div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
@@ -943,13 +1045,27 @@ class UIManager:
943
  """
944
  )
945
 
946
- # Status
947
  inpaint_status = gr.Textbox(
948
  label="Status",
949
  value="Ready for inpainting",
950
  interactive=False
951
  )
952
 
953
  # Output row
954
  with gr.Row():
955
  with gr.Column(scale=1):
@@ -971,7 +1087,15 @@ class UIManager:
971
  inpaint_template.change(
972
  fn=self.apply_inpainting_template,
973
  inputs=[inpaint_template, inpaint_prompt],
974
- outputs=[inpaint_prompt, conditioning_scale, feather_radius, conditioning_type]
 
 
 
 
 
 
 
 
975
  )
976
 
977
  inpaint_template.change(
@@ -980,9 +1104,16 @@ class UIManager:
980
  outputs=[template_tips]
981
  )
982
 
983
- # Copy uploaded image to mask editor
 
 
 
 
 
 
 
984
  inpaint_image.change(
985
- fn=lambda x: x,
986
  inputs=[inpaint_image],
987
  outputs=[mask_editor]
988
  )
@@ -994,19 +1125,29 @@ class UIManager:
994
  mask_editor,
995
  inpaint_prompt,
996
  inpaint_template,
 
997
  conditioning_type,
998
  conditioning_scale,
999
  feather_radius,
1000
  inpaint_guidance,
1001
- inpaint_steps
 
1002
  ],
1003
  outputs=[
1004
  inpaint_result,
1005
  inpaint_control,
1006
- inpaint_status
 
1007
  ]
1008
  )
1009
 
 
 
 
 
 
 
 
1010
  return tab
1011
 
1012
  def _get_template_tips(self, display_name: str) -> str:
@@ -1021,4 +1162,4 @@ class UIManager:
1021
  tips = self.inpainting_template_manager.get_usage_tips(template_key)
1022
  if tips:
1023
  return "**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips)
1024
- return ""
 
3
  import traceback
4
  from pathlib import Path
5
  from typing import Optional, Tuple, Dict, Any, List
6
+
 
7
  import cv2
8
  import gradio as gr
9
+ import numpy as np
10
+ from PIL import Image
11
 
 
12
  from css_styles import CSSStyles
13
  from scene_templates import SceneTemplateManager
14
  from inpainting_templates import InpaintingTemplateManager
15
+ from scene_weaver_core import SceneWeaverCore
16
+ from gpu_handlers import GPUHandlers
17
 
18
  logger = logging.getLogger(__name__)
19
  logger.setLevel(logging.INFO)
 
30
  Gradio UI Manager with support for background generation and inpainting.
31
 
32
  Provides a professional interface with mode switching, template selection,
33
+ and advanced parameter controls. GPU operations are delegated to GPUHandlers.
34
 
35
  Attributes:
36
+ gpu_handlers: GPUHandlers instance for GPU operations
37
  template_manager: Scene template manager
38
  inpainting_template_manager: Inpainting template manager
39
  """
40
 
41
  def __init__(self):
42
  self.sceneweaver = SceneWeaverCore()
43
+ self.gpu_handlers = GPUHandlers(
44
+ core=self.sceneweaver,
45
+ inpainting_template_manager=InpaintingTemplateManager()
46
+ )
47
  self.template_manager = SceneTemplateManager()
48
  self.inpainting_template_manager = InpaintingTemplateManager()
49
  self.generation_history = []
 
178
  if len(self.generation_history) > max_history:
179
  self.generation_history = self.generation_history[-max_history:]
180
 
 
181
  def generate_handler(
182
  self,
183
  uploaded_image: Optional[Image.Image],
 
189
  guidance: float,
190
  progress=gr.Progress()
191
  ):
192
+ """
193
+ Generation handler - delegates GPU work to GPUHandlers.
194
 
195
+ Parameters
196
+ ----------
197
+ uploaded_image : PIL.Image
198
+ Input image
199
+ prompt : str
200
+ Background description
201
+ combination_mode : str
202
+ Composition mode
203
+ focus_mode : str
204
+ Focus mode
205
+ negative_prompt : str
206
+ Negative prompt
207
+ steps : int
208
+ Inference steps
209
+ guidance : float
210
+ Guidance scale
211
+ progress : gr.Progress
212
+ Progress callback
213
+
214
+ Returns
215
+ -------
216
+ tuple
217
+ (combined, generated, original, status, download_btn_update)
218
+ """
219
  if uploaded_image is None:
220
  return None, None, None, "Please upload an image to get started!", gr.update(visible=False)
221
 
 
223
  return None, None, None, "Please describe the background scene you'd like!", gr.update(visible=False)
224
 
225
  try:
226
+ def progress_callback(msg: str, pct: int):
227
+ progress(pct / 100, desc=msg)
 
228
 
229
+ # Delegate to GPUHandlers
230
+ result = self.gpu_handlers.background_generate(
231
+ image=uploaded_image,
232
  prompt=prompt,
 
 
233
  negative_prompt=negative_prompt,
234
+ composition_mode=combination_mode,
235
+ focus_mode=focus_mode,
236
+ num_steps=int(steps),
237
  guidance_scale=float(guidance),
238
+ progress_callback=progress_callback
239
  )
240
 
241
  if result["success"]:
 
551
  self,
552
  display_name: str,
553
  current_prompt: str
554
+ ) -> Tuple[str, float, int, str, Any, Any, Any]:
555
  """
556
  Apply an inpainting template to the UI fields.
557
 
 
565
  Returns
566
  -------
567
  tuple
568
+ (prompt, conditioning_scale, feather_radius, conditioning_type,
569
+ controlnet_settings_visibility, mode_info_html, model_selection_visibility)
570
  """
571
+ # Default returns for no template selected
572
+ default_return = (
573
+ current_prompt,
574
+ 0.7,
575
+ 8,
576
+ "canny",
577
+ gr.update(visible=True), # Show ControlNet settings by default
578
+ "", # No mode info
579
+ gr.update(visible=True) # Show model selection by default
580
+ )
581
+
582
  if not display_name:
583
+ return default_return
584
 
585
  template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
586
  if not template_key:
587
+ return default_return
588
 
589
  template = self.inpainting_template_manager.get_template(template_key)
590
  if template:
591
  params = self.inpainting_template_manager.get_parameters_for_template(template_key)
592
+ use_controlnet = params.get('use_controlnet', True)
593
+
594
+ # Determine visibility and info based on mode
595
+ if use_controlnet:
596
+ controlnet_visibility = gr.update(visible=True)
597
+ model_visibility = gr.update(visible=True)
598
+ mode_info = """
599
+ <div style="background: linear-gradient(135deg, #e8f5e9 0%, #c8e6c9 100%);
600
+ border-left: 4px solid #4CAF50;
601
+ padding: 10px 14px;
602
+ border-radius: 8px;
603
+ margin: 8px 0;">
604
+ <p style="margin: 0; color: #2e7d32; font-size: 13px;">
605
+ 🎛️ <strong>ControlNet Mode</strong> - Structure will be preserved using edge/depth guidance.
606
+ You can adjust ControlNet settings and select model below.
607
+ </p>
608
+ </div>
609
+ """
610
+ else:
611
+ # Pure Inpainting mode - hide both ControlNet and Model Selection
612
+ controlnet_visibility = gr.update(visible=False)
613
+ model_visibility = gr.update(visible=False)
614
+ mode_info = """
615
+ <div style="background: linear-gradient(135deg, #fff3e0 0%, #ffe0b2 100%);
616
+ border-left: 4px solid #ff9800;
617
+ padding: 10px 14px;
618
+ border-radius: 8px;
619
+ margin: 8px 0;">
620
+ <p style="margin: 0; color: #e65100; font-size: 13px;">
621
+ 🚀 <strong>Pure Inpainting Mode</strong> - Using the dedicated SDXL Inpainting model.<br>
622
+ Model and ControlNet settings are automatically configured for best results.
623
+ </p>
624
+ </div>
625
+ """
626
+
627
  return (
628
  current_prompt,
629
  params.get('controlnet_conditioning_scale', 0.7),
630
  params.get('feather_radius', 8),
631
+ params.get('preferred_conditioning', 'canny'),
632
+ controlnet_visibility,
633
+ mode_info,
634
+ model_visibility
635
  )
636
 
637
+ return default_return
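The gr.update values returned here are what make the per-template UI switching work: returning gr.update(visible=...) for a gr.Group toggles every component inside it, so a single handler output can hide the whole ControlNet panel. A minimal standalone demonstration of the pattern:

```python
import gradio as gr

def toggle_groups(use_controlnet: bool):
    # One update per output component; gr.Group accepts visible=.
    return (
        gr.update(visible=use_controlnet),  # ControlNet settings group
        gr.update(visible=use_controlnet),  # model selection group
    )

with gr.Blocks() as demo:
    mode = gr.Checkbox(label="Use ControlNet", value=True)
    with gr.Group() as controls:
        gr.Slider(0.05, 1.0, value=0.7, label="ControlNet Strength")
    with gr.Group() as models:
        gr.Dropdown(choices=["sdxl_base"], label="Model")
    mode.change(toggle_groups, inputs=[mode], outputs=[controls, models])
```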
638
 
639
  def extract_mask_from_editor(self, editor_output: Dict[str, Any]) -> Optional[Image.Image]:
640
  """
 
718
  logger.error(f"Failed to extract mask from editor: {e}")
719
  return None
720
 
 
721
  def inpainting_handler(
722
  self,
723
  image: Optional[Image.Image],
724
  mask_editor: Dict[str, Any],
725
  prompt: str,
726
  template_dropdown: str,
727
+ model_choice: str,
728
  conditioning_type: str,
729
  conditioning_scale: float,
730
  feather_radius: int,
731
  guidance_scale: float,
732
  num_steps: int,
733
+ seed: int,
734
  progress: gr.Progress = gr.Progress()
735
+ ) -> Tuple[Optional[Image.Image], Optional[Image.Image], str, int]:
736
  """
737
+ Handle inpainting generation request - delegates GPU work to GPUHandlers.
738
 
739
  Parameters
740
  ----------
 
746
  Text description of desired content
747
  template_dropdown : str
748
  Selected template (optional)
749
+ model_choice : str
750
+ Model key to use (juggernaut_xl, realvis_xl, sdxl_base, animagine_xl)
751
  conditioning_type : str
752
  ControlNet conditioning type
753
  conditioning_scale : float
 
758
  Guidance scale for generation
759
  num_steps : int
760
  Number of inference steps
761
+ seed : int
762
+ Random seed (-1 for random)
763
  progress : gr.Progress
764
  Progress callback
765
 
766
  Returns
767
  -------
768
  tuple
769
+ (result_image, control_image, status_message, used_seed)
770
  """
771
  if image is None:
772
+ return None, None, "⚠️ Please upload an image first", -1
773
 
774
  # Extract mask
775
  mask = self.extract_mask_from_editor(mask_editor)
776
  if mask is None:
777
+ return None, None, "⚠️ Please draw a mask on the image", -1
778
 
779
  # Validate mask
780
  mask_array = np.array(mask)
781
  coverage = np.count_nonzero(mask_array > 127) / mask_array.size
782
  if coverage < 0.01:
783
+ return None, None, "⚠️ Mask too small - please select a larger area", -1
784
  if coverage > 0.95:
785
+ return None, None, "⚠️ Mask too large - consider using background generation instead", -1
786
 
787
  def progress_callback(msg: str, pct: int):
788
  progress(pct / 100, desc=msg)
789
 
790
  try:
 
 
791
  # Get template key if selected
792
  template_key = None
793
  if template_dropdown:
 
795
  template_dropdown
796
  )
797
 
798
+ # Delegate to GPUHandlers
799
+ result_image, control_image, status, used_seed = self.gpu_handlers.inpainting_generate(
800
  image=image,
801
  mask=mask,
802
  prompt=prompt,
 
803
  template_key=template_key,
804
+ model_key=model_choice,
805
  conditioning_type=conditioning_type,
806
+ conditioning_scale=conditioning_scale,
807
  feather_radius=feather_radius,
808
  guidance_scale=guidance_scale,
809
+ num_steps=num_steps,
810
+ seed=int(seed) if seed is not None else -1,
811
  progress_callback=progress_callback
812
  )
813
 
814
+ # Store in history if successful
815
+ if result_image is not None:
 
 
816
  self.inpainting_history.append({
817
+ 'result': result_image,
818
  'prompt': prompt,
819
+ 'seed': used_seed,
820
+ 'time': time.time()
821
  })
822
  if len(self.inpainting_history) > 3:
823
  self.inpainting_history.pop(0)
824
 
825
+ return result_image, control_image, status, used_seed
826
 
827
  except Exception as e:
828
  logger.error(f"Inpainting handler error: {e}")
829
  logger.error(traceback.format_exc())
830
+ return None, None, f"❌ Error: {str(e)}", -1
831
 
832
  def create_inpainting_tab(self) -> gr.Tab:
833
  """
 
855
  </span>
856
  </h3>
857
  <p style="color: #666; margin-bottom: 12px;">Draw a mask to select the area you want to regenerate</p>
858
+ </div>
859
+ """)
860
+
861
+ # Model Selection Guide
862
+ gr.HTML("""
863
+ <div style="background: linear-gradient(135deg, #f5f7fa 0%, #e4e8ec 100%);
864
+ padding: 16px;
865
+ border-radius: 12px;
866
+ margin: 12px 0;
867
+ border: 1px solid #ddd;">
868
+ <h4 style="margin: 0 0 12px 0; color: #333; font-size: 16px;">
869
+ 📸 Model Selection Guide
870
+ </h4>
871
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;">
872
+ <div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #4CAF50;">
873
+ <p style="margin: 0 0 8px 0; font-weight: bold; color: #4CAF50;">
874
+ 🖼️ Photo Mode (Real Photos)
875
+ </p>
876
+ <p style="margin: 0; font-size: 13px; color: #555;">
877
+ <strong>Best for:</strong> Photographs, portraits, product shots, nature photos
878
+ </p>
879
+ <p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
880
+ • <strong>JuggernautXL</strong> - Best for portraits and people<br>
881
+ • <strong>RealVisXL</strong> - Best for scenes and objects
882
+ </p>
883
+ </div>
884
+ <div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #9C27B0;">
885
+ <p style="margin: 0 0 8px 0; font-weight: bold; color: #9C27B0;">
886
+ 🎨 Anime Mode (Illustrations)
887
+ </p>
888
+ <p style="margin: 0; font-size: 13px; color: #555;">
889
+ <strong>Best for:</strong> Anime, manga, illustrations, digital art, cartoons
890
+ </p>
891
+ <p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
892
+ • <strong>Animagine XL</strong> - Best for anime/manga style<br>
893
+ • <strong>SDXL Base</strong> - Versatile for general art
894
+ </p>
895
+ </div>
896
  </div>
897
  </div>
898
  """)
 
929
  )
930
  template_tips = gr.Markdown("")
931
 
932
+ # Mode info (dynamically updated based on template)
933
+ mode_info_html = gr.HTML("")
934
+
935
  # Prompt
936
  inpaint_prompt = gr.Textbox(
937
  label="Prompt",
 
941
 
942
  # Right column - Settings and Output
943
  with gr.Column(scale=1):
944
+ # Model Selection (hidden for Pure Inpainting templates)
945
+ with gr.Group(visible=True) as model_selection_group:
946
+ with gr.Accordion("Model Selection", open=True):
947
+ model_choice = gr.Dropdown(
948
+ choices=[
949
+ ("🖼️ JuggernautXL v9 - Best for portraits & real photos", "juggernaut_xl"),
950
+ ("🖼️ RealVisXL v4 - Best for realistic scenes", "realvis_xl"),
951
+ ("🎨 SDXL Base - Versatile for general art", "sdxl_base"),
952
+ ("🎨 Animagine XL 3.1 - Best for anime/manga", "animagine_xl"),
953
+ ],
954
+ value="juggernaut_xl",
955
+ label="Select Model",
956
+ info="Choose based on your image type (photo vs illustration)"
957
+ )
958
 
959
+ # ControlNet Settings (hidden for Pure Inpainting templates)
960
+ with gr.Group(visible=True) as controlnet_settings_group:
961
+ with gr.Accordion("ControlNet Settings", open=True):
962
+ conditioning_type = gr.Radio(
963
+ choices=["canny", "depth"],
964
+ value="canny",
965
+ label="ControlNet Mode",
966
+ info="Canny: preserves edges | Depth: preserves 3D structure"
967
+ )
968
+
969
+ conditioning_scale = gr.Slider(
970
+ minimum=0.05,
971
+ maximum=1.0,
972
+ value=0.7,
973
+ step=0.05,
974
+ label="ControlNet Strength",
975
+ info="Higher = more structure preservation"
976
+ )
977
 
978
+ # General Settings (always visible)
979
+ with gr.Accordion("General Settings", open=True):
980
  feather_radius = gr.Slider(
981
  minimum=0,
982
  maximum=20,
983
  value=8,
984
  step=1,
985
+ label="Feather Radius (px)",
986
+ info="Edge blending softness"
987
  )
988
 
989
  with gr.Accordion("Advanced Settings", open=False):
 
1003
  label="Inference Steps"
1004
  )
1005
 
1006
+ # Seed control for reproducibility
1007
+ seed_input = gr.Number(
1008
+ label="Seed",
1009
+ value=-1,
1010
+ precision=0,
1011
+ info="-1 = random seed, or enter a specific number to reproduce results"
1012
+ )
1013
+
1014
  # Generate button
1015
  inpaint_btn = gr.Button(
1016
  "Generate Inpainting",
 
1027
  border-radius: 8px;
1028
  margin: 12px 0;">
1029
  <p style="margin: 0; color: #5d4037; font-size: 14px;">
1030
+ ⏳ <strong>Please be patient!</strong><br>
1031
+ <strong>First run:</strong> 5-7 minutes (model initialization)<br>
1032
+ <strong>Subsequent runs:</strong> 2-3 minutes (model cached)
1033
  </p>
1034
  </div>
1035
  <div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
 
1045
  """
1046
  )
1047
 
1048
+ # Status and Seed display
1049
  inpaint_status = gr.Textbox(
1050
  label="Status",
1051
  value="Ready for inpainting",
1052
  interactive=False
1053
  )
1054
 
1055
+ # Display used seed for reproducibility
1056
+ with gr.Row():
1057
+ used_seed_display = gr.Number(
1058
+ label="Used Seed (copy this to reproduce)",
1059
+ value=-1,
1060
+ precision=0,
1061
+ interactive=False
1062
+ )
1063
+ copy_seed_btn = gr.Button(
1064
+ "📋 Use This Seed",
1065
+ size="sm",
1066
+ scale=0
1067
+ )
1068
+
1069
  # Output row
1070
  with gr.Row():
1071
  with gr.Column(scale=1):
 
1087
  inpaint_template.change(
1088
  fn=self.apply_inpainting_template,
1089
  inputs=[inpaint_template, inpaint_prompt],
1090
+ outputs=[
1091
+ inpaint_prompt,
1092
+ conditioning_scale,
1093
+ feather_radius,
1094
+ conditioning_type,
1095
+ controlnet_settings_group,
1096
+ mode_info_html,
1097
+ model_selection_group
1098
+ ]
1099
  )
1100
 
1101
  inpaint_template.change(
 
1104
  outputs=[template_tips]
1105
  )
1106
 
1107
+ # Copy uploaded image to mask editor (as background)
1108
+ def set_mask_editor_background(image):
1109
+ """Set uploaded image as mask editor background."""
1110
+ if image is None:
1111
+ return None
1112
+ # Return dict format for ImageEditor with background
1113
+ return {"background": image, "layers": [], "composite": None}
1114
+
1115
  inpaint_image.change(
1116
+ fn=set_mask_editor_background,
1117
  inputs=[inpaint_image],
1118
  outputs=[mask_editor]
1119
  )
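set_mask_editor_background leans on the Gradio 4.x ImageEditor value format: an EditorValue dict with background, layers, and composite entries, so returning a fresh background also clears any previous strokes. The complementary extraction step is elided from this diff; one plausible reading of extract_mask_from_editor (an assumption, not the author's confirmed implementation) unions the alpha channels of the drawn layers:

```python
import numpy as np
from PIL import Image
from typing import Optional

def mask_from_layers(editor_value: dict) -> Optional[Image.Image]:
    """Union of all drawn strokes, taken from each layer's alpha channel."""
    layers = (editor_value or {}).get("layers") or []
    merged = None
    for layer in layers:  # layers are PIL images when the editor uses type="pil"
        alpha = np.array(layer.convert("RGBA"))[:, :, 3]
        merged = alpha if merged is None else np.maximum(merged, alpha)
    if merged is None:
        return None
    return Image.fromarray(np.where(merged > 0, 255, 0).astype(np.uint8), mode="L")
```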
 
1125
  mask_editor,
1126
  inpaint_prompt,
1127
  inpaint_template,
1128
+ model_choice,
1129
  conditioning_type,
1130
  conditioning_scale,
1131
  feather_radius,
1132
  inpaint_guidance,
1133
+ inpaint_steps,
1134
+ seed_input
1135
  ],
1136
  outputs=[
1137
  inpaint_result,
1138
  inpaint_control,
1139
+ inpaint_status,
1140
+ used_seed_display
1141
  ]
1142
  )
1143
 
1144
+ # Copy seed button - copies used seed to input
1145
+ copy_seed_btn.click(
1146
+ fn=lambda x: x,
1147
+ inputs=[used_seed_display],
1148
+ outputs=[seed_input]
1149
+ )
1150
+
1151
  return tab
1152
 
1153
  def _get_template_tips(self, display_name: str) -> str:
 
1162
  tips = self.inpainting_template_manager.get_usage_tips(template_key)
1163
  if tips:
1164
  return "**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips)
1165
+ return ""
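A closing note on the seed plumbing wired above: seed_input goes into inpainting_generate, used_seed_display echoes whatever seed was actually consumed, and the copy button feeds it back for reproduction. The resolution step itself lives in GPUHandlers, which this excerpt does not show; the conventional diffusers-style sketch is:

```python
import random
import torch

def resolve_seed(requested: int):
    """-1 draws a fresh random seed; any other value reproduces a prior run."""
    used = random.randint(0, 2**31 - 1) if requested == -1 else int(requested)
    generator = torch.Generator(device="cpu").manual_seed(used)
    return used, generator

# The handler returns `used` alongside the images so the UI can display it
# and the "Use This Seed" button can copy it back into seed_input.
```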