Spaces:
Running
on
Zero
Running
on
Zero
Upload 15 files
Browse files- app.py +1 -51
- control_image_processor.py +392 -0
- gpu_handlers.py +316 -0
- image_blender.py +13 -13
- inpainting_blender.py +485 -0
- inpainting_models.py +398 -0
- inpainting_module.py +335 -1073
- inpainting_templates.py +242 -320
- mask_generator.py +1 -1
- scene_templates.py +6 -6
- scene_weaver_core.py +51 -39
- ui_manager.py +262 -121
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import os
|
| 2 |
import sys
|
| 3 |
import traceback
|
| 4 |
import warnings
|
|
@@ -6,45 +5,6 @@ warnings.filterwarnings("ignore")
|
|
| 6 |
|
| 7 |
from ui_manager import UIManager
|
| 8 |
|
| 9 |
-
def preload_models_to_cache():
    """
    Warm the HuggingFace model cache before any GPU allocation happens.

    Runs entirely on CPU so the expensive downloads do not eat into
    @spaces.GPU execution time. Outside of a Spaces deployment (no
    SPACE_ID in the environment) this is a no-op. All failures are
    reported but swallowed: models simply download on first use instead.
    """
    # Only relevant on Hugging Face Spaces.
    if not os.getenv('SPACE_ID'):
        return  # Skip if not on Spaces

    print("📦 Pre-downloading models to cache (CPU only, no GPU usage)...")

    try:
        # Imported lazily so local runs never pay the diffusers import cost.
        from diffusers import ControlNetModel
        import torch

        # (repo id, human-readable label) pairs to pull into the cache.
        for repo_id, label in (
            ("diffusers/controlnet-canny-sdxl-1.0", "Canny ControlNet"),
            ("diffusers/controlnet-depth-sdxl-1.0", "Depth ControlNet"),
        ):
            print(f" ⬇️ Downloading {label} ({repo_id})...")
            try:
                # The loaded model object is discarded; the point is only
                # that from_pretrained populates the on-disk cache.
                _ = ControlNetModel.from_pretrained(
                    repo_id,
                    torch_dtype=torch.float16,
                    use_safetensors=True,
                    local_files_only=False  # Allow download
                )
                print(f" ✅ {label} cached")
            except Exception as e:
                print(f" ⚠️ {label} download failed (will retry on-demand): {e}")

        print("✅ Model pre-caching complete")

    except Exception as e:
        print(f"⚠️ Model pre-caching failed: {e}")
        print(" Models will be downloaded on first use instead.")
| 47 |
-
|
| 48 |
def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
|
| 49 |
"""Launch SceneWeaver Application"""
|
| 50 |
|
|
@@ -52,9 +12,6 @@ def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
|
|
| 52 |
print("✨ AI-Powered Image Background Generation")
|
| 53 |
|
| 54 |
try:
|
| 55 |
-
# Pre-download models on Spaces to avoid downloading during GPU time
|
| 56 |
-
preload_models_to_cache()
|
| 57 |
-
|
| 58 |
# Test imports first
|
| 59 |
print("🔍 Testing imports...")
|
| 60 |
try:
|
|
@@ -63,13 +20,6 @@ def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
|
|
| 63 |
ui = UIManager()
|
| 64 |
print("✅ UIManager instance created successfully")
|
| 65 |
|
| 66 |
-
# Note: On Hugging Face Spaces, models are pre-cached at startup
|
| 67 |
-
if os.getenv('SPACE_ID'):
|
| 68 |
-
print("\n🔧 Detected Hugging Face Spaces environment")
|
| 69 |
-
print("⚡ Models pre-cached - ready for fast inference")
|
| 70 |
-
print(" Expected inference time: ~300-350s (with cached models)")
|
| 71 |
-
print()
|
| 72 |
-
|
| 73 |
# Launch UI
|
| 74 |
print("🚀 Launching interface...")
|
| 75 |
interface = ui.launch(share=share, debug=debug)
|
|
@@ -128,4 +78,4 @@ def main():
|
|
| 128 |
raise
|
| 129 |
|
| 130 |
if __name__ == "__main__":
|
| 131 |
-
main()
|
|
|
|
|
|
|
| 1 |
import sys
|
| 2 |
import traceback
|
| 3 |
import warnings
|
|
|
|
| 5 |
|
| 6 |
from ui_manager import UIManager
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
|
| 9 |
"""Launch SceneWeaver Application"""
|
| 10 |
|
|
|
|
| 12 |
print("✨ AI-Powered Image Background Generation")
|
| 13 |
|
| 14 |
try:
|
|
|
|
|
|
|
|
|
|
| 15 |
# Test imports first
|
| 16 |
print("🔍 Testing imports...")
|
| 17 |
try:
|
|
|
|
| 20 |
ui = UIManager()
|
| 21 |
print("✅ UIManager instance created successfully")
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Launch UI
|
| 24 |
print("🚀 Launching interface...")
|
| 25 |
interface = ui.launch(share=share, debug=debug)
|
|
|
|
| 78 |
raise
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
| 81 |
+
main()
|
control_image_processor.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image, ImageFilter
|
| 8 |
+
|
| 9 |
+
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
|
| 10 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
logger.setLevel(logging.INFO)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ControlImageProcessor:
    """
    Generates control images for ControlNet conditioning.

    Supports Canny edge detection and depth map estimation with
    mask-aware processing for selective structure preservation.

    Attributes:
        device: Computation device (cuda/mps/cpu)
        canny_low_threshold: Low threshold for Canny edge detection
        canny_high_threshold: High threshold for Canny edge detection

    Example:
        >>> processor = ControlImageProcessor(device="cuda")
        >>> canny_image = processor.generate_canny_edges(image)
        >>> depth_map = processor.generate_depth_map(image)
    """

    # Depth model identifiers
    DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
    DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"

    def __init__(
        self,
        device: str = "cuda",
        canny_low_threshold: int = 100,
        canny_high_threshold: int = 200
    ):
        """
        Initialize the ControlImageProcessor.

        Parameters
        ----------
        device : str
            Computation device
        canny_low_threshold : int
            Low threshold for Canny edge detection
        canny_high_threshold : int
            High threshold for Canny edge detection
        """
        self.device = device
        self.canny_low_threshold = canny_low_threshold
        self.canny_high_threshold = canny_high_threshold

        # Depth estimation models (lazy loaded on first generate_depth_map call)
        self._depth_estimator = None
        self._depth_processor = None
        self._depth_model_loaded = False

        logger.info(f"ControlImageProcessor initialized on {device}")

    def generate_canny_edges(self, image: np.ndarray) -> Image.Image:
        """
        Generate Canny edge detection image.

        Parameters
        ----------
        image : np.ndarray
            Input image as numpy array (RGB)

        Returns
        -------
        PIL.Image
            Canny edge image replicated to 3 channels, as ControlNet
            expects an RGB conditioning image
        """
        # Convert to grayscale (accept both RGB and already-gray input)
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image

        # Apply Gaussian blur to reduce noise before edge detection
        blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)

        # Canny edge detection
        edges = cv2.Canny(
            blurred,
            self.canny_low_threshold,
            self.canny_high_threshold
        )

        # Convert to 3-channel for ControlNet
        edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

        return Image.fromarray(edges_3ch)

    def load_depth_estimator(self) -> bool:
        """
        Load depth estimation model.

        Tries DEPTH_MODEL_PRIMARY (Depth Anything) first, falling back to
        DEPTH_MODEL_FALLBACK (DPT hybrid MiDaS). Models are loaded in fp16
        on CUDA and fp32 otherwise.

        Returns
        -------
        bool
            True if loaded successfully
        """
        if self._depth_model_loaded:
            return True

        logger.info("Loading depth estimation model...")

        try:
            # Try primary model first (Depth Anything)
            self._depth_processor = AutoImageProcessor.from_pretrained(
                self.DEPTH_MODEL_PRIMARY
            )
            self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_PRIMARY,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self._depth_estimator = self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            self._depth_model_loaded = True
            logger.info(f"Loaded depth model: {self.DEPTH_MODEL_PRIMARY}")
            return True

        except Exception as e:
            logger.warning(f"Primary depth model failed: {e}, trying fallback...")

            try:
                # Fallback to DPT
                self._depth_processor = DPTImageProcessor.from_pretrained(
                    self.DEPTH_MODEL_FALLBACK
                )
                self._depth_estimator = DPTForDepthEstimation.from_pretrained(
                    self.DEPTH_MODEL_FALLBACK,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                self._depth_estimator = self._depth_estimator.to(self.device)
                self._depth_estimator.eval()
                self._depth_model_loaded = True
                logger.info(f"Loaded fallback depth model: {self.DEPTH_MODEL_FALLBACK}")
                return True

            except Exception as e2:
                logger.error(f"All depth models failed: {e2}")
                return False

    def generate_depth_map(self, image: Image.Image) -> Image.Image:
        """
        Generate depth map using depth estimation model.

        Parameters
        ----------
        image : PIL.Image
            Input image

        Returns
        -------
        PIL.Image
            Depth map image (normalized to 0-255, replicated to 3 channels).
            Falls back to a synthetic vertical gradient if no model loads
            or inference fails.
        """
        if not self._depth_model_loaded:
            if not self.load_depth_estimator():
                # Fallback to simple gradient
                logger.warning("Using fallback gradient depth")
                return self._generate_fallback_depth(image)

        try:
            # Prepare image for model
            inputs = self._depth_processor(
                images=image,
                return_tensors="pt"
            )
            # BUGFIX: the model is loaded in fp16 on CUDA but the image
            # processor emits fp32 tensors; moving them with .to(device)
            # alone keeps fp32 and crashes the fp16 model. Cast floating
            # inputs to the model's parameter dtype.
            model_dtype = next(self._depth_estimator.parameters()).dtype
            inputs = {
                k: (v.to(self.device, dtype=model_dtype)
                    if v.is_floating_point() else v.to(self.device))
                for k, v in inputs.items()
            }

            # Run inference
            with torch.no_grad():
                outputs = self._depth_estimator(**inputs)
                predicted_depth = outputs.predicted_depth

            # Normalize depth map to [0, 255]; epsilon guards a constant map
            depth = predicted_depth.squeeze().float().cpu().numpy()
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
            depth = (depth * 255).astype(np.uint8)

            # Resize to match input
            depth_image = Image.fromarray(depth)
            depth_image = depth_image.resize(image.size, Image.Resampling.BILINEAR)

            # Convert to 3-channel for ControlNet
            depth_3ch = np.stack([np.array(depth_image)] * 3, axis=-1)

            return Image.fromarray(depth_3ch)

        except Exception as e:
            logger.error(f"Depth estimation failed: {e}")
            return self._generate_fallback_depth(image)

    def _generate_fallback_depth(self, image: Image.Image) -> Image.Image:
        """
        Generate a simple fallback depth map using gradient.

        Parameters
        ----------
        image : PIL.Image
            Input image

        Returns
        -------
        PIL.Image
            Simple vertical-gradient depth map (top = far, bottom = near)
        """
        w, h = image.size
        # Create vertical gradient (top = far, bottom = near)
        gradient = np.linspace(50, 200, h).reshape(-1, 1)
        gradient = np.tile(gradient, (1, w))
        gradient = gradient.astype(np.uint8)

        # Stack to 3 channels
        depth_3ch = np.stack([gradient] * 3, axis=-1)
        return Image.fromarray(depth_3ch)

    def prepare_control_image(
        self,
        image: Image.Image,
        mode: str = "canny",
        mask: Optional[Image.Image] = None,
        preserve_structure: bool = False,
        edge_guidance_mode: str = "boundary"
    ) -> Image.Image:
        """
        Generate ControlNet conditioning image.

        Parameters
        ----------
        image : PIL.Image
            Input image
        mode : str
            Conditioning mode: "canny" or "depth"
        mask : PIL.Image, optional
            If provided, can modify edges based on edge_guidance_mode
        preserve_structure : bool
            If True, keep all edges in masked region (for color change tasks)
            If False, use edge_guidance_mode to determine edge handling
        edge_guidance_mode : str
            How to handle edges when preserve_structure=False:
            - "none": Completely remove edges in masked region (removal tasks)
            - "mask_outline": clear inside edges, draw the mask contour only
            - "boundary": Keep only boundary edges of masked region (replacement tasks)
            - "soft": Gradually fade edges from boundary (default for better blending)

        Returns
        -------
        PIL.Image
            Generated control image

        Raises
        ------
        ValueError
            If `mode` is neither "canny" nor "depth"
        """
        logger.info(f"Preparing control image: mode={mode}, preserve_structure={preserve_structure}, edge_guidance={edge_guidance_mode}")

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_array = np.array(image)

        if mode == "canny":
            control_image = self.generate_canny_edges(img_array)

            if mask is not None:
                control_array = np.array(control_image)
                mask_array = np.array(mask.convert('L'))

                if preserve_structure:
                    # Keep all edges - no modification needed
                    logger.info("Preserving all edges in masked region for color change")

                elif edge_guidance_mode == "none":
                    # Completely suppress edges in masked region (for removal)
                    mask_region = mask_array > 128
                    control_array[mask_region] = 0
                    logger.info("Suppressed all edges in masked region for removal")

                elif edge_guidance_mode == "mask_outline":
                    # For object replacement: clear inside edges, draw clear mask outline
                    # Outline guides WHERE and WHAT SIZE the new object should be
                    mask_binary = (mask_array > 128).astype(np.uint8) * 255

                    # Step 1: Clear all edges inside the mask
                    mask_region = mask_array > 128
                    control_array[mask_region] = 0

                    # Step 2: Draw clear mask outline for position/size guidance
                    contours, _ = cv2.findContours(
                        mask_binary,
                        cv2.RETR_EXTERNAL,
                        cv2.CHAIN_APPROX_SIMPLE
                    )

                    if contours:
                        # Draw visible white outline (thickness=2) for clear guidance
                        cv2.drawContours(control_array, contours, -1, (255, 255, 255), thickness=2)
                        logger.info(f"Drew {len(contours)} mask outline(s) for placement guidance")

                elif edge_guidance_mode == "boundary":
                    # Keep boundary edges to guide object placement and size
                    # This helps ControlNet understand WHERE to place the new object
                    mask_binary = (mask_array > 128).astype(np.uint8) * 255

                    # Create boundary band via morphological dilate - erode
                    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
                    dilated = cv2.dilate(mask_binary, kernel, iterations=1)
                    eroded = cv2.erode(mask_binary, kernel, iterations=1)
                    boundary = dilated - eroded

                    # Inner region (not boundary) - suppress edges
                    inner_region = (mask_array > 128) & (boundary == 0)
                    control_array[inner_region] = 0

                    # Keep boundary edges intact for object placement guidance
                    logger.info("Keeping boundary edges for object replacement guidance")

                elif edge_guidance_mode == "soft":
                    # Soft fade: gradually reduce edges from boundary to center
                    mask_binary = (mask_array > 128).astype(np.uint8) * 255

                    # Calculate distance from boundary
                    dist_transform = cv2.distanceTransform(mask_binary, cv2.DIST_L2, 5)
                    max_dist = dist_transform.max()
                    if max_dist > 0:
                        # Normalize and invert: 1 at boundary, 0 at center
                        fade_factor = 1 - (dist_transform / max_dist)
                        fade_factor = np.clip(fade_factor, 0, 1)

                        # Apply fade to masked region only
                        mask_region = mask_array > 128
                        for c in range(3):
                            control_array[:, :, c][mask_region] = (
                                control_array[:, :, c][mask_region] * fade_factor[mask_region]
                            ).astype(np.uint8)

                        logger.info("Applied soft edge fading in masked region")

                control_image = Image.fromarray(control_array)

            return control_image

        elif mode == "depth":
            control_image = self.generate_depth_map(image)

            # For depth mode with replacement, we want to keep depth info for context
            # but allow flexibility in the masked region
            if mask is not None and not preserve_structure:
                control_array = np.array(control_image)
                mask_array = np.array(mask.convert('L'))

                # Smooth the depth in masked region using surrounding context
                if edge_guidance_mode in ["boundary", "soft"]:
                    mask_binary = (mask_array > 128).astype(np.uint8)

                    # Inpaint the depth map in masked region using surrounding values
                    depth_gray = control_array[:, :, 0]
                    inpainted_depth = cv2.inpaint(
                        depth_gray,
                        mask_binary,
                        inpaintRadius=10,
                        flags=cv2.INPAINT_TELEA
                    )
                    control_array = np.stack([inpainted_depth] * 3, axis=-1)
                    logger.info("Inpainted depth map in masked region")

                control_image = Image.fromarray(control_array)

            return control_image

        else:
            raise ValueError(f"Unknown control mode: {mode}")

    def unload_depth_model(self) -> None:
        """Unload depth estimation model to free memory."""
        if self._depth_estimator is not None:
            del self._depth_estimator
            self._depth_estimator = None

        if self._depth_processor is not None:
            del self._depth_processor
            self._depth_processor = None

        self._depth_model_loaded = False
        logger.info("Depth model unloaded")
gpu_handlers.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
from typing import Any, Callable, Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import spaces
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
logger.setLevel(logging.INFO)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GPUHandlers:
|
| 15 |
+
"""
|
| 16 |
+
Handles all GPU-intensive generation operations.
|
| 17 |
+
|
| 18 |
+
This class encapsulates the execution logic for both background generation
|
| 19 |
+
and inpainting operations with proper @spaces.GPU decorator for
|
| 20 |
+
HuggingFace Spaces deployment.
|
| 21 |
+
|
| 22 |
+
Supports dual-mode inpainting:
|
| 23 |
+
- Pure Inpainting (use_controlnet=False): For object replacement/removal
|
| 24 |
+
- ControlNet Inpainting (use_controlnet=True): For clothing/color change
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
core: Any,
|
| 30 |
+
inpainting_template_manager: Any
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Initialize the GPU handlers.
|
| 34 |
+
|
| 35 |
+
Parameters
|
| 36 |
+
----------
|
| 37 |
+
core : SceneWeaverCore
|
| 38 |
+
Main engine instance
|
| 39 |
+
inpainting_template_manager : InpaintingTemplateManager
|
| 40 |
+
Template manager for inpainting
|
| 41 |
+
"""
|
| 42 |
+
self.core = core
|
| 43 |
+
self.inpainting_template_manager = inpainting_template_manager
|
| 44 |
+
logger.info("GPUHandlers initialized")
|
| 45 |
+
|
| 46 |
+
@spaces.GPU(duration=240)
|
| 47 |
+
def background_generate(
|
| 48 |
+
self,
|
| 49 |
+
image: Optional[Image.Image],
|
| 50 |
+
prompt: str,
|
| 51 |
+
negative_prompt: str,
|
| 52 |
+
composition_mode: str,
|
| 53 |
+
focus_mode: str,
|
| 54 |
+
num_steps: int,
|
| 55 |
+
guidance_scale: float,
|
| 56 |
+
progress_callback: Optional[Callable[[str, int], None]] = None
|
| 57 |
+
) -> Dict[str, Any]:
|
| 58 |
+
"""
|
| 59 |
+
Handle background generation request with GPU access.
|
| 60 |
+
|
| 61 |
+
Parameters
|
| 62 |
+
----------
|
| 63 |
+
image : PIL.Image, optional
|
| 64 |
+
Input image
|
| 65 |
+
prompt : str
|
| 66 |
+
Generation prompt
|
| 67 |
+
negative_prompt : str
|
| 68 |
+
Negative prompt
|
| 69 |
+
composition_mode : str
|
| 70 |
+
Composition mode (center, left_half, etc.)
|
| 71 |
+
focus_mode : str
|
| 72 |
+
Focus mode (person, scene)
|
| 73 |
+
num_steps : int
|
| 74 |
+
Number of inference steps
|
| 75 |
+
guidance_scale : float
|
| 76 |
+
Guidance scale
|
| 77 |
+
progress_callback : callable, optional
|
| 78 |
+
Progress update function(message, percentage)
|
| 79 |
+
|
| 80 |
+
Returns
|
| 81 |
+
-------
|
| 82 |
+
dict
|
| 83 |
+
Result dictionary with success status and images
|
| 84 |
+
"""
|
| 85 |
+
if image is None:
|
| 86 |
+
return {"success": False, "error": "Please upload an image first"}
|
| 87 |
+
|
| 88 |
+
if not prompt.strip():
|
| 89 |
+
return {"success": False, "error": "Please enter a prompt"}
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
logger.info(f"Starting background generation: {prompt[:50]}...")
|
| 93 |
+
start_time = time.time()
|
| 94 |
+
|
| 95 |
+
# Initialize if needed
|
| 96 |
+
if not self.core.is_initialized:
|
| 97 |
+
if progress_callback:
|
| 98 |
+
progress_callback("Loading AI models...", 5)
|
| 99 |
+
self.core.load_models(progress_callback=progress_callback)
|
| 100 |
+
|
| 101 |
+
# Generate and combine
|
| 102 |
+
if progress_callback:
|
| 103 |
+
progress_callback("Generating background...", 20)
|
| 104 |
+
|
| 105 |
+
result = self.core.generate_and_combine(
|
| 106 |
+
original_image=image,
|
| 107 |
+
prompt=prompt,
|
| 108 |
+
combination_mode=composition_mode,
|
| 109 |
+
focus_mode=focus_mode,
|
| 110 |
+
negative_prompt=negative_prompt,
|
| 111 |
+
num_inference_steps=num_steps,
|
| 112 |
+
guidance_scale=guidance_scale,
|
| 113 |
+
progress_callback=progress_callback
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
elapsed = time.time() - start_time
|
| 117 |
+
logger.info(f"Background generation complete in {elapsed:.1f}s")
|
| 118 |
+
|
| 119 |
+
return result
|
| 120 |
+
|
| 121 |
+
except Exception as e:
|
| 122 |
+
error_msg = str(e)
|
| 123 |
+
logger.error(f"Background generation error: {error_msg}")
|
| 124 |
+
return {"success": False, "error": error_msg}
|
| 125 |
+
|
| 126 |
+
@spaces.GPU(duration=420)
|
| 127 |
+
def inpainting_generate(
|
| 128 |
+
self,
|
| 129 |
+
image: Optional[Image.Image],
|
| 130 |
+
mask: Optional[Image.Image],
|
| 131 |
+
prompt: str,
|
| 132 |
+
template_key: Optional[str],
|
| 133 |
+
model_key: str,
|
| 134 |
+
conditioning_type: str,
|
| 135 |
+
conditioning_scale: float,
|
| 136 |
+
feather_radius: int,
|
| 137 |
+
guidance_scale: float,
|
| 138 |
+
num_steps: int,
|
| 139 |
+
seed: int = -1,
|
| 140 |
+
progress_callback: Optional[Callable[[str, int], None]] = None
|
| 141 |
+
) -> Tuple[Optional[Image.Image], Optional[Image.Image], str, int]:
|
| 142 |
+
"""
|
| 143 |
+
Handle inpainting request with GPU access.
|
| 144 |
+
|
| 145 |
+
Supports dual-mode operation based on template:
|
| 146 |
+
- Pure Inpainting: For object_replacement, removal
|
| 147 |
+
- ControlNet: For clothing_change, change_color
|
| 148 |
+
|
| 149 |
+
Parameters
|
| 150 |
+
----------
|
| 151 |
+
image : PIL.Image
|
| 152 |
+
Original image to inpaint
|
| 153 |
+
mask : PIL.Image
|
| 154 |
+
Inpainting mask (white = area to regenerate)
|
| 155 |
+
prompt : str
|
| 156 |
+
Inpainting prompt
|
| 157 |
+
template_key : str, optional
|
| 158 |
+
Template key if using a template
|
| 159 |
+
model_key : str
|
| 160 |
+
Model key (juggernaut_xl, realvis_xl, sdxl_base, animagine_xl)
|
| 161 |
+
conditioning_type : str
|
| 162 |
+
ControlNet conditioning type (canny/depth) - only for ControlNet mode
|
| 163 |
+
conditioning_scale : float
|
| 164 |
+
ControlNet conditioning scale
|
| 165 |
+
feather_radius : int
|
| 166 |
+
Mask feather radius
|
| 167 |
+
guidance_scale : float
|
| 168 |
+
Generation guidance scale
|
| 169 |
+
num_steps : int
|
| 170 |
+
Number of inference steps
|
| 171 |
+
seed : int
|
| 172 |
+
Random seed (-1 for random)
|
| 173 |
+
progress_callback : callable, optional
|
| 174 |
+
Progress update function
|
| 175 |
+
|
| 176 |
+
Returns
|
| 177 |
+
-------
|
| 178 |
+
tuple
|
| 179 |
+
(result_image, control_image, status_message, used_seed)
|
| 180 |
+
"""
|
| 181 |
+
if image is None:
|
| 182 |
+
return None, None, "Please upload an image first", -1
|
| 183 |
+
|
| 184 |
+
if mask is None:
|
| 185 |
+
return None, None, "Please draw a mask on the image", -1
|
| 186 |
+
|
| 187 |
+
try:
|
| 188 |
+
logger.info(f"Starting inpainting: prompt='{prompt[:30]}...', template={template_key}")
|
| 189 |
+
start_time = time.time()
|
| 190 |
+
|
| 191 |
+
# Get template parameters
|
| 192 |
+
built_prompt = prompt
|
| 193 |
+
negative_prompt = ""
|
| 194 |
+
template_params = {}
|
| 195 |
+
use_controlnet = True # Default to ControlNet mode
|
| 196 |
+
|
| 197 |
+
if template_key:
|
| 198 |
+
template = self.inpainting_template_manager.get_template(template_key)
|
| 199 |
+
if template:
|
| 200 |
+
# For removal template, use template prompt directly if user prompt is empty
|
| 201 |
+
if template_key == "removal" and not prompt.strip():
|
| 202 |
+
built_prompt = template.prompt_template
|
| 203 |
+
else:
|
| 204 |
+
built_prompt = self.inpainting_template_manager.build_prompt(template_key, prompt)
|
| 205 |
+
negative_prompt = self.inpainting_template_manager.get_negative_prompt(template_key)
|
| 206 |
+
template_params = self.inpainting_template_manager.get_parameters_for_template(template_key)
|
| 207 |
+
use_controlnet = template_params.get("use_controlnet", True)
|
| 208 |
+
logger.info(f"Template: {template_key}, use_controlnet={use_controlnet}")
|
| 209 |
+
|
| 210 |
+
# Build final parameters
|
| 211 |
+
final_params = {
|
| 212 |
+
# Pipeline mode
|
| 213 |
+
"use_controlnet": use_controlnet,
|
| 214 |
+
"mask_dilation": template_params.get("mask_dilation", 0),
|
| 215 |
+
|
| 216 |
+
# ControlNet parameters (only used if use_controlnet=True)
|
| 217 |
+
"conditioning_type": template_params.get("preferred_conditioning", conditioning_type),
|
| 218 |
+
"controlnet_conditioning_scale": template_params.get("controlnet_conditioning_scale", conditioning_scale),
|
| 219 |
+
"preserve_structure_in_mask": template_params.get("preserve_structure_in_mask", False),
|
| 220 |
+
"edge_guidance_mode": template_params.get("edge_guidance_mode", "boundary"),
|
| 221 |
+
|
| 222 |
+
# Generation parameters
|
| 223 |
+
"feather_radius": template_params.get("feather_radius", feather_radius),
|
| 224 |
+
"guidance_scale": template_params.get("guidance_scale", guidance_scale),
|
| 225 |
+
"num_inference_steps": template_params.get("num_inference_steps", num_steps),
|
| 226 |
+
"strength": template_params.get("strength", 0.99),
|
| 227 |
+
"negative_prompt": negative_prompt,
|
| 228 |
+
"seed": seed,
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
# Execute inpainting through core
|
| 232 |
+
result = self.core.execute_inpainting(
|
| 233 |
+
image=image,
|
| 234 |
+
mask=mask,
|
| 235 |
+
prompt=built_prompt,
|
| 236 |
+
model_key=model_key,
|
| 237 |
+
progress_callback=progress_callback,
|
| 238 |
+
**final_params
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
elapsed = time.time() - start_time
|
| 242 |
+
|
| 243 |
+
if result.get('success'):
|
| 244 |
+
mode_str = "Pure Inpainting" if not use_controlnet else "ControlNet"
|
| 245 |
+
# Get the actual seed used from metadata
|
| 246 |
+
used_seed = result.get('metadata', {}).get('seed', seed)
|
| 247 |
+
status = f"Complete ({mode_str}) in {elapsed:.1f}s | Seed: {used_seed}"
|
| 248 |
+
|
| 249 |
+
return (
|
| 250 |
+
result.get('combined_image'),
|
| 251 |
+
result.get('control_image'),
|
| 252 |
+
status,
|
| 253 |
+
used_seed
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
error_msg = result.get('error', 'Unknown error')
|
| 257 |
+
return None, None, f"Error: {error_msg}", -1
|
| 258 |
+
|
| 259 |
+
except Exception as e:
|
| 260 |
+
error_msg = str(e)
|
| 261 |
+
logger.error(f"Inpainting handler error: {e}")
|
| 262 |
+
return None, None, f"Error: {error_msg}", -1
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def extract_mask_from_editor(mask_editor: Dict[str, Any]) -> Optional[Image.Image]:
    """
    Extract a drawing mask from a Gradio ImageEditor component.

    All painted layers are reduced to grayscale and combined with a
    pixel-wise maximum, so strokes drawn on any layer contribute to the
    final mask (previously only the first layer was used).

    Parameters
    ----------
    mask_editor : dict
        ImageEditor output with 'background' and 'layers'

    Returns
    -------
    PIL.Image or None
        Extracted mask image (L mode), or None if nothing usable was drawn
    """
    if mask_editor is None:
        return None

    try:
        layers = mask_editor.get("layers", [])
        combined: Optional[np.ndarray] = None

        for layer in layers:
            if layer is None:
                continue

            # Layers may arrive as PIL images or as file paths
            if isinstance(layer, Image.Image):
                layer_array = np.array(layer)
            else:
                layer_array = np.array(Image.open(layer))

            # Reduce this layer to a single grayscale channel
            if layer_array.ndim == 3:
                if layer_array.shape[2] == 4:
                    # RGBA - use alpha channel combined with RGB brightness
                    alpha = layer_array[:, :, 3]
                    gray = cv2.cvtColor(layer_array[:, :, :3], cv2.COLOR_RGB2GRAY)
                    layer_gray = np.maximum(gray, alpha)
                elif layer_array.shape[2] == 3:
                    # RGB - convert to grayscale
                    layer_gray = cv2.cvtColor(layer_array, cv2.COLOR_RGB2GRAY)
                else:
                    layer_gray = layer_array[:, :, 0]
            else:
                layer_gray = layer_array

            # Merge with previously processed layers (union of strokes)
            if combined is None:
                combined = layer_gray
            else:
                combined = np.maximum(combined, layer_gray)

        if combined is None:
            return None

        return Image.fromarray(combined.astype(np.uint8), mode='L')

    except Exception as e:
        logger.error(f"Failed to extract mask from editor: {e}")
        return None
|
image_blender.py
CHANGED
|
@@ -483,7 +483,7 @@ class ImageBlender:
|
|
| 483 |
orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
|
| 484 |
logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")
|
| 485 |
|
| 486 |
-
# Remove original background color contamination from foreground
|
| 487 |
orig_array = self._remove_background_color_contamination(
|
| 488 |
orig_array,
|
| 489 |
mask_array,
|
|
@@ -491,7 +491,7 @@ class ImageBlender:
|
|
| 491 |
tolerance=self.BACKGROUND_COLOR_TOLERANCE
|
| 492 |
)
|
| 493 |
|
| 494 |
-
# Redefine trimap, optimized for cartoon characters
|
| 495 |
try:
|
| 496 |
kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 497 |
|
|
@@ -531,7 +531,7 @@ class ImageBlender:
|
|
| 531 |
|
| 532 |
fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
|
| 533 |
|
| 534 |
-
# Edge band spill suppression and repair
|
| 535 |
if np.any(ring_zone):
|
| 536 |
# Convert to Lab space
|
| 537 |
orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
|
|
@@ -625,20 +625,20 @@ class ImageBlender:
|
|
| 625 |
delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
|
| 626 |
delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
|
| 627 |
delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)
|
| 628 |
-
|
| 629 |
still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
|
| 630 |
-
|
| 631 |
if np.any(still_contaminated):
|
| 632 |
# Apply stronger correction to remaining contaminated pixels
|
| 633 |
remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
|
| 634 |
-
|
| 635 |
# More aggressive chroma neutralization
|
| 636 |
remaining_chroma = remaining_pixels[:, 1:3]
|
| 637 |
neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7
|
| 638 |
-
|
| 639 |
# Stronger luminance matching
|
| 640 |
neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
|
| 641 |
-
|
| 642 |
ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
|
| 643 |
ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
|
| 644 |
orig_lab[ring_zone] = ring_pixels_lab_pass2
|
|
@@ -691,7 +691,7 @@ class ImageBlender:
|
|
| 691 |
orig_linear = srgb_to_linear(orig_array)
|
| 692 |
bg_linear = srgb_to_linear(bg_array)
|
| 693 |
|
| 694 |
-
# Cartoon-optimized Alpha calculation
|
| 695 |
alpha = mask_array.astype(np.float32) / 255.0
|
| 696 |
|
| 697 |
# Core foreground region - fully opaque
|
|
@@ -701,13 +701,13 @@ class ImageBlender:
|
|
| 701 |
alpha[bg_zone] = 0.0
|
| 702 |
|
| 703 |
# [Key Fix] Force pixels with mask≥160 to α=1.0, avoiding white fill areas being limited to 0.9
|
| 704 |
-
high_confidence_pixels = mask_array >= 160
|
| 705 |
alpha[high_confidence_pixels] = 1.0
|
| 706 |
logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
|
| 707 |
|
| 708 |
# Ring area can be dehaloed, but doesn't affect already set high confidence pixels
|
| 709 |
ring_without_high_conf = ring_zone & (~high_confidence_pixels)
|
| 710 |
-
alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
|
| 711 |
|
| 712 |
# Retain existing black outline/strong edge protection
|
| 713 |
orig_gray = np.mean(orig_array, axis=2)
|
|
@@ -739,10 +739,10 @@ class ImageBlender:
|
|
| 739 |
result_srgb = linear_to_srgb(result_linear)
|
| 740 |
result_array = (result_srgb * 255).astype(np.uint8)
|
| 741 |
|
| 742 |
-
# Final edge cleanup pass
|
| 743 |
result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
|
| 744 |
|
| 745 |
-
# Protect core foreground from any background influence
|
| 746 |
# This ensures faces and bodies retain original colors
|
| 747 |
result_array = self._protect_foreground_core(
|
| 748 |
result_array,
|
|
|
|
| 483 |
orig_bg_color_lab = cv2.cvtColor(orig_bg_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
|
| 484 |
logger.info(f"🎨 Detected original background color: RGB{tuple(orig_bg_color_rgb)}")
|
| 485 |
|
| 486 |
+
# Remove original background color contamination from foreground
|
| 487 |
orig_array = self._remove_background_color_contamination(
|
| 488 |
orig_array,
|
| 489 |
mask_array,
|
|
|
|
| 491 |
tolerance=self.BACKGROUND_COLOR_TOLERANCE
|
| 492 |
)
|
| 493 |
|
| 494 |
+
# Redefine trimap, optimized for cartoon characters
|
| 495 |
try:
|
| 496 |
kernel_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 497 |
|
|
|
|
| 531 |
|
| 532 |
fg_rep_color_lab = cv2.cvtColor(fg_rep_color_rgb.reshape(1,1,3), cv2.COLOR_RGB2LAB)[0,0].astype(np.float32)
|
| 533 |
|
| 534 |
+
# Edge band spill suppression and repair
|
| 535 |
if np.any(ring_zone):
|
| 536 |
# Convert to Lab space
|
| 537 |
orig_lab = cv2.cvtColor(orig_array, cv2.COLOR_RGB2LAB).astype(np.float32)
|
|
|
|
| 625 |
delta_a_pass2 = ring_pixels_lab_pass2[:, 1] - orig_bg_color_lab[1]
|
| 626 |
delta_b_pass2 = ring_pixels_lab_pass2[:, 2] - orig_bg_color_lab[2]
|
| 627 |
delta_e_pass2 = np.sqrt(delta_l_pass2**2 + delta_a_pass2**2 + delta_b_pass2**2)
|
| 628 |
+
|
| 629 |
still_contaminated = delta_e_pass2 < (DELTAE_THRESHOLD * 0.8)
|
| 630 |
+
|
| 631 |
if np.any(still_contaminated):
|
| 632 |
# Apply stronger correction to remaining contaminated pixels
|
| 633 |
remaining_pixels = ring_pixels_lab_pass2[still_contaminated]
|
| 634 |
+
|
| 635 |
# More aggressive chroma neutralization
|
| 636 |
remaining_chroma = remaining_pixels[:, 1:3]
|
| 637 |
neutralized_chroma = remaining_chroma * 0.3 + fg_rep_color_lab[1:3] * 0.7
|
| 638 |
+
|
| 639 |
# Stronger luminance matching
|
| 640 |
neutralized_l = remaining_pixels[:, 0] * 0.4 + fg_rep_color_lab[0] * 0.6
|
| 641 |
+
|
| 642 |
ring_pixels_lab_pass2[still_contaminated, 0] = neutralized_l
|
| 643 |
ring_pixels_lab_pass2[still_contaminated, 1:3] = neutralized_chroma
|
| 644 |
orig_lab[ring_zone] = ring_pixels_lab_pass2
|
|
|
|
| 691 |
orig_linear = srgb_to_linear(orig_array)
|
| 692 |
bg_linear = srgb_to_linear(bg_array)
|
| 693 |
|
| 694 |
+
# Cartoon-optimized Alpha calculation
|
| 695 |
alpha = mask_array.astype(np.float32) / 255.0
|
| 696 |
|
| 697 |
# Core foreground region - fully opaque
|
|
|
|
| 701 |
alpha[bg_zone] = 0.0
|
| 702 |
|
| 703 |
# [Key Fix] Force pixels with mask≥160 to α=1.0, avoiding white fill areas being limited to 0.9
|
| 704 |
+
high_confidence_pixels = mask_array >= 160
|
| 705 |
alpha[high_confidence_pixels] = 1.0
|
| 706 |
logger.info(f"💯 High confidence pixels set to full opacity: {high_confidence_pixels.sum()}")
|
| 707 |
|
| 708 |
# Ring area can be dehaloed, but doesn't affect already set high confidence pixels
|
| 709 |
ring_without_high_conf = ring_zone & (~high_confidence_pixels)
|
| 710 |
+
alpha[ring_without_high_conf] = np.clip(alpha[ring_without_high_conf], 0.2, 0.9)
|
| 711 |
|
| 712 |
# Retain existing black outline/strong edge protection
|
| 713 |
orig_gray = np.mean(orig_array, axis=2)
|
|
|
|
| 739 |
result_srgb = linear_to_srgb(result_linear)
|
| 740 |
result_array = (result_srgb * 255).astype(np.uint8)
|
| 741 |
|
| 742 |
+
# Final edge cleanup pass
|
| 743 |
result_array = self._apply_edge_cleanup(result_array, bg_array, alpha)
|
| 744 |
|
| 745 |
+
# Protect core foreground from any background influence
|
| 746 |
# This ensures faces and bodies retain original colors
|
| 747 |
result_array = self._protect_foreground_core(
|
| 748 |
result_array,
|
inpainting_blender.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
logger.setLevel(logging.INFO)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class InpaintingBlender:
    """
    Handles mask processing, prompt enhancement, and result blending for inpainting.

    This class encapsulates all pre-processing and post-processing operations
    needed for inpainting, separate from the main generation pipeline:
    mask validation/feathering, context-aware prompt enhancement, boundary
    color matching, and linear-color-space compositing.

    Attributes:
        min_mask_coverage: Minimum mask coverage threshold (fraction, 0-1)
        max_mask_coverage: Maximum mask coverage threshold (fraction, 0-1)

    Example:
        >>> blender = InpaintingBlender()
        >>> processed_mask, info = blender.prepare_mask(mask, (512, 512), feather_radius=8)
        >>> enhanced_prompt, negative = blender.enhance_prompt_for_inpainting("a flower", image, mask)
        >>> result = blender.blend_result(original, generated, mask)
    """

    def __init__(
        self,
        min_mask_coverage: float = 0.01,
        max_mask_coverage: float = 0.95
    ):
        """
        Initialize the InpaintingBlender.

        Parameters
        ----------
        min_mask_coverage : float
            Minimum mask coverage (default: 1%)
        max_mask_coverage : float
            Maximum mask coverage (default: 95%)
        """
        self.min_mask_coverage = min_mask_coverage
        self.max_mask_coverage = max_mask_coverage
        logger.info("InpaintingBlender initialized")

    def prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        feather_radius: int = 8
    ) -> Tuple[Image.Image, Dict[str, Any]]:
        """
        Prepare and validate mask for inpainting.

        Converts the mask to grayscale, resizes it to match the input image,
        checks coverage against the configured thresholds, and feathers edges.

        Parameters
        ----------
        mask : PIL.Image
            Input mask (white = inpaint area)
        target_size : tuple
            Target (width, height) to match input image
        feather_radius : int
            Feathering radius in pixels (0 disables feathering)

        Returns
        -------
        tuple
            (processed_mask, validation_info). Coverage problems do NOT
            raise; they are reported through validation_info["valid"] and
            validation_info["warning"] so the caller can decide how to react.
        """
        # Convert to grayscale
        if mask.mode != 'L':
            mask = mask.convert('L')

        # Resize to match target
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        # Convert to array for processing
        mask_array = np.array(mask)

        # Coverage = fraction of pixels above mid-gray (counted as "inpaint")
        total_pixels = mask_array.size
        white_pixels = np.count_nonzero(mask_array > 127)
        coverage = white_pixels / total_pixels

        validation_info = {
            "coverage": coverage,
            "white_pixels": white_pixels,
            "total_pixels": total_pixels,
            "feather_radius": feather_radius,
            "valid": True,
            "warning": ""
        }

        # Validate coverage against configured bounds
        if coverage < self.min_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too low ({coverage:.1%}). "
                f"Please select a larger area to inpaint."
            )
            logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.min_mask_coverage:.1%}")

        elif coverage > self.max_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too high ({coverage:.1%}). "
                f"Consider using background generation instead."
            )
            logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.max_mask_coverage:.1%}")

        # Apply feathering (Gaussian kernel size must be odd, hence 2r+1)
        if feather_radius > 0:
            mask_array = cv2.GaussianBlur(
                mask_array,
                (feather_radius * 2 + 1, feather_radius * 2 + 1),
                feather_radius / 2
            )
            logger.debug(f"Applied {feather_radius}px feathering to mask")

        processed_mask = Image.fromarray(mask_array, mode='L')

        return processed_mask, validation_info

    def enhance_prompt_for_inpainting(
        self,
        prompt: str,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[str, str]:
        """
        Enhance prompt based on non-masked region analysis.

        Analyzes the surrounding context (lightness, color temperature,
        saturation, measured in Lab/HSV with robust medians) to append
        appropriate lighting and color descriptors to the user prompt.

        Parameters
        ----------
        prompt : str
            User-provided prompt
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask

        Returns
        -------
        tuple
            (enhanced_prompt, negative_prompt)
        """
        logger.info("Enhancing prompt for inpainting context...")

        # Convert to arrays
        img_array = np.array(image.convert('RGB'))
        mask_array = np.array(mask.convert('L'))

        # Pixels outside the mask provide the visual context
        non_masked = mask_array < 127

        if not np.any(non_masked):
            # No context available - fall back to generic quality tags
            enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
            negative_prompt = self._get_inpainting_negative_prompt()
            return enhanced_prompt, negative_prompt

        # Extract context pixels
        context_pixels = img_array[non_masked]

        # Convert to Lab for perceptual lightness/color-temperature analysis
        context_lab = cv2.cvtColor(
            context_pixels.reshape(-1, 1, 3),
            cv2.COLOR_RGB2LAB
        ).reshape(-1, 3)

        # Use robust statistics (median) to avoid outlier influence
        median_l = np.median(context_lab[:, 0])
        median_b = np.median(context_lab[:, 2])

        # Analyze lighting conditions from the L (lightness) channel
        lighting_descriptors = []

        if median_l > 170:
            lighting_descriptors.append("bright")
        elif median_l > 130:
            lighting_descriptors.append("well-lit")
        elif median_l > 80:
            lighting_descriptors.append("moderate lighting")
        else:
            lighting_descriptors.append("dim lighting")

        # Analyze color temperature (b channel: blue(-) to yellow(+))
        if median_b > 140:
            lighting_descriptors.append("warm golden tones")
        elif median_b > 120:
            lighting_descriptors.append("warm afternoon light")
        elif median_b < 110:
            lighting_descriptors.append("cool neutral tones")

        # Calculate saturation from context (HSV S channel)
        hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
        median_saturation = np.median(hsv[:, :, 1])

        if median_saturation > 150:
            lighting_descriptors.append("vibrant colors")
        elif median_saturation < 80:
            lighting_descriptors.append("subtle muted colors")

        # Build enhanced prompt
        lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
        quality_suffix = "high quality, detailed, photorealistic, seamless integration"

        if lighting_desc:
            enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
        else:
            enhanced_prompt = f"{prompt}, {quality_suffix}"

        negative_prompt = self._get_inpainting_negative_prompt()

        logger.info(f"Enhanced prompt with context: {lighting_desc}")

        return enhanced_prompt, negative_prompt

    def _get_inpainting_negative_prompt(self) -> str:
        """Get standard negative prompt for inpainting."""
        return (
            "inconsistent lighting, wrong perspective, mismatched colors, "
            "visible seams, blending artifacts, color bleeding, "
            "blurry, low quality, distorted, deformed, "
            "harsh edges, unnatural transition"
        )

    def blend_result(
        self,
        original: Image.Image,
        generated: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Blend generated content with original image.

        Uses boundary color matching and linear (gamma-decoded) color space
        alpha blending for seamless results.

        Parameters
        ----------
        original : PIL.Image
            Original image
        generated : PIL.Image
            Generated inpainted image
        mask : PIL.Image
            Blending mask (white = use generated)

        Returns
        -------
        PIL.Image
            Blended result
        """
        logger.info("Blending inpainting result with color matching...")

        # Ensure same size
        if generated.size != original.size:
            generated = generated.resize(original.size, Image.LANCZOS)
        if mask.size != original.size:
            mask = mask.resize(original.size, Image.LANCZOS)

        # Convert to arrays
        orig_array = np.array(original.convert('RGB')).astype(np.float32)
        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0

        # Apply color matching to generated region (use original mask for accurate boundary detection)
        gen_array = self._match_colors_at_boundary(orig_array, gen_array, mask_array)

        # Create blend mask: soften edges ONLY for blending (not for generation)
        # This ensures full generation coverage while smooth blending at edges
        blend_mask = self._create_blend_mask(mask_array)

        # sRGB to linear conversion (IEC 61966-2-1 transfer function)
        def srgb_to_linear(img: np.ndarray) -> np.ndarray:
            img_norm = img / 255.0
            return np.where(
                img_norm <= 0.04045,
                img_norm / 12.92,
                np.power((img_norm + 0.055) / 1.055, 2.4)
            )

        def linear_to_srgb(img: np.ndarray) -> np.ndarray:
            img_clipped = np.clip(img, 0, 1)
            return np.where(
                img_clipped <= 0.0031308,
                12.92 * img_clipped,
                1.055 * np.power(img_clipped, 1/2.4) - 0.055
            )

        # Convert to linear space so the blend is physically plausible
        orig_linear = srgb_to_linear(orig_array)
        gen_linear = srgb_to_linear(gen_array)

        # Alpha blending in linear space using the blend mask (with softened edges)
        alpha = blend_mask[:, :, np.newaxis]
        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)

        # Convert back to sRGB
        result_srgb = linear_to_srgb(result_linear)
        result_array = (result_srgb * 255).astype(np.uint8)

        logger.debug("Blending completed with color matching")

        return Image.fromarray(result_array)

    def _match_colors_at_boundary(
        self,
        original: np.ndarray,
        generated: np.ndarray,
        mask: np.ndarray
    ) -> np.ndarray:
        """
        Match colors of generated content to original at the boundary.

        Samples mean Lab color on both sides of the mask boundary and shifts
        the generated region toward the original (partial correction: 70%
        lightness, 50% chroma) for natural blending.

        Parameters
        ----------
        original : np.ndarray
            Original image array (float32, 0-255)
        generated : np.ndarray
            Generated image array (float32, 0-255)
        mask : np.ndarray
            Mask array (float32, 0-1)

        Returns
        -------
        np.ndarray
            Color-matched generated image (float32, 0-255)
        """
        # Create boundary region mask (dilated mask - eroded mask)
        mask_binary = (mask > 0.5).astype(np.uint8) * 255

        # Create narrow boundary region for sampling original colors
        kernel_size = 25  # Pixels to sample around boundary
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        dilated = cv2.dilate(mask_binary, kernel, iterations=1)
        eroded = cv2.erode(mask_binary, kernel, iterations=1)

        # Outer boundary (original side)
        outer_boundary = (dilated > 0) & (mask_binary == 0)
        # Inner boundary (generated side)
        inner_boundary = (mask_binary > 0) & (eroded == 0)

        if not np.any(outer_boundary) or not np.any(inner_boundary):
            logger.debug("No boundary region found, skipping color matching")
            return generated

        # Convert to Lab color space
        orig_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
        gen_lab = cv2.cvtColor(generated.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)

        # Sample colors from boundary regions
        orig_boundary_pixels = orig_lab[outer_boundary]
        gen_boundary_pixels = gen_lab[inner_boundary]

        if len(orig_boundary_pixels) < 10 or len(gen_boundary_pixels) < 10:
            logger.debug("Not enough boundary pixels, skipping color matching")
            return generated

        # Calculate mean statistics on both sides of the boundary
        orig_mean = np.mean(orig_boundary_pixels, axis=0)
        gen_mean = np.mean(gen_boundary_pixels, axis=0)

        # Calculate correction factors
        # Only correct L (lightness) and a,b (color) channels
        l_correction = (orig_mean[0] - gen_mean[0]) * 0.7  # 70% correction for lightness
        a_correction = (orig_mean[1] - gen_mean[1]) * 0.5  # 50% correction for color
        b_correction = (orig_mean[2] - gen_mean[2]) * 0.5

        logger.debug(f"Color correction: L={l_correction:.1f}, a={a_correction:.1f}, b={b_correction:.1f}")

        # Apply correction to masked region only
        corrected_lab = gen_lab.copy()
        mask_region = mask > 0.3  # Apply to most of masked region

        corrected_lab[mask_region, 0] = np.clip(
            corrected_lab[mask_region, 0] + l_correction, 0, 255
        )
        corrected_lab[mask_region, 1] = np.clip(
            corrected_lab[mask_region, 1] + a_correction, 0, 255
        )
        corrected_lab[mask_region, 2] = np.clip(
            corrected_lab[mask_region, 2] + b_correction, 0, 255
        )

        # Convert back to RGB
        corrected_rgb = cv2.cvtColor(
            corrected_lab.astype(np.uint8),
            cv2.COLOR_LAB2RGB
        ).astype(np.float32)

        logger.info("Applied boundary color matching")

        return corrected_rgb

    def _create_blend_mask(self, mask: np.ndarray) -> np.ndarray:
        """
        Create a blend mask with softened edges for natural compositing.

        The mask interior stays fully opaque (1.0) while only the edges
        get a smooth transition. This preserves full generated content
        while blending naturally at boundaries.

        Parameters
        ----------
        mask : np.ndarray
            Original mask array (float32, 0-1)

        Returns
        -------
        np.ndarray
            Blend mask with soft edges but solid interior (float32, 0-1)
        """
        # Convert to uint8 for morphological operations
        mask_uint8 = (mask * 255).astype(np.uint8)

        # Create eroded version (solid interior)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        eroded = cv2.erode(mask_uint8, kernel, iterations=1)

        # Create smooth transition zone at edges only
        # Blur the original mask for edge softness
        blurred = cv2.GaussianBlur(mask_uint8, (15, 15), 4)

        # Combine: use eroded (solid) for interior, blurred for edges
        # Where eroded > 0, use full opacity; elsewhere use blurred transition
        result = np.where(eroded > 128, mask_uint8, blurred)

        # Final light smoothing
        result = cv2.GaussianBlur(result, (5, 5), 1)

        # Convert back to float
        blend_mask = result.astype(np.float32) / 255.0

        logger.debug("Created blend mask with soft edges and solid interior")

        return blend_mask

    def validate_inputs(
        self,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[bool, str]:
        """
        Validate image and mask inputs before processing.

        Parameters
        ----------
        image : PIL.Image
            Input image
        mask : PIL.Image
            Input mask

        Returns
        -------
        tuple
            (is_valid, error_message)
        """
        if image is None:
            return False, "No image provided"

        if mask is None:
            return False, "No mask provided"

        # Check sizes match
        if image.size != mask.size:
            # Will be resized later, so just log a warning
            logger.warning(f"Image size {image.size} != mask size {mask.size}, will resize")

        return True, ""
|
inpainting_models.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Any, Dict, Optional, Tuple
|
| 6 |
+
from diffusers import StableDiffusionXLControlNetInpaintPipeline
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
logger.setLevel(logging.INFO)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ImageMode(Enum):
    """Closed set of image style modes used to select an inpainting model."""

    PHOTO = "photo"
    ANIME = "anime"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class ModelConfig:
    """
    Static description of one inpainting checkpoint.

    Field order is part of the public interface (positional construction);
    do not reorder.
    """

    # Hugging Face Hub repository id of the checkpoint.
    model_id: str
    # Short human-readable name shown in the UI.
    name: str
    # One-line description for dropdowns / tooltips.
    description: str
    # Style bucket this model belongs to (photo vs anime).
    mode: ImageMode
    # Whether from_pretrained should be passed a weights variant.
    requires_variant: bool = True
    # Variant label used when requires_variant is True.
    variant: str = "fp16"
    # Free-text usage recommendation for display.
    recommended_for: str = ""

    # Per-model generation defaults.
    default_guidance_scale: float = 7.5
    default_num_inference_steps: int = 25
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class InpaintingModelManager:
    """
    Registry and lifecycle manager for style-specific inpainting models.

    Maintains a catalogue of SDXL-based checkpoints split between
    photorealistic and anime/illustration use, loads them lazily, and shares
    a single Canny ControlNet across every base model. Only one pipeline is
    resident at a time; switching models unloads the previous one first.

    Attributes:
        AVAILABLE_MODELS: Catalogue of every supported checkpoint.
        DEFAULT_MODELS: Preferred model key for each ImageMode.

    Example:
        >>> manager = InpaintingModelManager(device="cuda")
        >>> pipeline = manager.load_pipeline(mode=ImageMode.PHOTO)
        >>> manager.switch_model("sdxl_base")
    """

    # Full checkpoint catalogue, keyed by a short stable identifier.
    AVAILABLE_MODELS: Dict[str, ModelConfig] = {
        # --- Photorealistic checkpoints ---
        "juggernaut_xl": ModelConfig(
            model_id="RunDiffusion/Juggernaut-XL-v9",
            name="JuggernautXL v9",
            description="Best for photorealistic images, portraits, and real photos",
            mode=ImageMode.PHOTO,
            requires_variant=True,
            variant="fp16",
            recommended_for="Real photos, portraits, professional photography",
            default_guidance_scale=7.0,
            default_num_inference_steps=25,
        ),
        "realvis_xl": ModelConfig(
            model_id="SG161222/RealVisXL_V4.0",
            name="RealVisXL v4",
            description="Excellent for realistic images with fine details",
            mode=ImageMode.PHOTO,
            requires_variant=True,
            variant="fp16",
            recommended_for="Realistic scenes, product photos, nature",
            default_guidance_scale=7.0,
            default_num_inference_steps=25,
        ),
        # --- Anime / illustration checkpoints ---
        "sdxl_base": ModelConfig(
            model_id="stabilityai/stable-diffusion-xl-base-1.0",
            name="SDXL Base",
            description="Versatile model for general use and illustrations",
            mode=ImageMode.ANIME,
            requires_variant=True,
            variant="fp16",
            recommended_for="General illustrations, digital art, versatile use",
            default_guidance_scale=7.5,
            default_num_inference_steps=25,
        ),
        "animagine_xl": ModelConfig(
            model_id="cagliostrolab/animagine-xl-3.1",
            name="Animagine XL 3.1",
            description="Specialized for anime and manga style images",
            mode=ImageMode.ANIME,
            requires_variant=False,
            recommended_for="Anime, manga, cartoon style images",
            default_guidance_scale=7.0,
            default_num_inference_steps=25,
        ),
    }

    # Model chosen when the caller specifies only a mode.
    DEFAULT_MODELS = {
        ImageMode.PHOTO: "juggernaut_xl",
        ImageMode.ANIME: "sdxl_base",
    }

    def __init__(self, device: Optional[str] = None):
        """
        Initialize the model manager.

        Parameters
        ----------
        device : str, optional
            Device to load models on. Auto-detected if not specified.
        """
        self.device = device or self._detect_device()
        self._current_model_key: Optional[str] = None
        self._pipeline: Optional[Any] = None
        self._controlnet: Optional[Any] = None
        self._controlnet_loaded: bool = False

        logger.info(f"InpaintingModelManager initialized on device: {self.device}")

    def _detect_device(self) -> str:
        """Pick the best available backend: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return "cuda"
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def get_models_for_mode(self, mode: ImageMode) -> Dict[str, ModelConfig]:
        """
        Return the subset of the catalogue matching *mode*.

        Parameters
        ----------
        mode : ImageMode
            The image mode (PHOTO or ANIME).

        Returns
        -------
        dict
            Model configs whose ``mode`` equals the requested one.
        """
        matching: Dict[str, ModelConfig] = {}
        for model_key, model_cfg in self.AVAILABLE_MODELS.items():
            if model_cfg.mode == mode:
                matching[model_key] = model_cfg
        return matching

    def get_model_choices(self) -> Dict[str, list]:
        """
        Return dropdown-ready choices grouped by style.

        Returns
        -------
        dict
            {'photo': [...], 'anime': [...]} where each entry is a
            (display_label, model_key) tuple.
        """
        grouped: Dict[str, list] = {"photo": [], "anime": []}

        for model_key, model_cfg in self.AVAILABLE_MODELS.items():
            label = f"{model_cfg.name} - {model_cfg.description}"
            bucket = "photo" if model_cfg.mode == ImageMode.PHOTO else "anime"
            grouped[bucket].append((label, model_key))

        return grouped

    def get_default_model(self, mode: ImageMode) -> str:
        """Return the preferred model key for *mode* ("sdxl_base" fallback)."""
        return self.DEFAULT_MODELS.get(mode, "sdxl_base")

    def load_controlnet(self) -> Any:
        """
        Load (at most once) and return the Canny ControlNet shared by all
        base models.

        Returns
        -------
        ControlNetModel
            Loaded ControlNet model.

        Raises
        ------
        Exception
            Re-raised from diffusers when the download/load fails.
        """
        # Reuse the cached instance when it is already resident.
        if self._controlnet_loaded and self._controlnet is not None:
            return self._controlnet

        try:
            from diffusers import ControlNetModel

            logger.info("Loading ControlNet Canny model...")
            weight_dtype = torch.float16 if self.device == "cuda" else torch.float32
            self._controlnet = ControlNetModel.from_pretrained(
                "diffusers/controlnet-canny-sdxl-1.0",
                torch_dtype=weight_dtype,
                use_safetensors=True
            )
            self._controlnet_loaded = True
            logger.info("ControlNet loaded successfully")
            return self._controlnet

        except Exception as e:
            logger.error(f"Failed to load ControlNet: {e}")
            raise

    def load_pipeline(
        self,
        model_key: Optional[str] = None,
        mode: Optional[ImageMode] = None
    ) -> Any:
        """
        Load an inpainting pipeline for the specified model.

        Parameters
        ----------
        model_key : str, optional
            Specific model key to load.
        mode : ImageMode, optional
            If model_key is not specified, load the default for this mode
            (PHOTO when neither is given).

        Returns
        -------
        StableDiffusionXLControlNetInpaintPipeline
            Loaded pipeline ready for inference.

        Raises
        ------
        ValueError
            If ``model_key`` is not in the catalogue.
        """
        # Resolve the target model: explicit key wins, then mode default.
        if model_key is None:
            model_key = self.get_default_model(
                mode if mode is not None else ImageMode.PHOTO
            )

        # Fast path: requested model is already resident.
        if self._pipeline is not None and self._current_model_key == model_key:
            logger.info(f"Model {model_key} already loaded")
            return self._pipeline

        # Free the previous model before loading a different one.
        if self._current_model_key != model_key:
            self.unload_pipeline()

        config = self.AVAILABLE_MODELS.get(model_key)
        if config is None:
            raise ValueError(f"Unknown model key: {model_key}")

        logger.info(f"Loading model: {config.name} ({config.model_id})")

        try:
            # The shared ControlNet must exist before the pipeline does.
            controlnet = self.load_controlnet()

            weight_dtype = torch.float16 if self.device == "cuda" else torch.float32
            pipeline_kwargs: Dict[str, Any] = {
                "controlnet": controlnet,
                "torch_dtype": weight_dtype,
                "use_safetensors": True,
            }
            if config.requires_variant:
                pipeline_kwargs["variant"] = config.variant

            self._pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
                config.model_id,
                **pipeline_kwargs
            )

            # Move to the target device, then apply CUDA-only optimizations.
            self._pipeline = self._pipeline.to(self.device)

            if self.device == "cuda":
                self._pipeline.enable_vae_tiling()
                try:
                    self._pipeline.enable_xformers_memory_efficient_attention()
                    logger.info("xformers enabled")
                except Exception:
                    logger.info("xformers not available, using default attention")

            self._current_model_key = model_key
            logger.info(f"Model {config.name} loaded successfully")

            return self._pipeline

        except Exception as e:
            logger.error(f"Failed to load model {model_key}: {e}")
            raise

    def unload_pipeline(self) -> None:
        """Unload the current pipeline to free memory."""
        if self._pipeline is not None:
            logger.info(f"Unloading model: {self._current_model_key}")
            del self._pipeline
            self._pipeline = None
            self._current_model_key = None

        if self.device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()

    def switch_model(self, model_key: str) -> Any:
        """
        Switch to a different model, unloading the resident one if needed.

        Parameters
        ----------
        model_key : str
            Model key to switch to.

        Returns
        -------
        Pipeline
            Newly loaded pipeline.
        """
        return self.load_pipeline(model_key=model_key)

    def get_current_model_config(self) -> Optional[ModelConfig]:
        """Return the config of the loaded model, or None if nothing is loaded."""
        active_key = self._current_model_key
        return None if active_key is None else self.AVAILABLE_MODELS.get(active_key)

    def get_pipeline(self) -> Optional[Any]:
        """Return the currently resident pipeline (may be None)."""
        return self._pipeline

    def is_loaded(self) -> bool:
        """Return True when a pipeline is resident in memory."""
        return self._pipeline is not None

    def get_status(self) -> Dict[str, Any]:
        """
        Snapshot the manager's state.

        Returns
        -------
        dict
            Device, resident model key/name, load flags, and catalogue keys.
        """
        active = self.get_current_model_config()
        return {
            "device": self.device,
            "current_model": self._current_model_key,
            "current_model_name": active.name if active else None,
            "is_loaded": self.is_loaded(),
            "controlnet_loaded": self._controlnet_loaded,
            "available_models": list(self.AVAILABLE_MODELS.keys()),
        }
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def get_model_selection_guide() -> str:
    """
    Build the static HTML panel explaining the photo-vs-anime model choice.

    Returns
    -------
    str
        HTML fragment suitable for embedding in the UI.
    """
    # Pure presentation: a fixed two-column card, no dynamic content.
    return """
    <div style="background: linear-gradient(135deg, #f5f7fa 0%, #e4e8ec 100%);
                padding: 16px;
                border-radius: 12px;
                margin: 12px 0;
                border: 1px solid #ddd;">
        <h4 style="margin: 0 0 12px 0; color: #333; font-size: 16px;">
            📸 Model Selection Guide
        </h4>
        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;">
            <div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #4CAF50;">
                <p style="margin: 0 0 8px 0; font-weight: bold; color: #4CAF50;">
                    🖼️ Photo Mode
                </p>
                <p style="margin: 0; font-size: 13px; color: #555;">
                    <strong>Best for:</strong> Real photographs, portraits, product shots, nature photos
                </p>
                <p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
                    Recommended: JuggernautXL for portraits, RealVisXL for scenes
                </p>
            </div>
            <div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #9C27B0;">
                <p style="margin: 0 0 8px 0; font-weight: bold; color: #9C27B0;">
                    🎨 Anime Mode
                </p>
                <p style="margin: 0; font-size: 13px; color: #555;">
                    <strong>Best for:</strong> Anime, manga, illustrations, digital art, cartoons
                </p>
                <p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
                    Recommended: Animagine XL for anime, SDXL Base for general art
                </p>
            </div>
        </div>
    </div>
    """
|
inpainting_module.py
CHANGED
|
@@ -4,55 +4,57 @@ import os
|
|
| 4 |
import time
|
| 5 |
import traceback
|
| 6 |
from dataclasses import dataclass, field
|
| 7 |
-
from typing import Any, Callable, Dict,
|
| 8 |
|
| 9 |
import cv2
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
-
from PIL import Image
|
| 13 |
|
| 14 |
-
from diffusers import
|
|
|
|
|
|
|
| 15 |
from diffusers import StableDiffusionXLControlNetInpaintPipeline
|
| 16 |
-
from
|
| 17 |
-
from transformers import
|
| 18 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
logger.setLevel(logging.INFO)
|
| 22 |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
@dataclass
|
| 25 |
class InpaintingConfig:
|
| 26 |
"""Configuration for inpainting operations."""
|
| 27 |
|
| 28 |
-
# ControlNet settings
|
| 29 |
controlnet_conditioning_scale: float = 0.7
|
| 30 |
-
conditioning_type: str = "canny"
|
| 31 |
|
| 32 |
# Canny edge detection parameters
|
| 33 |
canny_low_threshold: int = 100
|
| 34 |
canny_high_threshold: int = 200
|
| 35 |
|
| 36 |
# Mask settings
|
| 37 |
-
feather_radius: int =
|
| 38 |
min_mask_coverage: float = 0.01
|
| 39 |
max_mask_coverage: float = 0.95
|
| 40 |
|
| 41 |
# Generation settings
|
| 42 |
num_inference_steps: int = 25
|
| 43 |
guidance_scale: float = 7.5
|
| 44 |
-
strength: float =
|
| 45 |
-
preview_steps: int = 15
|
| 46 |
-
preview_guidance_scale: float = 8.0
|
| 47 |
-
|
| 48 |
-
# Quality settings
|
| 49 |
-
enable_auto_optimization: bool = True
|
| 50 |
-
max_optimization_retries: int = 3
|
| 51 |
-
min_quality_score: float = 70.0
|
| 52 |
|
| 53 |
# Memory settings
|
| 54 |
enable_vae_tiling: bool = True
|
| 55 |
-
enable_attention_slicing: bool = True
|
| 56 |
max_resolution: int = 1024
|
| 57 |
|
| 58 |
|
|
@@ -66,94 +68,81 @@ class InpaintingResult:
|
|
| 66 |
control_image: Optional[Image.Image] = None
|
| 67 |
blended_image: Optional[Image.Image] = None
|
| 68 |
quality_score: float = 0.0
|
| 69 |
-
quality_details: Dict[str, Any] = field(default_factory=dict)
|
| 70 |
generation_time: float = 0.0
|
| 71 |
-
retries: int = 0
|
| 72 |
error_message: str = ""
|
| 73 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 74 |
|
| 75 |
|
| 76 |
class InpaintingModule:
|
| 77 |
"""
|
| 78 |
-
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
is_initialized: Whether pipeline is loaded
|
| 88 |
|
| 89 |
Example:
|
| 90 |
>>> module = InpaintingModule(device="cuda")
|
| 91 |
-
>>>
|
| 92 |
-
>>>
|
| 93 |
-
|
| 94 |
-
... mask=my_mask,
|
| 95 |
-
... prompt="a beautiful garden"
|
| 96 |
-
... )
|
| 97 |
"""
|
| 98 |
|
| 99 |
-
#
|
| 100 |
CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
|
| 101 |
CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
|
| 102 |
DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
|
| 103 |
DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
def __init__(
|
| 107 |
self,
|
| 108 |
device: str = "auto",
|
| 109 |
config: Optional[InpaintingConfig] = None
|
| 110 |
):
|
| 111 |
-
"""
|
| 112 |
-
Initialize the InpaintingModule.
|
| 113 |
-
|
| 114 |
-
Parameters
|
| 115 |
-
----------
|
| 116 |
-
device : str, optional
|
| 117 |
-
Computation device. "auto" for automatic detection.
|
| 118 |
-
config : InpaintingConfig, optional
|
| 119 |
-
Configuration object. Uses defaults if not provided.
|
| 120 |
-
"""
|
| 121 |
self.device = self._setup_device(device)
|
| 122 |
self.config = config or InpaintingConfig()
|
| 123 |
|
| 124 |
-
#
|
| 125 |
-
self.
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
self._depth_estimator = None
|
| 129 |
self._depth_processor = None
|
| 130 |
|
| 131 |
# State tracking
|
| 132 |
self.is_initialized = False
|
|
|
|
| 133 |
self._current_conditioning_type = None
|
| 134 |
-
self.
|
| 135 |
-
self._cached_latents = None
|
| 136 |
-
self._use_controlnet = True # Track if ControlNet is available
|
| 137 |
-
|
| 138 |
-
# Reference to model manager (set by SceneWeaverCore)
|
| 139 |
-
self._model_manager = None
|
| 140 |
|
| 141 |
logger.info(f"InpaintingModule initialized on {self.device}")
|
| 142 |
|
| 143 |
def _setup_device(self, device: str) -> str:
|
| 144 |
-
"""
|
| 145 |
-
Setup computation device.
|
| 146 |
-
|
| 147 |
-
Parameters
|
| 148 |
-
----------
|
| 149 |
-
device : str
|
| 150 |
-
Device specification or "auto"
|
| 151 |
-
|
| 152 |
-
Returns
|
| 153 |
-
-------
|
| 154 |
-
str
|
| 155 |
-
Resolved device name
|
| 156 |
-
"""
|
| 157 |
if device == "auto":
|
| 158 |
if torch.cuda.is_available():
|
| 159 |
return "cuda"
|
|
@@ -162,224 +151,159 @@ class InpaintingModule:
|
|
| 162 |
return "cpu"
|
| 163 |
return device
|
| 164 |
|
| 165 |
-
def set_model_manager(self, manager: Any) -> None:
|
| 166 |
-
"""
|
| 167 |
-
Set reference to ModelManager for coordinated model lifecycle.
|
| 168 |
-
|
| 169 |
-
Parameters
|
| 170 |
-
----------
|
| 171 |
-
manager : ModelManager
|
| 172 |
-
The global model manager instance
|
| 173 |
-
"""
|
| 174 |
-
self._model_manager = manager
|
| 175 |
-
logger.info("ModelManager reference set for InpaintingModule")
|
| 176 |
-
|
| 177 |
def _memory_cleanup(self, aggressive: bool = False) -> None:
|
| 178 |
-
"""
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
Parameters
|
| 182 |
-
----------
|
| 183 |
-
aggressive : bool
|
| 184 |
-
If True, perform multiple GC rounds and sync CUDA
|
| 185 |
-
"""
|
| 186 |
-
rounds = 5 if aggressive else 2
|
| 187 |
-
for _ in range(rounds):
|
| 188 |
gc.collect()
|
| 189 |
|
| 190 |
-
# On Hugging Face Spaces, avoid CUDA operations in main process
|
| 191 |
-
# CUDA operations must only happen within @spaces.GPU decorated functions
|
| 192 |
is_spaces = os.getenv('SPACE_ID') is not None
|
| 193 |
-
|
| 194 |
if not is_spaces and torch.cuda.is_available():
|
| 195 |
torch.cuda.empty_cache()
|
| 196 |
if aggressive:
|
| 197 |
torch.cuda.ipc_collect()
|
| 198 |
-
torch.cuda.synchronize()
|
| 199 |
-
|
| 200 |
-
logger.debug(f"Memory cleanup completed (aggressive={aggressive}, spaces={is_spaces})")
|
| 201 |
-
|
| 202 |
-
def _check_memory_status(self) -> Dict[str, float]:
|
| 203 |
-
"""
|
| 204 |
-
Check current GPU memory status.
|
| 205 |
-
|
| 206 |
-
Returns
|
| 207 |
-
-------
|
| 208 |
-
dict
|
| 209 |
-
Memory statistics including allocated, total, and usage ratio
|
| 210 |
-
"""
|
| 211 |
-
# On Spaces, skip CUDA checks in main process
|
| 212 |
-
is_spaces = os.getenv('SPACE_ID') is not None
|
| 213 |
-
|
| 214 |
-
if is_spaces or not torch.cuda.is_available():
|
| 215 |
-
return {"available": True, "usage_ratio": 0.0}
|
| 216 |
|
| 217 |
-
|
| 218 |
-
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
| 219 |
-
usage_ratio = allocated / total
|
| 220 |
-
|
| 221 |
-
return {
|
| 222 |
-
"allocated_gb": round(allocated, 2),
|
| 223 |
-
"total_gb": round(total, 2),
|
| 224 |
-
"free_gb": round(total - allocated, 2),
|
| 225 |
-
"usage_ratio": round(usage_ratio, 3),
|
| 226 |
-
"available": usage_ratio < 0.9
|
| 227 |
-
}
|
| 228 |
-
|
| 229 |
-
def load_inpainting_pipeline(
|
| 230 |
self,
|
|
|
|
| 231 |
conditioning_type: str = "canny",
|
|
|
|
| 232 |
progress_callback: Optional[Callable[[str, int], None]] = None
|
| 233 |
) -> Tuple[bool, str]:
|
| 234 |
"""
|
| 235 |
-
Load the
|
| 236 |
-
|
| 237 |
-
Implements mutual exclusion with background generation pipeline.
|
| 238 |
-
Only one pipeline can be loaded at a time.
|
| 239 |
|
| 240 |
Parameters
|
| 241 |
----------
|
|
|
|
|
|
|
|
|
|
| 242 |
conditioning_type : str
|
| 243 |
-
|
|
|
|
|
|
|
| 244 |
progress_callback : callable, optional
|
| 245 |
-
|
| 246 |
|
| 247 |
Returns
|
| 248 |
-------
|
| 249 |
tuple
|
| 250 |
(success: bool, error_message: str)
|
| 251 |
"""
|
| 252 |
-
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
return True, ""
|
| 255 |
|
| 256 |
-
logger.info(f"Loading
|
| 257 |
|
| 258 |
try:
|
| 259 |
self._memory_cleanup(aggressive=True)
|
| 260 |
|
| 261 |
if progress_callback:
|
| 262 |
-
progress_callback("Preparing
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
-
|
| 265 |
-
if self._inpaint_pipeline is not None:
|
| 266 |
-
self._unload_pipeline()
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
|
|
|
| 271 |
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
|
|
|
|
|
|
| 278 |
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
| 280 |
if conditioning_type == "canny":
|
| 281 |
-
|
| 282 |
self.CONTROLNET_CANNY_MODEL,
|
| 283 |
torch_dtype=dtype,
|
| 284 |
use_safetensors=True
|
| 285 |
)
|
| 286 |
-
self._controlnet_canny = controlnet
|
| 287 |
-
logger.info("Loaded ControlNet Canny model")
|
| 288 |
-
|
| 289 |
elif conditioning_type == "depth":
|
| 290 |
-
|
| 291 |
self.CONTROLNET_DEPTH_MODEL,
|
| 292 |
torch_dtype=dtype,
|
| 293 |
use_safetensors=True
|
| 294 |
)
|
| 295 |
-
self._controlnet_depth = controlnet
|
| 296 |
-
|
| 297 |
-
# Load depth estimator
|
| 298 |
-
if progress_callback:
|
| 299 |
-
progress_callback("Loading depth estimation model...", 35)
|
| 300 |
self._load_depth_estimator()
|
| 301 |
-
logger.info("Loaded ControlNet Depth model")
|
| 302 |
else:
|
| 303 |
raise ValueError(f"Unknown conditioning type: {conditioning_type}")
|
| 304 |
-
else:
|
| 305 |
-
# Skip ControlNet loading for fallback mode
|
| 306 |
-
logger.info(f"Skipping ControlNet loading (fallback mode)")
|
| 307 |
|
| 308 |
-
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
self.BASE_MODEL,
|
| 315 |
-
controlnet=controlnet,
|
| 316 |
-
torch_dtype=dtype,
|
| 317 |
-
use_safetensors=True,
|
| 318 |
-
variant="fp16" if dtype == torch.float16 else None
|
| 319 |
)
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
self.
|
| 323 |
-
|
| 324 |
-
torch_dtype=dtype,
|
| 325 |
-
use_safetensors=True,
|
| 326 |
-
variant="fp16" if dtype == torch.float16 else None
|
| 327 |
-
)
|
| 328 |
-
self._use_controlnet = False
|
| 329 |
-
|
| 330 |
-
# Track ControlNet usage
|
| 331 |
-
self._use_controlnet = use_controlnet_inpaint and controlnet is not None
|
| 332 |
|
| 333 |
if progress_callback:
|
| 334 |
-
progress_callback("Configuring
|
| 335 |
|
| 336 |
-
# Configure scheduler
|
| 337 |
-
self.
|
| 338 |
-
self.
|
| 339 |
)
|
| 340 |
|
| 341 |
-
# Move to device
|
| 342 |
-
self.
|
| 343 |
-
|
| 344 |
-
if progress_callback:
|
| 345 |
-
progress_callback("Applying optimizations...", 85)
|
| 346 |
-
|
| 347 |
-
# Apply memory optimizations
|
| 348 |
-
self._apply_pipeline_optimizations()
|
| 349 |
-
|
| 350 |
-
# Set eval mode
|
| 351 |
-
self._inpaint_pipeline.unet.eval()
|
| 352 |
-
if hasattr(self._inpaint_pipeline, 'vae'):
|
| 353 |
-
self._inpaint_pipeline.vae.eval()
|
| 354 |
|
| 355 |
self.is_initialized = True
|
| 356 |
-
self._current_conditioning_type = conditioning_type if self._use_controlnet else "none"
|
| 357 |
|
| 358 |
if progress_callback:
|
| 359 |
-
progress_callback("
|
| 360 |
-
|
| 361 |
-
# Log memory status
|
| 362 |
-
mem_status = self._check_memory_status()
|
| 363 |
-
logger.info(f"Pipeline loaded. GPU memory: {mem_status.get('allocated_gb', 0):.1f}GB used")
|
| 364 |
|
| 365 |
return True, ""
|
| 366 |
|
| 367 |
except Exception as e:
|
| 368 |
error_msg = str(e)
|
| 369 |
-
logger.error(f"Failed to load
|
| 370 |
traceback.print_exc()
|
| 371 |
self._unload_pipeline()
|
| 372 |
return False, error_msg
|
| 373 |
|
| 374 |
def _load_depth_estimator(self) -> None:
|
| 375 |
-
"""
|
| 376 |
-
Load depth estimation model with fallback strategy.
|
| 377 |
-
|
| 378 |
-
Tries Depth-Anything first, falls back to MiDaS if unavailable.
|
| 379 |
-
"""
|
| 380 |
try:
|
| 381 |
-
logger.info(f"Attempting to load depth model: {self.DEPTH_MODEL_PRIMARY}")
|
| 382 |
-
|
| 383 |
self._depth_processor = AutoImageProcessor.from_pretrained(
|
| 384 |
self.DEPTH_MODEL_PRIMARY
|
| 385 |
)
|
|
@@ -389,70 +313,50 @@ class InpaintingModule:
|
|
| 389 |
)
|
| 390 |
self._depth_estimator.to(self.device)
|
| 391 |
self._depth_estimator.eval()
|
| 392 |
-
|
| 393 |
-
logger.info("Successfully loaded Depth-Anything model")
|
| 394 |
-
|
| 395 |
except Exception as e:
|
| 396 |
logger.warning(f"Primary depth model failed: {e}, trying fallback...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
)
|
| 402 |
-
self._depth_estimator = DPTForDepthEstimation.from_pretrained(
|
| 403 |
-
self.DEPTH_MODEL_FALLBACK,
|
| 404 |
-
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
|
| 405 |
-
)
|
| 406 |
-
self._depth_estimator.to(self.device)
|
| 407 |
-
self._depth_estimator.eval()
|
| 408 |
-
|
| 409 |
-
logger.info("Successfully loaded MiDaS fallback model")
|
| 410 |
-
|
| 411 |
-
except Exception as fallback_e:
|
| 412 |
-
logger.error(f"Fallback depth model also failed: {fallback_e}")
|
| 413 |
-
raise RuntimeError("Unable to load any depth estimation model")
|
| 414 |
-
|
| 415 |
-
def _apply_pipeline_optimizations(self) -> None:
|
| 416 |
-
"""Apply memory and performance optimizations to the pipeline."""
|
| 417 |
-
if self._inpaint_pipeline is None:
|
| 418 |
return
|
| 419 |
|
| 420 |
-
# Try xformers first
|
| 421 |
try:
|
| 422 |
-
self.
|
| 423 |
-
logger.info("Enabled xformers
|
| 424 |
except Exception:
|
| 425 |
try:
|
| 426 |
-
self.
|
| 427 |
logger.info("Enabled attention slicing")
|
| 428 |
except Exception:
|
| 429 |
-
|
| 430 |
|
| 431 |
-
# VAE optimizations
|
| 432 |
if self.config.enable_vae_tiling:
|
| 433 |
-
if hasattr(self.
|
| 434 |
-
self.
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
if hasattr(self._inpaint_pipeline, 'enable_vae_slicing'):
|
| 438 |
-
self._inpaint_pipeline.enable_vae_slicing()
|
| 439 |
-
logger.debug("Enabled VAE slicing")
|
| 440 |
|
| 441 |
def _unload_pipeline(self) -> None:
|
| 442 |
-
"""Unload
|
| 443 |
-
|
|
|
|
|
|
|
| 444 |
|
| 445 |
-
if self.
|
| 446 |
-
del self.
|
| 447 |
-
self.
|
| 448 |
-
|
| 449 |
-
if self._controlnet_canny is not None:
|
| 450 |
-
del self._controlnet_canny
|
| 451 |
-
self._controlnet_canny = None
|
| 452 |
-
|
| 453 |
-
if self._controlnet_depth is not None:
|
| 454 |
-
del self._controlnet_depth
|
| 455 |
-
self._controlnet_depth = None
|
| 456 |
|
| 457 |
if self._depth_estimator is not None:
|
| 458 |
del self._depth_estimator
|
|
@@ -463,942 +367,300 @@ class InpaintingModule:
|
|
| 463 |
self._depth_processor = None
|
| 464 |
|
| 465 |
self.is_initialized = False
|
|
|
|
| 466 |
self._current_conditioning_type = None
|
| 467 |
-
self._cached_latents = None
|
| 468 |
|
| 469 |
self._memory_cleanup(aggressive=True)
|
| 470 |
-
logger.info("
|
| 471 |
-
|
| 472 |
-
def prepare_control_image(
|
| 473 |
-
self,
|
| 474 |
-
image: Image.Image,
|
| 475 |
-
mode: str = "canny",
|
| 476 |
-
mask: Optional[Image.Image] = None,
|
| 477 |
-
preserve_structure: bool = False
|
| 478 |
-
) -> Image.Image:
|
| 479 |
-
"""
|
| 480 |
-
Generate ControlNet conditioning image.
|
| 481 |
-
|
| 482 |
-
Parameters
|
| 483 |
-
----------
|
| 484 |
-
image : PIL.Image
|
| 485 |
-
Input image
|
| 486 |
-
mode : str
|
| 487 |
-
Conditioning mode: "canny" or "depth"
|
| 488 |
-
mask : PIL.Image, optional
|
| 489 |
-
If provided, can suppress edges in masked region (when preserve_structure=False).
|
| 490 |
-
preserve_structure : bool
|
| 491 |
-
If True, keep edges in masked region (for color change tasks).
|
| 492 |
-
If False, suppress edges in masked region (for replacement/removal tasks).
|
| 493 |
-
|
| 494 |
-
Returns
|
| 495 |
-
-------
|
| 496 |
-
PIL.Image
|
| 497 |
-
Generated control image (edges or depth map)
|
| 498 |
-
"""
|
| 499 |
-
logger.info(f"Preparing control image with mode: {mode}, preserve_structure: {preserve_structure}")
|
| 500 |
-
|
| 501 |
-
# Convert to RGB if needed
|
| 502 |
-
if image.mode != 'RGB':
|
| 503 |
-
image = image.convert('RGB')
|
| 504 |
-
|
| 505 |
-
img_array = np.array(image)
|
| 506 |
-
|
| 507 |
-
if mode == "canny":
|
| 508 |
-
canny_image = self._generate_canny_edges(img_array)
|
| 509 |
-
|
| 510 |
-
# Mask-aware processing: suppress edges in masked region ONLY if not preserving structure
|
| 511 |
-
if mask is not None and not preserve_structure:
|
| 512 |
-
canny_array = np.array(canny_image)
|
| 513 |
-
mask_array = np.array(mask.convert('L'))
|
| 514 |
-
|
| 515 |
-
# In masked region, completely suppress Canny edges
|
| 516 |
-
# This allows complete replacement/removal of the object
|
| 517 |
-
mask_region = mask_array > 128 # White = masked area
|
| 518 |
-
canny_array[mask_region] = 0
|
| 519 |
-
|
| 520 |
-
canny_image = Image.fromarray(canny_array)
|
| 521 |
-
logger.info("Suppressed edges in masked region for replacement/removal")
|
| 522 |
-
elif preserve_structure:
|
| 523 |
-
logger.info("Preserving edges in masked region for color change")
|
| 524 |
-
|
| 525 |
-
return canny_image
|
| 526 |
-
|
| 527 |
-
elif mode == "depth":
|
| 528 |
-
return self._generate_depth_map(image)
|
| 529 |
-
else:
|
| 530 |
-
raise ValueError(f"Unknown control mode: {mode}")
|
| 531 |
-
|
| 532 |
-
def _generate_canny_edges(self, img_array: np.ndarray) -> Image.Image:
|
| 533 |
-
"""
|
| 534 |
-
Generate Canny edge detection image.
|
| 535 |
-
|
| 536 |
-
Parameters
|
| 537 |
-
----------
|
| 538 |
-
img_array : np.ndarray
|
| 539 |
-
Input image as RGB numpy array
|
| 540 |
-
|
| 541 |
-
Returns
|
| 542 |
-
-------
|
| 543 |
-
PIL.Image
|
| 544 |
-
Edge detection result as grayscale image
|
| 545 |
-
"""
|
| 546 |
-
# Convert to grayscale
|
| 547 |
-
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 548 |
-
|
| 549 |
-
# Apply Gaussian blur to reduce noise
|
| 550 |
-
blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)
|
| 551 |
-
|
| 552 |
-
# Canny edge detection
|
| 553 |
-
edges = cv2.Canny(
|
| 554 |
-
blurred,
|
| 555 |
-
self.config.canny_low_threshold,
|
| 556 |
-
self.config.canny_high_threshold
|
| 557 |
-
)
|
| 558 |
-
|
| 559 |
-
# Convert to 3-channel for ControlNet
|
| 560 |
-
edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
|
| 561 |
-
|
| 562 |
-
logger.debug(f"Generated Canny edges with thresholds "
|
| 563 |
-
f"{self.config.canny_low_threshold}/{self.config.canny_high_threshold}")
|
| 564 |
-
|
| 565 |
-
return Image.fromarray(edges_3ch)
|
| 566 |
-
|
| 567 |
-
def _generate_depth_map(self, image: Image.Image) -> Image.Image:
|
| 568 |
-
"""
|
| 569 |
-
Generate depth map using depth estimation model.
|
| 570 |
-
|
| 571 |
-
Parameters
|
| 572 |
-
----------
|
| 573 |
-
image : PIL.Image
|
| 574 |
-
Input RGB image
|
| 575 |
-
|
| 576 |
-
Returns
|
| 577 |
-
-------
|
| 578 |
-
PIL.Image
|
| 579 |
-
Depth map as grayscale image
|
| 580 |
-
"""
|
| 581 |
-
if self._depth_estimator is None or self._depth_processor is None:
|
| 582 |
-
raise RuntimeError("Depth estimator not loaded")
|
| 583 |
-
|
| 584 |
-
# Preprocess
|
| 585 |
-
inputs = self._depth_processor(images=image, return_tensors="pt")
|
| 586 |
-
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 587 |
-
|
| 588 |
-
# Inference
|
| 589 |
-
with torch.no_grad():
|
| 590 |
-
outputs = self._depth_estimator(**inputs)
|
| 591 |
-
predicted_depth = outputs.predicted_depth
|
| 592 |
-
|
| 593 |
-
# Interpolate to original size
|
| 594 |
-
prediction = torch.nn.functional.interpolate(
|
| 595 |
-
predicted_depth.unsqueeze(1),
|
| 596 |
-
size=image.size[::-1], # (H, W)
|
| 597 |
-
mode="bicubic",
|
| 598 |
-
align_corners=False
|
| 599 |
-
)
|
| 600 |
-
|
| 601 |
-
# Normalize to 0-255
|
| 602 |
-
depth_array = prediction.squeeze().cpu().numpy()
|
| 603 |
-
depth_min = depth_array.min()
|
| 604 |
-
depth_max = depth_array.max()
|
| 605 |
-
|
| 606 |
-
if depth_max - depth_min > 0:
|
| 607 |
-
depth_normalized = ((depth_array - depth_min) / (depth_max - depth_min) * 255)
|
| 608 |
-
else:
|
| 609 |
-
depth_normalized = np.zeros_like(depth_array)
|
| 610 |
-
|
| 611 |
-
depth_normalized = depth_normalized.astype(np.uint8)
|
| 612 |
-
|
| 613 |
-
# Convert to 3-channel for ControlNet
|
| 614 |
-
depth_3ch = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
|
| 615 |
-
|
| 616 |
-
logger.debug(f"Generated depth map, range: {depth_min:.2f} - {depth_max:.2f}")
|
| 617 |
-
|
| 618 |
-
return Image.fromarray(depth_3ch)
|
| 619 |
-
|
| 620 |
-
def prepare_mask(
|
| 621 |
-
self,
|
| 622 |
-
mask: Image.Image,
|
| 623 |
-
target_size: Tuple[int, int],
|
| 624 |
-
feather_radius: Optional[int] = None
|
| 625 |
-
) -> Tuple[Image.Image, Dict[str, Any]]:
|
| 626 |
-
"""
|
| 627 |
-
Prepare and validate mask for inpainting.
|
| 628 |
-
|
| 629 |
-
Parameters
|
| 630 |
-
----------
|
| 631 |
-
mask : PIL.Image
|
| 632 |
-
Input mask (white = inpaint area)
|
| 633 |
-
target_size : tuple
|
| 634 |
-
Target (width, height) to match input image
|
| 635 |
-
feather_radius : int, optional
|
| 636 |
-
Feathering radius in pixels. Uses config default if None.
|
| 637 |
-
|
| 638 |
-
Returns
|
| 639 |
-
-------
|
| 640 |
-
tuple
|
| 641 |
-
(processed_mask, validation_info)
|
| 642 |
-
|
| 643 |
-
Raises
|
| 644 |
-
------
|
| 645 |
-
ValueError
|
| 646 |
-
If mask coverage is outside acceptable range
|
| 647 |
-
"""
|
| 648 |
-
feather = feather_radius if feather_radius is not None else self.config.feather_radius
|
| 649 |
-
|
| 650 |
-
# Convert to grayscale
|
| 651 |
-
if mask.mode != 'L':
|
| 652 |
-
mask = mask.convert('L')
|
| 653 |
-
|
| 654 |
-
# Resize to match target
|
| 655 |
-
if mask.size != target_size:
|
| 656 |
-
mask = mask.resize(target_size, Image.LANCZOS)
|
| 657 |
-
|
| 658 |
-
# Convert to array for processing
|
| 659 |
-
mask_array = np.array(mask)
|
| 660 |
-
|
| 661 |
-
# Calculate coverage
|
| 662 |
-
total_pixels = mask_array.size
|
| 663 |
-
white_pixels = np.count_nonzero(mask_array > 127)
|
| 664 |
-
coverage = white_pixels / total_pixels
|
| 665 |
-
|
| 666 |
-
validation_info = {
|
| 667 |
-
"coverage": coverage,
|
| 668 |
-
"white_pixels": white_pixels,
|
| 669 |
-
"total_pixels": total_pixels,
|
| 670 |
-
"feather_radius": feather,
|
| 671 |
-
"valid": True,
|
| 672 |
-
"warning": ""
|
| 673 |
-
}
|
| 674 |
-
|
| 675 |
-
# Validate coverage
|
| 676 |
-
if coverage < self.config.min_mask_coverage:
|
| 677 |
-
validation_info["valid"] = False
|
| 678 |
-
validation_info["warning"] = (
|
| 679 |
-
f"Mask coverage too low ({coverage:.1%}). "
|
| 680 |
-
f"Please select a larger area to inpaint."
|
| 681 |
-
)
|
| 682 |
-
logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.config.min_mask_coverage:.1%}")
|
| 683 |
-
|
| 684 |
-
elif coverage > self.config.max_mask_coverage:
|
| 685 |
-
validation_info["valid"] = False
|
| 686 |
-
validation_info["warning"] = (
|
| 687 |
-
f"Mask coverage too high ({coverage:.1%}). "
|
| 688 |
-
f"Consider using background generation instead."
|
| 689 |
-
)
|
| 690 |
-
logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.config.max_mask_coverage:.1%}")
|
| 691 |
-
|
| 692 |
-
# Apply feathering
|
| 693 |
-
if feather > 0:
|
| 694 |
-
mask_array = cv2.GaussianBlur(
|
| 695 |
-
mask_array,
|
| 696 |
-
(feather * 2 + 1, feather * 2 + 1),
|
| 697 |
-
feather / 2
|
| 698 |
-
)
|
| 699 |
-
logger.debug(f"Applied {feather}px feathering to mask")
|
| 700 |
-
|
| 701 |
-
processed_mask = Image.fromarray(mask_array, mode='L')
|
| 702 |
-
|
| 703 |
-
return processed_mask, validation_info
|
| 704 |
-
|
| 705 |
-
def enhance_prompt_for_inpainting(
|
| 706 |
-
self,
|
| 707 |
-
prompt: str,
|
| 708 |
-
image: Image.Image,
|
| 709 |
-
mask: Image.Image
|
| 710 |
-
) -> Tuple[str, str]:
|
| 711 |
-
"""
|
| 712 |
-
Enhance prompt based on non-masked region analysis.
|
| 713 |
-
|
| 714 |
-
Analyzes the surrounding context to generate appropriate
|
| 715 |
-
lighting and color descriptors.
|
| 716 |
-
|
| 717 |
-
Parameters
|
| 718 |
-
----------
|
| 719 |
-
prompt : str
|
| 720 |
-
User-provided prompt
|
| 721 |
-
image : PIL.Image
|
| 722 |
-
Original image
|
| 723 |
-
mask : PIL.Image
|
| 724 |
-
Inpainting mask
|
| 725 |
-
|
| 726 |
-
Returns
|
| 727 |
-
-------
|
| 728 |
-
tuple
|
| 729 |
-
(enhanced_prompt, negative_prompt)
|
| 730 |
-
"""
|
| 731 |
-
logger.info("Enhancing prompt for inpainting context...")
|
| 732 |
-
|
| 733 |
-
# Convert to arrays
|
| 734 |
-
img_array = np.array(image.convert('RGB'))
|
| 735 |
-
mask_array = np.array(mask.convert('L'))
|
| 736 |
-
|
| 737 |
-
# Analyze non-masked regions
|
| 738 |
-
non_masked = mask_array < 127
|
| 739 |
-
|
| 740 |
-
if not np.any(non_masked):
|
| 741 |
-
# No context available
|
| 742 |
-
enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
|
| 743 |
-
negative_prompt = self._get_inpainting_negative_prompt()
|
| 744 |
-
return enhanced_prompt, negative_prompt
|
| 745 |
-
|
| 746 |
-
# Extract context pixels
|
| 747 |
-
context_pixels = img_array[non_masked]
|
| 748 |
-
|
| 749 |
-
# Convert to Lab for analysis
|
| 750 |
-
context_lab = cv2.cvtColor(
|
| 751 |
-
context_pixels.reshape(-1, 1, 3),
|
| 752 |
-
cv2.COLOR_RGB2LAB
|
| 753 |
-
).reshape(-1, 3)
|
| 754 |
-
|
| 755 |
-
# Use robust statistics (median) to avoid outlier influence
|
| 756 |
-
median_l = np.median(context_lab[:, 0])
|
| 757 |
-
median_a = np.median(context_lab[:, 1])
|
| 758 |
-
median_b = np.median(context_lab[:, 2])
|
| 759 |
-
|
| 760 |
-
# Analyze lighting conditions
|
| 761 |
-
lighting_descriptors = []
|
| 762 |
-
|
| 763 |
-
if median_l > 170:
|
| 764 |
-
lighting_descriptors.append("bright")
|
| 765 |
-
elif median_l > 130:
|
| 766 |
-
lighting_descriptors.append("well-lit")
|
| 767 |
-
elif median_l > 80:
|
| 768 |
-
lighting_descriptors.append("moderate lighting")
|
| 769 |
-
else:
|
| 770 |
-
lighting_descriptors.append("dim lighting")
|
| 771 |
-
|
| 772 |
-
# Analyze color temperature (b channel: blue(-) to yellow(+))
|
| 773 |
-
if median_b > 140:
|
| 774 |
-
lighting_descriptors.append("warm golden tones")
|
| 775 |
-
elif median_b > 120:
|
| 776 |
-
lighting_descriptors.append("warm afternoon light")
|
| 777 |
-
elif median_b < 110:
|
| 778 |
-
lighting_descriptors.append("cool neutral tones")
|
| 779 |
-
|
| 780 |
-
# Calculate saturation from context
|
| 781 |
-
hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
|
| 782 |
-
median_saturation = np.median(hsv[:, :, 1])
|
| 783 |
-
|
| 784 |
-
if median_saturation > 150:
|
| 785 |
-
lighting_descriptors.append("vibrant colors")
|
| 786 |
-
elif median_saturation < 80:
|
| 787 |
-
lighting_descriptors.append("subtle muted colors")
|
| 788 |
-
|
| 789 |
-
# Build enhanced prompt
|
| 790 |
-
lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
|
| 791 |
-
quality_suffix = "high quality, detailed, photorealistic, seamless integration"
|
| 792 |
-
|
| 793 |
-
if lighting_desc:
|
| 794 |
-
enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
|
| 795 |
-
else:
|
| 796 |
-
enhanced_prompt = f"{prompt}, {quality_suffix}"
|
| 797 |
-
|
| 798 |
-
negative_prompt = self._get_inpainting_negative_prompt()
|
| 799 |
-
|
| 800 |
-
logger.info(f"Enhanced prompt with context: {lighting_desc}")
|
| 801 |
-
|
| 802 |
-
return enhanced_prompt, negative_prompt
|
| 803 |
-
|
| 804 |
-
def _get_inpainting_negative_prompt(self) -> str:
|
| 805 |
-
"""Get standard negative prompt for inpainting."""
|
| 806 |
-
return (
|
| 807 |
-
"inconsistent lighting, wrong perspective, mismatched colors, "
|
| 808 |
-
"visible seams, blending artifacts, color bleeding, "
|
| 809 |
-
"blurry, low quality, distorted, deformed, "
|
| 810 |
-
"harsh edges, unnatural transition"
|
| 811 |
-
)
|
| 812 |
|
| 813 |
def execute_inpainting(
|
| 814 |
self,
|
| 815 |
image: Image.Image,
|
| 816 |
mask: Image.Image,
|
| 817 |
prompt: str,
|
| 818 |
-
preview_only: bool = False,
|
| 819 |
-
seed: Optional[int] = None,
|
| 820 |
progress_callback: Optional[Callable[[str, int], None]] = None,
|
| 821 |
**kwargs
|
| 822 |
) -> InpaintingResult:
|
| 823 |
"""
|
| 824 |
-
Execute
|
| 825 |
-
|
| 826 |
-
Implements two-stage generation: fast preview followed by
|
| 827 |
-
full quality generation if requested.
|
| 828 |
|
| 829 |
Parameters
|
| 830 |
----------
|
| 831 |
image : PIL.Image
|
| 832 |
-
Original image
|
| 833 |
mask : PIL.Image
|
| 834 |
Inpainting mask (white = area to regenerate)
|
| 835 |
prompt : str
|
| 836 |
-
Text description
|
| 837 |
-
preview_only : bool
|
| 838 |
-
If True, only generate preview (faster)
|
| 839 |
-
seed : int, optional
|
| 840 |
-
Random seed for reproducibility
|
| 841 |
progress_callback : callable, optional
|
| 842 |
-
Progress update function
|
| 843 |
**kwargs
|
| 844 |
-
Additional parameters
|
| 845 |
-
- controlnet_conditioning_scale: float
|
| 846 |
-
- feather_radius: int
|
| 847 |
-
- num_inference_steps: int
|
| 848 |
-
- guidance_scale: float
|
| 849 |
|
| 850 |
Returns
|
| 851 |
-------
|
| 852 |
InpaintingResult
|
| 853 |
-
Result
|
| 854 |
"""
|
| 855 |
start_time = time.time()
|
| 856 |
|
| 857 |
if not self.is_initialized:
|
| 858 |
return InpaintingResult(
|
| 859 |
success=False,
|
| 860 |
-
error_message="
|
| 861 |
)
|
| 862 |
|
| 863 |
-
logger.info(f"
|
| 864 |
|
| 865 |
try:
|
| 866 |
-
# Update config with kwargs
|
| 867 |
-
conditioning_scale = kwargs.get(
|
| 868 |
-
'controlnet_conditioning_scale',
|
| 869 |
-
self.config.controlnet_conditioning_scale
|
| 870 |
-
)
|
| 871 |
-
feather_radius = kwargs.get('feather_radius', self.config.feather_radius)
|
| 872 |
-
strength = kwargs.get('strength', self.config.strength)
|
| 873 |
-
preserve_structure = kwargs.get('preserve_structure_in_mask', False)
|
| 874 |
-
|
| 875 |
if progress_callback:
|
| 876 |
-
progress_callback("Preparing images...",
|
| 877 |
|
| 878 |
# Prepare image
|
| 879 |
if image.mode != 'RGB':
|
| 880 |
image = image.convert('RGB')
|
| 881 |
|
| 882 |
-
#
|
|
|
|
|
|
|
|
|
|
| 883 |
width, height = image.size
|
| 884 |
new_width = (width // 8) * 8
|
| 885 |
new_height = (height // 8) * 8
|
| 886 |
-
|
| 887 |
if new_width != width or new_height != height:
|
| 888 |
image = image.resize((new_width, new_height), Image.LANCZOS)
|
| 889 |
|
| 890 |
-
#
|
| 891 |
max_res = self.config.max_resolution
|
| 892 |
if max(new_width, new_height) > max_res:
|
| 893 |
scale = max_res / max(new_width, new_height)
|
| 894 |
new_width = int(new_width * scale) // 8 * 8
|
| 895 |
new_height = int(new_height * scale) // 8 * 8
|
| 896 |
image = image.resize((new_width, new_height), Image.LANCZOS)
|
| 897 |
-
logger.info(f"Reduced resolution to {new_width}x{new_height} for memory")
|
| 898 |
|
| 899 |
-
# Prepare mask
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
processed_mask, mask_info = self.prepare_mask(
|
| 904 |
mask,
|
| 905 |
(new_width, new_height),
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
if not mask_info["valid"]:
|
| 910 |
-
return InpaintingResult(
|
| 911 |
-
success=False,
|
| 912 |
-
error_message=mask_info["warning"]
|
| 913 |
-
)
|
| 914 |
-
|
| 915 |
-
# Generate control image
|
| 916 |
-
if progress_callback:
|
| 917 |
-
progress_callback("Generating control image...", 20)
|
| 918 |
-
|
| 919 |
-
control_image = self.prepare_control_image(
|
| 920 |
-
image,
|
| 921 |
-
self._current_conditioning_type,
|
| 922 |
-
mask=processed_mask,
|
| 923 |
-
preserve_structure=preserve_structure # True for color change, False for replacement/removal
|
| 924 |
)
|
| 925 |
|
| 926 |
-
#
|
| 927 |
-
|
| 928 |
-
|
|
|
|
|
|
|
| 929 |
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
prompt, image, processed_mask
|
| 935 |
-
)
|
| 936 |
-
logger.info(f"Prompt enhanced with OpenCLIP context")
|
| 937 |
-
else:
|
| 938 |
-
# Use prompt directly without enhancement
|
| 939 |
-
enhanced_prompt = prompt
|
| 940 |
-
negative_prompt = self._get_inpainting_negative_prompt()
|
| 941 |
-
logger.info("Prompt enhancement disabled for this template")
|
| 942 |
|
| 943 |
-
# Setup generator
|
| 944 |
-
|
|
|
|
|
|
|
| 945 |
seed = int(time.time() * 1000) % (2**32)
|
| 946 |
-
|
|
|
|
| 947 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
|
|
|
| 948 |
|
| 949 |
-
#
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
# Stage 1: Preview generation
|
| 953 |
-
# On Spaces, skip preview to save time (300s hard limit)
|
| 954 |
-
preview_result = None
|
| 955 |
-
|
| 956 |
-
if preview_only or not is_spaces:
|
| 957 |
if progress_callback:
|
| 958 |
-
progress_callback("Generating
|
| 959 |
-
|
| 960 |
-
# Optimize preview steps for Hugging Face Spaces
|
| 961 |
-
preview_steps = self.config.preview_steps
|
| 962 |
-
if is_spaces:
|
| 963 |
-
# On Spaces, use minimal preview steps
|
| 964 |
-
preview_steps = min(preview_steps, 8)
|
| 965 |
-
logger.debug(f"Spaces environment - using {preview_steps} preview steps")
|
| 966 |
|
| 967 |
-
|
| 968 |
image=image,
|
| 969 |
mask=processed_mask,
|
| 970 |
-
|
| 971 |
-
prompt=enhanced_prompt,
|
| 972 |
negative_prompt=negative_prompt,
|
| 973 |
-
|
| 974 |
-
guidance_scale=
|
| 975 |
-
controlnet_conditioning_scale=conditioning_scale,
|
| 976 |
strength=strength,
|
| 977 |
generator=generator
|
| 978 |
)
|
|
|
|
|
|
|
| 979 |
else:
|
| 980 |
-
|
|
|
|
|
|
|
| 981 |
|
| 982 |
-
|
| 983 |
-
|
|
|
|
| 984 |
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
"seed": seed,
|
| 992 |
-
"prompt": enhanced_prompt,
|
| 993 |
-
"conditioning_type": self._current_conditioning_type,
|
| 994 |
-
"conditioning_scale": conditioning_scale,
|
| 995 |
-
"preview_only": True
|
| 996 |
-
}
|
| 997 |
)
|
| 998 |
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
progress_callback("Generating full quality...", 60)
|
| 1002 |
-
|
| 1003 |
-
# Use same seed for reproducibility
|
| 1004 |
-
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 1005 |
-
|
| 1006 |
-
num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
|
| 1007 |
-
guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
|
| 1008 |
-
|
| 1009 |
-
# Optimize for Hugging Face Spaces ZeroGPU (stateless, 300s hard limit)
|
| 1010 |
-
if is_spaces:
|
| 1011 |
-
# ZeroGPU timing breakdown with model caching (actual measurements):
|
| 1012 |
-
# - Model loading from cache: ~60s (cached models, CPU to GPU transfer)
|
| 1013 |
-
# - Inference: ~28-29s/step (observed on shared H200)
|
| 1014 |
-
# - Blending & overhead: ~35s
|
| 1015 |
-
# - Platform limit: 300s hard limit (Pro tier)
|
| 1016 |
-
#
|
| 1017 |
-
# Strategy with unified 10-step approach:
|
| 1018 |
-
# - Skip preview completely (done above)
|
| 1019 |
-
# - Use 10 steps for balance of quality and speed
|
| 1020 |
-
# - Time budget: 60s (load) + 285s (10 steps) + 35s (blend) = 380s
|
| 1021 |
-
# - Note: Still may timeout, but parameter optimization is more important than step count
|
| 1022 |
-
# - Quality comes from correct conditioning_scale, not high step count
|
| 1023 |
-
|
| 1024 |
-
spaces_max_steps = 10 # Optimized: 10 steps sufficient with proper parameters
|
| 1025 |
-
|
| 1026 |
-
if num_steps > spaces_max_steps:
|
| 1027 |
-
num_steps = spaces_max_steps
|
| 1028 |
-
logger.debug(f"Spaces deployment: using {num_steps} steps (optimized for parameter quality)")
|
| 1029 |
-
|
| 1030 |
-
full_result = self._generate_inpaint(
|
| 1031 |
-
image=image,
|
| 1032 |
-
mask=processed_mask,
|
| 1033 |
-
control_image=control_image,
|
| 1034 |
-
prompt=enhanced_prompt,
|
| 1035 |
-
negative_prompt=negative_prompt,
|
| 1036 |
-
num_inference_steps=num_steps,
|
| 1037 |
-
guidance_scale=guidance,
|
| 1038 |
-
controlnet_conditioning_scale=conditioning_scale,
|
| 1039 |
-
strength=strength,
|
| 1040 |
-
generator=generator
|
| 1041 |
-
)
|
| 1042 |
|
| 1043 |
-
|
| 1044 |
-
|
|
|
|
|
|
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1048 |
|
| 1049 |
generation_time = time.time() - start_time
|
| 1050 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1051 |
if progress_callback:
|
| 1052 |
progress_callback("Complete!", 100)
|
| 1053 |
|
| 1054 |
return InpaintingResult(
|
| 1055 |
success=True,
|
| 1056 |
-
result_image=
|
| 1057 |
-
|
| 1058 |
control_image=control_image,
|
| 1059 |
-
blended_image=blended,
|
| 1060 |
generation_time=generation_time,
|
| 1061 |
metadata={
|
| 1062 |
"seed": seed,
|
| 1063 |
-
"prompt":
|
| 1064 |
-
"
|
| 1065 |
-
"
|
| 1066 |
-
"
|
| 1067 |
"strength": strength,
|
| 1068 |
-
"
|
| 1069 |
-
"num_inference_steps": num_steps,
|
| 1070 |
-
"guidance_scale": guidance,
|
| 1071 |
-
"feather_radius": feather_radius,
|
| 1072 |
-
"mask_coverage": mask_info["coverage"],
|
| 1073 |
-
"preview_only": False
|
| 1074 |
}
|
| 1075 |
)
|
| 1076 |
|
| 1077 |
except torch.cuda.OutOfMemoryError:
|
| 1078 |
-
logger.error("CUDA out of memory
|
| 1079 |
self._memory_cleanup(aggressive=True)
|
| 1080 |
return InpaintingResult(
|
| 1081 |
success=False,
|
| 1082 |
-
error_message="GPU memory exhausted.
|
| 1083 |
)
|
| 1084 |
-
|
| 1085 |
except Exception as e:
|
| 1086 |
logger.error(f"Inpainting failed: {e}")
|
| 1087 |
-
|
| 1088 |
return InpaintingResult(
|
| 1089 |
success=False,
|
| 1090 |
-
error_message=
|
| 1091 |
)
|
| 1092 |
|
| 1093 |
-
def
|
| 1094 |
self,
|
| 1095 |
-
image: Image.Image,
|
| 1096 |
mask: Image.Image,
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
num_inference_steps: int,
|
| 1101 |
-
guidance_scale: float,
|
| 1102 |
-
controlnet_conditioning_scale: float,
|
| 1103 |
-
strength: float,
|
| 1104 |
-
generator: torch.Generator
|
| 1105 |
-
) -> Image.Image:
|
| 1106 |
-
"""
|
| 1107 |
-
Internal method to run the inpainting pipeline.
|
| 1108 |
-
|
| 1109 |
-
Supports both ControlNet and non-ControlNet pipelines.
|
| 1110 |
-
|
| 1111 |
-
Parameters
|
| 1112 |
-
----------
|
| 1113 |
-
image : PIL.Image
|
| 1114 |
-
Original image
|
| 1115 |
-
mask : PIL.Image
|
| 1116 |
-
Processed mask
|
| 1117 |
-
control_image : PIL.Image
|
| 1118 |
-
ControlNet conditioning image (ignored if ControlNet not available)
|
| 1119 |
-
prompt : str
|
| 1120 |
-
Enhanced prompt
|
| 1121 |
-
negative_prompt : str
|
| 1122 |
-
Negative prompt
|
| 1123 |
-
num_inference_steps : int
|
| 1124 |
-
Number of denoising steps
|
| 1125 |
-
guidance_scale : float
|
| 1126 |
-
Classifier-free guidance scale
|
| 1127 |
-
controlnet_conditioning_scale : float
|
| 1128 |
-
ControlNet influence strength (ignored if ControlNet not available)
|
| 1129 |
-
strength : float
|
| 1130 |
-
Inpainting strength (0.0-1.0). 1.0 = fully repaint masked area.
|
| 1131 |
-
generator : torch.Generator
|
| 1132 |
-
Random generator for reproducibility
|
| 1133 |
-
|
| 1134 |
-
Returns
|
| 1135 |
-
-------
|
| 1136 |
-
PIL.Image
|
| 1137 |
-
Generated image
|
| 1138 |
-
"""
|
| 1139 |
-
with torch.inference_mode():
|
| 1140 |
-
if self._use_controlnet:
|
| 1141 |
-
# Full ControlNet inpainting pipeline
|
| 1142 |
-
result = self._inpaint_pipeline(
|
| 1143 |
-
prompt=prompt,
|
| 1144 |
-
negative_prompt=negative_prompt,
|
| 1145 |
-
image=image,
|
| 1146 |
-
mask_image=mask,
|
| 1147 |
-
control_image=control_image,
|
| 1148 |
-
num_inference_steps=num_inference_steps,
|
| 1149 |
-
guidance_scale=guidance_scale,
|
| 1150 |
-
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
| 1151 |
-
strength=strength,
|
| 1152 |
-
generator=generator
|
| 1153 |
-
)
|
| 1154 |
-
else:
|
| 1155 |
-
# Fallback: Standard SDXL inpainting without ControlNet
|
| 1156 |
-
result = self._inpaint_pipeline(
|
| 1157 |
-
prompt=prompt,
|
| 1158 |
-
negative_prompt=negative_prompt,
|
| 1159 |
-
image=image,
|
| 1160 |
-
mask_image=mask,
|
| 1161 |
-
num_inference_steps=num_inference_steps,
|
| 1162 |
-
guidance_scale=guidance_scale,
|
| 1163 |
-
strength=strength,
|
| 1164 |
-
generator=generator
|
| 1165 |
-
)
|
| 1166 |
-
|
| 1167 |
-
return result.images[0]
|
| 1168 |
-
|
| 1169 |
-
def blend_result(
|
| 1170 |
-
self,
|
| 1171 |
-
original: Image.Image,
|
| 1172 |
-
generated: Image.Image,
|
| 1173 |
-
mask: Image.Image
|
| 1174 |
) -> Image.Image:
|
| 1175 |
-
"""
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
|
|
|
|
|
|
| 1179 |
|
| 1180 |
-
|
| 1181 |
-
----------
|
| 1182 |
-
original : PIL.Image
|
| 1183 |
-
Original image
|
| 1184 |
-
generated : PIL.Image
|
| 1185 |
-
Generated inpainted image
|
| 1186 |
-
mask : PIL.Image
|
| 1187 |
-
Blending mask (white = use generated)
|
| 1188 |
|
| 1189 |
-
|
| 1190 |
-
|
| 1191 |
-
|
| 1192 |
-
|
| 1193 |
-
|
| 1194 |
-
logger.info("Blending inpainting result...")
|
| 1195 |
-
|
| 1196 |
-
# Ensure same size
|
| 1197 |
-
if generated.size != original.size:
|
| 1198 |
-
generated = generated.resize(original.size, Image.LANCZOS)
|
| 1199 |
-
if mask.size != original.size:
|
| 1200 |
-
mask = mask.resize(original.size, Image.LANCZOS)
|
| 1201 |
-
|
| 1202 |
-
# Convert to arrays
|
| 1203 |
-
orig_array = np.array(original.convert('RGB')).astype(np.float32)
|
| 1204 |
-
gen_array = np.array(generated.convert('RGB')).astype(np.float32)
|
| 1205 |
-
mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
|
| 1206 |
-
|
| 1207 |
-
# sRGB to linear conversion
|
| 1208 |
-
def srgb_to_linear(img):
|
| 1209 |
-
img_norm = img / 255.0
|
| 1210 |
-
return np.where(
|
| 1211 |
-
img_norm <= 0.04045,
|
| 1212 |
-
img_norm / 12.92,
|
| 1213 |
-
np.power((img_norm + 0.055) / 1.055, 2.4)
|
| 1214 |
)
|
|
|
|
|
|
|
| 1215 |
|
| 1216 |
-
|
| 1217 |
-
|
| 1218 |
-
|
| 1219 |
-
|
| 1220 |
-
|
| 1221 |
-
|
| 1222 |
)
|
| 1223 |
|
| 1224 |
-
|
| 1225 |
-
orig_linear = srgb_to_linear(orig_array)
|
| 1226 |
-
gen_linear = srgb_to_linear(gen_array)
|
| 1227 |
-
|
| 1228 |
-
# Alpha blending in linear space
|
| 1229 |
-
alpha = mask_array[:, :, np.newaxis]
|
| 1230 |
-
result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
|
| 1231 |
-
|
| 1232 |
-
# Convert back to sRGB
|
| 1233 |
-
result_srgb = linear_to_srgb(result_linear)
|
| 1234 |
-
result_array = (result_srgb * 255).astype(np.uint8)
|
| 1235 |
|
| 1236 |
-
|
| 1237 |
-
|
| 1238 |
-
return Image.fromarray(result_array)
|
| 1239 |
-
|
| 1240 |
-
def execute_with_auto_optimization(
|
| 1241 |
self,
|
| 1242 |
image: Image.Image,
|
| 1243 |
mask: Image.Image,
|
| 1244 |
prompt: str,
|
| 1245 |
-
|
| 1246 |
-
|
| 1247 |
-
|
| 1248 |
-
|
| 1249 |
-
|
| 1250 |
-
|
| 1251 |
-
|
| 1252 |
-
|
| 1253 |
-
|
| 1254 |
-
|
| 1255 |
-
|
| 1256 |
-
|
| 1257 |
-
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
-
|
| 1261 |
-
|
| 1262 |
-
quality_checker : QualityChecker
|
| 1263 |
-
Quality assessment instance
|
| 1264 |
-
progress_callback : callable, optional
|
| 1265 |
-
Progress update function
|
| 1266 |
-
**kwargs
|
| 1267 |
-
Additional inpainting parameters
|
| 1268 |
-
|
| 1269 |
-
Returns
|
| 1270 |
-
-------
|
| 1271 |
-
InpaintingResult
|
| 1272 |
-
Best result achieved (may include retry information)
|
| 1273 |
-
"""
|
| 1274 |
-
if not self.config.enable_auto_optimization:
|
| 1275 |
-
return self.execute_inpainting(
|
| 1276 |
-
image, mask, prompt,
|
| 1277 |
-
progress_callback=progress_callback,
|
| 1278 |
-
**kwargs
|
| 1279 |
)
|
|
|
|
| 1280 |
|
| 1281 |
-
|
| 1282 |
-
|
| 1283 |
-
|
| 1284 |
-
|
| 1285 |
-
|
| 1286 |
-
|
| 1287 |
-
|
| 1288 |
-
|
| 1289 |
-
|
| 1290 |
-
|
| 1291 |
-
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
|
| 1295 |
-
|
| 1296 |
-
|
| 1297 |
-
|
| 1298 |
-
|
| 1299 |
-
|
| 1300 |
-
|
| 1301 |
-
|
| 1302 |
-
|
| 1303 |
-
|
| 1304 |
-
controlnet_conditioning_scale=
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
**{k: v for k, v in kwargs.items()
|
| 1308 |
-
if k not in ['feather_radius', 'controlnet_conditioning_scale',
|
| 1309 |
-
'guidance_scale']}
|
| 1310 |
)
|
| 1311 |
-
|
| 1312 |
-
if not result.success:
|
| 1313 |
-
return result
|
| 1314 |
-
|
| 1315 |
-
# Evaluate quality
|
| 1316 |
-
if result.blended_image is not None:
|
| 1317 |
-
quality_results = quality_checker.run_all_checks(
|
| 1318 |
-
foreground=image,
|
| 1319 |
-
background=result.result_image,
|
| 1320 |
-
mask=mask,
|
| 1321 |
-
combined=result.blended_image
|
| 1322 |
-
)
|
| 1323 |
-
quality_score = quality_results.get("overall_score", 0)
|
| 1324 |
-
else:
|
| 1325 |
-
quality_score = 50.0 # Default if no blended image
|
| 1326 |
-
|
| 1327 |
-
result.quality_score = quality_score
|
| 1328 |
-
result.quality_details = quality_results if result.blended_image else {}
|
| 1329 |
-
result.retries = retry_count
|
| 1330 |
-
|
| 1331 |
-
logger.info(f"Quality score: {quality_score:.1f} (attempt {retry_count + 1})")
|
| 1332 |
-
|
| 1333 |
-
# Track best result
|
| 1334 |
-
if quality_score > best_score:
|
| 1335 |
-
best_score = quality_score
|
| 1336 |
-
best_result = result
|
| 1337 |
-
|
| 1338 |
-
# Check if quality is acceptable
|
| 1339 |
-
if quality_score >= self.config.min_quality_score:
|
| 1340 |
-
logger.info(f"Quality threshold met: {quality_score:.1f}")
|
| 1341 |
-
return best_result
|
| 1342 |
-
|
| 1343 |
-
# Check for minimal improvement (early termination)
|
| 1344 |
-
if retry_count > 0 and abs(quality_score - prev_score) < 5.0:
|
| 1345 |
-
logger.info("Minimal improvement, stopping optimization")
|
| 1346 |
-
return best_result
|
| 1347 |
-
|
| 1348 |
-
prev_score = quality_score
|
| 1349 |
-
retry_count += 1
|
| 1350 |
-
|
| 1351 |
-
if retry_count > self.config.max_optimization_retries:
|
| 1352 |
-
break
|
| 1353 |
-
|
| 1354 |
-
# Adjust parameters based on quality issues
|
| 1355 |
-
checks = quality_results.get("checks", {})
|
| 1356 |
-
|
| 1357 |
-
edge_score = checks.get("edge_continuity", {}).get("score", 100)
|
| 1358 |
-
harmony_score = checks.get("color_harmony", {}).get("score", 100)
|
| 1359 |
-
|
| 1360 |
-
if edge_score < 60:
|
| 1361 |
-
# Edge issues: increase feathering, decrease control strength
|
| 1362 |
-
current_feather = min(20, current_feather + 3)
|
| 1363 |
-
current_scale = max(0.5, current_scale - 0.1)
|
| 1364 |
-
logger.debug(f"Adjusting for edges: feather={current_feather}, scale={current_scale}")
|
| 1365 |
-
|
| 1366 |
-
if harmony_score < 60:
|
| 1367 |
-
# Color harmony issues: emphasize consistency in prompt
|
| 1368 |
-
if "color consistent" not in current_prompt.lower():
|
| 1369 |
-
current_prompt = f"{current_prompt}, color consistent with surroundings, matching lighting"
|
| 1370 |
-
current_guidance = min(12.0, current_guidance + 1.0)
|
| 1371 |
-
logger.debug(f"Adjusting for harmony: guidance={current_guidance}")
|
| 1372 |
-
|
| 1373 |
-
if edge_score < 60 and harmony_score < 60:
|
| 1374 |
-
# Both issues: stronger guidance
|
| 1375 |
-
current_guidance = min(12.0, current_guidance + 1.5)
|
| 1376 |
-
|
| 1377 |
-
logger.info(f"Optimization complete. Best score: {best_score:.1f}")
|
| 1378 |
-
return best_result
|
| 1379 |
|
| 1380 |
def get_status(self) -> Dict[str, Any]:
|
| 1381 |
-
"""
|
| 1382 |
-
|
| 1383 |
-
|
| 1384 |
-
Returns
|
| 1385 |
-
-------
|
| 1386 |
-
dict
|
| 1387 |
-
Status information including initialization state and memory usage
|
| 1388 |
-
"""
|
| 1389 |
-
status = {
|
| 1390 |
"initialized": self.is_initialized,
|
| 1391 |
"device": self.device,
|
|
|
|
| 1392 |
"conditioning_type": self._current_conditioning_type,
|
| 1393 |
-
"
|
| 1394 |
-
"config": {
|
| 1395 |
-
"controlnet_conditioning_scale": self.config.controlnet_conditioning_scale,
|
| 1396 |
-
"feather_radius": self.config.feather_radius,
|
| 1397 |
-
"num_inference_steps": self.config.num_inference_steps,
|
| 1398 |
-
"guidance_scale": self.config.guidance_scale
|
| 1399 |
-
}
|
| 1400 |
}
|
| 1401 |
-
|
| 1402 |
-
status["memory"] = self._check_memory_status()
|
| 1403 |
-
|
| 1404 |
-
return status
|
|
|
|
| 4 |
import time
|
| 5 |
import traceback
|
| 6 |
from dataclasses import dataclass, field
|
| 7 |
+
from typing import Any, Callable, Dict, Optional, Tuple
|
| 8 |
|
| 9 |
import cv2
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
|
| 14 |
+
from diffusers import AutoPipelineForInpainting
|
| 15 |
+
from diffusers import ControlNetModel
|
| 16 |
+
from diffusers import DPMSolverMultistepScheduler
|
| 17 |
from diffusers import StableDiffusionXLControlNetInpaintPipeline
|
| 18 |
+
from transformers import AutoImageProcessor
|
| 19 |
+
from transformers import AutoModelForDepthEstimation
|
| 20 |
+
from transformers import DPTForDepthEstimation
|
| 21 |
+
from transformers import DPTImageProcessor
|
| 22 |
+
|
| 23 |
+
from control_image_processor import ControlImageProcessor
|
| 24 |
+
from inpainting_blender import InpaintingBlender
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
logger.setLevel(logging.INFO)
|
| 28 |
|
| 29 |
|
| 30 |
+
# Dedicated SDXL Inpainting model - trained specifically for inpainting
|
| 31 |
+
SDXL_INPAINTING_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
@dataclass
|
| 35 |
class InpaintingConfig:
|
| 36 |
"""Configuration for inpainting operations."""
|
| 37 |
|
| 38 |
+
# ControlNet settings (for ControlNet mode only)
|
| 39 |
controlnet_conditioning_scale: float = 0.7
|
| 40 |
+
conditioning_type: str = "canny"
|
| 41 |
|
| 42 |
# Canny edge detection parameters
|
| 43 |
canny_low_threshold: int = 100
|
| 44 |
canny_high_threshold: int = 200
|
| 45 |
|
| 46 |
# Mask settings
|
| 47 |
+
feather_radius: int = 3
|
| 48 |
min_mask_coverage: float = 0.01
|
| 49 |
max_mask_coverage: float = 0.95
|
| 50 |
|
| 51 |
# Generation settings
|
| 52 |
num_inference_steps: int = 25
|
| 53 |
guidance_scale: float = 7.5
|
| 54 |
+
strength: float = 0.99 # Use 0.99 to avoid noise issues with 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
# Memory settings
|
| 57 |
enable_vae_tiling: bool = True
|
|
|
|
| 58 |
max_resolution: int = 1024
|
| 59 |
|
| 60 |
|
|
|
|
| 68 |
control_image: Optional[Image.Image] = None
|
| 69 |
blended_image: Optional[Image.Image] = None
|
| 70 |
quality_score: float = 0.0
|
|
|
|
| 71 |
generation_time: float = 0.0
|
|
|
|
| 72 |
error_message: str = ""
|
| 73 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 74 |
|
| 75 |
|
| 76 |
class InpaintingModule:
|
| 77 |
"""
|
| 78 |
+
Dual-mode Inpainting Module for SceneWeaver.
|
| 79 |
|
| 80 |
+
Supports two modes:
|
| 81 |
+
1. Pure Inpainting (use_controlnet=False): Uses dedicated SDXL Inpainting model
|
| 82 |
+
- Best for: Object replacement, Object removal
|
| 83 |
+
- More stable, better edge blending
|
| 84 |
|
| 85 |
+
2. ControlNet Inpainting (use_controlnet=True): Uses ControlNet + SDXL
|
| 86 |
+
- Best for: Clothing change (depth), Color change (canny)
|
| 87 |
+
- Preserves structure in masked region
|
|
|
|
| 88 |
|
| 89 |
Example:
|
| 90 |
>>> module = InpaintingModule(device="cuda")
|
| 91 |
+
>>> # For object replacement (no ControlNet)
|
| 92 |
+
>>> module.load_pipeline(use_controlnet=False)
|
| 93 |
+
>>> result = module.execute_inpainting(image, mask, "a vase with flowers")
|
|
|
|
|
|
|
|
|
|
| 94 |
"""
|
| 95 |
|
| 96 |
+
# ControlNet model identifiers
|
| 97 |
CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
|
| 98 |
CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
|
| 99 |
DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
|
| 100 |
DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
|
| 101 |
+
|
| 102 |
+
# Base models for ControlNet mode
|
| 103 |
+
SUPPORTED_MODELS = {
|
| 104 |
+
"juggernaut_xl": "RunDiffusion/Juggernaut-XL-v9",
|
| 105 |
+
"realvis_xl": "SG161222/RealVisXL_V4.0",
|
| 106 |
+
"sdxl_base": "stabilityai/stable-diffusion-xl-base-1.0",
|
| 107 |
+
"animagine_xl": "cagliostrolab/animagine-xl-3.1",
|
| 108 |
+
}
|
| 109 |
|
| 110 |
def __init__(
|
| 111 |
self,
|
| 112 |
device: str = "auto",
|
| 113 |
config: Optional[InpaintingConfig] = None
|
| 114 |
):
|
| 115 |
+
"""Initialize the InpaintingModule."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
self.device = self._setup_device(device)
|
| 117 |
self.config = config or InpaintingConfig()
|
| 118 |
|
| 119 |
+
# Sub-modules
|
| 120 |
+
self._control_processor = ControlImageProcessor(
|
| 121 |
+
device=self.device,
|
| 122 |
+
canny_low_threshold=self.config.canny_low_threshold,
|
| 123 |
+
canny_high_threshold=self.config.canny_high_threshold
|
| 124 |
+
)
|
| 125 |
+
self._blender = InpaintingBlender(
|
| 126 |
+
min_mask_coverage=self.config.min_mask_coverage,
|
| 127 |
+
max_mask_coverage=self.config.max_mask_coverage
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Pipeline instances
|
| 131 |
+
self._pipeline = None
|
| 132 |
+
self._controlnet = None
|
| 133 |
self._depth_estimator = None
|
| 134 |
self._depth_processor = None
|
| 135 |
|
| 136 |
# State tracking
|
| 137 |
self.is_initialized = False
|
| 138 |
+
self._current_mode = None # "pure" or "controlnet"
|
| 139 |
self._current_conditioning_type = None
|
| 140 |
+
self._current_model_key = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
logger.info(f"InpaintingModule initialized on {self.device}")
|
| 143 |
|
| 144 |
def _setup_device(self, device: str) -> str:
|
| 145 |
+
"""Setup computation device."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
if device == "auto":
|
| 147 |
if torch.cuda.is_available():
|
| 148 |
return "cuda"
|
|
|
|
| 151 |
return "cpu"
|
| 152 |
return device
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def _memory_cleanup(self, aggressive: bool = False) -> None:
|
| 155 |
+
"""Perform memory cleanup."""
|
| 156 |
+
for _ in range(5 if aggressive else 2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
gc.collect()
|
| 158 |
|
|
|
|
|
|
|
| 159 |
is_spaces = os.getenv('SPACE_ID') is not None
|
|
|
|
| 160 |
if not is_spaces and torch.cuda.is_available():
|
| 161 |
torch.cuda.empty_cache()
|
| 162 |
if aggressive:
|
| 163 |
torch.cuda.ipc_collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
def load_pipeline(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
self,
|
| 167 |
+
use_controlnet: bool = False,
|
| 168 |
conditioning_type: str = "canny",
|
| 169 |
+
model_key: str = "sdxl_base",
|
| 170 |
progress_callback: Optional[Callable[[str, int], None]] = None
|
| 171 |
) -> Tuple[bool, str]:
|
| 172 |
"""
|
| 173 |
+
Load the appropriate inpainting pipeline.
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
Parameters
|
| 176 |
----------
|
| 177 |
+
use_controlnet : bool
|
| 178 |
+
If False, use dedicated SDXL Inpainting model (for replacement/removal)
|
| 179 |
+
If True, use ControlNet pipeline (for clothing/color change)
|
| 180 |
conditioning_type : str
|
| 181 |
+
ControlNet type: "canny" or "depth" (only used when use_controlnet=True)
|
| 182 |
+
model_key : str
|
| 183 |
+
Base model for ControlNet mode
|
| 184 |
progress_callback : callable, optional
|
| 185 |
+
Progress update function
|
| 186 |
|
| 187 |
Returns
|
| 188 |
-------
|
| 189 |
tuple
|
| 190 |
(success: bool, error_message: str)
|
| 191 |
"""
|
| 192 |
+
mode = "controlnet" if use_controlnet else "pure"
|
| 193 |
+
|
| 194 |
+
# Check if already loaded with same config
|
| 195 |
+
if (self.is_initialized and
|
| 196 |
+
self._current_mode == mode and
|
| 197 |
+
(not use_controlnet or
|
| 198 |
+
(self._current_conditioning_type == conditioning_type and
|
| 199 |
+
self._current_model_key == model_key))):
|
| 200 |
+
logger.info(f"Pipeline already loaded: mode={mode}")
|
| 201 |
return True, ""
|
| 202 |
|
| 203 |
+
logger.info(f"Loading pipeline: mode={mode}, conditioning={conditioning_type}")
|
| 204 |
|
| 205 |
try:
|
| 206 |
self._memory_cleanup(aggressive=True)
|
| 207 |
|
| 208 |
if progress_callback:
|
| 209 |
+
progress_callback("Preparing pipeline...", 10)
|
| 210 |
+
|
| 211 |
+
# Unload existing pipeline
|
| 212 |
+
self._unload_pipeline()
|
| 213 |
|
| 214 |
+
dtype = torch.float16 if self.device == "cuda" else torch.float32
|
|
|
|
|
|
|
| 215 |
|
| 216 |
+
if not use_controlnet:
|
| 217 |
+
# Mode A: Pure SDXL Inpainting (for replacement/removal)
|
| 218 |
+
if progress_callback:
|
| 219 |
+
progress_callback("Loading SDXL Inpainting model...", 30)
|
| 220 |
|
| 221 |
+
self._pipeline = AutoPipelineForInpainting.from_pretrained(
|
| 222 |
+
SDXL_INPAINTING_MODEL,
|
| 223 |
+
torch_dtype=dtype,
|
| 224 |
+
variant="fp16" if dtype == torch.float16 else None,
|
| 225 |
+
)
|
| 226 |
+
self._current_mode = "pure"
|
| 227 |
+
self._current_conditioning_type = None
|
| 228 |
+
logger.info("Loaded pure SDXL Inpainting pipeline")
|
| 229 |
|
| 230 |
+
else:
|
| 231 |
+
# Mode B: ControlNet Inpainting (for structure-preserving tasks)
|
| 232 |
+
if model_key not in self.SUPPORTED_MODELS:
|
| 233 |
+
model_key = "sdxl_base"
|
| 234 |
+
base_model_id = self.SUPPORTED_MODELS[model_key]
|
| 235 |
|
| 236 |
+
if progress_callback:
|
| 237 |
+
progress_callback("Loading ControlNet model...", 30)
|
| 238 |
+
|
| 239 |
+
# Load ControlNet
|
| 240 |
if conditioning_type == "canny":
|
| 241 |
+
self._controlnet = ControlNetModel.from_pretrained(
|
| 242 |
self.CONTROLNET_CANNY_MODEL,
|
| 243 |
torch_dtype=dtype,
|
| 244 |
use_safetensors=True
|
| 245 |
)
|
|
|
|
|
|
|
|
|
|
| 246 |
elif conditioning_type == "depth":
|
| 247 |
+
self._controlnet = ControlNetModel.from_pretrained(
|
| 248 |
self.CONTROLNET_DEPTH_MODEL,
|
| 249 |
torch_dtype=dtype,
|
| 250 |
use_safetensors=True
|
| 251 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
self._load_depth_estimator()
|
|
|
|
| 253 |
else:
|
| 254 |
raise ValueError(f"Unknown conditioning type: {conditioning_type}")
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
+
if progress_callback:
|
| 257 |
+
progress_callback(f"Loading {model_key}...", 60)
|
| 258 |
+
|
| 259 |
+
# Load pipeline with ControlNet
|
| 260 |
+
use_variant = model_key != "animagine_xl"
|
| 261 |
+
load_kwargs = {
|
| 262 |
+
"controlnet": self._controlnet,
|
| 263 |
+
"torch_dtype": dtype,
|
| 264 |
+
"use_safetensors": True,
|
| 265 |
+
}
|
| 266 |
+
if use_variant and dtype == torch.float16:
|
| 267 |
+
load_kwargs["variant"] = "fp16"
|
| 268 |
|
| 269 |
+
self._pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
|
| 270 |
+
base_model_id,
|
| 271 |
+
**load_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
)
|
| 273 |
+
self._current_mode = "controlnet"
|
| 274 |
+
self._current_conditioning_type = conditioning_type
|
| 275 |
+
self._current_model_key = model_key
|
| 276 |
+
logger.info(f"Loaded ControlNet pipeline: {model_key} + {conditioning_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
if progress_callback:
|
| 279 |
+
progress_callback("Configuring pipeline...", 80)
|
| 280 |
|
| 281 |
+
# Configure scheduler
|
| 282 |
+
self._pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 283 |
+
self._pipeline.scheduler.config
|
| 284 |
)
|
| 285 |
|
| 286 |
+
# Move to device and optimize
|
| 287 |
+
self._pipeline = self._pipeline.to(self.device)
|
| 288 |
+
self._apply_optimizations()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
self.is_initialized = True
|
|
|
|
| 291 |
|
| 292 |
if progress_callback:
|
| 293 |
+
progress_callback("Pipeline ready!", 100)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
return True, ""
|
| 296 |
|
| 297 |
except Exception as e:
|
| 298 |
error_msg = str(e)
|
| 299 |
+
logger.error(f"Failed to load pipeline: {error_msg}")
|
| 300 |
traceback.print_exc()
|
| 301 |
self._unload_pipeline()
|
| 302 |
return False, error_msg
|
| 303 |
|
| 304 |
def _load_depth_estimator(self) -> None:
|
| 305 |
+
"""Load depth estimation model."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
try:
|
|
|
|
|
|
|
| 307 |
self._depth_processor = AutoImageProcessor.from_pretrained(
|
| 308 |
self.DEPTH_MODEL_PRIMARY
|
| 309 |
)
|
|
|
|
| 313 |
)
|
| 314 |
self._depth_estimator.to(self.device)
|
| 315 |
self._depth_estimator.eval()
|
| 316 |
+
logger.info("Loaded Depth-Anything model")
|
|
|
|
|
|
|
| 317 |
except Exception as e:
|
| 318 |
logger.warning(f"Primary depth model failed: {e}, trying fallback...")
|
| 319 |
+
self._depth_processor = DPTImageProcessor.from_pretrained(
|
| 320 |
+
self.DEPTH_MODEL_FALLBACK
|
| 321 |
+
)
|
| 322 |
+
self._depth_estimator = DPTForDepthEstimation.from_pretrained(
|
| 323 |
+
self.DEPTH_MODEL_FALLBACK,
|
| 324 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
|
| 325 |
+
)
|
| 326 |
+
self._depth_estimator.to(self.device)
|
| 327 |
+
self._depth_estimator.eval()
|
| 328 |
+
logger.info("Loaded MiDaS fallback model")
|
| 329 |
|
| 330 |
+
def _apply_optimizations(self) -> None:
|
| 331 |
+
"""Apply memory and performance optimizations."""
|
| 332 |
+
if self._pipeline is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
return
|
| 334 |
|
|
|
|
| 335 |
try:
|
| 336 |
+
self._pipeline.enable_xformers_memory_efficient_attention()
|
| 337 |
+
logger.info("Enabled xformers attention")
|
| 338 |
except Exception:
|
| 339 |
try:
|
| 340 |
+
self._pipeline.enable_attention_slicing()
|
| 341 |
logger.info("Enabled attention slicing")
|
| 342 |
except Exception:
|
| 343 |
+
pass
|
| 344 |
|
|
|
|
| 345 |
if self.config.enable_vae_tiling:
|
| 346 |
+
if hasattr(self._pipeline, 'enable_vae_tiling'):
|
| 347 |
+
self._pipeline.enable_vae_tiling()
|
| 348 |
+
if hasattr(self._pipeline, 'enable_vae_slicing'):
|
| 349 |
+
self._pipeline.enable_vae_slicing()
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
def _unload_pipeline(self) -> None:
|
| 352 |
+
"""Unload pipeline and free memory."""
|
| 353 |
+
if self._pipeline is not None:
|
| 354 |
+
del self._pipeline
|
| 355 |
+
self._pipeline = None
|
| 356 |
|
| 357 |
+
if self._controlnet is not None:
|
| 358 |
+
del self._controlnet
|
| 359 |
+
self._controlnet = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
|
| 361 |
if self._depth_estimator is not None:
|
| 362 |
del self._depth_estimator
|
|
|
|
| 367 |
self._depth_processor = None
|
| 368 |
|
| 369 |
self.is_initialized = False
|
| 370 |
+
self._current_mode = None
|
| 371 |
self._current_conditioning_type = None
|
|
|
|
| 372 |
|
| 373 |
self._memory_cleanup(aggressive=True)
|
| 374 |
+
logger.info("Pipeline unloaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
def execute_inpainting(
|
| 377 |
self,
|
| 378 |
image: Image.Image,
|
| 379 |
mask: Image.Image,
|
| 380 |
prompt: str,
|
|
|
|
|
|
|
| 381 |
progress_callback: Optional[Callable[[str, int], None]] = None,
|
| 382 |
**kwargs
|
| 383 |
) -> InpaintingResult:
|
| 384 |
"""
|
| 385 |
+
Execute inpainting operation.
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
Parameters
|
| 388 |
----------
|
| 389 |
image : PIL.Image
|
| 390 |
+
Original image
|
| 391 |
mask : PIL.Image
|
| 392 |
Inpainting mask (white = area to regenerate)
|
| 393 |
prompt : str
|
| 394 |
+
Text description
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
progress_callback : callable, optional
|
| 396 |
+
Progress update function
|
| 397 |
**kwargs
|
| 398 |
+
Additional parameters from template
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
Returns
|
| 401 |
-------
|
| 402 |
InpaintingResult
|
| 403 |
+
Result with generated image
|
| 404 |
"""
|
| 405 |
start_time = time.time()
|
| 406 |
|
| 407 |
if not self.is_initialized:
|
| 408 |
return InpaintingResult(
|
| 409 |
success=False,
|
| 410 |
+
error_message="Pipeline not initialized. Call load_pipeline() first."
|
| 411 |
)
|
| 412 |
|
| 413 |
+
logger.info(f"Inpainting: mode={self._current_mode}, prompt='{prompt[:50]}...'")
|
| 414 |
|
| 415 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
if progress_callback:
|
| 417 |
+
progress_callback("Preparing images...", 10)
|
| 418 |
|
| 419 |
# Prepare image
|
| 420 |
if image.mode != 'RGB':
|
| 421 |
image = image.convert('RGB')
|
| 422 |
|
| 423 |
+
# Store original size for later restoration
|
| 424 |
+
original_size = image.size # (width, height)
|
| 425 |
+
|
| 426 |
+
# Ensure dimensions are multiple of 8 for model compatibility
|
| 427 |
width, height = image.size
|
| 428 |
new_width = (width // 8) * 8
|
| 429 |
new_height = (height // 8) * 8
|
|
|
|
| 430 |
if new_width != width or new_height != height:
|
| 431 |
image = image.resize((new_width, new_height), Image.LANCZOS)
|
| 432 |
|
| 433 |
+
# Limit resolution for memory efficiency
|
| 434 |
max_res = self.config.max_resolution
|
| 435 |
if max(new_width, new_height) > max_res:
|
| 436 |
scale = max_res / max(new_width, new_height)
|
| 437 |
new_width = int(new_width * scale) // 8 * 8
|
| 438 |
new_height = int(new_height * scale) // 8 * 8
|
| 439 |
image = image.resize((new_width, new_height), Image.LANCZOS)
|
|
|
|
| 440 |
|
| 441 |
+
# Prepare mask with dilation
|
| 442 |
+
mask_dilation = kwargs.get('mask_dilation', 0)
|
| 443 |
+
processed_mask = self._prepare_mask(
|
|
|
|
|
|
|
| 444 |
mask,
|
| 445 |
(new_width, new_height),
|
| 446 |
+
dilation=mask_dilation,
|
| 447 |
+
feather_radius=kwargs.get('feather_radius', self.config.feather_radius)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
)
|
| 449 |
|
| 450 |
+
# Get generation parameters
|
| 451 |
+
strength = kwargs.get('strength', self.config.strength)
|
| 452 |
+
guidance_scale = kwargs.get('guidance_scale', self.config.guidance_scale)
|
| 453 |
+
num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
|
| 454 |
+
negative_prompt = kwargs.get('negative_prompt', "")
|
| 455 |
|
| 456 |
+
# Optimize for HuggingFace Spaces
|
| 457 |
+
is_spaces = os.getenv('SPACE_ID') is not None
|
| 458 |
+
if is_spaces:
|
| 459 |
+
num_steps = min(num_steps, 15)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
|
| 461 |
+
# Setup generator with seed
|
| 462 |
+
# If seed is -1 or None, use random seed based on current time
|
| 463 |
+
input_seed = kwargs.get('seed', -1)
|
| 464 |
+
if input_seed is None or input_seed < 0:
|
| 465 |
seed = int(time.time() * 1000) % (2**32)
|
| 466 |
+
else:
|
| 467 |
+
seed = int(input_seed)
|
| 468 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 469 |
+
logger.info(f"Using seed: {seed}")
|
| 470 |
|
| 471 |
+
# Generate based on mode
|
| 472 |
+
if self._current_mode == "pure":
|
| 473 |
+
# Pure inpainting - no ControlNet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
if progress_callback:
|
| 475 |
+
progress_callback("Generating (Pure Inpainting)...", 40)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
|
| 477 |
+
result_image = self._generate_pure_inpaint(
|
| 478 |
image=image,
|
| 479 |
mask=processed_mask,
|
| 480 |
+
prompt=prompt,
|
|
|
|
| 481 |
negative_prompt=negative_prompt,
|
| 482 |
+
num_steps=num_steps,
|
| 483 |
+
guidance_scale=guidance_scale,
|
|
|
|
| 484 |
strength=strength,
|
| 485 |
generator=generator
|
| 486 |
)
|
| 487 |
+
control_image = None
|
| 488 |
+
|
| 489 |
else:
|
| 490 |
+
# ControlNet inpainting
|
| 491 |
+
if progress_callback:
|
| 492 |
+
progress_callback("Generating control image...", 30)
|
| 493 |
|
| 494 |
+
# Prepare control image
|
| 495 |
+
preserve_structure = kwargs.get('preserve_structure_in_mask', False)
|
| 496 |
+
edge_guidance_mode = kwargs.get('edge_guidance_mode', 'boundary')
|
| 497 |
|
| 498 |
+
control_image = self._control_processor.prepare_control_image(
|
| 499 |
+
image=image,
|
| 500 |
+
mode=self._current_conditioning_type,
|
| 501 |
+
mask=processed_mask,
|
| 502 |
+
preserve_structure=preserve_structure,
|
| 503 |
+
edge_guidance_mode=edge_guidance_mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
)
|
| 505 |
|
| 506 |
+
if progress_callback:
|
| 507 |
+
progress_callback("Generating (ControlNet)...", 50)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
+
conditioning_scale = kwargs.get(
|
| 510 |
+
'controlnet_conditioning_scale',
|
| 511 |
+
self.config.controlnet_conditioning_scale
|
| 512 |
+
)
|
| 513 |
|
| 514 |
+
result_image = self._generate_controlnet_inpaint(
|
| 515 |
+
image=image,
|
| 516 |
+
mask=processed_mask,
|
| 517 |
+
control_image=control_image,
|
| 518 |
+
prompt=prompt,
|
| 519 |
+
negative_prompt=negative_prompt,
|
| 520 |
+
num_steps=num_steps,
|
| 521 |
+
guidance_scale=guidance_scale,
|
| 522 |
+
conditioning_scale=conditioning_scale,
|
| 523 |
+
strength=strength,
|
| 524 |
+
generator=generator
|
| 525 |
+
)
|
| 526 |
|
| 527 |
generation_time = time.time() - start_time
|
| 528 |
|
| 529 |
+
# Restore original size if it was changed
|
| 530 |
+
if result_image.size != original_size:
|
| 531 |
+
result_image = result_image.resize(original_size, Image.LANCZOS)
|
| 532 |
+
logger.info(f"Restored result to original size: {original_size}")
|
| 533 |
+
|
| 534 |
if progress_callback:
|
| 535 |
progress_callback("Complete!", 100)
|
| 536 |
|
| 537 |
return InpaintingResult(
|
| 538 |
success=True,
|
| 539 |
+
result_image=result_image,
|
| 540 |
+
blended_image=result_image, # Pipeline output is already blended
|
| 541 |
control_image=control_image,
|
|
|
|
| 542 |
generation_time=generation_time,
|
| 543 |
metadata={
|
| 544 |
"seed": seed,
|
| 545 |
+
"prompt": prompt,
|
| 546 |
+
"mode": self._current_mode,
|
| 547 |
+
"num_steps": num_steps,
|
| 548 |
+
"guidance_scale": guidance_scale,
|
| 549 |
"strength": strength,
|
| 550 |
+
"original_size": original_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
}
|
| 552 |
)
|
| 553 |
|
| 554 |
except torch.cuda.OutOfMemoryError:
|
| 555 |
+
logger.error("CUDA out of memory")
|
| 556 |
self._memory_cleanup(aggressive=True)
|
| 557 |
return InpaintingResult(
|
| 558 |
success=False,
|
| 559 |
+
error_message="GPU memory exhausted."
|
| 560 |
)
|
|
|
|
| 561 |
except Exception as e:
|
| 562 |
logger.error(f"Inpainting failed: {e}")
|
| 563 |
+
traceback.print_exc()
|
| 564 |
return InpaintingResult(
|
| 565 |
success=False,
|
| 566 |
+
error_message=str(e)
|
| 567 |
)
|
| 568 |
|
| 569 |
+
def _prepare_mask(
|
| 570 |
self,
|
|
|
|
| 571 |
mask: Image.Image,
|
| 572 |
+
target_size: Tuple[int, int],
|
| 573 |
+
dilation: int = 0,
|
| 574 |
+
feather_radius: int = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
) -> Image.Image:
|
| 576 |
+
"""Prepare mask with optional dilation and feathering."""
|
| 577 |
+
# Convert and resize
|
| 578 |
+
if mask.mode != 'L':
|
| 579 |
+
mask = mask.convert('L')
|
| 580 |
+
if mask.size != target_size:
|
| 581 |
+
mask = mask.resize(target_size, Image.LANCZOS)
|
| 582 |
|
| 583 |
+
mask_array = np.array(mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
|
| 585 |
+
# Apply dilation to expand mask
|
| 586 |
+
if dilation > 0:
|
| 587 |
+
kernel = cv2.getStructuringElement(
|
| 588 |
+
cv2.MORPH_ELLIPSE,
|
| 589 |
+
(dilation * 2 + 1, dilation * 2 + 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
)
|
| 591 |
+
mask_array = cv2.dilate(mask_array, kernel, iterations=1)
|
| 592 |
+
logger.debug(f"Applied mask dilation: {dilation}px")
|
| 593 |
|
| 594 |
+
# Apply feathering
|
| 595 |
+
if feather_radius > 0:
|
| 596 |
+
mask_array = cv2.GaussianBlur(
|
| 597 |
+
mask_array,
|
| 598 |
+
(feather_radius * 2 + 1, feather_radius * 2 + 1),
|
| 599 |
+
feather_radius / 2
|
| 600 |
)
|
| 601 |
|
| 602 |
+
return Image.fromarray(mask_array, mode='L')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
+
def _generate_pure_inpaint(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
self,
|
| 606 |
image: Image.Image,
|
| 607 |
mask: Image.Image,
|
| 608 |
prompt: str,
|
| 609 |
+
negative_prompt: str,
|
| 610 |
+
num_steps: int,
|
| 611 |
+
guidance_scale: float,
|
| 612 |
+
strength: float,
|
| 613 |
+
generator: torch.Generator
|
| 614 |
+
) -> Image.Image:
|
| 615 |
+
"""Generate using pure SDXL Inpainting pipeline."""
|
| 616 |
+
with torch.inference_mode():
|
| 617 |
+
result = self._pipeline(
|
| 618 |
+
prompt=prompt,
|
| 619 |
+
negative_prompt=negative_prompt,
|
| 620 |
+
image=image,
|
| 621 |
+
mask_image=mask,
|
| 622 |
+
num_inference_steps=num_steps,
|
| 623 |
+
guidance_scale=guidance_scale,
|
| 624 |
+
strength=strength,
|
| 625 |
+
generator=generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
)
|
| 627 |
+
return result.images[0]
|
| 628 |
|
| 629 |
+
def _generate_controlnet_inpaint(
|
| 630 |
+
self,
|
| 631 |
+
image: Image.Image,
|
| 632 |
+
mask: Image.Image,
|
| 633 |
+
control_image: Image.Image,
|
| 634 |
+
prompt: str,
|
| 635 |
+
negative_prompt: str,
|
| 636 |
+
num_steps: int,
|
| 637 |
+
guidance_scale: float,
|
| 638 |
+
conditioning_scale: float,
|
| 639 |
+
strength: float,
|
| 640 |
+
generator: torch.Generator
|
| 641 |
+
) -> Image.Image:
|
| 642 |
+
"""Generate using ControlNet Inpainting pipeline."""
|
| 643 |
+
with torch.inference_mode():
|
| 644 |
+
result = self._pipeline(
|
| 645 |
+
prompt=prompt,
|
| 646 |
+
negative_prompt=negative_prompt,
|
| 647 |
+
image=image,
|
| 648 |
+
mask_image=mask,
|
| 649 |
+
control_image=control_image,
|
| 650 |
+
num_inference_steps=num_steps,
|
| 651 |
+
guidance_scale=guidance_scale,
|
| 652 |
+
controlnet_conditioning_scale=conditioning_scale,
|
| 653 |
+
strength=strength,
|
| 654 |
+
generator=generator
|
|
|
|
|
|
|
|
|
|
| 655 |
)
|
| 656 |
+
return result.images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
|
| 658 |
def get_status(self) -> Dict[str, Any]:
|
| 659 |
+
"""Get current module status."""
|
| 660 |
+
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
"initialized": self.is_initialized,
|
| 662 |
"device": self.device,
|
| 663 |
+
"mode": self._current_mode,
|
| 664 |
"conditioning_type": self._current_conditioning_type,
|
| 665 |
+
"model_key": self._current_model_key,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 666 |
}
|
|
|
|
|
|
|
|
|
|
|
|
inpainting_templates.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import logging
|
| 2 |
from dataclasses import dataclass, field
|
| 3 |
-
from typing import Dict, List, Optional
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
|
@@ -19,30 +19,31 @@ class InpaintingTemplate:
|
|
| 19 |
prompt_template: str
|
| 20 |
negative_prompt: str
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
guidance_scale: float = 7.5
|
| 26 |
-
num_inference_steps: int = 25
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
| 30 |
-
strength: float = 1.0
|
| 31 |
-
|
| 32 |
-
# Conditioning type preference
|
| 33 |
preferred_conditioning: str = "canny" # "canny" or "depth"
|
| 34 |
-
|
| 35 |
-
# Structure preservation in masked area
|
| 36 |
-
# True = keep edges in mask (for color change), False = clear edges (for replacement/removal)
|
| 37 |
preserve_structure_in_mask: bool = False
|
|
|
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
|
| 44 |
|
| 45 |
-
#
|
|
|
|
|
|
|
|
|
|
| 46 |
usage_tips: List[str] = field(default_factory=list)
|
| 47 |
|
| 48 |
|
|
@@ -50,417 +51,338 @@ class InpaintingTemplateManager:
|
|
| 50 |
"""
|
| 51 |
Manages inpainting templates for various use cases.
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
Attributes:
|
| 57 |
-
TEMPLATES: Dictionary of all available templates
|
| 58 |
-
CATEGORIES: List of category names in display order
|
| 59 |
|
| 60 |
Example:
|
| 61 |
>>> manager = InpaintingTemplateManager()
|
| 62 |
>>> template = manager.get_template("object_replacement")
|
| 63 |
-
>>>
|
|
|
|
|
|
|
| 64 |
"""
|
| 65 |
|
| 66 |
TEMPLATES: Dict[str, InpaintingTemplate] = {
|
| 67 |
-
#
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
category="Color",
|
| 76 |
-
icon="🎨",
|
| 77 |
-
description="Change color ONLY - fills the masked area with a solid, flat color",
|
| 78 |
-
prompt_template="{content} color, solid flat {content}, uniform color, no patterns, smooth surface",
|
| 79 |
negative_prompt=(
|
| 80 |
-
"
|
| 81 |
-
"
|
| 82 |
-
"
|
| 83 |
-
"patterns, floral, stripes, plaid, checkered, decorative patterns, "
|
| 84 |
-
"diamond pattern, grid pattern, geometric patterns, "
|
| 85 |
-
"texture, textured, wrinkles, folds, creases, "
|
| 86 |
-
"gradients, shading variations, color variations, "
|
| 87 |
-
"complex patterns, printed patterns, embroidery"
|
| 88 |
),
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
difficulty="easy",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
usage_tips=[
|
| 99 |
-
"🎯 Purpose:
|
| 100 |
"",
|
| 101 |
-
"
|
| 102 |
-
" •
|
| 103 |
-
" •
|
| 104 |
-
" •
|
| 105 |
-
" • 'bright yellow' - eye-catching yellow",
|
| 106 |
-
" • 'pure white' - clean, solid white",
|
| 107 |
"",
|
| 108 |
"💡 Tips:",
|
| 109 |
-
" •
|
| 110 |
-
" •
|
| 111 |
-
" •
|
| 112 |
]
|
| 113 |
),
|
| 114 |
|
| 115 |
-
# 2.
|
| 116 |
-
"
|
| 117 |
-
key="
|
| 118 |
-
name="
|
| 119 |
-
category="
|
| 120 |
-
icon="
|
| 121 |
-
description="
|
| 122 |
-
prompt_template="
|
| 123 |
negative_prompt=(
|
| 124 |
-
"
|
| 125 |
-
"
|
| 126 |
-
"
|
| 127 |
-
"black clothing, dark original color, distorted body, naked, nudity, "
|
| 128 |
-
"cartoon, anime, illustration, drawing, painted"
|
| 129 |
),
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
difficulty="easy",
|
|
|
|
|
|
|
| 138 |
usage_tips=[
|
| 139 |
-
"🎯 Purpose:
|
| 140 |
"",
|
| 141 |
-
"📝
|
| 142 |
-
"
|
| 143 |
-
"
|
| 144 |
-
"
|
| 145 |
-
" • 'white polo shirt with collar' - casual business",
|
| 146 |
-
" • 'cozy cream knit sweater' - warm casual style",
|
| 147 |
-
" • 'vintage denim jacket' - retro fashion",
|
| 148 |
"",
|
| 149 |
"💡 Tips:",
|
| 150 |
-
" •
|
| 151 |
-
" •
|
| 152 |
-
" • Body structure is preserved automatically"
|
| 153 |
]
|
| 154 |
),
|
| 155 |
|
| 156 |
-
#
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
| 160 |
category="Replacement",
|
| 161 |
-
icon="
|
| 162 |
-
description="
|
| 163 |
-
prompt_template="{content}, photorealistic,
|
| 164 |
negative_prompt=(
|
| 165 |
-
"
|
| 166 |
-
"
|
| 167 |
-
"
|
| 168 |
-
"multiple different objects, mixed objects, various items, "
|
| 169 |
-
"cartoon, anime, illustration, drawing, painted"
|
| 170 |
),
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
difficulty="medium",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
usage_tips=[
|
| 180 |
-
"🎯 Purpose:
|
| 181 |
"",
|
| 182 |
-
"
|
| 183 |
-
" •
|
| 184 |
-
" •
|
| 185 |
-
" • 'stack of leather-bound vintage books' - classic decoration",
|
| 186 |
-
" • 'healthy green potted succulent' - natural element",
|
| 187 |
-
" • 'antique brass table lamp with fabric shade' - lighting",
|
| 188 |
"",
|
| 189 |
"💡 Tips:",
|
| 190 |
-
" •
|
| 191 |
-
" •
|
| 192 |
-
" •
|
| 193 |
]
|
| 194 |
),
|
| 195 |
|
| 196 |
-
# 4.
|
| 197 |
-
"
|
| 198 |
-
key="
|
| 199 |
-
name="
|
| 200 |
-
category="
|
| 201 |
-
icon="
|
| 202 |
-
description="
|
| 203 |
-
prompt_template="
|
| 204 |
negative_prompt=(
|
| 205 |
-
"
|
| 206 |
-
"
|
| 207 |
-
"
|
| 208 |
-
"mismatched pattern, color discontinuity, artificial blending, "
|
| 209 |
-
"cartoon, anime, illustration, drawing, painted"
|
| 210 |
),
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
usage_tips=[
|
| 220 |
-
"🎯 Purpose:
|
| 221 |
-
"",
|
| 222 |
-
"📝 Example Prompts:",
|
| 223 |
-
" • 'polished hardwood floor with natural grain' - indoor floors",
|
| 224 |
-
" • 'smooth white painted wall' - wall backgrounds",
|
| 225 |
-
" • 'lush green grass lawn' - outdoor areas",
|
| 226 |
-
" • 'soft beige carpet texture' - carpeted floors",
|
| 227 |
-
" • 'clear blue sky with soft clouds' - sky backgrounds",
|
| 228 |
"",
|
| 229 |
"💡 Tips:",
|
| 230 |
-
" •
|
| 231 |
-
" •
|
| 232 |
-
" •
|
| 233 |
]
|
| 234 |
),
|
| 235 |
}
|
| 236 |
|
| 237 |
-
|
| 238 |
# Category display order
|
| 239 |
-
CATEGORIES = ["Color", "Replacement", "Removal"]
|
| 240 |
|
| 241 |
def __init__(self):
|
| 242 |
"""Initialize the InpaintingTemplateManager."""
|
| 243 |
logger.info(f"InpaintingTemplateManager initialized with {len(self.TEMPLATES)} templates")
|
| 244 |
|
| 245 |
def get_all_templates(self) -> Dict[str, InpaintingTemplate]:
|
| 246 |
-
"""
|
| 247 |
-
Get all available templates.
|
| 248 |
-
|
| 249 |
-
Returns
|
| 250 |
-
-------
|
| 251 |
-
dict
|
| 252 |
-
Dictionary of all templates keyed by template key
|
| 253 |
-
"""
|
| 254 |
return self.TEMPLATES
|
| 255 |
|
| 256 |
def get_template(self, key: str) -> Optional[InpaintingTemplate]:
|
| 257 |
-
"""
|
| 258 |
-
Get a specific template by key.
|
| 259 |
-
|
| 260 |
-
Parameters
|
| 261 |
-
----------
|
| 262 |
-
key : str
|
| 263 |
-
Template identifier
|
| 264 |
-
|
| 265 |
-
Returns
|
| 266 |
-
-------
|
| 267 |
-
InpaintingTemplate or None
|
| 268 |
-
Template if found, None otherwise
|
| 269 |
-
"""
|
| 270 |
return self.TEMPLATES.get(key)
|
| 271 |
|
| 272 |
def get_templates_by_category(self, category: str) -> List[InpaintingTemplate]:
|
| 273 |
-
"""
|
| 274 |
-
Get all templates in a specific category.
|
| 275 |
-
|
| 276 |
-
Parameters
|
| 277 |
-
----------
|
| 278 |
-
category : str
|
| 279 |
-
Category name
|
| 280 |
-
|
| 281 |
-
Returns
|
| 282 |
-
-------
|
| 283 |
-
list
|
| 284 |
-
List of templates in the category
|
| 285 |
-
"""
|
| 286 |
return [t for t in self.TEMPLATES.values() if t.category == category]
|
| 287 |
|
| 288 |
def get_categories(self) -> List[str]:
|
| 289 |
-
"""
|
| 290 |
-
Get list of all categories in display order.
|
| 291 |
-
|
| 292 |
-
Returns
|
| 293 |
-
-------
|
| 294 |
-
list
|
| 295 |
-
Category names
|
| 296 |
-
"""
|
| 297 |
return self.CATEGORIES
|
| 298 |
|
| 299 |
def get_template_choices_sorted(self) -> List[str]:
|
| 300 |
-
"""
|
| 301 |
-
Get template choices formatted for Gradio dropdown.
|
| 302 |
-
|
| 303 |
-
Returns list of display strings sorted by category then A-Z.
|
| 304 |
-
Format: "icon Name"
|
| 305 |
-
|
| 306 |
-
Returns
|
| 307 |
-
-------
|
| 308 |
-
list
|
| 309 |
-
Formatted display strings for dropdown
|
| 310 |
-
"""
|
| 311 |
display_list = []
|
| 312 |
-
|
| 313 |
for category in self.CATEGORIES:
|
| 314 |
templates = self.get_templates_by_category(category)
|
| 315 |
for template in sorted(templates, key=lambda t: t.name):
|
| 316 |
display_name = f"{template.icon} {template.name}"
|
| 317 |
display_list.append(display_name)
|
| 318 |
-
|
| 319 |
return display_list
|
| 320 |
|
| 321 |
def get_template_key_from_display(self, display_name: str) -> Optional[str]:
|
| 322 |
-
"""
|
| 323 |
-
Get template key from display name.
|
| 324 |
-
|
| 325 |
-
Parameters
|
| 326 |
-
----------
|
| 327 |
-
display_name : str
|
| 328 |
-
Display string like "🔄 Object Replacement"
|
| 329 |
-
|
| 330 |
-
Returns
|
| 331 |
-
-------
|
| 332 |
-
str or None
|
| 333 |
-
Template key if found
|
| 334 |
-
"""
|
| 335 |
if not display_name:
|
| 336 |
return None
|
| 337 |
-
|
| 338 |
for key, template in self.TEMPLATES.items():
|
| 339 |
if f"{template.icon} {template.name}" == display_name:
|
| 340 |
return key
|
| 341 |
return None
|
| 342 |
|
| 343 |
-
def get_parameters_for_template(self, key: str) -> Dict[str,
|
| 344 |
-
"""
|
| 345 |
-
Get recommended parameters for a template.
|
| 346 |
-
|
| 347 |
-
Parameters
|
| 348 |
-
----------
|
| 349 |
-
key : str
|
| 350 |
-
Template key
|
| 351 |
-
|
| 352 |
-
Returns
|
| 353 |
-
-------
|
| 354 |
-
dict
|
| 355 |
-
Dictionary of parameter names and values
|
| 356 |
-
"""
|
| 357 |
template = self.get_template(key)
|
| 358 |
if not template:
|
| 359 |
return {}
|
| 360 |
|
| 361 |
return {
|
|
|
|
|
|
|
| 362 |
"controlnet_conditioning_scale": template.controlnet_conditioning_scale,
|
| 363 |
-
"
|
|
|
|
|
|
|
| 364 |
"guidance_scale": template.guidance_scale,
|
| 365 |
"num_inference_steps": template.num_inference_steps,
|
| 366 |
"strength": template.strength,
|
| 367 |
-
"
|
| 368 |
-
"
|
| 369 |
-
"enhance_prompt": template.enhance_prompt
|
| 370 |
}
|
| 371 |
|
| 372 |
def build_prompt(self, key: str, content: str) -> str:
|
| 373 |
-
"""
|
| 374 |
-
Build complete prompt from template and user content.
|
| 375 |
-
|
| 376 |
-
Parameters
|
| 377 |
-
----------
|
| 378 |
-
key : str
|
| 379 |
-
Template key
|
| 380 |
-
content : str
|
| 381 |
-
User-provided content description
|
| 382 |
-
|
| 383 |
-
Returns
|
| 384 |
-
-------
|
| 385 |
-
str
|
| 386 |
-
Formatted prompt with content inserted
|
| 387 |
-
"""
|
| 388 |
template = self.get_template(key)
|
| 389 |
if not template:
|
| 390 |
return content
|
| 391 |
-
|
| 392 |
return template.prompt_template.format(content=content)
|
| 393 |
|
| 394 |
def get_negative_prompt(self, key: str) -> str:
|
| 395 |
-
"""
|
| 396 |
-
Get negative prompt for a template.
|
| 397 |
-
|
| 398 |
-
Parameters
|
| 399 |
-
----------
|
| 400 |
-
key : str
|
| 401 |
-
Template key
|
| 402 |
-
|
| 403 |
-
Returns
|
| 404 |
-
-------
|
| 405 |
-
str
|
| 406 |
-
Negative prompt string
|
| 407 |
-
"""
|
| 408 |
template = self.get_template(key)
|
| 409 |
if not template:
|
| 410 |
return ""
|
| 411 |
return template.negative_prompt
|
| 412 |
|
| 413 |
def get_usage_tips(self, key: str) -> List[str]:
|
| 414 |
-
"""
|
| 415 |
-
Get usage tips for a template.
|
| 416 |
-
|
| 417 |
-
Parameters
|
| 418 |
-
----------
|
| 419 |
-
key : str
|
| 420 |
-
Template key
|
| 421 |
-
|
| 422 |
-
Returns
|
| 423 |
-
-------
|
| 424 |
-
list
|
| 425 |
-
List of tip strings
|
| 426 |
-
"""
|
| 427 |
template = self.get_template(key)
|
| 428 |
if not template:
|
| 429 |
return []
|
| 430 |
return template.usage_tips
|
| 431 |
|
| 432 |
-
def
|
| 433 |
-
"""
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
str
|
| 439 |
-
HTML string for Gradio display
|
| 440 |
-
"""
|
| 441 |
-
html_parts = ['<div class="inpainting-gallery">']
|
| 442 |
-
|
| 443 |
-
for category in self.CATEGORIES:
|
| 444 |
-
templates = self.get_templates_by_category(category)
|
| 445 |
-
if not templates:
|
| 446 |
-
continue
|
| 447 |
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
|
|
|
| 453 |
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
<span class="inpainting-name">{template.name}</span>
|
| 459 |
-
<span class="inpainting-desc">{template.description[:50]}...</span>
|
| 460 |
-
</div>
|
| 461 |
-
''')
|
| 462 |
-
|
| 463 |
-
html_parts.append('</div></div>')
|
| 464 |
-
|
| 465 |
-
html_parts.append('</div>')
|
| 466 |
-
return ''.join(html_parts)
|
|
|
|
| 1 |
import logging
|
| 2 |
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
|
|
|
| 19 |
prompt_template: str
|
| 20 |
negative_prompt: str
|
| 21 |
|
| 22 |
+
# Pipeline mode selection
|
| 23 |
+
use_controlnet: bool = True # False = use pure SDXL Inpainting model (more stable)
|
| 24 |
+
mask_dilation: int = 0 # Pixels to expand mask for better edge blending
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
# ControlNet parameters (only used when use_controlnet=True)
|
| 27 |
+
controlnet_conditioning_scale: float = 0.7
|
|
|
|
|
|
|
|
|
|
| 28 |
preferred_conditioning: str = "canny" # "canny" or "depth"
|
|
|
|
|
|
|
|
|
|
| 29 |
preserve_structure_in_mask: bool = False
|
| 30 |
+
edge_guidance_mode: str = "boundary"
|
| 31 |
|
| 32 |
+
# Generation parameters
|
| 33 |
+
guidance_scale: float = 7.5
|
| 34 |
+
num_inference_steps: int = 25
|
| 35 |
+
strength: float = 0.99 # Use 0.99 instead of 1.0 to avoid noise issues
|
| 36 |
+
|
| 37 |
+
# Mask parameters
|
| 38 |
+
feather_radius: int = 3 # Minimal feathering, let pipeline handle blending
|
| 39 |
|
| 40 |
+
# Prompt enhancement
|
| 41 |
+
enhance_prompt: bool = True
|
| 42 |
|
| 43 |
+
# UI metadata
|
| 44 |
+
difficulty: str = "medium"
|
| 45 |
+
recommended_models: List[str] = field(default_factory=lambda: ["sdxl_base"])
|
| 46 |
+
example_prompts: List[str] = field(default_factory=list)
|
| 47 |
usage_tips: List[str] = field(default_factory=list)
|
| 48 |
|
| 49 |
|
|
|
|
| 51 |
"""
|
| 52 |
Manages inpainting templates for various use cases.
|
| 53 |
|
| 54 |
+
Templates are categorized into two pipeline modes:
|
| 55 |
+
- Pure Inpainting (use_controlnet=False): For replacement/removal tasks
|
| 56 |
+
- ControlNet Inpainting (use_controlnet=True): For structure-preserving tasks
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
Example:
|
| 59 |
>>> manager = InpaintingTemplateManager()
|
| 60 |
>>> template = manager.get_template("object_replacement")
|
| 61 |
+
>>> if not template.use_controlnet:
|
| 62 |
+
... # Use pure SDXL Inpainting pipeline
|
| 63 |
+
... pass
|
| 64 |
"""
|
| 65 |
|
| 66 |
TEMPLATES: Dict[str, InpaintingTemplate] = {
|
| 67 |
+
# 1. OBJECT REPLACEMENT - Replace one object with another
|
| 68 |
+
"object_replacement": InpaintingTemplate(
|
| 69 |
+
key="object_replacement",
|
| 70 |
+
name="Object Replacement",
|
| 71 |
+
category="Replacement",
|
| 72 |
+
icon="🔄",
|
| 73 |
+
description="Replace objects naturally - uses dedicated inpainting model for best results",
|
| 74 |
+
prompt_template="{content}, photorealistic, natural lighting, seamlessly integrated, high quality, detailed",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
negative_prompt=(
|
| 76 |
+
"blurry, low quality, distorted, deformed, "
|
| 77 |
+
"visible seams, harsh edges, unnatural, "
|
| 78 |
+
"cartoon, anime, illustration, drawing"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
),
|
| 80 |
+
# Pipeline mode
|
| 81 |
+
use_controlnet=False, # Pure inpainting for stable results
|
| 82 |
+
mask_dilation=5, # Expand mask for seamless blending
|
| 83 |
+
|
| 84 |
+
# Generation parameters
|
| 85 |
+
guidance_scale=8.0,
|
| 86 |
+
num_inference_steps=25,
|
| 87 |
+
strength=0.99,
|
| 88 |
+
|
| 89 |
+
# Mask parameters
|
| 90 |
+
feather_radius=3,
|
| 91 |
+
|
| 92 |
+
# Not used for Pure Inpainting but kept for compatibility
|
| 93 |
+
controlnet_conditioning_scale=0.0,
|
| 94 |
+
preferred_conditioning="canny", # Placeholder, not used in Pure Inpainting mode
|
| 95 |
+
preserve_structure_in_mask=False,
|
| 96 |
+
edge_guidance_mode="none",
|
| 97 |
+
|
| 98 |
+
enhance_prompt=True,
|
| 99 |
difficulty="easy",
|
| 100 |
+
recommended_models=["realvis_xl", "juggernaut_xl"],
|
| 101 |
+
example_prompts=[
|
| 102 |
+
"elegant ceramic vase with fresh roses",
|
| 103 |
+
"modern minimalist desk lamp, chrome finish",
|
| 104 |
+
"vintage leather-bound book with gold lettering"
|
| 105 |
+
],
|
| 106 |
usage_tips=[
|
| 107 |
+
"🎯 Purpose: Replace an object with something completely different.",
|
| 108 |
"",
|
| 109 |
+
"💡 Example Prompts:",
|
| 110 |
+
" • elegant ceramic vase with fresh roses",
|
| 111 |
+
" • modern minimalist desk lamp, chrome finish",
|
| 112 |
+
" • vintage leather-bound book with gold lettering",
|
|
|
|
|
|
|
| 113 |
"",
|
| 114 |
"💡 Tips:",
|
| 115 |
+
" • Draw mask slightly larger than the object",
|
| 116 |
+
" • Describe the NEW object in detail",
|
| 117 |
+
" • Include material, color, style for better results"
|
| 118 |
]
|
| 119 |
),
|
| 120 |
|
| 121 |
+
# 2. OBJECT REMOVAL - Remove and fill with background (NO PROMPT NEEDED)
|
| 122 |
+
"removal": InpaintingTemplate(
|
| 123 |
+
key="removal",
|
| 124 |
+
name="Remove Object",
|
| 125 |
+
category="Removal",
|
| 126 |
+
icon="🗑️",
|
| 127 |
+
description="Remove unwanted objects - just draw mask, no prompt needed",
|
| 128 |
+
prompt_template="seamless background, natural texture continuation, photorealistic, high quality",
|
| 129 |
negative_prompt=(
|
| 130 |
+
"object, item, thing, foreground element, new object, "
|
| 131 |
+
"visible patch, inconsistent texture, "
|
| 132 |
+
"blurry, low quality, artificial"
|
|
|
|
|
|
|
| 133 |
),
|
| 134 |
+
# Pipeline mode
|
| 135 |
+
use_controlnet=False, # Pure inpainting for clean removal
|
| 136 |
+
mask_dilation=8, # Larger expansion to cover shadows/reflections
|
| 137 |
+
|
| 138 |
+
# Generation parameters
|
| 139 |
+
guidance_scale=7.0, # Lower guidance for natural fill
|
| 140 |
+
num_inference_steps=20,
|
| 141 |
+
strength=0.99,
|
| 142 |
+
|
| 143 |
+
# Mask parameters
|
| 144 |
+
feather_radius=5, # More feathering for seamless blend
|
| 145 |
+
|
| 146 |
+
# Not used for Pure Inpainting but kept for compatibility
|
| 147 |
+
controlnet_conditioning_scale=0.0,
|
| 148 |
+
preferred_conditioning="canny",
|
| 149 |
+
preserve_structure_in_mask=False,
|
| 150 |
+
edge_guidance_mode="none",
|
| 151 |
+
|
| 152 |
+
enhance_prompt=False, # Do NOT enhance - keep it simple
|
| 153 |
difficulty="easy",
|
| 154 |
+
recommended_models=["realvis_xl", "juggernaut_xl"],
|
| 155 |
+
example_prompts=[], # No prompts needed for removal
|
| 156 |
usage_tips=[
|
| 157 |
+
"🎯 Purpose: Remove unwanted objects from image.",
|
| 158 |
"",
|
| 159 |
+
"📝 No prompt needed! Just:",
|
| 160 |
+
" 1. Draw white mask over the object",
|
| 161 |
+
" 2. Include shadows in your mask",
|
| 162 |
+
" 3. Click Generate",
|
|
|
|
|
|
|
|
|
|
| 163 |
"",
|
| 164 |
"💡 Tips:",
|
| 165 |
+
" • Make mask larger than the object",
|
| 166 |
+
" • If artifacts remain, draw a bigger mask and retry"
|
|
|
|
| 167 |
]
|
| 168 |
),
|
| 169 |
|
| 170 |
+
# CONTROLNET TEMPLATES (Structure Preserving)
|
| 171 |
+
# 3. CLOTHING CHANGE - Change clothes while keeping body
|
| 172 |
+
"clothing_change": InpaintingTemplate(
|
| 173 |
+
key="clothing_change",
|
| 174 |
+
name="Clothing Change",
|
| 175 |
category="Replacement",
|
| 176 |
+
icon="👕",
|
| 177 |
+
description="Change clothing style while preserving body structure",
|
| 178 |
+
prompt_template="{content}, photorealistic, realistic fabric, natural fit, high quality",
|
| 179 |
negative_prompt=(
|
| 180 |
+
"wrong proportions, distorted body, floating fabric, "
|
| 181 |
+
"mismatched lighting, naked, nudity, "
|
| 182 |
+
"cartoon, anime, illustration"
|
|
|
|
|
|
|
| 183 |
),
|
| 184 |
+
# Pipeline mode
|
| 185 |
+
use_controlnet=True, # Need ControlNet to preserve body
|
| 186 |
+
mask_dilation=3, # Small expansion for clothing edges
|
| 187 |
+
|
| 188 |
+
# ControlNet parameters
|
| 189 |
+
controlnet_conditioning_scale=0.4,
|
| 190 |
+
preferred_conditioning="depth", # Depth preserves body structure
|
| 191 |
+
preserve_structure_in_mask=False,
|
| 192 |
+
edge_guidance_mode="soft",
|
| 193 |
+
|
| 194 |
+
# Generation parameters
|
| 195 |
+
guidance_scale=8.0,
|
| 196 |
+
num_inference_steps=25,
|
| 197 |
+
strength=1.0, # Full repaint for clothing
|
| 198 |
+
|
| 199 |
+
# Mask parameters
|
| 200 |
+
feather_radius=5,
|
| 201 |
+
|
| 202 |
+
enhance_prompt=True,
|
| 203 |
difficulty="medium",
|
| 204 |
+
recommended_models=["juggernaut_xl", "realvis_xl"],
|
| 205 |
+
example_prompts=[
|
| 206 |
+
"tailored charcoal suit with silk tie",
|
| 207 |
+
"navy blazer with gold buttons",
|
| 208 |
+
"elegant black evening dress",
|
| 209 |
+
"casual white t-shirt",
|
| 210 |
+
"cozy cream sweater",
|
| 211 |
+
"leather motorcycle jacket",
|
| 212 |
+
"formal white dress shirt",
|
| 213 |
+
"vintage denim jacket",
|
| 214 |
+
"red cocktail dress",
|
| 215 |
+
"professional grey blazer"
|
| 216 |
+
],
|
| 217 |
usage_tips=[
|
| 218 |
+
"🎯 Purpose: Change clothing while keeping body shape.",
|
| 219 |
"",
|
| 220 |
+
"🤖 Recommended Models:",
|
| 221 |
+
" • JuggernautXL - Best for formal wear",
|
| 222 |
+
" • RealVisXL - Great for casual clothing",
|
|
|
|
|
|
|
|
|
|
| 223 |
"",
|
| 224 |
"💡 Tips:",
|
| 225 |
+
" • Mask only the clothing area",
|
| 226 |
+
" • Include fabric type: 'silk', 'cotton', 'wool'",
|
| 227 |
+
" • Body proportions are preserved automatically"
|
| 228 |
]
|
| 229 |
),
|
| 230 |
|
| 231 |
+
# 4. COLOR CHANGE - Change color only, keep structure
|
| 232 |
+
"change_color": InpaintingTemplate(
|
| 233 |
+
key="change_color",
|
| 234 |
+
name="Change Color",
|
| 235 |
+
category="Color",
|
| 236 |
+
icon="🎨",
|
| 237 |
+
description="Change color only - strictly preserves shape and texture",
|
| 238 |
+
prompt_template="{content} color, solid uniform {content}, flat color, smooth surface",
|
| 239 |
negative_prompt=(
|
| 240 |
+
"different shape, changed structure, new pattern, "
|
| 241 |
+
"texture change, deformed, distorted, "
|
| 242 |
+
"gradient, multiple colors, pattern"
|
|
|
|
|
|
|
| 243 |
),
|
| 244 |
+
# Pipeline mode
|
| 245 |
+
use_controlnet=True, # Need ControlNet to preserve exact shape
|
| 246 |
+
mask_dilation=0, # No expansion - precise color change
|
| 247 |
+
|
| 248 |
+
# ControlNet parameters
|
| 249 |
+
controlnet_conditioning_scale=0.85, # High: strict structure preservation
|
| 250 |
+
preferred_conditioning="canny", # Canny preserves edges exactly
|
| 251 |
+
preserve_structure_in_mask=True, # Keep all edges
|
| 252 |
+
edge_guidance_mode="boundary",
|
| 253 |
+
|
| 254 |
+
# Generation parameters
|
| 255 |
+
guidance_scale=12.0, # High: force the exact color
|
| 256 |
+
num_inference_steps=15,
|
| 257 |
+
strength=1.0,
|
| 258 |
+
|
| 259 |
+
# Mask parameters
|
| 260 |
+
feather_radius=2, # Very small
|
| 261 |
+
|
| 262 |
+
enhance_prompt=False, # Use color prompt directly
|
| 263 |
+
difficulty="easy",
|
| 264 |
+
recommended_models=["juggernaut_xl", "realvis_xl"],
|
| 265 |
+
example_prompts=[
|
| 266 |
+
"vibrant red",
|
| 267 |
+
"deep navy blue",
|
| 268 |
+
"bright yellow",
|
| 269 |
+
"emerald green",
|
| 270 |
+
"soft pink",
|
| 271 |
+
"pure white",
|
| 272 |
+
"charcoal grey",
|
| 273 |
+
"royal purple",
|
| 274 |
+
"coral orange",
|
| 275 |
+
"golden brown"
|
| 276 |
+
],
|
| 277 |
usage_tips=[
|
| 278 |
+
"🎯 Purpose: Change color only, shape stays exactly the same.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
"",
|
| 280 |
"💡 Tips:",
|
| 281 |
+
" • Enter ONLY the color name",
|
| 282 |
+
" • Use modifiers: 'bright', 'dark', 'pastel'",
|
| 283 |
+
" • Shape and texture are preserved exactly"
|
| 284 |
]
|
| 285 |
),
|
| 286 |
}
|
| 287 |
|
|
|
|
| 288 |
# Category display order
|
| 289 |
+
CATEGORIES = ["Color", "Replacement", "Removal"]
|
| 290 |
|
| 291 |
def __init__(self):
|
| 292 |
"""Initialize the InpaintingTemplateManager."""
|
| 293 |
logger.info(f"InpaintingTemplateManager initialized with {len(self.TEMPLATES)} templates")
|
| 294 |
|
| 295 |
def get_all_templates(self) -> Dict[str, InpaintingTemplate]:
|
| 296 |
+
"""Get all available templates."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
return self.TEMPLATES
|
| 298 |
|
| 299 |
def get_template(self, key: str) -> Optional[InpaintingTemplate]:
|
| 300 |
+
"""Get a specific template by key."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
return self.TEMPLATES.get(key)
|
| 302 |
|
| 303 |
def get_templates_by_category(self, category: str) -> List[InpaintingTemplate]:
|
| 304 |
+
"""Get all templates in a specific category."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
return [t for t in self.TEMPLATES.values() if t.category == category]
|
| 306 |
|
| 307 |
def get_categories(self) -> List[str]:
|
| 308 |
+
"""Get list of all categories in display order."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
return self.CATEGORIES
|
| 310 |
|
| 311 |
def get_template_choices_sorted(self) -> List[str]:
|
| 312 |
+
"""Get template choices formatted for Gradio dropdown."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
display_list = []
|
|
|
|
| 314 |
for category in self.CATEGORIES:
|
| 315 |
templates = self.get_templates_by_category(category)
|
| 316 |
for template in sorted(templates, key=lambda t: t.name):
|
| 317 |
display_name = f"{template.icon} {template.name}"
|
| 318 |
display_list.append(display_name)
|
|
|
|
| 319 |
return display_list
|
| 320 |
|
| 321 |
def get_template_key_from_display(self, display_name: str) -> Optional[str]:
|
| 322 |
+
"""Get template key from display name."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
if not display_name:
|
| 324 |
return None
|
|
|
|
| 325 |
for key, template in self.TEMPLATES.items():
|
| 326 |
if f"{template.icon} {template.name}" == display_name:
|
| 327 |
return key
|
| 328 |
return None
|
| 329 |
|
| 330 |
+
def get_parameters_for_template(self, key: str) -> Dict[str, Any]:
|
| 331 |
+
"""Get recommended parameters for a template."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
template = self.get_template(key)
|
| 333 |
if not template:
|
| 334 |
return {}
|
| 335 |
|
| 336 |
return {
|
| 337 |
+
"use_controlnet": template.use_controlnet,
|
| 338 |
+
"mask_dilation": template.mask_dilation,
|
| 339 |
"controlnet_conditioning_scale": template.controlnet_conditioning_scale,
|
| 340 |
+
"preferred_conditioning": template.preferred_conditioning,
|
| 341 |
+
"preserve_structure_in_mask": template.preserve_structure_in_mask,
|
| 342 |
+
"edge_guidance_mode": template.edge_guidance_mode,
|
| 343 |
"guidance_scale": template.guidance_scale,
|
| 344 |
"num_inference_steps": template.num_inference_steps,
|
| 345 |
"strength": template.strength,
|
| 346 |
+
"feather_radius": template.feather_radius,
|
| 347 |
+
"enhance_prompt": template.enhance_prompt,
|
|
|
|
| 348 |
}
|
| 349 |
|
| 350 |
def build_prompt(self, key: str, content: str) -> str:
|
| 351 |
+
"""Build complete prompt from template and user content."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
template = self.get_template(key)
|
| 353 |
if not template:
|
| 354 |
return content
|
|
|
|
| 355 |
return template.prompt_template.format(content=content)
|
| 356 |
|
| 357 |
def get_negative_prompt(self, key: str) -> str:
|
| 358 |
+
"""Get negative prompt for a template."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
template = self.get_template(key)
|
| 360 |
if not template:
|
| 361 |
return ""
|
| 362 |
return template.negative_prompt
|
| 363 |
|
| 364 |
def get_usage_tips(self, key: str) -> List[str]:
|
| 365 |
+
"""Get usage tips for a template."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
template = self.get_template(key)
|
| 367 |
if not template:
|
| 368 |
return []
|
| 369 |
return template.usage_tips
|
| 370 |
|
| 371 |
+
def get_recommended_models(self, key: str) -> List[str]:
|
| 372 |
+
"""Get recommended models for a template."""
|
| 373 |
+
template = self.get_template(key)
|
| 374 |
+
if not template:
|
| 375 |
+
return ["sdxl_base"]
|
| 376 |
+
return template.recommended_models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
+
def get_example_prompts(self, key: str) -> List[str]:
|
| 379 |
+
"""Get example prompts for a template."""
|
| 380 |
+
template = self.get_template(key)
|
| 381 |
+
if not template:
|
| 382 |
+
return []
|
| 383 |
+
return template.example_prompts
|
| 384 |
|
| 385 |
+
def get_primary_recommended_model(self, key: str) -> str:
|
| 386 |
+
"""Get the primary recommended model for a template."""
|
| 387 |
+
models = self.get_recommended_models(key)
|
| 388 |
+
return models[0] if models else "sdxl_base"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mask_generator.py
CHANGED
|
@@ -298,7 +298,7 @@ class MaskGenerator:
|
|
| 298 |
# High confidence areas - keep at full opacity
|
| 299 |
final_alpha[high_confidence] = 255
|
| 300 |
|
| 301 |
-
# Medium confidence - boost significantly
|
| 302 |
final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)
|
| 303 |
|
| 304 |
# Low confidence - moderate boost (catches faint extremities)
|
|
|
|
| 298 |
# High confidence areas - keep at full opacity
|
| 299 |
final_alpha[high_confidence] = 255
|
| 300 |
|
| 301 |
+
# Medium confidence - boost significantly
|
| 302 |
final_alpha[medium_confidence] = np.clip(alpha_stretched[medium_confidence] * 1.8, 200, 255)
|
| 303 |
|
| 304 |
# Low confidence - moderate boost (catches faint extremities)
|
scene_templates.py
CHANGED
|
@@ -24,7 +24,7 @@ class SceneTemplateManager:
|
|
| 24 |
|
| 25 |
# Scene template definitions
|
| 26 |
TEMPLATES: Dict[str, SceneTemplate] = {
|
| 27 |
-
# Professional Category
|
| 28 |
"office_modern": SceneTemplate(
|
| 29 |
key="office_modern",
|
| 30 |
name="Modern Office",
|
|
@@ -71,7 +71,7 @@ class SceneTemplateManager:
|
|
| 71 |
guidance_scale=7.5
|
| 72 |
),
|
| 73 |
|
| 74 |
-
# Nature Category
|
| 75 |
"beach_sunset": SceneTemplate(
|
| 76 |
key="beach_sunset",
|
| 77 |
name="Sunset Beach",
|
|
@@ -127,7 +127,7 @@ class SceneTemplateManager:
|
|
| 127 |
guidance_scale=7.0
|
| 128 |
),
|
| 129 |
|
| 130 |
-
# Urban Category
|
| 131 |
"city_skyline": SceneTemplate(
|
| 132 |
key="city_skyline",
|
| 133 |
name="City Skyline",
|
|
@@ -174,7 +174,7 @@ class SceneTemplateManager:
|
|
| 174 |
guidance_scale=7.5
|
| 175 |
),
|
| 176 |
|
| 177 |
-
# Artistic Category
|
| 178 |
"gradient_soft": SceneTemplate(
|
| 179 |
key="gradient_soft",
|
| 180 |
name="Soft Gradient",
|
|
@@ -212,7 +212,7 @@ class SceneTemplateManager:
|
|
| 212 |
guidance_scale=6.5
|
| 213 |
),
|
| 214 |
|
| 215 |
-
# Seasonal Category
|
| 216 |
"autumn_foliage": SceneTemplate(
|
| 217 |
key="autumn_foliage",
|
| 218 |
name="Autumn Foliage",
|
|
@@ -425,4 +425,4 @@ class SceneTemplateManager:
|
|
| 425 |
grid-template-columns: repeat(3, 1fr);
|
| 426 |
}
|
| 427 |
}
|
| 428 |
-
"""
|
|
|
|
| 24 |
|
| 25 |
# Scene template definitions
|
| 26 |
TEMPLATES: Dict[str, SceneTemplate] = {
|
| 27 |
+
# Professional Category
|
| 28 |
"office_modern": SceneTemplate(
|
| 29 |
key="office_modern",
|
| 30 |
name="Modern Office",
|
|
|
|
| 71 |
guidance_scale=7.5
|
| 72 |
),
|
| 73 |
|
| 74 |
+
# Nature Category
|
| 75 |
"beach_sunset": SceneTemplate(
|
| 76 |
key="beach_sunset",
|
| 77 |
name="Sunset Beach",
|
|
|
|
| 127 |
guidance_scale=7.0
|
| 128 |
),
|
| 129 |
|
| 130 |
+
# Urban Category
|
| 131 |
"city_skyline": SceneTemplate(
|
| 132 |
key="city_skyline",
|
| 133 |
name="City Skyline",
|
|
|
|
| 174 |
guidance_scale=7.5
|
| 175 |
),
|
| 176 |
|
| 177 |
+
# Artistic Category
|
| 178 |
"gradient_soft": SceneTemplate(
|
| 179 |
key="gradient_soft",
|
| 180 |
name="Soft Gradient",
|
|
|
|
| 212 |
guidance_scale=6.5
|
| 213 |
),
|
| 214 |
|
| 215 |
+
# Seasonal Category
|
| 216 |
"autumn_foliage": SceneTemplate(
|
| 217 |
key="autumn_foliage",
|
| 218 |
name="Autumn Foliage",
|
|
|
|
| 425 |
grid-template-columns: repeat(3, 1fr);
|
| 426 |
}
|
| 427 |
}
|
| 428 |
+
"""
|
scene_weaver_core.py
CHANGED
|
@@ -321,7 +321,7 @@ class SceneWeaverCore:
|
|
| 321 |
# Analyze image characteristics
|
| 322 |
img_array = np.array(foreground_image.convert('RGB'))
|
| 323 |
|
| 324 |
-
# Analyze color temperature
|
| 325 |
# Convert to LAB to analyze color temperature
|
| 326 |
lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
|
| 327 |
avg_a = np.mean(lab[:, :, 1]) # a channel: green(-) to red(+)
|
|
@@ -330,12 +330,12 @@ class SceneWeaverCore:
|
|
| 330 |
# Determine warm/cool tone
|
| 331 |
is_warm = avg_b > 128 # b > 128 means more yellow/warm
|
| 332 |
|
| 333 |
-
# Analyze brightness
|
| 334 |
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 335 |
avg_brightness = np.mean(gray)
|
| 336 |
is_bright = avg_brightness > 127
|
| 337 |
|
| 338 |
-
# Get subject type from CLIP
|
| 339 |
clip_analysis = self.analyze_image_with_clip(foreground_image)
|
| 340 |
subject_type = "unknown"
|
| 341 |
|
|
@@ -369,7 +369,7 @@ class SceneWeaverCore:
|
|
| 369 |
|
| 370 |
quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
|
| 371 |
|
| 372 |
-
# Select appropriate fragments
|
| 373 |
# Lighting based on color temperature and brightness
|
| 374 |
if is_warm and is_bright:
|
| 375 |
lighting = lighting_options["warm_bright"]
|
|
@@ -383,7 +383,7 @@ class SceneWeaverCore:
|
|
| 383 |
# Atmosphere based on subject type
|
| 384 |
atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])
|
| 385 |
|
| 386 |
-
# Check for conflicts in user prompt
|
| 387 |
user_prompt_lower = user_prompt.lower()
|
| 388 |
|
| 389 |
# Avoid adding conflicting descriptions
|
|
@@ -392,7 +392,7 @@ class SceneWeaverCore:
|
|
| 392 |
if "dark" in user_prompt_lower or "night" in user_prompt_lower:
|
| 393 |
lighting = lighting.replace("bright", "").replace("daylight", "")
|
| 394 |
|
| 395 |
-
# Combine enhanced prompt
|
| 396 |
fragments = [user_prompt]
|
| 397 |
|
| 398 |
if lighting:
|
|
@@ -864,25 +864,33 @@ class SceneWeaverCore:
|
|
| 864 |
"""
|
| 865 |
if self._inpainting_module is None:
|
| 866 |
self._inpainting_module = InpaintingModule(device=self.device)
|
| 867 |
-
self._inpainting_module.set_model_manager(self._model_manager)
|
| 868 |
logger.info("InpaintingModule created (lazy load)")
|
| 869 |
|
| 870 |
return self._inpainting_module
|
| 871 |
|
| 872 |
def switch_to_inpainting_mode(
|
| 873 |
self,
|
|
|
|
| 874 |
conditioning_type: str = "canny",
|
|
|
|
| 875 |
progress_callback: Optional[Callable[[str, int], None]] = None
|
| 876 |
) -> bool:
|
| 877 |
"""
|
| 878 |
Switch to inpainting mode, unloading background pipeline.
|
| 879 |
|
| 880 |
-
|
|
|
|
|
|
|
| 881 |
|
| 882 |
Parameters
|
| 883 |
----------
|
|
|
|
|
|
|
|
|
|
| 884 |
conditioning_type : str
|
| 885 |
-
ControlNet conditioning type: "canny" or "depth"
|
|
|
|
|
|
|
| 886 |
progress_callback : callable, optional
|
| 887 |
Progress update function(message, percentage)
|
| 888 |
|
|
@@ -891,7 +899,8 @@ class SceneWeaverCore:
|
|
| 891 |
bool
|
| 892 |
True if switch was successful
|
| 893 |
"""
|
| 894 |
-
|
|
|
|
| 895 |
|
| 896 |
try:
|
| 897 |
# Unload background pipeline first
|
|
@@ -912,12 +921,14 @@ class SceneWeaverCore:
|
|
| 912 |
|
| 913 |
def inpaint_progress(msg, pct):
|
| 914 |
if progress_callback:
|
| 915 |
-
# Map inpainting progress (0-100) to (20-90)
|
| 916 |
mapped_pct = 20 + int(pct * 0.7)
|
| 917 |
progress_callback(msg, mapped_pct)
|
| 918 |
|
| 919 |
-
|
|
|
|
|
|
|
| 920 |
conditioning_type=conditioning_type,
|
|
|
|
| 921 |
progress_callback=inpaint_progress
|
| 922 |
)
|
| 923 |
|
|
@@ -997,6 +1008,7 @@ class SceneWeaverCore:
|
|
| 997 |
prompt: str,
|
| 998 |
preview_only: bool = False,
|
| 999 |
template_key: Optional[str] = None,
|
|
|
|
| 1000 |
progress_callback: Optional[Callable[[str, int], None]] = None,
|
| 1001 |
**kwargs
|
| 1002 |
) -> Dict[str, Any]:
|
|
@@ -1017,6 +1029,8 @@ class SceneWeaverCore:
|
|
| 1017 |
If True, generate quick preview only
|
| 1018 |
template_key : str, optional
|
| 1019 |
Inpainting template key to use
|
|
|
|
|
|
|
| 1020 |
progress_callback : callable, optional
|
| 1021 |
Progress update function
|
| 1022 |
**kwargs
|
|
@@ -1027,10 +1041,30 @@ class SceneWeaverCore:
|
|
| 1027 |
dict
|
| 1028 |
Result dictionary with images and metadata
|
| 1029 |
"""
|
| 1030 |
-
#
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1034 |
error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
|
| 1035 |
return {
|
| 1036 |
"success": False,
|
|
@@ -1038,33 +1072,11 @@ class SceneWeaverCore:
|
|
| 1038 |
}
|
| 1039 |
|
| 1040 |
inpaint_module = self.get_inpainting_module()
|
| 1041 |
-
|
| 1042 |
-
# Apply template if specified
|
| 1043 |
-
if template_key:
|
| 1044 |
-
template_mgr = InpaintingTemplateManager()
|
| 1045 |
-
template = template_mgr.get_template(template_key)
|
| 1046 |
-
|
| 1047 |
-
if template:
|
| 1048 |
-
# Build prompt from template
|
| 1049 |
-
prompt = template_mgr.build_prompt(template_key, prompt)
|
| 1050 |
-
# Apply template parameters as defaults
|
| 1051 |
-
params = template_mgr.get_parameters_for_template(template_key)
|
| 1052 |
-
for key, value in params.items():
|
| 1053 |
-
if key not in kwargs:
|
| 1054 |
-
kwargs[key] = value
|
| 1055 |
-
|
| 1056 |
-
# Pass enhance_prompt flag to inpainting module
|
| 1057 |
-
if 'enhance_prompt' not in kwargs:
|
| 1058 |
-
kwargs['enhance_prompt'] = template.enhance_prompt
|
| 1059 |
-
|
| 1060 |
-
# Execute inpainting
|
| 1061 |
result = inpaint_module.execute_inpainting(
|
| 1062 |
image=image,
|
| 1063 |
mask=mask,
|
| 1064 |
prompt=prompt,
|
| 1065 |
-
preview_only=preview_only,
|
| 1066 |
progress_callback=progress_callback,
|
| 1067 |
-
template_key=template_key, # Pass template_key for conditional prompt enhancement
|
| 1068 |
**kwargs
|
| 1069 |
)
|
| 1070 |
|
|
@@ -1191,4 +1203,4 @@ class SceneWeaverCore:
|
|
| 1191 |
|
| 1192 |
status = self._inpainting_module.get_status()
|
| 1193 |
status["mode"] = self._current_mode
|
| 1194 |
-
return status
|
|
|
|
| 321 |
# Analyze image characteristics
|
| 322 |
img_array = np.array(foreground_image.convert('RGB'))
|
| 323 |
|
| 324 |
+
# Analyze color temperature
|
| 325 |
# Convert to LAB to analyze color temperature
|
| 326 |
lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
|
| 327 |
avg_a = np.mean(lab[:, :, 1]) # a channel: green(-) to red(+)
|
|
|
|
| 330 |
# Determine warm/cool tone
|
| 331 |
is_warm = avg_b > 128 # b > 128 means more yellow/warm
|
| 332 |
|
| 333 |
+
# Analyze brightness
|
| 334 |
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 335 |
avg_brightness = np.mean(gray)
|
| 336 |
is_bright = avg_brightness > 127
|
| 337 |
|
| 338 |
+
# Get subject type from CLIP
|
| 339 |
clip_analysis = self.analyze_image_with_clip(foreground_image)
|
| 340 |
subject_type = "unknown"
|
| 341 |
|
|
|
|
| 369 |
|
| 370 |
quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
|
| 371 |
|
| 372 |
+
# Select appropriate fragments
|
| 373 |
# Lighting based on color temperature and brightness
|
| 374 |
if is_warm and is_bright:
|
| 375 |
lighting = lighting_options["warm_bright"]
|
|
|
|
| 383 |
# Atmosphere based on subject type
|
| 384 |
atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])
|
| 385 |
|
| 386 |
+
# Check for conflicts in user prompt
|
| 387 |
user_prompt_lower = user_prompt.lower()
|
| 388 |
|
| 389 |
# Avoid adding conflicting descriptions
|
|
|
|
| 392 |
if "dark" in user_prompt_lower or "night" in user_prompt_lower:
|
| 393 |
lighting = lighting.replace("bright", "").replace("daylight", "")
|
| 394 |
|
| 395 |
+
# Combine enhanced prompt
|
| 396 |
fragments = [user_prompt]
|
| 397 |
|
| 398 |
if lighting:
|
|
|
|
| 864 |
"""
|
| 865 |
if self._inpainting_module is None:
|
| 866 |
self._inpainting_module = InpaintingModule(device=self.device)
|
|
|
|
| 867 |
logger.info("InpaintingModule created (lazy load)")
|
| 868 |
|
| 869 |
return self._inpainting_module
|
| 870 |
|
| 871 |
def switch_to_inpainting_mode(
|
| 872 |
self,
|
| 873 |
+
use_controlnet: bool = True,
|
| 874 |
conditioning_type: str = "canny",
|
| 875 |
+
model_key: str = "sdxl_base",
|
| 876 |
progress_callback: Optional[Callable[[str, int], None]] = None
|
| 877 |
) -> bool:
|
| 878 |
"""
|
| 879 |
Switch to inpainting mode, unloading background pipeline.
|
| 880 |
|
| 881 |
+
Supports dual-mode inpainting:
|
| 882 |
+
- Pure Inpainting (use_controlnet=False): For object replacement/removal
|
| 883 |
+
- ControlNet Inpainting (use_controlnet=True): For clothing/color change
|
| 884 |
|
| 885 |
Parameters
|
| 886 |
----------
|
| 887 |
+
use_controlnet : bool
|
| 888 |
+
If False, use dedicated SDXL Inpainting model
|
| 889 |
+
If True, use ControlNet + SDXL model
|
| 890 |
conditioning_type : str
|
| 891 |
+
ControlNet conditioning type: "canny" or "depth" (only for ControlNet mode)
|
| 892 |
+
model_key : str
|
| 893 |
+
Model key for ControlNet mode base model
|
| 894 |
progress_callback : callable, optional
|
| 895 |
Progress update function(message, percentage)
|
| 896 |
|
|
|
|
| 899 |
bool
|
| 900 |
True if switch was successful
|
| 901 |
"""
|
| 902 |
+
mode_str = "ControlNet" if use_controlnet else "Pure Inpainting"
|
| 903 |
+
logger.info(f"Switching to inpainting mode: {mode_str} (model: {model_key})")
|
| 904 |
|
| 905 |
try:
|
| 906 |
# Unload background pipeline first
|
|
|
|
| 921 |
|
| 922 |
def inpaint_progress(msg, pct):
|
| 923 |
if progress_callback:
|
|
|
|
| 924 |
mapped_pct = 20 + int(pct * 0.7)
|
| 925 |
progress_callback(msg, mapped_pct)
|
| 926 |
|
| 927 |
+
# Use the new load_pipeline method with dual-mode support
|
| 928 |
+
success, error_msg = inpaint_module.load_pipeline(
|
| 929 |
+
use_controlnet=use_controlnet,
|
| 930 |
conditioning_type=conditioning_type,
|
| 931 |
+
model_key=model_key,
|
| 932 |
progress_callback=inpaint_progress
|
| 933 |
)
|
| 934 |
|
|
|
|
| 1008 |
prompt: str,
|
| 1009 |
preview_only: bool = False,
|
| 1010 |
template_key: Optional[str] = None,
|
| 1011 |
+
model_key: str = "sdxl_base",
|
| 1012 |
progress_callback: Optional[Callable[[str, int], None]] = None,
|
| 1013 |
**kwargs
|
| 1014 |
) -> Dict[str, Any]:
|
|
|
|
| 1029 |
If True, generate quick preview only
|
| 1030 |
template_key : str, optional
|
| 1031 |
Inpainting template key to use
|
| 1032 |
+
model_key : str
|
| 1033 |
+
Model key for the base model (juggernaut_xl, realvis_xl, sdxl_base, animagine_xl)
|
| 1034 |
progress_callback : callable, optional
|
| 1035 |
Progress update function
|
| 1036 |
**kwargs
|
|
|
|
| 1041 |
dict
|
| 1042 |
Result dictionary with images and metadata
|
| 1043 |
"""
|
| 1044 |
+
# Get pipeline mode from kwargs
|
| 1045 |
+
use_controlnet = kwargs.get('use_controlnet', True)
|
| 1046 |
+
conditioning_type = kwargs.get('conditioning_type', 'canny')
|
| 1047 |
+
|
| 1048 |
+
# Check if we need to reinitialize
|
| 1049 |
+
inpaint_module = self.get_inpainting_module()
|
| 1050 |
+
current_mode = getattr(inpaint_module, '_current_mode', None)
|
| 1051 |
+
current_model = getattr(inpaint_module, '_current_model_key', None)
|
| 1052 |
+
|
| 1053 |
+
expected_mode = "controlnet" if use_controlnet else "pure"
|
| 1054 |
+
needs_reinit = (
|
| 1055 |
+
self._current_mode != "inpainting" or
|
| 1056 |
+
not self._inpainting_initialized or
|
| 1057 |
+
current_mode != expected_mode or
|
| 1058 |
+
(use_controlnet and current_model != model_key)
|
| 1059 |
+
)
|
| 1060 |
+
|
| 1061 |
+
if needs_reinit:
|
| 1062 |
+
if not self.switch_to_inpainting_mode(
|
| 1063 |
+
use_controlnet=use_controlnet,
|
| 1064 |
+
conditioning_type=conditioning_type,
|
| 1065 |
+
model_key=model_key,
|
| 1066 |
+
progress_callback=progress_callback
|
| 1067 |
+
):
|
| 1068 |
error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
|
| 1069 |
return {
|
| 1070 |
"success": False,
|
|
|
|
| 1072 |
}
|
| 1073 |
|
| 1074 |
inpaint_module = self.get_inpainting_module()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1075 |
result = inpaint_module.execute_inpainting(
|
| 1076 |
image=image,
|
| 1077 |
mask=mask,
|
| 1078 |
prompt=prompt,
|
|
|
|
| 1079 |
progress_callback=progress_callback,
|
|
|
|
| 1080 |
**kwargs
|
| 1081 |
)
|
| 1082 |
|
|
|
|
| 1203 |
|
| 1204 |
status = self._inpainting_module.get_status()
|
| 1205 |
status["mode"] = self._current_mode
|
| 1206 |
+
return status
|
ui_manager.py
CHANGED
|
@@ -3,16 +3,17 @@ import time
|
|
| 3 |
import traceback
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Optional, Tuple, Dict, Any, List
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
import cv2
|
| 9 |
import gradio as gr
|
| 10 |
-
import
|
|
|
|
| 11 |
|
| 12 |
-
from scene_weaver_core import SceneWeaverCore
|
| 13 |
from css_styles import CSSStyles
|
| 14 |
from scene_templates import SceneTemplateManager
|
| 15 |
from inpainting_templates import InpaintingTemplateManager
|
|
|
|
|
|
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
logger.setLevel(logging.INFO)
|
|
@@ -29,16 +30,20 @@ class UIManager:
|
|
| 29 |
Gradio UI Manager with support for background generation and inpainting.
|
| 30 |
|
| 31 |
Provides a professional interface with mode switching, template selection,
|
| 32 |
-
and advanced parameter controls.
|
| 33 |
|
| 34 |
Attributes:
|
| 35 |
-
|
| 36 |
template_manager: Scene template manager
|
| 37 |
inpainting_template_manager: Inpainting template manager
|
| 38 |
"""
|
| 39 |
|
| 40 |
def __init__(self):
|
| 41 |
self.sceneweaver = SceneWeaverCore()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
self.template_manager = SceneTemplateManager()
|
| 43 |
self.inpainting_template_manager = InpaintingTemplateManager()
|
| 44 |
self.generation_history = []
|
|
@@ -173,7 +178,6 @@ class UIManager:
|
|
| 173 |
if len(self.generation_history) > max_history:
|
| 174 |
self.generation_history = self.generation_history[-max_history:]
|
| 175 |
|
| 176 |
-
@spaces.GPU(duration=240)
|
| 177 |
def generate_handler(
|
| 178 |
self,
|
| 179 |
uploaded_image: Optional[Image.Image],
|
|
@@ -185,8 +189,33 @@ class UIManager:
|
|
| 185 |
guidance: float,
|
| 186 |
progress=gr.Progress()
|
| 187 |
):
|
| 188 |
-
"""
|
|
|
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
if uploaded_image is None:
|
| 191 |
return None, None, None, "Please upload an image to get started!", gr.update(visible=False)
|
| 192 |
|
|
@@ -194,44 +223,19 @@ class UIManager:
|
|
| 194 |
return None, None, None, "Please describe the background scene you'd like!", gr.update(visible=False)
|
| 195 |
|
| 196 |
try:
|
| 197 |
-
|
| 198 |
-
progress(
|
| 199 |
-
|
| 200 |
-
def init_progress(msg, pct):
|
| 201 |
-
if pct < 30:
|
| 202 |
-
desc = "Loading image analysis models..."
|
| 203 |
-
elif pct < 60:
|
| 204 |
-
desc = "Loading Stable Diffusion XL..."
|
| 205 |
-
elif pct < 90:
|
| 206 |
-
desc = "Applying memory optimizations..."
|
| 207 |
-
else:
|
| 208 |
-
desc = "Almost ready..."
|
| 209 |
-
progress(0.05 + (pct/100) * 0.2, desc=desc)
|
| 210 |
-
|
| 211 |
-
self.sceneweaver.load_models(progress_callback=init_progress)
|
| 212 |
-
|
| 213 |
-
def gen_progress(msg, pct):
|
| 214 |
-
if pct < 20:
|
| 215 |
-
desc = "Analyzing your image..."
|
| 216 |
-
elif pct < 50:
|
| 217 |
-
desc = "Generating background scene..."
|
| 218 |
-
elif pct < 80:
|
| 219 |
-
desc = "Blending foreground and background..."
|
| 220 |
-
elif pct < 95:
|
| 221 |
-
desc = "Applying final touches..."
|
| 222 |
-
else:
|
| 223 |
-
desc = "Complete!"
|
| 224 |
-
progress(0.25 + (pct/100) * 0.75, desc=desc)
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
|
|
|
| 228 |
prompt=prompt,
|
| 229 |
-
combination_mode=combination_mode,
|
| 230 |
-
focus_mode=focus_mode,
|
| 231 |
negative_prompt=negative_prompt,
|
| 232 |
-
|
|
|
|
|
|
|
| 233 |
guidance_scale=float(guidance),
|
| 234 |
-
progress_callback=
|
| 235 |
)
|
| 236 |
|
| 237 |
if result["success"]:
|
|
@@ -547,7 +551,7 @@ class UIManager:
|
|
| 547 |
self,
|
| 548 |
display_name: str,
|
| 549 |
current_prompt: str
|
| 550 |
-
) -> Tuple[str, float, int, str]:
|
| 551 |
"""
|
| 552 |
Apply an inpainting template to the UI fields.
|
| 553 |
|
|
@@ -561,26 +565,76 @@ class UIManager:
|
|
| 561 |
Returns
|
| 562 |
-------
|
| 563 |
tuple
|
| 564 |
-
(prompt, conditioning_scale, feather_radius, conditioning_type
|
|
|
|
| 565 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
if not display_name:
|
| 567 |
-
return
|
| 568 |
|
| 569 |
template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
|
| 570 |
if not template_key:
|
| 571 |
-
return
|
| 572 |
|
| 573 |
template = self.inpainting_template_manager.get_template(template_key)
|
| 574 |
if template:
|
| 575 |
params = self.inpainting_template_manager.get_parameters_for_template(template_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
return (
|
| 577 |
current_prompt,
|
| 578 |
params.get('controlnet_conditioning_scale', 0.7),
|
| 579 |
params.get('feather_radius', 8),
|
| 580 |
-
params.get('preferred_conditioning', 'canny')
|
|
|
|
|
|
|
|
|
|
| 581 |
)
|
| 582 |
|
| 583 |
-
return
|
| 584 |
|
| 585 |
def extract_mask_from_editor(self, editor_output: Dict[str, Any]) -> Optional[Image.Image]:
|
| 586 |
"""
|
|
@@ -664,22 +718,23 @@ class UIManager:
|
|
| 664 |
logger.error(f"Failed to extract mask from editor: {e}")
|
| 665 |
return None
|
| 666 |
|
| 667 |
-
@spaces.GPU(duration=420)
|
| 668 |
def inpainting_handler(
|
| 669 |
self,
|
| 670 |
image: Optional[Image.Image],
|
| 671 |
mask_editor: Dict[str, Any],
|
| 672 |
prompt: str,
|
| 673 |
template_dropdown: str,
|
|
|
|
| 674 |
conditioning_type: str,
|
| 675 |
conditioning_scale: float,
|
| 676 |
feather_radius: int,
|
| 677 |
guidance_scale: float,
|
| 678 |
num_steps: int,
|
|
|
|
| 679 |
progress: gr.Progress = gr.Progress()
|
| 680 |
-
) -> Tuple[Optional[Image.Image], Optional[Image.Image],
|
| 681 |
"""
|
| 682 |
-
Handle inpainting generation request.
|
| 683 |
|
| 684 |
Parameters
|
| 685 |
----------
|
|
@@ -691,6 +746,8 @@ class UIManager:
|
|
| 691 |
Text description of desired content
|
| 692 |
template_dropdown : str
|
| 693 |
Selected template (optional)
|
|
|
|
|
|
|
| 694 |
conditioning_type : str
|
| 695 |
ControlNet conditioning type
|
| 696 |
conditioning_scale : float
|
|
@@ -701,36 +758,36 @@ class UIManager:
|
|
| 701 |
Guidance scale for generation
|
| 702 |
num_steps : int
|
| 703 |
Number of inference steps
|
|
|
|
|
|
|
| 704 |
progress : gr.Progress
|
| 705 |
Progress callback
|
| 706 |
|
| 707 |
Returns
|
| 708 |
-------
|
| 709 |
tuple
|
| 710 |
-
(result_image, control_image, status_message)
|
| 711 |
"""
|
| 712 |
if image is None:
|
| 713 |
-
return None, None, "⚠️ Please upload an image first"
|
| 714 |
|
| 715 |
# Extract mask
|
| 716 |
mask = self.extract_mask_from_editor(mask_editor)
|
| 717 |
if mask is None:
|
| 718 |
-
return None, None, "⚠️ Please draw a mask on the image"
|
| 719 |
|
| 720 |
# Validate mask
|
| 721 |
mask_array = np.array(mask)
|
| 722 |
coverage = np.count_nonzero(mask_array > 127) / mask_array.size
|
| 723 |
if coverage < 0.01:
|
| 724 |
-
return None, None, "⚠️ Mask too small - please select a larger area"
|
| 725 |
if coverage > 0.95:
|
| 726 |
-
return None, None, "⚠️ Mask too large - consider using background generation instead"
|
| 727 |
|
| 728 |
def progress_callback(msg: str, pct: int):
|
| 729 |
progress(pct / 100, desc=msg)
|
| 730 |
|
| 731 |
try:
|
| 732 |
-
start_time = time.time()
|
| 733 |
-
|
| 734 |
# Get template key if selected
|
| 735 |
template_key = None
|
| 736 |
if template_dropdown:
|
|
@@ -738,53 +795,39 @@ class UIManager:
|
|
| 738 |
template_dropdown
|
| 739 |
)
|
| 740 |
|
| 741 |
-
#
|
| 742 |
-
|
| 743 |
image=image,
|
| 744 |
mask=mask,
|
| 745 |
prompt=prompt,
|
| 746 |
-
preview_only=False,
|
| 747 |
template_key=template_key,
|
|
|
|
| 748 |
conditioning_type=conditioning_type,
|
| 749 |
-
|
| 750 |
feather_radius=feather_radius,
|
| 751 |
guidance_scale=guidance_scale,
|
| 752 |
-
|
|
|
|
| 753 |
progress_callback=progress_callback
|
| 754 |
)
|
| 755 |
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
if result.get('success'):
|
| 759 |
-
# Store in history
|
| 760 |
self.inpainting_history.append({
|
| 761 |
-
'result':
|
| 762 |
'prompt': prompt,
|
| 763 |
-
'
|
|
|
|
| 764 |
})
|
| 765 |
if len(self.inpainting_history) > 3:
|
| 766 |
self.inpainting_history.pop(0)
|
| 767 |
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
# Clean, simple status message
|
| 771 |
-
status = f"✅ Inpainting complete in {elapsed:.1f}s"
|
| 772 |
-
if quality_score > 0:
|
| 773 |
-
status += f" | Quality: {quality_score:.0f}/100"
|
| 774 |
-
|
| 775 |
-
return (
|
| 776 |
-
result.get('combined_image'),
|
| 777 |
-
result.get('control_image'),
|
| 778 |
-
status
|
| 779 |
-
)
|
| 780 |
-
else:
|
| 781 |
-
error_msg = result.get('error', 'Unknown error')
|
| 782 |
-
return None, None, f"❌ Inpainting failed: {error_msg}"
|
| 783 |
|
| 784 |
except Exception as e:
|
| 785 |
logger.error(f"Inpainting handler error: {e}")
|
| 786 |
logger.error(traceback.format_exc())
|
| 787 |
-
return None, None, f"❌ Error: {str(e)}"
|
| 788 |
|
| 789 |
def create_inpainting_tab(self) -> gr.Tab:
|
| 790 |
"""
|
|
@@ -812,17 +855,44 @@ class UIManager:
|
|
| 812 |
</span>
|
| 813 |
</h3>
|
| 814 |
<p style="color: #666; margin-bottom: 12px;">Draw a mask to select the area you want to regenerate</p>
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
</div>
|
| 827 |
</div>
|
| 828 |
""")
|
|
@@ -859,6 +929,9 @@ class UIManager:
|
|
| 859 |
)
|
| 860 |
template_tips = gr.Markdown("")
|
| 861 |
|
|
|
|
|
|
|
|
|
|
| 862 |
# Prompt
|
| 863 |
inpaint_prompt = gr.Textbox(
|
| 864 |
label="Prompt",
|
|
@@ -868,28 +941,49 @@ class UIManager:
|
|
| 868 |
|
| 869 |
# Right column - Settings and Output
|
| 870 |
with gr.Column(scale=1):
|
| 871 |
-
#
|
| 872 |
-
with gr.
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 886 |
|
|
|
|
|
|
|
| 887 |
feather_radius = gr.Slider(
|
| 888 |
minimum=0,
|
| 889 |
maximum=20,
|
| 890 |
value=8,
|
| 891 |
step=1,
|
| 892 |
-
label="Feather Radius (px)"
|
|
|
|
| 893 |
)
|
| 894 |
|
| 895 |
with gr.Accordion("Advanced Settings", open=False):
|
|
@@ -909,6 +1003,14 @@ class UIManager:
|
|
| 909 |
label="Inference Steps"
|
| 910 |
)
|
| 911 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
# Generate button
|
| 913 |
inpaint_btn = gr.Button(
|
| 914 |
"Generate Inpainting",
|
|
@@ -925,9 +1027,9 @@ class UIManager:
|
|
| 925 |
border-radius: 8px;
|
| 926 |
margin: 12px 0;">
|
| 927 |
<p style="margin: 0; color: #5d4037; font-size: 14px;">
|
| 928 |
-
⏳ <strong>Please be patient!</strong>
|
| 929 |
-
|
| 930 |
-
|
| 931 |
</p>
|
| 932 |
</div>
|
| 933 |
<div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
|
|
@@ -943,13 +1045,27 @@ class UIManager:
|
|
| 943 |
"""
|
| 944 |
)
|
| 945 |
|
| 946 |
-
# Status
|
| 947 |
inpaint_status = gr.Textbox(
|
| 948 |
label="Status",
|
| 949 |
value="Ready for inpainting",
|
| 950 |
interactive=False
|
| 951 |
)
|
| 952 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
# Output row
|
| 954 |
with gr.Row():
|
| 955 |
with gr.Column(scale=1):
|
|
@@ -971,7 +1087,15 @@ class UIManager:
|
|
| 971 |
inpaint_template.change(
|
| 972 |
fn=self.apply_inpainting_template,
|
| 973 |
inputs=[inpaint_template, inpaint_prompt],
|
| 974 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 975 |
)
|
| 976 |
|
| 977 |
inpaint_template.change(
|
|
@@ -980,9 +1104,16 @@ class UIManager:
|
|
| 980 |
outputs=[template_tips]
|
| 981 |
)
|
| 982 |
|
| 983 |
-
# Copy uploaded image to mask editor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 984 |
inpaint_image.change(
|
| 985 |
-
fn=
|
| 986 |
inputs=[inpaint_image],
|
| 987 |
outputs=[mask_editor]
|
| 988 |
)
|
|
@@ -994,19 +1125,29 @@ class UIManager:
|
|
| 994 |
mask_editor,
|
| 995 |
inpaint_prompt,
|
| 996 |
inpaint_template,
|
|
|
|
| 997 |
conditioning_type,
|
| 998 |
conditioning_scale,
|
| 999 |
feather_radius,
|
| 1000 |
inpaint_guidance,
|
| 1001 |
-
inpaint_steps
|
|
|
|
| 1002 |
],
|
| 1003 |
outputs=[
|
| 1004 |
inpaint_result,
|
| 1005 |
inpaint_control,
|
| 1006 |
-
inpaint_status
|
|
|
|
| 1007 |
]
|
| 1008 |
)
|
| 1009 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1010 |
return tab
|
| 1011 |
|
| 1012 |
def _get_template_tips(self, display_name: str) -> str:
|
|
@@ -1021,4 +1162,4 @@ class UIManager:
|
|
| 1021 |
tips = self.inpainting_template_manager.get_usage_tips(template_key)
|
| 1022 |
if tips:
|
| 1023 |
return "**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips)
|
| 1024 |
-
return ""
|
|
|
|
| 3 |
import traceback
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Optional, Tuple, Dict, Any, List
|
| 6 |
+
|
|
|
|
| 7 |
import cv2
|
| 8 |
import gradio as gr
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
|
|
|
|
| 12 |
from css_styles import CSSStyles
|
| 13 |
from scene_templates import SceneTemplateManager
|
| 14 |
from inpainting_templates import InpaintingTemplateManager
|
| 15 |
+
from scene_weaver_core import SceneWeaverCore
|
| 16 |
+
from gpu_handlers import GPUHandlers
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
logger.setLevel(logging.INFO)
|
|
|
|
| 30 |
Gradio UI Manager with support for background generation and inpainting.
|
| 31 |
|
| 32 |
Provides a professional interface with mode switching, template selection,
|
| 33 |
+
and advanced parameter controls. GPU operations are delegated to GPUHandlers.
|
| 34 |
|
| 35 |
Attributes:
|
| 36 |
+
gpu_handlers: GPUHandlers instance for GPU operations
|
| 37 |
template_manager: Scene template manager
|
| 38 |
inpainting_template_manager: Inpainting template manager
|
| 39 |
"""
|
| 40 |
|
| 41 |
def __init__(self):
|
| 42 |
self.sceneweaver = SceneWeaverCore()
|
| 43 |
+
self.gpu_handlers = GPUHandlers(
|
| 44 |
+
core=self.sceneweaver,
|
| 45 |
+
inpainting_template_manager=InpaintingTemplateManager()
|
| 46 |
+
)
|
| 47 |
self.template_manager = SceneTemplateManager()
|
| 48 |
self.inpainting_template_manager = InpaintingTemplateManager()
|
| 49 |
self.generation_history = []
|
|
|
|
| 178 |
if len(self.generation_history) > max_history:
|
| 179 |
self.generation_history = self.generation_history[-max_history:]
|
| 180 |
|
|
|
|
| 181 |
def generate_handler(
|
| 182 |
self,
|
| 183 |
uploaded_image: Optional[Image.Image],
|
|
|
|
| 189 |
guidance: float,
|
| 190 |
progress=gr.Progress()
|
| 191 |
):
|
| 192 |
+
"""
|
| 193 |
+
Generation handler - delegates GPU work to GPUHandlers.
|
| 194 |
|
| 195 |
+
Parameters
|
| 196 |
+
----------
|
| 197 |
+
uploaded_image : PIL.Image
|
| 198 |
+
Input image
|
| 199 |
+
prompt : str
|
| 200 |
+
Background description
|
| 201 |
+
combination_mode : str
|
| 202 |
+
Composition mode
|
| 203 |
+
focus_mode : str
|
| 204 |
+
Focus mode
|
| 205 |
+
negative_prompt : str
|
| 206 |
+
Negative prompt
|
| 207 |
+
steps : int
|
| 208 |
+
Inference steps
|
| 209 |
+
guidance : float
|
| 210 |
+
Guidance scale
|
| 211 |
+
progress : gr.Progress
|
| 212 |
+
Progress callback
|
| 213 |
+
|
| 214 |
+
Returns
|
| 215 |
+
-------
|
| 216 |
+
tuple
|
| 217 |
+
(combined, generated, original, status, download_btn_update)
|
| 218 |
+
"""
|
| 219 |
if uploaded_image is None:
|
| 220 |
return None, None, None, "Please upload an image to get started!", gr.update(visible=False)
|
| 221 |
|
|
|
|
| 223 |
return None, None, None, "Please describe the background scene you'd like!", gr.update(visible=False)
|
| 224 |
|
| 225 |
try:
|
| 226 |
+
def progress_callback(msg: str, pct: int):
|
| 227 |
+
progress(pct / 100, desc=msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
+
# Delegate to GPUHandlers
|
| 230 |
+
result = self.gpu_handlers.background_generate(
|
| 231 |
+
image=uploaded_image,
|
| 232 |
prompt=prompt,
|
|
|
|
|
|
|
| 233 |
negative_prompt=negative_prompt,
|
| 234 |
+
composition_mode=combination_mode,
|
| 235 |
+
focus_mode=focus_mode,
|
| 236 |
+
num_steps=int(steps),
|
| 237 |
guidance_scale=float(guidance),
|
| 238 |
+
progress_callback=progress_callback
|
| 239 |
)
|
| 240 |
|
| 241 |
if result["success"]:
|
|
|
|
| 551 |
self,
|
| 552 |
display_name: str,
|
| 553 |
current_prompt: str
|
| 554 |
+
) -> Tuple[str, float, int, str, Any, Any, Any]:
|
| 555 |
"""
|
| 556 |
Apply an inpainting template to the UI fields.
|
| 557 |
|
|
|
|
| 565 |
Returns
|
| 566 |
-------
|
| 567 |
tuple
|
| 568 |
+
(prompt, conditioning_scale, feather_radius, conditioning_type,
|
| 569 |
+
controlnet_settings_visibility, mode_info_html, model_selection_visibility)
|
| 570 |
"""
|
| 571 |
+
# Default returns for no template selected
|
| 572 |
+
default_return = (
|
| 573 |
+
current_prompt,
|
| 574 |
+
0.7,
|
| 575 |
+
8,
|
| 576 |
+
"canny",
|
| 577 |
+
gr.update(visible=True), # Show ControlNet settings by default
|
| 578 |
+
"", # No mode info
|
| 579 |
+
gr.update(visible=True) # Show model selection by default
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
if not display_name:
|
| 583 |
+
return default_return
|
| 584 |
|
| 585 |
template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
|
| 586 |
if not template_key:
|
| 587 |
+
return default_return
|
| 588 |
|
| 589 |
template = self.inpainting_template_manager.get_template(template_key)
|
| 590 |
if template:
|
| 591 |
params = self.inpainting_template_manager.get_parameters_for_template(template_key)
|
| 592 |
+
use_controlnet = params.get('use_controlnet', True)
|
| 593 |
+
|
| 594 |
+
# Determine visibility and info based on mode
|
| 595 |
+
if use_controlnet:
|
| 596 |
+
controlnet_visibility = gr.update(visible=True)
|
| 597 |
+
model_visibility = gr.update(visible=True)
|
| 598 |
+
mode_info = """
|
| 599 |
+
<div style="background: linear-gradient(135deg, #e8f5e9 0%, #c8e6c9 100%);
|
| 600 |
+
border-left: 4px solid #4CAF50;
|
| 601 |
+
padding: 10px 14px;
|
| 602 |
+
border-radius: 8px;
|
| 603 |
+
margin: 8px 0;">
|
| 604 |
+
<p style="margin: 0; color: #2e7d32; font-size: 13px;">
|
| 605 |
+
🎛️ <strong>ControlNet Mode</strong> - Structure will be preserved using edge/depth guidance.
|
| 606 |
+
You can adjust ControlNet settings and select model below.
|
| 607 |
+
</p>
|
| 608 |
+
</div>
|
| 609 |
+
"""
|
| 610 |
+
else:
|
| 611 |
+
# Pure Inpainting mode - hide both ControlNet and Model Selection
|
| 612 |
+
controlnet_visibility = gr.update(visible=False)
|
| 613 |
+
model_visibility = gr.update(visible=False)
|
| 614 |
+
mode_info = """
|
| 615 |
+
<div style="background: linear-gradient(135deg, #fff3e0 0%, #ffe0b2 100%);
|
| 616 |
+
border-left: 4px solid #ff9800;
|
| 617 |
+
padding: 10px 14px;
|
| 618 |
+
border-radius: 8px;
|
| 619 |
+
margin: 8px 0;">
|
| 620 |
+
<p style="margin: 0; color: #e65100; font-size: 13px;">
|
| 621 |
+
🚀 <strong>Pure Inpainting Mode</strong> - Using dedicated SDXL Inpainting model.<br>
|
| 622 |
+
Model and ControlNet settings are automatically configured for best results.
|
| 623 |
+
</p>
|
| 624 |
+
</div>
|
| 625 |
+
"""
|
| 626 |
+
|
| 627 |
return (
|
| 628 |
current_prompt,
|
| 629 |
params.get('controlnet_conditioning_scale', 0.7),
|
| 630 |
params.get('feather_radius', 8),
|
| 631 |
+
params.get('preferred_conditioning', 'canny'),
|
| 632 |
+
controlnet_visibility,
|
| 633 |
+
mode_info,
|
| 634 |
+
model_visibility
|
| 635 |
)
|
| 636 |
|
| 637 |
+
return default_return
|
| 638 |
|
| 639 |
def extract_mask_from_editor(self, editor_output: Dict[str, Any]) -> Optional[Image.Image]:
|
| 640 |
"""
|
|
|
|
| 718 |
logger.error(f"Failed to extract mask from editor: {e}")
|
| 719 |
return None
|
| 720 |
|
|
|
|
| 721 |
def inpainting_handler(
|
| 722 |
self,
|
| 723 |
image: Optional[Image.Image],
|
| 724 |
mask_editor: Dict[str, Any],
|
| 725 |
prompt: str,
|
| 726 |
template_dropdown: str,
|
| 727 |
+
model_choice: str,
|
| 728 |
conditioning_type: str,
|
| 729 |
conditioning_scale: float,
|
| 730 |
feather_radius: int,
|
| 731 |
guidance_scale: float,
|
| 732 |
num_steps: int,
|
| 733 |
+
seed: int,
|
| 734 |
progress: gr.Progress = gr.Progress()
|
| 735 |
+
) -> Tuple[Optional[Image.Image], Optional[Image.Image], str, int]:
|
| 736 |
"""
|
| 737 |
+
Handle inpainting generation request - delegates GPU work to GPUHandlers.
|
| 738 |
|
| 739 |
Parameters
|
| 740 |
----------
|
|
|
|
| 746 |
Text description of desired content
|
| 747 |
template_dropdown : str
|
| 748 |
Selected template (optional)
|
| 749 |
+
model_choice : str
|
| 750 |
+
Model key to use (juggernaut_xl, realvis_xl, sdxl_base, animagine_xl)
|
| 751 |
conditioning_type : str
|
| 752 |
ControlNet conditioning type
|
| 753 |
conditioning_scale : float
|
|
|
|
| 758 |
Guidance scale for generation
|
| 759 |
num_steps : int
|
| 760 |
Number of inference steps
|
| 761 |
+
seed : int
|
| 762 |
+
Random seed (-1 for random)
|
| 763 |
progress : gr.Progress
|
| 764 |
Progress callback
|
| 765 |
|
| 766 |
Returns
|
| 767 |
-------
|
| 768 |
tuple
|
| 769 |
+
(result_image, control_image, status_message, used_seed)
|
| 770 |
"""
|
| 771 |
if image is None:
|
| 772 |
+
return None, None, "⚠️ Please upload an image first", -1
|
| 773 |
|
| 774 |
# Extract mask
|
| 775 |
mask = self.extract_mask_from_editor(mask_editor)
|
| 776 |
if mask is None:
|
| 777 |
+
return None, None, "⚠️ Please draw a mask on the image", -1
|
| 778 |
|
| 779 |
# Validate mask
|
| 780 |
mask_array = np.array(mask)
|
| 781 |
coverage = np.count_nonzero(mask_array > 127) / mask_array.size
|
| 782 |
if coverage < 0.01:
|
| 783 |
+
return None, None, "⚠️ Mask too small - please select a larger area", -1
|
| 784 |
if coverage > 0.95:
|
| 785 |
+
return None, None, "⚠️ Mask too large - consider using background generation instead", -1
|
| 786 |
|
| 787 |
def progress_callback(msg: str, pct: int):
|
| 788 |
progress(pct / 100, desc=msg)
|
| 789 |
|
| 790 |
try:
|
|
|
|
|
|
|
| 791 |
# Get template key if selected
|
| 792 |
template_key = None
|
| 793 |
if template_dropdown:
|
|
|
|
| 795 |
template_dropdown
|
| 796 |
)
|
| 797 |
|
| 798 |
+
# Delegate to GPUHandlers
|
| 799 |
+
result_image, control_image, status, used_seed = self.gpu_handlers.inpainting_generate(
|
| 800 |
image=image,
|
| 801 |
mask=mask,
|
| 802 |
prompt=prompt,
|
|
|
|
| 803 |
template_key=template_key,
|
| 804 |
+
model_key=model_choice,
|
| 805 |
conditioning_type=conditioning_type,
|
| 806 |
+
conditioning_scale=conditioning_scale,
|
| 807 |
feather_radius=feather_radius,
|
| 808 |
guidance_scale=guidance_scale,
|
| 809 |
+
num_steps=num_steps,
|
| 810 |
+
seed=int(seed) if seed is not None else -1,
|
| 811 |
progress_callback=progress_callback
|
| 812 |
)
|
| 813 |
|
| 814 |
+
# Store in history if successful
|
| 815 |
+
if result_image is not None:
|
|
|
|
|
|
|
| 816 |
self.inpainting_history.append({
|
| 817 |
+
'result': result_image,
|
| 818 |
'prompt': prompt,
|
| 819 |
+
'seed': used_seed,
|
| 820 |
+
'time': time.time()
|
| 821 |
})
|
| 822 |
if len(self.inpainting_history) > 3:
|
| 823 |
self.inpainting_history.pop(0)
|
| 824 |
|
| 825 |
+
return result_image, control_image, status, used_seed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
|
| 827 |
except Exception as e:
|
| 828 |
logger.error(f"Inpainting handler error: {e}")
|
| 829 |
logger.error(traceback.format_exc())
|
| 830 |
+
return None, None, f"❌ Error: {str(e)}", -1
|
| 831 |
|
| 832 |
def create_inpainting_tab(self) -> gr.Tab:
|
| 833 |
"""
|
|
|
|
| 855 |
</span>
|
| 856 |
</h3>
|
| 857 |
<p style="color: #666; margin-bottom: 12px;">Draw a mask to select the area you want to regenerate</p>
|
| 858 |
+
</div>
|
| 859 |
+
""")
|
| 860 |
+
|
| 861 |
+
# Model Selection Guide
|
| 862 |
+
gr.HTML("""
|
| 863 |
+
<div style="background: linear-gradient(135deg, #f5f7fa 0%, #e4e8ec 100%);
|
| 864 |
+
padding: 16px;
|
| 865 |
+
border-radius: 12px;
|
| 866 |
+
margin: 12px 0;
|
| 867 |
+
border: 1px solid #ddd;">
|
| 868 |
+
<h4 style="margin: 0 0 12px 0; color: #333; font-size: 16px;">
|
| 869 |
+
📸 Model Selection Guide
|
| 870 |
+
</h4>
|
| 871 |
+
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;">
|
| 872 |
+
<div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #4CAF50;">
|
| 873 |
+
<p style="margin: 0 0 8px 0; font-weight: bold; color: #4CAF50;">
|
| 874 |
+
🖼️ Photo Mode (Real Photos)
|
| 875 |
+
</p>
|
| 876 |
+
<p style="margin: 0; font-size: 13px; color: #555;">
|
| 877 |
+
<strong>Best for:</strong> Photographs, portraits, product shots, nature photos
|
| 878 |
+
</p>
|
| 879 |
+
<p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
|
| 880 |
+
• <strong>JuggernautXL</strong> - Best for portraits and people<br>
|
| 881 |
+
• <strong>RealVisXL</strong> - Best for scenes and objects
|
| 882 |
+
</p>
|
| 883 |
+
</div>
|
| 884 |
+
<div style="background: white; padding: 12px; border-radius: 8px; border-left: 4px solid #9C27B0;">
|
| 885 |
+
<p style="margin: 0 0 8px 0; font-weight: bold; color: #9C27B0;">
|
| 886 |
+
🎨 Anime Mode (Illustrations)
|
| 887 |
+
</p>
|
| 888 |
+
<p style="margin: 0; font-size: 13px; color: #555;">
|
| 889 |
+
<strong>Best for:</strong> Anime, manga, illustrations, digital art, cartoons
|
| 890 |
+
</p>
|
| 891 |
+
<p style="margin: 8px 0 0 0; font-size: 12px; color: #777;">
|
| 892 |
+
• <strong>Animagine XL</strong> - Best for anime/manga style<br>
|
| 893 |
+
• <strong>SDXL Base</strong> - Versatile for general art
|
| 894 |
+
</p>
|
| 895 |
+
</div>
|
| 896 |
</div>
|
| 897 |
</div>
|
| 898 |
""")
|
|
|
|
| 929 |
)
|
| 930 |
template_tips = gr.Markdown("")
|
| 931 |
|
| 932 |
+
# Mode info (dynamically updated based on template)
|
| 933 |
+
mode_info_html = gr.HTML("")
|
| 934 |
+
|
| 935 |
# Prompt
|
| 936 |
inpaint_prompt = gr.Textbox(
|
| 937 |
label="Prompt",
|
|
|
|
| 941 |
|
| 942 |
# Right column - Settings and Output
|
| 943 |
with gr.Column(scale=1):
|
| 944 |
+
# Model Selection (hidden for Pure Inpainting templates)
|
| 945 |
+
with gr.Group(visible=True) as model_selection_group:
|
| 946 |
+
with gr.Accordion("Model Selection", open=True):
|
| 947 |
+
model_choice = gr.Dropdown(
|
| 948 |
+
choices=[
|
| 949 |
+
("🖼️ JuggernautXL v9 - Best for portraits & real photos", "juggernaut_xl"),
|
| 950 |
+
("🖼️ RealVisXL v4 - Best for realistic scenes", "realvis_xl"),
|
| 951 |
+
("🎨 SDXL Base - Versatile for general art", "sdxl_base"),
|
| 952 |
+
("🎨 Animagine XL 3.1 - Best for anime/manga", "animagine_xl"),
|
| 953 |
+
],
|
| 954 |
+
value="juggernaut_xl",
|
| 955 |
+
label="Select Model",
|
| 956 |
+
info="Choose based on your image type (photo vs illustration)"
|
| 957 |
+
)
|
| 958 |
|
| 959 |
+
# ControlNet Settings (hidden for Pure Inpainting templates)
|
| 960 |
+
with gr.Group(visible=True) as controlnet_settings_group:
|
| 961 |
+
with gr.Accordion("ControlNet Settings", open=True):
|
| 962 |
+
conditioning_type = gr.Radio(
|
| 963 |
+
choices=["canny", "depth"],
|
| 964 |
+
value="canny",
|
| 965 |
+
label="ControlNet Mode",
|
| 966 |
+
info="Canny: preserves edges | Depth: preserves 3D structure"
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
conditioning_scale = gr.Slider(
|
| 970 |
+
minimum=0.05,
|
| 971 |
+
maximum=1.0,
|
| 972 |
+
value=0.7,
|
| 973 |
+
step=0.05,
|
| 974 |
+
label="ControlNet Strength",
|
| 975 |
+
info="Higher = more structure preservation"
|
| 976 |
+
)
|
| 977 |
|
| 978 |
+
# General Settings (always visible)
|
| 979 |
+
with gr.Accordion("General Settings", open=True):
|
| 980 |
feather_radius = gr.Slider(
|
| 981 |
minimum=0,
|
| 982 |
maximum=20,
|
| 983 |
value=8,
|
| 984 |
step=1,
|
| 985 |
+
label="Feather Radius (px)",
|
| 986 |
+
info="Edge blending softness"
|
| 987 |
)
|
| 988 |
|
| 989 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
|
| 1003 |
label="Inference Steps"
|
| 1004 |
)
|
| 1005 |
|
| 1006 |
+
# Seed control for reproducibility
|
| 1007 |
+
seed_input = gr.Number(
|
| 1008 |
+
label="Seed",
|
| 1009 |
+
value=-1,
|
| 1010 |
+
precision=0,
|
| 1011 |
+
info="-1 = random seed, or enter a specific number to reproduce results"
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
# Generate button
|
| 1015 |
inpaint_btn = gr.Button(
|
| 1016 |
"Generate Inpainting",
|
|
|
|
| 1027 |
border-radius: 8px;
|
| 1028 |
margin: 12px 0;">
|
| 1029 |
<p style="margin: 0; color: #5d4037; font-size: 14px;">
|
| 1030 |
+
⏳ <strong>Please be patient!</strong><br>
|
| 1031 |
+
• <strong>First run:</strong> 5-7 minutes (model initialization)<br>
|
| 1032 |
+
• <strong>Subsequent runs:</strong> 2-3 minutes (model cached)
|
| 1033 |
</p>
|
| 1034 |
</div>
|
| 1035 |
<div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
|
|
|
|
| 1045 |
"""
|
| 1046 |
)
|
| 1047 |
|
| 1048 |
+
# Status and Seed display
|
| 1049 |
inpaint_status = gr.Textbox(
|
| 1050 |
label="Status",
|
| 1051 |
value="Ready for inpainting",
|
| 1052 |
interactive=False
|
| 1053 |
)
|
| 1054 |
|
| 1055 |
+
# Display used seed for reproducibility
|
| 1056 |
+
with gr.Row():
|
| 1057 |
+
used_seed_display = gr.Number(
|
| 1058 |
+
label="Used Seed (copy this to reproduce)",
|
| 1059 |
+
value=-1,
|
| 1060 |
+
precision=0,
|
| 1061 |
+
interactive=False
|
| 1062 |
+
)
|
| 1063 |
+
copy_seed_btn = gr.Button(
|
| 1064 |
+
"📋 Use This Seed",
|
| 1065 |
+
size="sm",
|
| 1066 |
+
scale=0
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
# Output row
|
| 1070 |
with gr.Row():
|
| 1071 |
with gr.Column(scale=1):
|
|
|
|
| 1087 |
inpaint_template.change(
|
| 1088 |
fn=self.apply_inpainting_template,
|
| 1089 |
inputs=[inpaint_template, inpaint_prompt],
|
| 1090 |
+
outputs=[
|
| 1091 |
+
inpaint_prompt,
|
| 1092 |
+
conditioning_scale,
|
| 1093 |
+
feather_radius,
|
| 1094 |
+
conditioning_type,
|
| 1095 |
+
controlnet_settings_group,
|
| 1096 |
+
mode_info_html,
|
| 1097 |
+
model_selection_group
|
| 1098 |
+
]
|
| 1099 |
)
|
| 1100 |
|
| 1101 |
inpaint_template.change(
|
|
|
|
| 1104 |
outputs=[template_tips]
|
| 1105 |
)
|
| 1106 |
|
| 1107 |
+
# Copy uploaded image to mask editor (as background)
|
| 1108 |
+
def set_mask_editor_background(image):
|
| 1109 |
+
"""Set uploaded image as mask editor background."""
|
| 1110 |
+
if image is None:
|
| 1111 |
+
return None
|
| 1112 |
+
# Return dict format for ImageEditor with background
|
| 1113 |
+
return {"background": image, "layers": [], "composite": None}
|
| 1114 |
+
|
| 1115 |
inpaint_image.change(
|
| 1116 |
+
fn=set_mask_editor_background,
|
| 1117 |
inputs=[inpaint_image],
|
| 1118 |
outputs=[mask_editor]
|
| 1119 |
)
|
|
|
|
| 1125 |
mask_editor,
|
| 1126 |
inpaint_prompt,
|
| 1127 |
inpaint_template,
|
| 1128 |
+
model_choice,
|
| 1129 |
conditioning_type,
|
| 1130 |
conditioning_scale,
|
| 1131 |
feather_radius,
|
| 1132 |
inpaint_guidance,
|
| 1133 |
+
inpaint_steps,
|
| 1134 |
+
seed_input
|
| 1135 |
],
|
| 1136 |
outputs=[
|
| 1137 |
inpaint_result,
|
| 1138 |
inpaint_control,
|
| 1139 |
+
inpaint_status,
|
| 1140 |
+
used_seed_display
|
| 1141 |
]
|
| 1142 |
)
|
| 1143 |
|
| 1144 |
+
# Copy seed button - copies used seed to input
|
| 1145 |
+
copy_seed_btn.click(
|
| 1146 |
+
fn=lambda x: x,
|
| 1147 |
+
inputs=[used_seed_display],
|
| 1148 |
+
outputs=[seed_input]
|
| 1149 |
+
)
|
| 1150 |
+
|
| 1151 |
return tab
|
| 1152 |
|
| 1153 |
def _get_template_tips(self, display_name: str) -> str:
|
|
|
|
| 1162 |
tips = self.inpainting_template_manager.get_usage_tips(template_key)
|
| 1163 |
if tips:
|
| 1164 |
return "**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips)
|
| 1165 |
+
return ""
|