Spaces: Running on Zero
| import gc | |
| import logging | |
| import os | |
| import time | |
| import traceback | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Dict, Optional, Tuple | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from diffusers import AutoPipelineForInpainting | |
| from diffusers import ControlNetModel | |
| from diffusers import DPMSolverMultistepScheduler | |
| from diffusers import StableDiffusionXLControlNetInpaintPipeline | |
| from transformers import AutoImageProcessor | |
| from transformers import AutoModelForDepthEstimation | |
| from transformers import DPTForDepthEstimation | |
| from transformers import DPTImageProcessor | |
| from control_image_processor import ControlImageProcessor | |
| from inpainting_blender import InpaintingBlender | |
# Hugging Face repo id of the dedicated SDXL inpainting checkpoint, used when
# the module runs in "pure" (non-ControlNet) mode.
SDXL_INPAINTING_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

# Module-level logger; its level is pinned to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@dataclass
class InpaintingConfig:
    """Configuration for inpainting operations.

    Declared as a dataclass so individual settings can be overridden per
    instance, e.g. ``InpaintingConfig(num_inference_steps=30)``. Without the
    decorator the class has no keyword constructor, and every instance merely
    reads the shared class-level defaults.
    """

    # ControlNet settings (for ControlNet mode only)
    controlnet_conditioning_scale: float = 0.7
    conditioning_type: str = "canny"
    # Canny edge detection thresholds (forwarded to the control-image processor)
    canny_low_threshold: int = 100
    canny_high_threshold: int = 200
    # Mask settings
    feather_radius: int = 3  # blur radius (px) used to soften the mask edge
    min_mask_coverage: float = 0.01  # forwarded to InpaintingBlender
    max_mask_coverage: float = 0.95  # forwarded to InpaintingBlender
    # Generation settings
    num_inference_steps: int = 25
    guidance_scale: float = 7.5
    strength: float = 0.99  # Use 0.99 to avoid noise issues with 1.0
    # Memory settings
    enable_vae_tiling: bool = True  # also turns on VAE slicing when available
    max_resolution: int = 1024  # cap on the longest image side fed to the model
@dataclass
class InpaintingResult:
    """Result container for inpainting operations.

    Declared as a dataclass so results can be built with keyword arguments,
    e.g. ``InpaintingResult(success=False, error_message="...")``. It also
    gives every instance its own ``metadata`` dict via ``default_factory``.
    Without the decorator the keyword constructor does not exist, and
    ``metadata`` would be a raw ``dataclasses.Field`` object.
    """

    success: bool
    result_image: Optional[Image.Image] = None
    preview_image: Optional[Image.Image] = None  # not set by execute_inpainting
    control_image: Optional[Image.Image] = None  # None when no ControlNet was used
    blended_image: Optional[Image.Image] = None
    quality_score: float = 0.0  # not set by execute_inpainting
    generation_time: float = 0.0  # wall-clock seconds
    error_message: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
class InpaintingModule:
    """
    Dual-mode Inpainting Module for SceneWeaver.

    Supports two modes:

    1. Pure Inpainting (use_controlnet=False): uses the dedicated SDXL
       Inpainting model.
       - Best for: object replacement, object removal
       - More stable, better edge blending
    2. ControlNet Inpainting (use_controlnet=True): uses ControlNet + an SDXL
       base model.
       - Best for: clothing change (depth), color change (canny)
       - Preserves structure in the masked region

    Example:
        >>> module = InpaintingModule(device="cuda")
        >>> # For object replacement (no ControlNet)
        >>> module.load_pipeline(use_controlnet=False)
        >>> result = module.execute_inpainting(image, mask, "a vase with flowers")
    """

    # ControlNet model identifiers
    CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
    CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
    # Depth-estimation models loaded together with the depth ControlNet
    # (the fallback is tried only if the primary fails to load)
    DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
    DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"

    # Conditioning types accepted by load_pipeline in ControlNet mode
    SUPPORTED_CONDITIONING_TYPES = ("canny", "depth")

    # Base models for ControlNet mode
    SUPPORTED_MODELS = {
        "juggernaut_xl": "RunDiffusion/Juggernaut-XL-v9",
        "realvis_xl": "SG161222/RealVisXL_V4.0",
        "sdxl_base": "stabilityai/stable-diffusion-xl-base-1.0",
        "animagine_xl": "cagliostrolab/animagine-xl-3.1",
    }

    def __init__(
        self,
        device: str = "auto",
        config: Optional[InpaintingConfig] = None
    ):
        """Initialize the InpaintingModule.

        Parameters
        ----------
        device : str
            "auto" picks cuda, then mps, then cpu; any other value is used as-is.
        config : InpaintingConfig, optional
            Settings; a default ``InpaintingConfig()`` is used when omitted.

        No model weights are loaded here; call ``load_pipeline`` first.
        """
        self.device = self._setup_device(device)
        self.config = config or InpaintingConfig()

        # Sub-modules
        self._control_processor = ControlImageProcessor(
            device=self.device,
            canny_low_threshold=self.config.canny_low_threshold,
            canny_high_threshold=self.config.canny_high_threshold
        )
        self._blender = InpaintingBlender(
            min_mask_coverage=self.config.min_mask_coverage,
            max_mask_coverage=self.config.max_mask_coverage
        )

        # Pipeline instances
        self._pipeline = None
        self._controlnet = None
        self._depth_estimator = None
        self._depth_processor = None

        # State tracking
        self.is_initialized = False
        self._current_mode = None  # "pure" or "controlnet"
        self._current_conditioning_type = None
        self._current_model_key = None

        logger.info(f"InpaintingModule initialized on {self.device}")

    def _setup_device(self, device: str) -> str:
        """Resolve "auto" to cuda, mps or cpu (in that order); return other values unchanged."""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        return device

    def _memory_cleanup(self, aggressive: bool = False) -> None:
        """Run the garbage collector and, outside HF Spaces, release cached CUDA memory.

        ``aggressive`` runs more GC passes and also calls ``torch.cuda.ipc_collect``.
        The CUDA calls are skipped when ``SPACE_ID`` is set, presumably because the
        GPU is managed by the Spaces runtime there (TODO confirm).
        """
        for _ in range(5 if aggressive else 2):
            gc.collect()

        is_spaces = os.getenv('SPACE_ID') is not None
        if not is_spaces and torch.cuda.is_available():
            torch.cuda.empty_cache()
            if aggressive:
                torch.cuda.ipc_collect()

    def load_pipeline(
        self,
        use_controlnet: bool = False,
        conditioning_type: str = "canny",
        model_key: str = "sdxl_base",
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> Tuple[bool, str]:
        """
        Load the appropriate inpainting pipeline.

        Parameters
        ----------
        use_controlnet : bool
            If False, use the dedicated SDXL Inpainting model (for replacement/removal).
            If True, use the ControlNet pipeline (for clothing/color change).
        conditioning_type : str
            ControlNet type: "canny" or "depth" (only used when use_controlnet=True).
            An unknown type is rejected before the current pipeline is unloaded.
        model_key : str
            Base model for ControlNet mode. Unknown keys fall back to "sdxl_base",
            with a warning.
        progress_callback : callable, optional
            Progress update function, called as ``callback(message, percent)``.

        Returns
        -------
        tuple
            (success: bool, error_message: str)
        """
        mode = "controlnet" if use_controlnet else "pure"

        if use_controlnet:
            # Validate before touching the loaded pipeline so a bad argument
            # does not tear down a working setup.
            if conditioning_type not in self.SUPPORTED_CONDITIONING_TYPES:
                error_msg = f"Unknown conditioning type: {conditioning_type}"
                logger.error(f"Failed to load pipeline: {error_msg}")
                return False, error_msg
            # Normalise before the cache check: the stored key is the resolved
            # one, so comparing against the raw key would force a reload of the
            # whole pipeline on every call with an unsupported key.
            if model_key not in self.SUPPORTED_MODELS:
                logger.warning(
                    f"Unknown model key '{model_key}', falling back to 'sdxl_base'"
                )
                model_key = "sdxl_base"

        # Check if already loaded with same config
        if (self.is_initialized
                and self._current_mode == mode
                and (not use_controlnet
                     or (self._current_conditioning_type == conditioning_type
                         and self._current_model_key == model_key))):
            logger.info(f"Pipeline already loaded: mode={mode}")
            return True, ""

        logger.info(f"Loading pipeline: mode={mode}, conditioning={conditioning_type}")

        try:
            self._memory_cleanup(aggressive=True)
            if progress_callback:
                progress_callback("Preparing pipeline...", 10)

            # Unload existing pipeline
            self._unload_pipeline()

            # Half precision only on CUDA; other devices stay in float32
            dtype = torch.float16 if self.device == "cuda" else torch.float32

            if not use_controlnet:
                # Mode A: Pure SDXL Inpainting (for replacement/removal)
                if progress_callback:
                    progress_callback("Loading SDXL Inpainting model...", 30)

                self._pipeline = AutoPipelineForInpainting.from_pretrained(
                    SDXL_INPAINTING_MODEL,
                    torch_dtype=dtype,
                    variant="fp16" if dtype == torch.float16 else None,
                )
                self._current_mode = "pure"
                self._current_conditioning_type = None
                self._current_model_key = None  # base model choice does not apply here
                logger.info("Loaded pure SDXL Inpainting pipeline")
            else:
                # Mode B: ControlNet Inpainting (for structure-preserving tasks)
                base_model_id = self.SUPPORTED_MODELS[model_key]

                if progress_callback:
                    progress_callback("Loading ControlNet model...", 30)

                controlnet_id = (
                    self.CONTROLNET_CANNY_MODEL
                    if conditioning_type == "canny"
                    else self.CONTROLNET_DEPTH_MODEL
                )
                self._controlnet = ControlNetModel.from_pretrained(
                    controlnet_id,
                    torch_dtype=dtype,
                    use_safetensors=True
                )
                if conditioning_type == "depth":
                    self._load_depth_estimator()

                if progress_callback:
                    progress_callback(f"Loading {model_key}...", 60)

                # animagine_xl is loaded without the "fp16" variant
                use_variant = model_key != "animagine_xl"
                load_kwargs = {
                    "controlnet": self._controlnet,
                    "torch_dtype": dtype,
                    "use_safetensors": True,
                }
                if use_variant and dtype == torch.float16:
                    load_kwargs["variant"] = "fp16"

                self._pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
                    base_model_id,
                    **load_kwargs
                )
                self._current_mode = "controlnet"
                self._current_conditioning_type = conditioning_type
                self._current_model_key = model_key
                logger.info(f"Loaded ControlNet pipeline: {model_key} + {conditioning_type}")

            if progress_callback:
                progress_callback("Configuring pipeline...", 80)

            # Swap in the DPM-Solver multistep scheduler, reusing the model's scheduler config
            self._pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self._pipeline.scheduler.config
            )

            # Move to device and optimize
            self._pipeline = self._pipeline.to(self.device)
            self._apply_optimizations()

            self.is_initialized = True
            if progress_callback:
                progress_callback("Pipeline ready!", 100)
            return True, ""

        except Exception as e:
            error_msg = str(e)
            logger.exception(f"Failed to load pipeline: {error_msg}")
            self._unload_pipeline()
            return False, error_msg

    def _load_depth_estimator(self) -> None:
        """Load the depth-estimation model and processor, falling back to DPT/MiDaS on failure.

        Errors from the fallback load are not caught and propagate to the caller.
        """
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        try:
            self._depth_processor = AutoImageProcessor.from_pretrained(
                self.DEPTH_MODEL_PRIMARY
            )
            self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_PRIMARY,
                torch_dtype=dtype
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Loaded Depth-Anything model")
        except Exception as e:
            logger.warning(f"Primary depth model failed: {e}, trying fallback...")
            self._depth_processor = DPTImageProcessor.from_pretrained(
                self.DEPTH_MODEL_FALLBACK
            )
            self._depth_estimator = DPTForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_FALLBACK,
                torch_dtype=dtype
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Loaded MiDaS fallback model")

    def _apply_optimizations(self) -> None:
        """Enable memory-saving features on the loaded pipeline (best effort).

        Tries xformers attention first and falls back to attention slicing.
        Failures are logged at debug level rather than raised. VAE tiling and
        slicing are enabled when ``config.enable_vae_tiling`` is set and the
        pipeline supports them.
        """
        if self._pipeline is None:
            return

        try:
            self._pipeline.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xformers attention")
        except Exception as xformers_err:
            logger.debug(f"xformers attention unavailable: {xformers_err}")
            try:
                self._pipeline.enable_attention_slicing()
                logger.info("Enabled attention slicing")
            except Exception as slicing_err:
                logger.debug(f"Attention slicing unavailable: {slicing_err}")

        if self.config.enable_vae_tiling:
            if hasattr(self._pipeline, 'enable_vae_tiling'):
                self._pipeline.enable_vae_tiling()
            if hasattr(self._pipeline, 'enable_vae_slicing'):
                self._pipeline.enable_vae_slicing()

    def _unload_pipeline(self) -> None:
        """Drop all loaded models, reset every piece of state tracking and free memory."""
        if self._pipeline is not None:
            del self._pipeline
            self._pipeline = None
        if self._controlnet is not None:
            del self._controlnet
            self._controlnet = None
        if self._depth_estimator is not None:
            del self._depth_estimator
            self._depth_estimator = None
        if self._depth_processor is not None:
            del self._depth_processor
            self._depth_processor = None

        self.is_initialized = False
        self._current_mode = None
        self._current_conditioning_type = None
        # Reset too, otherwise get_status() reports a model that is no longer loaded
        self._current_model_key = None

        self._memory_cleanup(aggressive=True)
        logger.info("Pipeline unloaded")

    @staticmethod
    def _compute_model_size(width: int, height: int, max_res: int) -> Tuple[int, int]:
        """Return the (width, height) the model is run at for a ``width`` x ``height`` input.

        Both sides are floored to a multiple of 8 for model compatibility. If the
        longer side then exceeds ``max_res``, both are scaled down proportionally
        and floored to a multiple of 8 again. Each side is clamped to at least
        8 px so tiny or very elongated inputs never yield a zero-sized image.
        """
        new_width = max(8, (width // 8) * 8)
        new_height = max(8, (height // 8) * 8)
        longest = max(new_width, new_height)
        if longest > max_res:
            scale = max_res / longest
            new_width = max(8, int(new_width * scale) // 8 * 8)
            new_height = max(8, int(new_height * scale) // 8 * 8)
        return new_width, new_height

    @staticmethod
    def _resolve_seed(input_seed: Optional[int]) -> int:
        """Return ``input_seed`` as an int, or a time-derived 32-bit seed when it is None or negative."""
        if input_seed is None or input_seed < 0:
            return int(time.time() * 1000) % (2**32)
        return int(input_seed)

    def execute_inpainting(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> InpaintingResult:
        """
        Execute inpainting operation.

        Parameters
        ----------
        image : PIL.Image
            Original image (converted to RGB if needed).
        mask : PIL.Image
            Inpainting mask (white = area to regenerate).
        prompt : str
            Text description.
        progress_callback : callable, optional
            Progress update function, called as ``callback(message, percent)``.
        **kwargs
            Additional parameters from template. Recognised keys:
            mask_dilation, feather_radius, strength, guidance_scale,
            num_inference_steps, negative_prompt, seed (None/negative = random),
            and, in ControlNet mode, preserve_structure_in_mask,
            edge_guidance_mode, controlnet_conditioning_scale.

        Returns
        -------
        InpaintingResult
            Result with the generated image resized back to the input size.
            Failures (including CUDA OOM) are reported via ``success=False``
            and ``error_message`` rather than raised.
        """
        start_time = time.time()

        if not self.is_initialized:
            return InpaintingResult(
                success=False,
                error_message="Pipeline not initialized. Call load_pipeline() first."
            )

        logger.info(f"Inpainting: mode={self._current_mode}, prompt='{prompt[:50]}...'")

        try:
            if progress_callback:
                progress_callback("Preparing images...", 10)

            # Prepare image
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Store original size for later restoration
            original_size = image.size  # (width, height)

            # Compute the final model size once and resample only once
            new_width, new_height = self._compute_model_size(
                original_size[0], original_size[1], self.config.max_resolution
            )
            if (new_width, new_height) != original_size:
                image = image.resize((new_width, new_height), Image.LANCZOS)

            # Prepare mask with dilation
            processed_mask = self._prepare_mask(
                mask,
                (new_width, new_height),
                dilation=kwargs.get('mask_dilation', 0),
                feather_radius=kwargs.get('feather_radius', self.config.feather_radius)
            )

            # Get generation parameters
            strength = kwargs.get('strength', self.config.strength)
            guidance_scale = kwargs.get('guidance_scale', self.config.guidance_scale)
            num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
            negative_prompt = kwargs.get('negative_prompt', "")

            # Optimize for HuggingFace Spaces: cap the step count
            if os.getenv('SPACE_ID') is not None:
                num_steps = min(num_steps, 15)

            seed = self._resolve_seed(kwargs.get('seed', -1))
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

            # Generate based on mode
            if self._current_mode == "pure":
                # Pure inpainting - no ControlNet
                if progress_callback:
                    progress_callback("Generating (Pure Inpainting)...", 40)

                result_image = self._generate_pure_inpaint(
                    image=image,
                    mask=processed_mask,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_steps=num_steps,
                    guidance_scale=guidance_scale,
                    strength=strength,
                    generator=generator
                )
                control_image = None
            else:
                # ControlNet inpainting
                if progress_callback:
                    progress_callback("Generating control image...", 30)

                control_image = self._control_processor.prepare_control_image(
                    image=image,
                    mode=self._current_conditioning_type,
                    mask=processed_mask,
                    preserve_structure=kwargs.get('preserve_structure_in_mask', False),
                    edge_guidance_mode=kwargs.get('edge_guidance_mode', 'boundary')
                )

                if progress_callback:
                    progress_callback("Generating (ControlNet)...", 50)

                conditioning_scale = kwargs.get(
                    'controlnet_conditioning_scale',
                    self.config.controlnet_conditioning_scale
                )
                result_image = self._generate_controlnet_inpaint(
                    image=image,
                    mask=processed_mask,
                    control_image=control_image,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_steps=num_steps,
                    guidance_scale=guidance_scale,
                    conditioning_scale=conditioning_scale,
                    strength=strength,
                    generator=generator
                )

            generation_time = time.time() - start_time

            # Restore original size if it was changed
            if result_image.size != original_size:
                result_image = result_image.resize(original_size, Image.LANCZOS)
                logger.info(f"Restored result to original size: {original_size}")

            if progress_callback:
                progress_callback("Complete!", 100)

            return InpaintingResult(
                success=True,
                result_image=result_image,
                blended_image=result_image,  # Pipeline output is already blended
                control_image=control_image,
                generation_time=generation_time,
                metadata={
                    "seed": seed,
                    "prompt": prompt,
                    "mode": self._current_mode,
                    "num_steps": num_steps,
                    "guidance_scale": guidance_scale,
                    "strength": strength,
                    "original_size": original_size,
                }
            )

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory")
            self._memory_cleanup(aggressive=True)
            return InpaintingResult(
                success=False,
                error_message="GPU memory exhausted."
            )
        except Exception as e:
            logger.exception(f"Inpainting failed: {e}")
            return InpaintingResult(
                success=False,
                error_message=str(e)
            )

    def _prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        dilation: int = 0,
        feather_radius: int = 3
    ) -> Image.Image:
        """Convert the mask to grayscale at ``target_size``, then optionally dilate and feather it.

        ``dilation`` grows the white region with an elliptical kernel of size
        ``2 * dilation + 1``. ``feather_radius`` applies a Gaussian blur with
        kernel size ``2 * feather_radius + 1`` and sigma ``feather_radius / 2``.
        Non-positive values skip the corresponding step.
        """
        # Convert and resize
        if mask.mode != 'L':
            mask = mask.convert('L')
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        mask_array = np.array(mask)

        # Apply dilation to expand mask
        if dilation > 0:
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE,
                (dilation * 2 + 1, dilation * 2 + 1)
            )
            mask_array = cv2.dilate(mask_array, kernel, iterations=1)
            logger.debug(f"Applied mask dilation: {dilation}px")

        # Apply feathering
        if feather_radius > 0:
            mask_array = cv2.GaussianBlur(
                mask_array,
                (feather_radius * 2 + 1, feather_radius * 2 + 1),
                feather_radius / 2
            )

        # A 2-D uint8 array maps to mode "L"; the explicit `mode=` argument is
        # deprecated in recent Pillow releases.
        return Image.fromarray(mask_array)

    def _generate_pure_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_steps: int,
        guidance_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """Run the pure SDXL Inpainting pipeline under inference mode and return its first image."""
        with torch.inference_mode():
            result = self._pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                mask_image=mask,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                strength=strength,
                generator=generator
            )
        return result.images[0]

    def _generate_controlnet_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        control_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """Run the ControlNet Inpainting pipeline under inference mode and return its first image."""
        with torch.inference_mode():
            result = self._pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                mask_image=mask,
                control_image=control_image,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=conditioning_scale,
                strength=strength,
                generator=generator
            )
        return result.images[0]

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of the loaded state (initialized flag, device, mode, conditioning, model key)."""
        return {
            "initialized": self.is_initialized,
            "device": self.device,
            "mode": self._current_mode,
            "conditioning_type": self._current_conditioning_type,
            "model_key": self._current_model_key,
        }