"""
TransNormal Pipeline for Surface Normal Estimation

This pipeline is designed for transparent object surface normal estimation,
using a DINOv3 encoder for semantic-guided geometry estimation.

Based on the Lotus-D deterministic pipeline architecture.
"""

import inspect
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np

from diffusers import DiffusionPipeline, StableDiffusionMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from transformers import CLIPTextModel, CLIPTokenizer

from .utils import resize_max_res, get_tv_resample_method

logger = logging.get_logger(__name__)


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Get timesteps from scheduler.
    
    Args:
        scheduler: The scheduler to get timesteps from
        num_inference_steps: Number of diffusion steps
        device: Device to move timesteps to
        timesteps: Custom timesteps (optional)
    
    Returns:
        Tuple of (timesteps, num_inference_steps)
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__} does not support custom "
                f"timestep schedules."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
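
# Note: `retrieve_timesteps` is retained for API parity with other diffusers
# pipelines; the single-step deterministic path in `__call__` below does not
# call it. Sketch of intended use (assumes a scheduler that accepts a custom
# timestep schedule; not all schedulers do):
#
#     timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=10, device="cuda")
#     timesteps, n = retrieve_timesteps(scheduler, timesteps=[999, 500, 1], device="cuda")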


class TransNormalPipeline(DiffusionPipeline, StableDiffusionMixin):
    """
    TransNormal Pipeline for Surface Normal Estimation
    
    This pipeline uses DINOv3 encoder for semantic-guided geometry estimation,
    particularly effective for transparent objects where traditional methods fail.
    
    Args:
        vae: Variational Autoencoder for encoding/decoding images
        text_encoder: CLIP text encoder (kept for compatibility)
        tokenizer: CLIP tokenizer (kept for compatibility)
        unet: UNet2DConditionModel for denoising
        scheduler: Noise scheduler
        dino_encoder: Optional DINOv3 encoder for semantic features
    """
    
    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["text_encoder", "tokenizer", "dino_encoder"]
    
    # Default processing resolution
    default_processing_resolution = 768
    
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        dino_encoder: Optional[nn.Module] = None,
    ):
        super().__init__()
        
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            dino_encoder=dino_encoder,
        )
        
        # VAE scale factor (typically 8 for SD)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
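        # e.g. block_out_channels of length 4 (standard SD VAEs) gives 2**3 = 8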
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        
        # DINOv3 encoder usage flag
        self._use_dino_for_cross_attention = dino_encoder is not None
    
    def set_dino_encoder(self, dino_encoder: Optional[nn.Module], device: Optional[torch.device] = None):
        """
        Set or remove the DINOv3 encoder.
        
        Args:
            dino_encoder: DINOv3 encoder module, or None to disable
            device: Target device for the encoder
        """
        if dino_encoder is not None and device is not None:
            dino_encoder = dino_encoder.to(device)
            if hasattr(dino_encoder, 'dino_backbone') and dino_encoder.dino_backbone is not None:
                dino_encoder.dino_backbone = dino_encoder.dino_backbone.to(device)
        
        # Update registered module
        self.register_modules(dino_encoder=dino_encoder)
        self._use_dino_for_cross_attention = dino_encoder is not None
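
    # Usage sketch (`my_dino_encoder` is a hypothetical encoder loaded elsewhere):
    #     pipe.set_dino_encoder(my_dino_encoder, device=pipe.device)  # attach
    #     pipe.set_dino_encoder(None)                                 # detach -> CLIP fallback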
    
    def encode_prompt(
        self,
        prompt: str,
        device: torch.device,
        num_images_per_prompt: int = 1,
    ) -> torch.Tensor:
        """
        Encode text prompt using CLIP text encoder.
        
        Args:
            prompt: Text prompt
            device: Target device
            num_images_per_prompt: Number of images per prompt
        
        Returns:
            Text embeddings tensor
        """
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        
        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        
        return prompt_embeds
    
    def _get_encoder_hidden_states(
        self,
        rgb_in: torch.Tensor,
        prompt: str,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Get encoder hidden states for cross-attention.
        
        Uses DINOv3 features if encoder is available, otherwise uses CLIP text embeddings.
        
        Args:
            rgb_in: Input RGB image tensor, shape (B, 3, H, W), range [-1, 1]
            prompt: Text prompt (used only if DINO encoder is not available)
            device: Target device
        
        Returns:
            Encoder hidden states for cross-attention
        """
        if self._use_dino_for_cross_attention and self.dino_encoder is not None:
            # Use DINOv3 to extract semantic features
            encoder_hidden_states = self.dino_encoder.get_cross_attention_features(rgb_in)
            
            # Ensure dtype matches UNet
            if self.unet is not None:
                encoder_hidden_states = encoder_hidden_states.to(dtype=self.unet.dtype)
            return encoder_hidden_states
        else:
            # Fall back to CLIP text embeddings; guard against the optional
            # text components being absent
            if self.text_encoder is None or self.tokenizer is None:
                raise ValueError(
                    "No DINO encoder is set and the CLIP text encoder/tokenizer "
                    "are unavailable, so encoder hidden states cannot be built."
                )
            return self.encode_prompt(prompt, device)
    
    def preprocess_image(
        self,
        image: Union[torch.Tensor, Image.Image, np.ndarray, str],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Preprocess input image to tensor format.
        
        Args:
            image: Input image (PIL, numpy, tensor, or path)
            device: Target device
            dtype: Target dtype
        
        Returns:
            Preprocessed image tensor, shape (1, 3, H, W), range [-1, 1]
        """
        # Load image if path is provided
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        
        # Convert PIL to numpy
        if isinstance(image, Image.Image):
            image = np.array(image)
        
        # Convert numpy to tensor
        if isinstance(image, np.ndarray):
            # Ensure HWC format: replicate grayscale to 3 channels; a leading
            # dim of 3 is heuristically treated as CHW and transposed to HWC
            if image.ndim == 2:
                image = np.stack([image] * 3, axis=-1)
            elif image.shape[0] == 3:
                image = np.transpose(image, (1, 2, 0))
            
            # Normalize to [0, 1]
            if image.dtype == np.uint8:
                image = image.astype(np.float32) / 255.0
            
            # Convert to tensor (B, C, H, W)
            image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
        
        # Ensure batch dimension
        if image.dim() == 3:
            image = image.unsqueeze(0)
        
        # Map [0, 1] inputs to [-1, 1]; inputs already in [-1, 1] pass through unchanged
        if image.min() >= 0 and image.max() <= 1:
            image = image * 2.0 - 1.0
        
        return image.to(device=device, dtype=dtype)
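
    # All accepted input types converge to the same contract (sketch; shapes
    # assume an RGB input):
    #     preprocess_image("img.png", device, dtype)              -> (1, 3, H, W)
    #     preprocess_image(np.zeros((64, 64, 3)), device, dtype)  -> (1, 3, 64, 64)
    # with pixel values mapped into [-1, 1].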
    
    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, Image.Image, np.ndarray, str],
        prompt: str = "",
        timestep: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        output_type: str = "np",
        return_dict: bool = False,
        **kwargs,
    ):
        """
        Run surface normal estimation on input image.
        
        Args:
            image: Input RGB image (PIL, numpy, tensor, or file path)
            prompt: Text prompt (optional, used only if DINO encoder is not available)
            timestep: Diffusion timestep for deterministic prediction (default: 1)
            processing_res: Processing resolution (default: 768)
            match_input_res: Whether to resize output to match input resolution
            resample_method: Resampling method used when resizing the input to the
                processing resolution (the output is resized back with bilinear)
            output_type: Output format - "np" (numpy), "pt" (tensor), or "pil" (PIL Image)
            return_dict: Whether to return a dict with additional info
        
        Returns:
            Normal map in specified format. Normal vectors are in camera coordinates:
            - X: right (positive = right)
            - Y: down (positive = down)  
            - Z: forward (positive = into screen)
            
            Output range is [0, 1] where 0.5 represents zero in each axis.
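        
        Example (a minimal sketch; the checkpoint path and image file are
        placeholders):
            >>> pipe = TransNormalPipeline.from_pretrained("path/to/checkpoint")
            >>> normal_np = pipe("scene.png")   # defaults to output_type="np"
            >>> normal_np.shape                 # (H, W, 3), values in [0, 1]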
        """
        # Set default processing resolution
        if processing_res is None:
            processing_res = self.default_processing_resolution
        
        device = self._execution_device
        dtype = self.unet.dtype if self.unet is not None else torch.float32
        
        # Preprocess input image
        rgb_in = self.preprocess_image(image, device, dtype)
        input_size = rgb_in.shape[-2:]
        
        # Resize to processing resolution
        resample_method_tv = get_tv_resample_method(resample_method)
        if processing_res > 0:
            rgb_in = resize_max_res(
                rgb_in,
                max_edge_resolution=processing_res,
                resample_method=resample_method_tv,
            )
        
        # Get encoder hidden states (DINO or CLIP)
        encoder_hidden_states = self._get_encoder_hidden_states(
            rgb_in=rgb_in,
            prompt=prompt,
            device=device,
        )
        
        # Prepare timestep
        timesteps = torch.tensor([timestep], device=device).long()
        
        # Encode RGB to latent space
        rgb_latents = self.vae.encode(rgb_in).latent_dist.sample()
        rgb_latents = rgb_latents * self.vae.config.scaling_factor
        
        # Task embedding for normal estimation
        task_emb = torch.tensor([1, 0], dtype=dtype, device=device).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
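        # Yields [[sin(1), sin(0), cos(1), cos(0)]], shape (1, 4): a
        # Lotus-style task switcher passed to the UNet via `class_labels`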
        
        # Single-step deterministic prediction
        t = timesteps[0]
        pred = self.unet(
            rgb_latents,
            t,
            encoder_hidden_states=encoder_hidden_states,
            return_dict=False,
            class_labels=task_emb,
        )[0]
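        # Following Lotus-D, the UNet acts as a single-step x0-predictor, so
        # `pred` is treated directly as the clean normal latent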
        
        # Decode prediction
        normal_latent = pred / self.vae.config.scaling_factor
        normal_image = self.vae.decode(normal_latent, return_dict=False)[0]
        
        # Post-process to [0, 1] range
        normal_image = (normal_image / 2 + 0.5).clamp(0, 1)
        
        # Resize back to input resolution if requested
        if match_input_res and processing_res > 0:
            normal_image = F.interpolate(
                normal_image,
                size=input_size,
                mode='bilinear',
                align_corners=False,
            )
        
        # Convert to output format
        if output_type == "pt":
            output = normal_image  # (B, 3, H, W), range [0, 1]
        elif output_type == "np":
            # Convert to float32 first (bfloat16 not supported by numpy)
            output = normal_image.float().cpu().permute(0, 2, 3, 1).numpy()  # (B, H, W, 3)
            if output.shape[0] == 1:
                output = output[0]  # (H, W, 3)
        elif output_type == "pil":
            # Convert to float32 first (bfloat16 not supported by numpy)
            output = normal_image.float().cpu().permute(0, 2, 3, 1).numpy()
            output = (output * 255).round().astype(np.uint8)
            if output.shape[0] == 1:
                output = Image.fromarray(output[0])
            else:
                output = [Image.fromarray(img) for img in output]
        else:
            raise ValueError(f"Unknown output_type: {output_type}")
        
        if return_dict:
            return {"normal": output, "resolution": normal_image.shape[-2:]}
        return output
    
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        dino_encoder: Optional[nn.Module] = None,
        **kwargs,
    ):
        """
        Load TransNormalPipeline from pretrained weights.
        
        Args:
            pretrained_model_name_or_path: Path to pretrained model or HuggingFace model ID
            dino_encoder: Optional pre-loaded DINO encoder
            **kwargs: Additional arguments passed to DiffusionPipeline.from_pretrained
        
        Returns:
            TransNormalPipeline instance
        """
        # Load base pipeline components
        pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        
        # Set DINO encoder if provided
        if dino_encoder is not None:
            pipeline.set_dino_encoder(dino_encoder)
        
        return pipeline
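

# Minimal smoke-test sketch. The command-line checkpoint path, input image, and
# output filename are placeholders; loading a DINOv3 encoder is project-specific
# and omitted here, so the pipeline falls back to CLIP text conditioning.
if __name__ == "__main__":
    import sys

    ckpt, img_path = sys.argv[1], sys.argv[2]
    pipe = TransNormalPipeline.from_pretrained(ckpt)
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # (H, W, 3) float array in [0, 1]; 0.5 encodes zero along each axis
    normal_np = pipe(img_path, output_type="np")

    # Save an 8-bit visualization of the normal map
    Image.fromarray((normal_np * 255).round().astype(np.uint8)).save("normal_vis.png")
    print("Saved normal_vis.png with shape", normal_np.shape)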