""" TransNormal Pipeline for Surface Normal Estimation This pipeline is designed for transparent object surface normal estimation, using DINOv3 encoder for semantic-guided geometry estimation. Based on the Lotus-D deterministic pipeline architecture. """ import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image import numpy as np from diffusers import DiffusionPipeline, StableDiffusionMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.image_processor import VaeImageProcessor from diffusers.utils import logging from transformers import CLIPTextModel, CLIPTokenizer from .utils import resize_max_res, resize_back, get_tv_resample_method from torchvision.transforms import InterpolationMode logger = logging.get_logger(__name__) def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, **kwargs, ): """ Get timesteps from scheduler. Args: scheduler: The scheduler to get timesteps from num_inference_steps: Number of diffusion steps device: Device to move timesteps to timesteps: Custom timesteps (optional) Returns: Tuple of (timesteps, num_inference_steps) """ if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__} does not support custom " f"timestep schedules." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class TransNormalPipeline(DiffusionPipeline, StableDiffusionMixin): """ TransNormal Pipeline for Surface Normal Estimation This pipeline uses DINOv3 encoder for semantic-guided geometry estimation, particularly effective for transparent objects where traditional methods fail. Args: vae: Variational Autoencoder for encoding/decoding images text_encoder: CLIP text encoder (kept for compatibility) tokenizer: CLIP tokenizer (kept for compatibility) unet: UNet2DConditionModel for denoising scheduler: Noise scheduler dino_encoder: Optional DINOv3 encoder for semantic features """ model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["text_encoder", "tokenizer", "dino_encoder"] # Default processing resolution default_processing_resolution = 768 def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, dino_encoder: Optional[nn.Module] = None, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, dino_encoder=dino_encoder, ) # VAE scale factor (typically 8 for SD) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # DINOv3 encoder usage flag self._use_dino_for_cross_attention = dino_encoder is not None def set_dino_encoder(self, dino_encoder: Optional[nn.Module], device: torch.device = None): """ Set or remove the DINOv3 encoder. Args: dino_encoder: DINOv3 encoder module, or None to disable device: Target device for the encoder """ if dino_encoder is not None and device is not None: dino_encoder = dino_encoder.to(device) if hasattr(dino_encoder, 'dino_backbone') and dino_encoder.dino_backbone is not None: dino_encoder.dino_backbone = dino_encoder.dino_backbone.to(device) # Update registered module self.register_modules(dino_encoder=dino_encoder) self._use_dino_for_cross_attention = dino_encoder is not None def encode_prompt( self, prompt: str, device: torch.device, num_images_per_prompt: int = 1, ) -> torch.Tensor: """ Encode text prompt using CLIP text encoder. Args: prompt: Text prompt device: Target device num_images_per_prompt: Number of images per prompt Returns: Text embeddings tensor """ text_inputs = self.tokenizer( prompt, padding="do_not_pad", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) return prompt_embeds def _get_encoder_hidden_states( self, rgb_in: torch.Tensor, prompt: str, device: torch.device, ) -> torch.Tensor: """ Get encoder hidden states for cross-attention. Uses DINOv3 features if encoder is available, otherwise uses CLIP text embeddings. Args: rgb_in: Input RGB image tensor, shape (B, 3, H, W), range [-1, 1] prompt: Text prompt (used only if DINO encoder is not available) device: Target device Returns: Encoder hidden states for cross-attention """ if self._use_dino_for_cross_attention and self.dino_encoder is not None: # Use DINOv3 to extract semantic features encoder_hidden_states = self.dino_encoder.get_cross_attention_features(rgb_in) # Ensure dtype matches UNet if self.unet is not None: encoder_hidden_states = encoder_hidden_states.to(dtype=self.unet.dtype) return encoder_hidden_states else: # Fallback to CLIP text encoder return self.encode_prompt(prompt, device) def preprocess_image( self, image: Union[torch.Tensor, Image.Image, np.ndarray, str], device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: """ Preprocess input image to tensor format. Args: image: Input image (PIL, numpy, tensor, or path) device: Target device dtype: Target dtype Returns: Preprocessed image tensor, shape (1, 3, H, W), range [-1, 1] """ # Load image if path is provided if isinstance(image, str): image = Image.open(image).convert("RGB") # Convert PIL to numpy if isinstance(image, Image.Image): image = np.array(image) # Convert numpy to tensor if isinstance(image, np.ndarray): # Ensure HWC format if image.ndim == 2: image = np.stack([image] * 3, axis=-1) elif image.shape[0] == 3: # CHW format image = np.transpose(image, (1, 2, 0)) # Normalize to [0, 1] if image.dtype == np.uint8: image = image.astype(np.float32) / 255.0 # Convert to tensor (B, C, H, W) image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) # Ensure batch dimension if image.dim() == 3: image = image.unsqueeze(0) # Normalize to [-1, 1] if image.min() >= 0 and image.max() <= 1: image = image * 2.0 - 1.0 return image.to(device=device, dtype=dtype) @torch.no_grad() def __call__( self, image: Union[torch.Tensor, Image.Image, np.ndarray, str], prompt: str = "", timestep: int = 1, processing_res: Optional[int] = None, match_input_res: bool = True, resample_method: str = "bilinear", output_type: str = "np", return_dict: bool = False, **kwargs, ): """ Run surface normal estimation on input image. Args: image: Input RGB image (PIL, numpy, tensor, or file path) prompt: Text prompt (optional, used only if DINO encoder is not available) timestep: Diffusion timestep for deterministic prediction (default: 1) processing_res: Processing resolution (default: 768) match_input_res: Whether to resize output to match input resolution resample_method: Resampling method for resizing output_type: Output format - "np" (numpy), "pt" (tensor), or "pil" (PIL Image) return_dict: Whether to return a dict with additional info Returns: Normal map in specified format. Normal vectors are in camera coordinates: - X: right (positive = right) - Y: down (positive = down) - Z: forward (positive = into screen) Output range is [0, 1] where 0.5 represents zero in each axis. """ # Set default processing resolution if processing_res is None: processing_res = self.default_processing_resolution device = self._execution_device dtype = self.unet.dtype if self.unet is not None else torch.float32 # Preprocess input image rgb_in = self.preprocess_image(image, device, dtype) input_size = rgb_in.shape[-2:] # Resize to processing resolution resample_method_tv = get_tv_resample_method(resample_method) if processing_res > 0: rgb_in = resize_max_res( rgb_in, max_edge_resolution=processing_res, resample_method=resample_method_tv, ) # Get encoder hidden states (DINO or CLIP) encoder_hidden_states = self._get_encoder_hidden_states( rgb_in=rgb_in, prompt=prompt, device=device, ) # Prepare timestep timesteps = torch.tensor([timestep], device=device).long() # Encode RGB to latent space rgb_latents = self.vae.encode(rgb_in).latent_dist.sample() rgb_latents = rgb_latents * self.vae.config.scaling_factor # Task embedding for normal estimation task_emb = torch.tensor([1, 0], dtype=dtype, device=device).unsqueeze(0) task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1) # Single-step deterministic prediction t = timesteps[0] pred = self.unet( rgb_latents, t, encoder_hidden_states=encoder_hidden_states, return_dict=False, class_labels=task_emb, )[0] # Decode prediction normal_latent = pred / self.vae.config.scaling_factor normal_image = self.vae.decode(normal_latent, return_dict=False)[0] # Post-process to [0, 1] range normal_image = (normal_image / 2 + 0.5).clamp(0, 1) # Resize back to input resolution if requested if match_input_res and processing_res > 0: normal_image = F.interpolate( normal_image, size=input_size, mode='bilinear', align_corners=False, ) # Convert to output format if output_type == "pt": output = normal_image # (B, 3, H, W), range [0, 1] elif output_type == "np": # Convert to float32 first (bfloat16 not supported by numpy) output = normal_image.float().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3) if output.shape[0] == 1: output = output[0] # (H, W, 3) elif output_type == "pil": # Convert to float32 first (bfloat16 not supported by numpy) output = normal_image.float().cpu().permute(0, 2, 3, 1).numpy() output = (output * 255).astype(np.uint8) if output.shape[0] == 1: output = Image.fromarray(output[0]) else: output = [Image.fromarray(img) for img in output] else: raise ValueError(f"Unknown output_type: {output_type}") if return_dict: return {"normal": output, "resolution": normal_image.shape[-2:]} return output @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str, dino_encoder: Optional[nn.Module] = None, **kwargs, ): """ Load TransNormalPipeline from pretrained weights. Args: pretrained_model_name_or_path: Path to pretrained model or HuggingFace model ID dino_encoder: Optional pre-loaded DINO encoder **kwargs: Additional arguments passed to DiffusionPipeline.from_pretrained Returns: TransNormalPipeline instance """ # Load base pipeline components pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs) # Set DINO encoder if provided if dino_encoder is not None: pipeline.set_dino_encoder(dino_encoder) return pipeline