File size: 9,138 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""ControlNet processor for structure-preserving image generation.

Standalone implementation that doesn't require ComfyUI imports.
Provides Canny edge detection and ControlNet model loading using LightDiffusion's infrastructure.
"""

import os
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Tuple, Any, Dict, List, Callable
from PIL import Image
import logging

from src.Utilities import util
from src.Device import Device

logger = logging.getLogger(__name__)


class CannyPreprocessor:
    """Canny edge detection preprocessor for ControlNet."""

    @staticmethod
    def detect(image: torch.Tensor, low_threshold: int = 100, high_threshold: int = 200) -> torch.Tensor:
        """Run Canny edge detection over a batch of images.

        Args:
            image: Input image tensor [B, H, W, C] (or a single [H, W, C]) in range [0, 1]
            low_threshold: Lower hysteresis threshold for edge detection
            high_threshold: Upper hysteresis threshold for edge detection

        Returns:
            Edge map tensor [B, H, W, C] with 3 identical channels, in range [0, 1]
        """
        try:
            import cv2
        except ImportError:
            raise ImportError("OpenCV (cv2) is required for Canny edge detection. Install with: pip install opencv-python")

        # Promote a single [H, W, C] image to a batch of one.
        if image.dim() == 3:
            image = image.unsqueeze(0)

        edge_maps = []
        for frame in image:
            # OpenCV wants uint8 numpy data in [0, 255].
            frame_np = (frame.cpu().numpy() * 255).astype(np.uint8)

            # Canny operates on a single channel; collapse RGB first.
            if frame_np.shape[-1] == 3:
                gray = cv2.cvtColor(frame_np, cv2.COLOR_RGB2GRAY)
            else:
                gray = frame_np[..., 0]

            edges = cv2.Canny(gray, low_threshold, high_threshold)

            # Replicate the edge map across 3 channels and rescale to [0, 1].
            edges_rgb = np.stack([edges, edges, edges], axis=-1)
            edge_maps.append(torch.from_numpy(edges_rgb.astype(np.float32) / 255.0))

        return torch.stack(edge_maps)


class ControlNetConditioner:
    """Wraps a preprocessed control image for structure-guided generation.

    No full ControlNet model is loaded here; the control image is simply
    prepared for use with img2img at high denoise, which yields a comparable
    structure-preserving effect.
    """

    def __init__(self, control_image: torch.Tensor, strength: float = 1.0):
        """Store the control image and guidance strength.

        Args:
            control_image: Preprocessed control image [B, H, W, C] or [B, C, H, W]
            strength: Control strength (0-2)
        """
        self.control_image = control_image
        self.strength = strength
        self._models = []  # no real models are managed by this conditioner

    def get_models(self) -> List:
        """Return the (always empty) model list, for cond_util compatibility."""
        return self._models

    def inference_memory_requirements(self, dtype: torch.dtype) -> int:
        """Return extra inference memory needed — always 0 (cond_util compatibility)."""
        return 0

    def cleanup(self):
        """Drop the reference to the control image."""
        self.control_image = None


class ControlNetProcessor:
    """ControlNet processor using img2img with preprocessed edges.

    Loading a full ControlNet model would require ComfyUI dependencies, so
    this hybrid approach is used instead: Canny edge detection combined with
    high-denoise img2img, which gives similar structure-preserving results.
    """

    @classmethod
    def preprocess_image(
        cls,
        image: torch.Tensor,
        preprocessor: str = "canny",
        **kwargs
    ) -> torch.Tensor:
        """Produce a structure-guidance image from the input.

        Args:
            image: Input image tensor [B, H, W, C]
            preprocessor: Preprocessor type ("canny", "none")
            **kwargs: Preprocessor-specific arguments

        Returns:
            Preprocessed image tensor (the input unchanged for "none"/unknown)
        """
        if preprocessor == "canny":
            return CannyPreprocessor.detect(
                image,
                kwargs.get("low_threshold", 100),
                kwargs.get("high_threshold", 200),
            )
        if preprocessor == "none":
            return image
        # Unknown preprocessor names degrade gracefully to a pass-through.
        logger.warning(f"Unknown preprocessor '{preprocessor}', returning original image")
        return image

    @classmethod
    def create_conditioner(
        cls,
        control_image: torch.Tensor,
        strength: float = 1.0,
    ) -> ControlNetConditioner:
        """Build a ControlNetConditioner around the control image.

        Args:
            control_image: Preprocessed control image
            strength: Control strength

        Returns:
            ControlNetConditioner object
        """
        return ControlNetConditioner(control_image, strength)


def apply_controlnet_to_img2img(
    ctx,
    model,
    positive,
    negative,
    control_image: torch.Tensor,
    strength: float = 1.0,
    original_image: Optional[torch.Tensor] = None,
    last_step: Optional[int] = None,
    callback: Optional[Callable] = None,
) -> Tuple[torch.Tensor, Any]:
    """Run ControlNet-style generation via img2img with edge guidance.

    Instead of a real ControlNet, edge detection plus img2img at a controlled
    denoise level is used: low denoise preserves the input structure, and the
    edges are blended with the original image so both structure AND content
    guide the generation.

    Args:
        ctx: Pipeline context
        model: Loaded model
        positive: Positive conditioning
        negative: Negative conditioning
        control_image: Preprocessed control image (e.g. Canny edges)
        strength: How much to preserve structure (higher = more preservation)
        original_image: Original input image (required for proper guidance)
        last_step: Optional step to stop at (for refiner handoff)
        callback: Optional callback for live previews

    Returns:
        Generated latents and context
    """
    from src.Processors.Img2Img import Img2Img

    # Model family determines how aggressively we can denoise.
    caps = model.capabilities
    is_flux2 = getattr(caps, "is_flux2", False)
    is_flux = getattr(caps, "is_flux", False)
    flux_family = is_flux or is_flux2

    if flux_family:
        # Flux: edges cause artifacts, so skip them entirely and lean on the
        # original image with a moderate denoise for structure preservation.
        denoise = 0.55 + (strength * 0.15)  # Range: 0.55-0.7
        edge_blend = 0.0
    else:
        # SD1.5/SDXL: balanced denoise — keep structure, allow prompt changes.
        denoise = 0.45 + (strength * 0.2)  # Range: 0.45-0.65
        edge_blend = strength * 0.3  # Range: 0.0-0.3 for edges

    family_label = 'Flux' if is_flux or is_flux2 else 'SD'

    if original_image is None:
        # Fallback: edges alone (not recommended).
        input_image = control_image
        logger.warning("ControlNet: No original image provided, using edges only (may not work well)")
        logger.info(
            f"ControlNet {family_label}: "
            f"strength={strength:.2f}, denoise={denoise:.2f}, edges only"
        )
    else:
        # Edges supply structure; the original supplies content/color reference.
        input_image = control_image * edge_blend + original_image * (1.0 - edge_blend)
        step_note = f", last_step={last_step}" if last_step else ""
        logger.info(
            f"ControlNet {family_label}: "
            f"strength={strength:.2f}, denoise={denoise:.2f}, edge_blend={edge_blend:.2f}"
            + step_note
        )

    # Moderate-denoise img2img does the actual structure-preserving generation.
    latents = Img2Img.simple_img2img(
        ctx, model, positive, negative,
        image_tensor=input_image,
        denoise=denoise,
        last_step=last_step,
        callback=callback,
    )

    return latents, ctx


def find_controlnet_models(search_dir: str = None) -> list:
    """Locate ControlNet model files on disk.

    Args:
        search_dir: Directory to search (default: ./include/controlnets)

    Returns:
        Sorted list of paths to .safetensors/.pth/.pt files; empty if the
        directory does not exist.
    """
    directory = "./include/controlnets" if search_dir is None else search_dir

    if not os.path.exists(directory):
        return []

    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith((".safetensors", ".pth", ".pt"))
    )