# SceneWeaver / inpainting_module.py
# Author: DawnC
# Commit message: "Upload 15 files"
# Commit: 991a517 (verified)
import gc
import logging
import os
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting
from diffusers import ControlNetModel
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionXLControlNetInpaintPipeline
from transformers import AutoImageProcessor
from transformers import AutoModelForDepthEstimation
from transformers import DPTForDepthEstimation
from transformers import DPTImageProcessor
from control_image_processor import ControlImageProcessor
from inpainting_blender import InpaintingBlender
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a library module's logger may override a
# level the application configured for this logger name — confirm intended.
logger.setLevel(logging.INFO)
# Dedicated SDXL Inpainting model - trained specifically for inpainting.
# Loaded by InpaintingModule.load_pipeline() in "pure" (non-ControlNet) mode.
SDXL_INPAINTING_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
@dataclass
class InpaintingConfig:
    """
    Configuration for inpainting operations.

    Holds the defaults used by InpaintingModule. Several generation fields
    (steps, guidance, strength, feathering, ControlNet scale) can be
    overridden per call through keyword arguments to ``execute_inpainting``.
    """
    # ControlNet settings (for ControlNet mode only)
    # Default controlnet_conditioning_scale when the caller does not pass one.
    controlnet_conditioning_scale: float = 0.7
    # NOTE(review): not read by InpaintingModule in this file;
    # load_pipeline() takes its own conditioning_type argument.
    conditioning_type: str = "canny"
    # Canny edge detection parameters (forwarded to ControlImageProcessor)
    canny_low_threshold: int = 100
    canny_high_threshold: int = 200
    # Mask settings
    # Default Gaussian feathering radius (pixels) applied to the mask.
    feather_radius: int = 3
    # Forwarded to InpaintingBlender; presumably the min/max fraction of the
    # image a mask may cover — confirm against InpaintingBlender.
    min_mask_coverage: float = 0.01
    max_mask_coverage: float = 0.95
    # Generation settings (defaults when not overridden per call)
    num_inference_steps: int = 25
    guidance_scale: float = 7.5
    strength: float = 0.99  # Use 0.99 to avoid noise issues with 1.0
    # Memory settings
    # When True, enables VAE tiling and VAE slicing if the pipeline supports them.
    enable_vae_tiling: bool = True
    # Upper bound (pixels) on the longest side of the image fed to the pipeline.
    max_resolution: int = 1024
@dataclass
class InpaintingResult:
    """
    Result container for inpainting operations.

    On success InpaintingModule fills ``result_image``, ``blended_image``,
    ``control_image``, ``generation_time`` and ``metadata``; on failure it
    sets only ``success=False`` and ``error_message``.
    """
    # True when generation completed; check before reading the images.
    success: bool
    # Generated image (InpaintingModule restores it to the input's size).
    result_image: Optional[Image.Image] = None
    # NOTE(review): not populated by InpaintingModule in this file.
    preview_image: Optional[Image.Image] = None
    # ControlNet conditioning image; None in pure inpainting mode.
    control_image: Optional[Image.Image] = None
    # InpaintingModule sets this to the same object as result_image.
    blended_image: Optional[Image.Image] = None
    # NOTE(review): not populated by InpaintingModule in this file.
    quality_score: float = 0.0
    # Seconds (time.time()-based) from the start of execute_inpainting until
    # generation finished.
    generation_time: float = 0.0
    # Failure reason; empty string on success.
    error_message: str = ""
    # Free-form details (InpaintingModule stores seed, prompt, mode, steps, ...).
    metadata: Dict[str, Any] = field(default_factory=dict)
class InpaintingModule:
    """
    Dual-mode Inpainting Module for SceneWeaver.

    Supports two modes:

    1. Pure Inpainting (use_controlnet=False): Uses dedicated SDXL Inpainting model
       - Best for: Object replacement, Object removal
       - More stable, better edge blending
    2. ControlNet Inpainting (use_controlnet=True): Uses ControlNet + SDXL
       - Best for: Clothing change (depth), Color change (canny)
       - Preserves structure in masked region

    Typical flow: call ``load_pipeline()`` (a no-op when the requested
    configuration is already loaded), then ``execute_inpainting()`` per request.

    Example:
        >>> module = InpaintingModule(device="cuda")
        >>> # For object replacement (no ControlNet)
        >>> module.load_pipeline(use_controlnet=False)
        >>> result = module.execute_inpainting(image, mask, "a vase with flowers")
    """

    # ControlNet model identifiers
    CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
    CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"

    # Depth estimators for "depth" conditioning (primary, then fallback)
    DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
    DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"

    # Base models for ControlNet mode
    SUPPORTED_MODELS = {
        "juggernaut_xl": "RunDiffusion/Juggernaut-XL-v9",
        "realvis_xl": "SG161222/RealVisXL_V4.0",
        "sdxl_base": "stabilityai/stable-diffusion-xl-base-1.0",
        "animagine_xl": "cagliostrolab/animagine-xl-3.1",
    }

    # Base model used when load_pipeline() receives an unknown model_key
    _DEFAULT_MODEL_KEY = "sdxl_base"

    # ControlNet conditioning types accepted by load_pipeline()
    _SUPPORTED_CONDITIONING_TYPES = ("canny", "depth")

    # Base models loaded without the fp16 weight variant
    # (presumably because their repos do not publish one — verify).
    _NO_FP16_VARIANT_MODELS = frozenset({"animagine_xl"})

    # Image sides are floored to a multiple of this before generation
    _SIZE_MULTIPLE = 8

    def __init__(
        self,
        device: str = "auto",
        config: Optional[InpaintingConfig] = None
    ):
        """
        Initialize the InpaintingModule.

        No diffusion model is loaded here; call ``load_pipeline()`` before
        ``execute_inpainting()``.

        Parameters
        ----------
        device : str
            "auto" selects "cuda", then "mps", then "cpu"; any other value is
            used as given.
        config : InpaintingConfig, optional
            Defaults for generation, masks and memory; a fresh
            ``InpaintingConfig()`` is used when omitted.
        """
        self.device = self._setup_device(device)
        self.config = config or InpaintingConfig()

        # Sub-modules
        self._control_processor = ControlImageProcessor(
            device=self.device,
            canny_low_threshold=self.config.canny_low_threshold,
            canny_high_threshold=self.config.canny_high_threshold
        )
        # NOTE(review): self._blender is not used by any method of this class.
        self._blender = InpaintingBlender(
            min_mask_coverage=self.config.min_mask_coverage,
            max_mask_coverage=self.config.max_mask_coverage
        )

        # Pipeline instances
        self._pipeline = None
        self._controlnet = None
        self._depth_estimator = None
        self._depth_processor = None

        # State tracking
        self.is_initialized = False
        self._current_mode = None  # "pure" or "controlnet"
        self._current_conditioning_type = None
        self._current_model_key = None

        logger.info(f"InpaintingModule initialized on {self.device}")

    def _setup_device(self, device: str) -> str:
        """Resolve "auto" to the best available backend (cuda > mps > cpu)."""
        if device != "auto":
            return device
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _memory_cleanup(self, aggressive: bool = False) -> None:
        """
        Run garbage collection and, outside HuggingFace Spaces, release
        cached CUDA memory.

        Parameters
        ----------
        aggressive : bool
            Run more GC passes and additionally call ``torch.cuda.ipc_collect()``.
        """
        for _ in range(5 if aggressive else 2):
            gc.collect()

        # CUDA cache calls are skipped when SPACE_ID is set (HuggingFace
        # Spaces) — presumably to avoid touching CUDA outside a GPU-allocated
        # context there; confirm.
        is_spaces = os.getenv('SPACE_ID') is not None
        if not is_spaces and torch.cuda.is_available():
            torch.cuda.empty_cache()
            if aggressive:
                torch.cuda.ipc_collect()

    def load_pipeline(
        self,
        use_controlnet: bool = False,
        conditioning_type: str = "canny",
        model_key: str = "sdxl_base",
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> Tuple[bool, str]:
        """
        Load the appropriate inpainting pipeline.

        Returns immediately when the requested configuration is already
        loaded. Otherwise the current pipeline is unloaded, the new one is
        loaded, its scheduler is replaced with DPMSolverMultistepScheduler,
        and it is moved to ``self.device`` and optimized.

        Parameters
        ----------
        use_controlnet : bool
            If False, use dedicated SDXL Inpainting model (for replacement/removal)
            If True, use ControlNet pipeline (for clothing/color change)
        conditioning_type : str
            ControlNet type: "canny" or "depth" (only used when use_controlnet=True).
            An unknown type is rejected without unloading the current pipeline.
        model_key : str
            Base model for ControlNet mode (a key of ``SUPPORTED_MODELS``);
            unknown keys fall back to "sdxl_base".
        progress_callback : callable, optional
            Progress update function, called as ``callback(message, percent)``

        Returns
        -------
        tuple
            (success: bool, error_message: str). Load errors are caught,
            logged and reported here, and the module is left unloaded.
        """
        mode = "controlnet" if use_controlnet else "pure"

        if use_controlnet:
            # Validate before unloading anything: a bad argument must not
            # discard a working pipeline.
            if conditioning_type not in self._SUPPORTED_CONDITIONING_TYPES:
                error_msg = f"Unknown conditioning type: {conditioning_type}"
                logger.error(f"Failed to load pipeline: {error_msg}")
                return False, error_msg
            # Normalize before the cache check so it compares against the key
            # that is actually stored once loading succeeds.
            if model_key not in self.SUPPORTED_MODELS:
                logger.warning(
                    f"Unknown model_key '{model_key}', "
                    f"using '{self._DEFAULT_MODEL_KEY}'"
                )
                model_key = self._DEFAULT_MODEL_KEY

        # Check if already loaded with same config
        already_loaded = (
            self.is_initialized
            and self._current_mode == mode
            and (
                not use_controlnet
                or (self._current_conditioning_type == conditioning_type
                    and self._current_model_key == model_key)
            )
        )
        if already_loaded:
            logger.info(f"Pipeline already loaded: mode={mode}")
            return True, ""

        logger.info(f"Loading pipeline: mode={mode}, conditioning={conditioning_type}")

        try:
            if progress_callback:
                progress_callback("Preparing pipeline...", 10)

            # Drop the previous pipeline; _unload_pipeline also runs GC.
            self._unload_pipeline()

            dtype = torch.float16 if self.device == "cuda" else torch.float32

            if use_controlnet:
                # Mode B: ControlNet Inpainting (for structure-preserving tasks)
                self._pipeline = self._load_controlnet_pipeline(
                    conditioning_type, model_key, dtype, progress_callback
                )
            else:
                # Mode A: Pure SDXL Inpainting (for replacement/removal)
                self._pipeline = self._load_pure_pipeline(dtype, progress_callback)

            if progress_callback:
                progress_callback("Configuring pipeline...", 80)

            self._pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self._pipeline.scheduler.config
            )
            self._pipeline = self._pipeline.to(self.device)
            self._apply_optimizations()

            # Record state only once everything succeeded.
            self._current_mode = mode
            self._current_conditioning_type = conditioning_type if use_controlnet else None
            self._current_model_key = model_key if use_controlnet else None
            self.is_initialized = True

            if progress_callback:
                progress_callback("Pipeline ready!", 100)
            return True, ""

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load pipeline: {error_msg}")
            traceback.print_exc()
            self._unload_pipeline()
            return False, error_msg

    def _load_pure_pipeline(
        self,
        dtype: torch.dtype,
        progress_callback: Optional[Callable[[str, int], None]]
    ):
        """Load and return the dedicated SDXL inpainting pipeline (not yet on device)."""
        if progress_callback:
            progress_callback("Loading SDXL Inpainting model...", 30)

        pipeline = AutoPipelineForInpainting.from_pretrained(
            SDXL_INPAINTING_MODEL,
            torch_dtype=dtype,
            variant="fp16" if dtype == torch.float16 else None,
        )
        logger.info("Loaded pure SDXL Inpainting pipeline")
        return pipeline

    def _load_controlnet_pipeline(
        self,
        conditioning_type: str,
        model_key: str,
        dtype: torch.dtype,
        progress_callback: Optional[Callable[[str, int], None]]
    ):
        """
        Load the ControlNet (and depth estimator, for "depth"), then return a
        ControlNet inpaint pipeline built on ``SUPPORTED_MODELS[model_key]``.

        Expects ``conditioning_type`` and ``model_key`` to be validated already.
        Sets ``self._controlnet`` so a failure can be cleaned up by the caller.
        """
        if progress_callback:
            progress_callback("Loading ControlNet model...", 30)

        controlnet_ids = {
            "canny": self.CONTROLNET_CANNY_MODEL,
            "depth": self.CONTROLNET_DEPTH_MODEL,
        }
        self._controlnet = ControlNetModel.from_pretrained(
            controlnet_ids[conditioning_type],
            torch_dtype=dtype,
            use_safetensors=True
        )
        if conditioning_type == "depth":
            # NOTE(review): the depth estimator is not handed to
            # _control_processor anywhere in this class; confirm it is used,
            # otherwise it only costs memory.
            self._load_depth_estimator()

        if progress_callback:
            progress_callback(f"Loading {model_key}...", 60)

        load_kwargs = {
            "controlnet": self._controlnet,
            "torch_dtype": dtype,
            "use_safetensors": True,
        }
        if model_key not in self._NO_FP16_VARIANT_MODELS and dtype == torch.float16:
            load_kwargs["variant"] = "fp16"

        pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
            self.SUPPORTED_MODELS[model_key],
            **load_kwargs
        )
        logger.info(f"Loaded ControlNet pipeline: {model_key} + {conditioning_type}")
        return pipeline

    def _load_depth_estimator(self) -> None:
        """
        Load depth estimation model.

        Tries ``DEPTH_MODEL_PRIMARY`` first and falls back to
        ``DEPTH_MODEL_FALLBACK``; an error from the fallback propagates.
        """
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        try:
            self._depth_processor = AutoImageProcessor.from_pretrained(
                self.DEPTH_MODEL_PRIMARY
            )
            self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_PRIMARY,
                torch_dtype=dtype
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Loaded Depth-Anything model")
        except Exception as e:
            logger.warning(f"Primary depth model failed: {e}, trying fallback...")
            self._depth_processor = DPTImageProcessor.from_pretrained(
                self.DEPTH_MODEL_FALLBACK
            )
            self._depth_estimator = DPTForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_FALLBACK,
                torch_dtype=dtype
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Loaded MiDaS fallback model")

    def _apply_optimizations(self) -> None:
        """
        Apply memory and performance optimizations to the loaded pipeline.

        Prefers xformers attention and falls back to attention slicing (both
        best-effort: failures are logged at debug level). When
        ``config.enable_vae_tiling`` is set, also enables VAE tiling and VAE
        slicing where the pipeline provides them.
        """
        if self._pipeline is None:
            return

        try:
            self._pipeline.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xformers attention")
        except Exception as xformers_error:
            logger.debug(f"xformers attention unavailable: {xformers_error}")
            try:
                self._pipeline.enable_attention_slicing()
                logger.info("Enabled attention slicing")
            except Exception as slicing_error:
                logger.debug(f"Attention slicing unavailable: {slicing_error}")

        if self.config.enable_vae_tiling:
            if hasattr(self._pipeline, 'enable_vae_tiling'):
                self._pipeline.enable_vae_tiling()
            if hasattr(self._pipeline, 'enable_vae_slicing'):
                self._pipeline.enable_vae_slicing()

    def _unload_pipeline(self) -> None:
        """Drop all model references, reset state tracking, and free memory."""
        # Rebinding to None releases this object's references to the models.
        self._pipeline = None
        self._controlnet = None
        self._depth_estimator = None
        self._depth_processor = None

        self.is_initialized = False
        self._current_mode = None
        self._current_conditioning_type = None
        self._current_model_key = None

        self._memory_cleanup(aggressive=True)
        logger.info("Pipeline unloaded")

    @staticmethod
    def _compute_target_size(
        size: Tuple[int, int],
        max_resolution: int,
        multiple: int = 8
    ) -> Tuple[int, int]:
        """
        Return the (width, height) to generate at.

        Both sides are floored to ``multiple``. If the longest side then
        exceeds ``max_resolution``, both are scaled down proportionally and
        floored again. Each side is at least ``multiple`` so tiny or
        extreme-aspect inputs never collapse to zero.
        """
        width = (size[0] // multiple) * multiple
        height = (size[1] // multiple) * multiple

        longest = max(width, height)
        if longest > max_resolution:
            scale = max_resolution / longest
            width = int(width * scale) // multiple * multiple
            height = int(height * scale) // multiple * multiple

        return max(multiple, width), max(multiple, height)

    @staticmethod
    def _resolve_seed(input_seed: Optional[int]) -> int:
        """Return ``input_seed`` as int, or a time-based seed when it is None or negative."""
        if input_seed is None or input_seed < 0:
            return int(time.time() * 1000) % (2**32)
        return int(input_seed)

    def execute_inpainting(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> InpaintingResult:
        """
        Execute inpainting operation.

        The image is converted to RGB and resized once to sides that are
        multiples of 8, with the longest side at most ``config.max_resolution``.
        It is then inpainted with the currently loaded pipeline, and the
        result is resized back to the input's size.

        Parameters
        ----------
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask (white = area to regenerate)
        prompt : str
            Text description
        progress_callback : callable, optional
            Progress update function, called as ``callback(message, percent)``
        **kwargs
            Additional parameters from template. Recognized keys:
            ``strength``, ``guidance_scale``, ``num_inference_steps``,
            ``negative_prompt``, ``seed`` (None or negative = time-based),
            ``mask_dilation``, ``feather_radius``; ControlNet mode only:
            ``controlnet_conditioning_scale``, ``preserve_structure_in_mask``,
            ``edge_guidance_mode``. On HuggingFace Spaces (``SPACE_ID`` set)
            the step count is capped at 15.

        Returns
        -------
        InpaintingResult
            Result with generated image. Failures (including CUDA OOM) are
            reported via ``success=False`` and ``error_message``, not raised.
        """
        start_time = time.time()

        if not self.is_initialized:
            return InpaintingResult(
                success=False,
                error_message="Pipeline not initialized. Call load_pipeline() first."
            )

        logger.info(f"Inpainting: mode={self._current_mode}, prompt='{prompt[:50]}...'")

        try:
            if progress_callback:
                progress_callback("Preparing images...", 10)

            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Store original size for later restoration
            original_size = image.size  # (width, height)

            # Resample once, directly to the final size, rather than chaining
            # a multiple-of-8 resize with a max-resolution resize.
            target_size = self._compute_target_size(
                original_size, self.config.max_resolution, self._SIZE_MULTIPLE
            )
            if target_size != original_size:
                image = image.resize(target_size, Image.LANCZOS)

            processed_mask = self._prepare_mask(
                mask,
                target_size,
                dilation=kwargs.get('mask_dilation', 0),
                feather_radius=kwargs.get('feather_radius', self.config.feather_radius)
            )

            # Get generation parameters
            strength = kwargs.get('strength', self.config.strength)
            guidance_scale = kwargs.get('guidance_scale', self.config.guidance_scale)
            num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
            negative_prompt = kwargs.get('negative_prompt', "")

            # Optimize for HuggingFace Spaces
            if os.getenv('SPACE_ID') is not None:
                num_steps = min(num_steps, 15)

            seed = self._resolve_seed(kwargs.get('seed', -1))
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

            if self._current_mode == "pure":
                if progress_callback:
                    progress_callback("Generating (Pure Inpainting)...", 40)

                result_image = self._generate_pure_inpaint(
                    image=image,
                    mask=processed_mask,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_steps=num_steps,
                    guidance_scale=guidance_scale,
                    strength=strength,
                    generator=generator
                )
                control_image = None
            else:
                if progress_callback:
                    progress_callback("Generating control image...", 30)

                control_image = self._control_processor.prepare_control_image(
                    image=image,
                    mode=self._current_conditioning_type,
                    mask=processed_mask,
                    preserve_structure=kwargs.get('preserve_structure_in_mask', False),
                    edge_guidance_mode=kwargs.get('edge_guidance_mode', 'boundary')
                )

                if progress_callback:
                    progress_callback("Generating (ControlNet)...", 50)

                conditioning_scale = kwargs.get(
                    'controlnet_conditioning_scale',
                    self.config.controlnet_conditioning_scale
                )
                result_image = self._generate_controlnet_inpaint(
                    image=image,
                    mask=processed_mask,
                    control_image=control_image,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_steps=num_steps,
                    guidance_scale=guidance_scale,
                    conditioning_scale=conditioning_scale,
                    strength=strength,
                    generator=generator
                )

            generation_time = time.time() - start_time

            # Restore original size if it was changed
            if result_image.size != original_size:
                result_image = result_image.resize(original_size, Image.LANCZOS)
                logger.info(f"Restored result to original size: {original_size}")

            if progress_callback:
                progress_callback("Complete!", 100)

            return InpaintingResult(
                success=True,
                result_image=result_image,
                blended_image=result_image,  # Pipeline output is already blended
                control_image=control_image,
                generation_time=generation_time,
                metadata={
                    "seed": seed,
                    "prompt": prompt,
                    "mode": self._current_mode,
                    "num_steps": num_steps,
                    "guidance_scale": guidance_scale,
                    "strength": strength,
                    "original_size": original_size,
                }
            )

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory")
            self._memory_cleanup(aggressive=True)
            return InpaintingResult(
                success=False,
                error_message="GPU memory exhausted."
            )
        except Exception as e:
            logger.error(f"Inpainting failed: {e}")
            traceback.print_exc()
            return InpaintingResult(
                success=False,
                error_message=str(e)
            )

    def _prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        dilation: int = 0,
        feather_radius: int = 3
    ) -> Image.Image:
        """
        Prepare mask with optional dilation and feathering.

        Parameters
        ----------
        mask : PIL.Image
            Mask in any mode; converted to 8-bit grayscale ("L").
        target_size : tuple of int
            (width, height) the mask is resized to.
        dilation : int
            Radius (pixels) of an elliptical dilation; 0 disables it.
        feather_radius : int
            Gaussian blur radius (kernel ``2*r+1``, sigma ``r/2``); 0 disables it.

        Float values (e.g. from UI sliders) are rounded to the nearest int,
        since OpenCV kernel sizes must be integers.

        Returns
        -------
        PIL.Image
            Grayscale ("L") mask of size ``target_size``.
        """
        if mask.mode != 'L':
            mask = mask.convert('L')
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        mask_array = np.array(mask)

        dilation = int(round(dilation))
        feather_radius = int(round(feather_radius))

        # Expand the mask so the regenerated area slightly overlaps its border
        if dilation > 0:
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE,
                (dilation * 2 + 1, dilation * 2 + 1)
            )
            mask_array = cv2.dilate(mask_array, kernel, iterations=1)
            logger.debug(f"Applied mask dilation: {dilation}px")

        # Soften the mask edge
        if feather_radius > 0:
            mask_array = cv2.GaussianBlur(
                mask_array,
                (feather_radius * 2 + 1, feather_radius * 2 + 1),
                feather_radius / 2
            )

        # A 2-D uint8 array is inferred as mode "L"; passing mode= explicitly
        # is deprecated in recent Pillow releases.
        return Image.fromarray(mask_array)

    def _generate_pure_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_steps: int,
        guidance_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """
        Generate using pure SDXL Inpainting pipeline.

        ``height``/``width`` are passed explicitly. Per the diffusers docs, the
        SDXL inpaint pipeline otherwise defaults to the UNet's native size
        (1024x1024), which would stretch non-square inputs.
        """
        with torch.inference_mode():
            result = self._pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                mask_image=mask,
                height=image.height,
                width=image.width,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                strength=strength,
                generator=generator
            )
        return result.images[0]

    def _generate_controlnet_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        control_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """
        Generate using ControlNet Inpainting pipeline.

        ``height``/``width`` are pinned to the prepared image's size, the
        same as in pure mode.
        """
        with torch.inference_mode():
            result = self._pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                mask_image=mask,
                control_image=control_image,
                height=image.height,
                width=image.width,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=conditioning_scale,
                strength=strength,
                generator=generator
            )
        return result.images[0]

    def get_status(self) -> Dict[str, Any]:
        """
        Get current module status.

        ``conditioning_type`` and ``model_key`` are None unless a ControlNet
        pipeline is loaded.
        """
        return {
            "initialized": self.is_initialized,
            "device": self.device,
            "mode": self._current_mode,
            "conditioning_type": self._current_conditioning_type,
            "model_key": self._current_model_key,
        }