# Snapshot provenance (Hugging Face Space page residue, converted to a comment
# so the module parses): "Deploy ZeroGPU Gradio Space snapshot", commit b701455.
"""ControlNet processor for structure-preserving image generation.
Standalone implementation that doesn't require ComfyUI imports.
Provides Canny edge detection and ControlNet model loading using LightDiffusion's infrastructure.
"""
import os
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Tuple, Any, Dict, List, Callable
from PIL import Image
import logging
from src.Utilities import util
from src.Device import Device
logger = logging.getLogger(__name__)
class CannyPreprocessor:
    """Canny edge detection preprocessor for ControlNet."""

    @staticmethod
    def detect(image: torch.Tensor, low_threshold: int = 100, high_threshold: int = 200) -> torch.Tensor:
        """Detect edges in an image using the Canny algorithm.

        Args:
            image: Input image tensor [B, H, W, C] (or [H, W, C]) in range [0, 1]
            low_threshold: Lower hysteresis threshold for edge detection
            high_threshold: Upper hysteresis threshold for edge detection

        Returns:
            Edge map tensor [B, H, W, C] with 3 identical channels, in range [0, 1]
            (always on CPU).

        Raises:
            ImportError: If OpenCV is not installed.
        """
        try:
            import cv2
        except ImportError:
            raise ImportError("OpenCV (cv2) is required for Canny edge detection. Install with: pip install opencv-python")
        # Promote a single [H, W, C] image to a batch of one.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        results = []
        for i in range(image.shape[0]):
            # Convert to numpy [H, W, C] in [0, 255]. Clip BEFORE the uint8 cast:
            # values slightly outside [0, 1] would otherwise wrap around (e.g.
            # 1.004 * 255 -> 256 -> 0), producing spurious black pixels.
            # .detach() allows tensors that still require grad.
            img_np = image[i].detach().cpu().numpy() * 255.0
            img_np = np.clip(img_np, 0, 255).astype(np.uint8)
            # Canny operates on a single grayscale channel.
            if img_np.shape[-1] == 3:
                gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
            else:
                # Non-RGB input: use the first channel as-is.
                gray = img_np[..., 0]
            edges = cv2.Canny(gray, low_threshold, high_threshold)
            # Replicate the edge map into 3 channels and normalize to [0, 1].
            edges_rgb = np.stack([edges, edges, edges], axis=-1)
            results.append(torch.from_numpy(edges_rgb.astype(np.float32) / 255.0))
        return torch.stack(results)
class ControlNetConditioner:
    """Lightweight ControlNet conditioner that applies preprocessing to conditioning.

    This implementation doesn't load a full ControlNet model - instead it prepares
    the control image for use with img2img at high denoise, which achieves similar
    structure-preserving effects.
    """

    def __init__(self, control_image: torch.Tensor, strength: float = 1.0):
        """Store the control image and its strength.

        Args:
            control_image: Preprocessed control image [B, H, W, C] or [B, C, H, W]
            strength: Control strength (0-2)
        """
        # No real model is held; _models stays empty and exists purely so the
        # conditioning utilities can iterate over it.
        self._models: List = []
        self.strength = strength
        self.control_image = control_image

    def get_models(self) -> List:
        """Return the (empty) list of models to load (cond_util compatibility)."""
        return self._models

    def inference_memory_requirements(self, dtype: torch.dtype) -> int:
        """Report zero extra memory needed (cond_util compatibility)."""
        return 0

    def cleanup(self):
        """Drop the reference to the control image so it can be freed."""
        self.control_image = None
class ControlNetProcessor:
"""ControlNet processor using img2img with preprocessed edges.
Since full ControlNet model loading requires ComfyUI dependencies,
this implementation uses a hybrid approach: Canny edge detection +
high-denoise img2img, which achieves similar structure-preserving results.
"""
@classmethod
def preprocess_image(
cls,
image: torch.Tensor,
preprocessor: str = "canny",
**kwargs
) -> torch.Tensor:
"""Preprocess an image for structure guidance.
Args:
image: Input image tensor [B, H, W, C]
preprocessor: Preprocessor type ("canny", "none")
**kwargs: Preprocessor-specific arguments
Returns:
Preprocessed image tensor
"""
if preprocessor == "canny":
low = kwargs.get("low_threshold", 100)
high = kwargs.get("high_threshold", 200)
return CannyPreprocessor.detect(image, low, high)
elif preprocessor == "none":
return image
else:
logger.warning(f"Unknown preprocessor '{preprocessor}', returning original image")
return image
@classmethod
def create_conditioner(
cls,
control_image: torch.Tensor,
strength: float = 1.0,
) -> ControlNetConditioner:
"""Create a ControlNet conditioner for the control image.
Args:
control_image: Preprocessed control image
strength: Control strength
Returns:
ControlNetConditioner object
"""
return ControlNetConditioner(control_image, strength)
def apply_controlnet_to_img2img(
    ctx,
    model,
    positive,
    negative,
    control_image: torch.Tensor,
    strength: float = 1.0,
    original_image: Optional[torch.Tensor] = None,
    last_step: Optional[int] = None,
    callback: Optional[Callable] = None,
) -> Tuple[torch.Tensor, Any]:
    """Apply ControlNet-style generation using img2img with edge guidance.

    This simplified ControlNet uses edge detection + img2img with controlled
    denoise to preserve input structure while allowing content changes. Low
    denoise preserves structure; blending edges with the original image gives
    both structure AND content guidance.

    Args:
        ctx: Pipeline context
        model: Loaded model
        positive: Positive conditioning
        negative: Negative conditioning
        control_image: Preprocessed control image (e.g., Canny edges)
        strength: How much to preserve structure (higher = more preservation)
        original_image: Original input image (required for proper guidance)
        last_step: Optional step to stop at (for refiner handoff)
        callback: Optional callback for live previews

    Returns:
        Tuple of (generated latents, context).
    """
    from src.Processors.Img2Img import Img2Img

    # Model family determines how aggressively we denoise and whether edges help.
    is_flux2 = getattr(model.capabilities, "is_flux2", False)
    is_flux = getattr(model.capabilities, "is_flux", False)

    if is_flux2 or is_flux:
        # Flux: edges cause artifacts, so guide with the original image only,
        # at a moderate denoise (0.55-0.7) for structure preservation.
        denoise = 0.55 + (strength * 0.15)
        edge_blend = 0.0
    else:
        # SD1.5/SDXL: balanced denoise (0.45-0.65) plus a light edge blend
        # (0.0-0.3) so structure and color can both shift with the prompt.
        denoise = 0.45 + (strength * 0.2)
        edge_blend = strength * 0.3

    if original_image is None:
        # Edges-only fallback; generally gives poor results.
        input_image = control_image
        logger.warning("ControlNet: No original image provided, using edges only (may not work well)")
        logger.info(
            f"ControlNet {'Flux' if is_flux or is_flux2 else 'SD'}: "
            f"strength={strength:.2f}, denoise={denoise:.2f}, edges only"
        )
    else:
        # Edges supply structure; the original supplies content/color reference.
        input_image = control_image * edge_blend + original_image * (1.0 - edge_blend)
        logger.info(
            f"ControlNet {'Flux' if is_flux or is_flux2 else 'SD'}: "
            f"strength={strength:.2f}, denoise={denoise:.2f}, edge_blend={edge_blend:.2f}"
            + (f", last_step={last_step}" if last_step else "")
        )

    # Moderate-denoise img2img does the actual structure-preserving generation.
    latents = Img2Img.simple_img2img(
        ctx,
        model,
        positive,
        negative,
        image_tensor=input_image,
        denoise=denoise,
        last_step=last_step,
        callback=callback,
    )
    return latents, ctx
def find_controlnet_models(search_dir: Optional[str] = None) -> list:
    """Find ControlNet models in the specified directory.

    Args:
        search_dir: Directory to search (default: ./include/controlnets)

    Returns:
        Sorted list of ControlNet model paths (.safetensors/.pth/.pt files);
        empty list if the directory is missing or is not a directory.
    """
    if search_dir is None:
        search_dir = "./include/controlnets"
    # isdir (not exists): a regular file at this path would previously pass the
    # exists() check and make os.listdir raise NotADirectoryError.
    if not os.path.isdir(search_dir):
        return []
    return sorted(
        os.path.join(search_dir, name)
        for name in os.listdir(search_dir)
        if name.endswith((".safetensors", ".pth", ".pt"))
    )