Longxiang-ai commited on
Commit
159500c
·
0 Parent(s):

Initial release: TransNormal with Zero GPU support

Browse files
README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TransNormal
3
+ emoji: 🔮
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-4.0
11
+ suggested_hardware: zero-a10g
12
+ ---
13
+
14
+ # TransNormal
15
+
16
+ Surface Normal Estimation for Transparent Objects using Dense Visual Semantics.
17
+
18
+ **Paper:** [TransNormal: Dense Visual Semantics for Diffusion-based Transparent Object Normal Estimation](https://longxiang-ai.github.io/TransNormal/)
19
+
20
+ **Authors:** Mingwei Li, Hehe Fan, Yi Yang (Zhejiang University)
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ TransNormal - Hugging Face Spaces Zero GPU Version
4
+
5
+ Surface Normal Estimation for Transparent Objects
6
+ """
7
+
8
+ import os
9
+ import spaces
10
+ import torch
11
+ import gradio as gr
12
+ from PIL import Image
13
+ from huggingface_hub import snapshot_download
14
+
15
+ from transnormal import TransNormalPipeline, create_dino_encoder
16
+
17
+ # ============== Model Paths ==============
18
+ TRANSNORMAL_REPO = "Longxiang-ai/TransNormal"
19
+ DINO_REPO = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
20
+ # =========================================
21
+
22
+ # Global pipeline
23
+ pipe = None
24
+ weights_downloaded = False
25
+
26
+
27
def download_weights():
    """Download model weights from HuggingFace Hub (idempotent).

    Returns:
        Tuple of (transnormal_path, dino_path): local directories holding the
        TransNormal checkpoint and the DINOv3 backbone snapshot.
    """
    global weights_downloaded

    # Single source of truth for the cache locations: previously the cache-hit
    # early return and the snapshot_download targets repeated these literals
    # and could silently drift apart.
    transnormal_dir = "./weights/transnormal"
    dino_dir = "./weights/dinov3_vith16plus"

    if weights_downloaded:
        return transnormal_dir, dino_dir

    print("[TransNormal] Downloading TransNormal weights...")
    transnormal_path = snapshot_download(
        TRANSNORMAL_REPO,
        local_dir=transnormal_dir,
    )

    print("[TransNormal] Downloading DINOv3 weights...")
    dino_path = snapshot_download(
        DINO_REPO,
        local_dir=dino_dir,
    )

    # Flag is only set after BOTH downloads succeeded, so a failure on the
    # second snapshot leaves the function retryable on the next call.
    weights_downloaded = True
    print("[TransNormal] Weights downloaded successfully!")
    return transnormal_path, dino_path
49
+
50
+
51
def load_pipeline():
    """Return the global TransNormal pipeline, constructing it on first use."""
    global pipe

    # Reuse the cached pipeline if it has already been built.
    if pipe is not None:
        return pipe

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    print(f"[TransNormal] Loading model on {device} with {dtype}...")

    # Fetch checkpoints (no-op when already cached on disk).
    transnormal_path, dino_path = download_weights()
    projector_path = os.path.join(transnormal_path, "cross_attention_projector.pt")

    # Semantic encoder that supplies cross-attention features to the UNet.
    dino_encoder = create_dino_encoder(
        model_name="dinov3_vith16plus",
        cross_attention_dim=1024,
        weights_path=dino_path,
        projector_path=projector_path,
        device=device,
        dtype=dtype,
        freeze_encoder=True,
    )

    # Diffusion pipeline with the DINO encoder plugged in.
    pipe = TransNormalPipeline.from_pretrained(
        transnormal_path,
        dino_encoder=dino_encoder,
        torch_dtype=dtype,
    ).to(device)

    print("[TransNormal] Model loaded successfully!")
    return pipe
88
+
89
+
90
@spaces.GPU(duration=120)
def predict_normal(image: Image.Image, processing_res: int = 768) -> Image.Image:
    """
    Predict surface normal from input image using Zero GPU.

    Args:
        image: Input RGB image
        processing_res: Processing resolution

    Returns:
        Normal map as PIL Image, or None when no image was provided
    """
    # Nothing to do without an input image.
    if image is None:
        return None

    # The pipeline is built lazily inside the GPU allocation window.
    pipeline = load_pipeline()

    # Inference only — no gradients needed.
    with torch.no_grad():
        return pipeline(
            image=image,
            processing_res=processing_res,
            output_type="pil",
        )
117
+
118
+
119
# ============== Gradio Interface ==============

# Minimal CSS tweaks layered on top of the Soft theme.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif !important;
}
h1 {
    font-weight: 600 !important;
}
"""

with gr.Blocks(
    title="TransNormal",
    theme=gr.themes.Soft(),
    css=custom_css,
) as demo:

    # Header / usage notes shown above the widgets.
    gr.Markdown(
        """
        # 🔮 TransNormal
        ### Surface Normal Estimation for Transparent Objects

        Upload an image to estimate surface normals. Particularly effective for **transparent objects** like glass and plastic.

        **Normal Convention:** Red=X (Left) | Green=Y (Up) | Blue=Z (Out)

        > ⏱️ First inference may take ~1-2 minutes to load model weights.
        """
    )

    # Left column: inputs; right column: predicted normal map.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400,
            )

            # Step of 64 keeps resolutions friendly to the VAE downsampling.
            processing_res = gr.Slider(
                minimum=256,
                maximum=1024,
                value=768,
                step=64,
                label="Processing Resolution",
                info="Higher resolution = better quality but slower"
            )

            submit_btn = gr.Button("🚀 Estimate Normal", variant="primary", size="lg")

        with gr.Column():
            output_image = gr.Image(
                label="Normal Map",
                type="pil",
                height=400,
            )

    # Event handlers
    submit_btn.click(
        fn=predict_normal,
        inputs=[input_image, processing_res],
        outputs=output_image,
    )

    # Footer
    gr.Markdown(
        """
        ---

        **Paper:** [TransNormal: Dense Visual Semantics for Diffusion-based Transparent Object Normal Estimation](https://longxiang-ai.github.io/TransNormal/)

        **Authors:** Mingwei Li, Hehe Fan, Yi Yang (Zhejiang University)

        **Code:** [GitHub](https://github.com/longxiang-ai/TransNormal)
        """
    )

# Launch
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TransNormal HuggingFace Space Dependencies
2
+
3
+ # PyTorch
4
+ torch>=2.0.0
5
+ torchvision>=0.15.0
6
+
7
+ # Diffusers and Transformers
8
+ diffusers>=0.28.0
9
+ transformers>=4.56.0
10
+ accelerate>=0.24.0
11
+ safetensors>=0.4.0
12
+
13
+ # Image processing
14
+ Pillow>=9.0.0
15
+ numpy>=1.23.0
16
+
17
+ # HuggingFace
18
+ huggingface_hub
19
+
20
+ # Gradio and Spaces
21
+ gradio>=5.0.0
22
+ spaces
transnormal/__init__.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TransNormal: Surface Normal Estimation for Transparent Objects
3
+
4
+ This package provides a diffusion-based pipeline for estimating surface normals
5
+ from RGB images, with particular effectiveness on transparent objects.
6
+
7
+ Example usage:
8
+ from transnormal import TransNormalPipeline, create_dino_encoder
9
+ import torch
10
+
11
+ # Create DINO encoder
12
+ dino_encoder = create_dino_encoder(
13
+ model_name="dinov3_vith16plus",
14
+ weights_path="path/to/dinov3_weights",
15
+ projector_path="path/to/projector.pt",
16
+ device="cuda",
17
+ )
18
+
19
+ # Load pipeline
20
+ pipe = TransNormalPipeline.from_pretrained(
21
+ "path/to/transnormal_model",
22
+ dino_encoder=dino_encoder,
23
+ torch_dtype=torch.float16,
24
+ )
25
+ pipe = pipe.to("cuda")
26
+
27
+ # Run inference
28
+ normal_map = pipe("path/to/image.jpg", output_type="np")
29
+ """
30
+
31
+ __version__ = "1.0.0"
32
+ __author__ = "TransNormal Team"
33
+
34
+ from .pipeline import TransNormalPipeline
35
+ from .dino_encoder import DINOv3Encoder, create_dino_encoder
36
+ from .utils import (
37
+ resize_max_res,
38
+ resize_back,
39
+ get_tv_resample_method,
40
+ get_pil_resample_method,
41
+ normal_to_rgb,
42
+ save_normal_map,
43
+ load_image,
44
+ concatenate_images,
45
+ )
46
+
47
+ __all__ = [
48
+ "TransNormalPipeline",
49
+ "DINOv3Encoder",
50
+ "create_dino_encoder",
51
+ "resize_max_res",
52
+ "resize_back",
53
+ "get_tv_resample_method",
54
+ "get_pil_resample_method",
55
+ "normal_to_rgb",
56
+ "save_normal_map",
57
+ "load_image",
58
+ "concatenate_images",
59
+ ]
transnormal/dino_encoder.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DINOv3 Encoder for Semantic-Guided Surface Normal Estimation
3
+
4
+ This module provides a simplified DINOv3 encoder that extracts semantic features
5
+ from RGB images for cross-attention in the TransNormal pipeline.
6
+
7
+ The encoder is particularly effective for transparent objects, as DINOv3's
8
+ strong semantic features can "see through" refraction artifacts.
9
+ """
10
+
11
+ import os
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from typing import Optional, Dict
16
+
17
+
18
# DINOv3 model configurations
# embed_dim: ViT hidden size; patch_size: patch edge in pixels;
# n_storage_tokens: register tokens that follow the CLS token in the sequence.
DINOV3_CONFIGS = {
    "dinov3_vits16": {
        "embed_dim": 384,
        "patch_size": 16,
        "n_storage_tokens": 4,
    },
    "dinov3_vitb16": {
        "embed_dim": 768,
        "patch_size": 16,
        "n_storage_tokens": 4,
    },
    "dinov3_vitl16": {
        "embed_dim": 1024,
        "patch_size": 16,
        "n_storage_tokens": 4,
    },
    "dinov3_vith16plus": {
        "embed_dim": 1280,
        "patch_size": 16,
        "n_storage_tokens": 4,
    },
}


class DINOv3Encoder(nn.Module):
    """
    DINOv3 Encoder for extracting semantic features from RGB images.

    This encoder provides projected patch tokens for cross-attention in the UNet,
    replacing CLIP text embeddings with visual semantic features.

    Args:
        model_name: DINOv3 model name (e.g., "dinov3_vith16plus")
        cross_attention_dim: Target dimension for cross-attention (1024 for SD 2.x)
        weights_path: Path to DINOv3 pretrained weights (HuggingFace format)
        freeze_encoder: Whether to freeze the DINOv3 backbone

    Raises:
        ValueError: if ``model_name`` is not one of DINOV3_CONFIGS.
    """

    def __init__(
        self,
        model_name: str = "dinov3_vith16plus",
        cross_attention_dim: int = 1024,
        weights_path: Optional[str] = None,
        freeze_encoder: bool = True,
    ):
        super().__init__()

        self.model_name = model_name
        self.cross_attention_dim = cross_attention_dim
        self.weights_path = weights_path
        self.freeze_encoder = freeze_encoder

        # Get model configuration
        if model_name not in DINOV3_CONFIGS:
            raise ValueError(f"Unknown DINOv3 model: {model_name}. Available: {list(DINOV3_CONFIGS.keys())}")

        self.config = DINOV3_CONFIGS[model_name]
        self.dino_hidden_dim = self.config["embed_dim"]
        self.patch_size = self.config["patch_size"]
        self.n_storage_tokens = self.config["n_storage_tokens"]

        # DINOv3 backbone (loaded lazily by load_dino_model)
        self.dino_backbone = None
        self._use_hf_interface = False
        self._is_loaded = False

        # Cross-attention projector: DINO hidden_dim -> SD cross_attention_dim
        self.cross_attention_projector = nn.Linear(self.dino_hidden_dim, cross_attention_dim)
        self._init_projector()

        # ImageNet normalization for DINOv3 (non-persistent: never serialized)
        self.register_buffer(
            "imagenet_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False
        )
        self.register_buffer(
            "imagenet_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False
        )

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype of the encoder (for diffusers compatibility)."""
        return self.cross_attention_projector.weight.dtype

    @property
    def device(self) -> torch.device:
        """Return the device of the encoder."""
        return self.cross_attention_projector.weight.device

    def _init_projector(self):
        """Initialize the cross-attention projector with Xavier initialization."""
        nn.init.xavier_uniform_(self.cross_attention_projector.weight)
        nn.init.zeros_(self.cross_attention_projector.bias)

    def _preprocess_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Preprocess image from [-1, 1] to ImageNet normalized format.

        Args:
            pixel_values: Input images, shape (B, 3, H, W), normalized to [-1, 1]

        Returns:
            Preprocessed images with ImageNet normalization
        """
        # Convert from [-1, 1] to [0, 1]
        pixel_values = (pixel_values + 1.0) / 2.0

        # Ensure mean/std are on the same device and dtype
        mean = self.imagenet_mean.to(device=pixel_values.device, dtype=pixel_values.dtype)
        std = self.imagenet_std.to(device=pixel_values.device, dtype=pixel_values.dtype)

        # Apply ImageNet normalization
        pixel_values = (pixel_values - mean) / std

        return pixel_values

    def load_dino_model(self, device: torch.device = None, dtype: torch.dtype = None):
        """
        Load the DINOv3 model from HuggingFace format.

        Args:
            device: Device to load the model on
            dtype: Data type for the model weights

        Raises:
            ValueError: if no weights_path was configured.
            RuntimeError: if loading fails for any reason (original cause chained
                in the message).
        """
        if self._is_loaded:
            return

        if self.weights_path is None:
            raise ValueError("weights_path must be provided to load DINOv3 model")

        try:
            from transformers import AutoModel

            print(f"[DINOv3] Loading from: {self.weights_path}")
            self.dino_backbone = AutoModel.from_pretrained(
                self.weights_path,
                trust_remote_code=True,
            )

            # Update config from loaded model (checkpoint wins over the
            # hard-coded defaults above)
            hf_config = getattr(self.dino_backbone, "config", None)
            if hf_config is not None:
                self.dino_hidden_dim = getattr(hf_config, "hidden_size", self.dino_hidden_dim)
                self.patch_size = getattr(hf_config, "patch_size", self.patch_size)
                self.n_storage_tokens = getattr(hf_config, "num_register_tokens", self.n_storage_tokens)

            # Reinitialize projector if hidden dim changed
            if self.cross_attention_projector.in_features != self.dino_hidden_dim:
                self.cross_attention_projector = nn.Linear(
                    self.dino_hidden_dim, self.cross_attention_dim
                )
                self._init_projector()

            self._use_hf_interface = True

            # Move to device/dtype (projector moved too, in case it was rebuilt)
            if device is not None:
                self.dino_backbone = self.dino_backbone.to(device)
                self.cross_attention_projector = self.cross_attention_projector.to(device)

            if dtype is not None:
                self.dino_backbone = self.dino_backbone.to(dtype)
                self.cross_attention_projector = self.cross_attention_projector.to(dtype)

            # Freeze backbone
            if self.freeze_encoder:
                self.dino_backbone.requires_grad_(False)
                self.dino_backbone.eval()

            self._is_loaded = True
            print(f"[DINOv3] Successfully loaded {self.model_name}")
            print(f"  - Hidden dim: {self.dino_hidden_dim}")
            print(f"  - Patch size: {self.patch_size}")
            print(f"  - Cross-attention dim: {self.cross_attention_dim}")

        except Exception as e:
            raise RuntimeError(
                f"Failed to load DINOv3 model from {self.weights_path}.\n"
                f"Error: {e}"
            )

    def _ensure_loaded(self):
        """Ensure the model is loaded before forward pass."""
        if not self._is_loaded:
            raise RuntimeError(
                "DINOv3 model not loaded. Call load_dino_model() first."
            )

    def extract_patch_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Extract patch tokens from DINOv3.

        Args:
            pixel_values: Input images, shape (B, 3, H, W), normalized to [-1, 1]

        Returns:
            patch_tokens: Shape (B, N, D) where N is number of patches, D is hidden_dim

        Raises:
            RuntimeError: if load_dino_model() has not been called.
        """
        self._ensure_loaded()

        # Preprocess image
        preprocessed = self._preprocess_image(pixel_values)

        # Ensure dimensions are multiples of patch_size (ViT requirement);
        # rounds DOWN, so the image may shrink by up to patch_size-1 pixels per edge
        _, _, H, W = preprocessed.shape
        new_H = (H // self.patch_size) * self.patch_size
        new_W = (W // self.patch_size) * self.patch_size
        if new_H != H or new_W != W:
            preprocessed = F.interpolate(
                preprocessed,
                size=(new_H, new_W),
                mode='bilinear',
                align_corners=False
            )

        # Forward through DINOv3 (no grad when frozen)
        with torch.no_grad() if self.freeze_encoder else torch.enable_grad():
            if self._use_hf_interface:
                outputs = self.dino_backbone(
                    pixel_values=preprocessed,
                    output_hidden_states=True
                )
                last_hidden = outputs.last_hidden_state
                # Remove CLS and register tokens
                n_special = 1 + self.n_storage_tokens
                patch_tokens = last_hidden[:, n_special:, :]
            else:
                # Non-HF path (e.g. torch.hub-style backbone) — only reachable
                # if _use_hf_interface was left False by a custom loader
                outputs = self.dino_backbone.forward_features(preprocessed, masks=None)
                patch_tokens = outputs['x_norm_patchtokens']

        return patch_tokens

    def forward(self, pixel_values: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass to extract features for cross-attention.

        Args:
            pixel_values: Input images, shape (B, 3, H, W), normalized to [-1, 1]

        Returns:
            dict with 'cross_attention_features': Projected features, shape (B, N, cross_attention_dim)
        """
        self._ensure_loaded()

        # Extract patch tokens
        patch_tokens = self.extract_patch_tokens(pixel_values)

        # Project to cross-attention dimension (cast first: backbone may run
        # in a different dtype than the projector)
        projector_dtype = next(self.cross_attention_projector.parameters()).dtype
        if patch_tokens.dtype != projector_dtype:
            patch_tokens = patch_tokens.to(dtype=projector_dtype)

        cross_attention_features = self.cross_attention_projector(patch_tokens)

        return {'cross_attention_features': cross_attention_features}

    def get_cross_attention_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Convenience method to get only cross-attention features.

        Args:
            pixel_values: Input images, shape (B, 3, H, W), normalized to [-1, 1]

        Returns:
            cross_attention_features: Shape (B, N, cross_attention_dim)
        """
        return self.forward(pixel_values)['cross_attention_features']

    def load_projector(self, projector_path: str, device: torch.device = None):
        """
        Load pretrained projector weights.

        Args:
            projector_path: Path to projector weights file (.pt)
            device: Device to load weights on

        Raises:
            FileNotFoundError: if the weights file does not exist.
        """
        if not os.path.exists(projector_path):
            raise FileNotFoundError(f"Projector weights not found: {projector_path}")

        # weights_only=True: the checkpoint is a plain tensor state_dict, so
        # refuse to unpickle (and execute) arbitrary objects from the file.
        state_dict = torch.load(
            projector_path,
            map_location=device or "cpu",
            weights_only=True,
        )
        self.cross_attention_projector.load_state_dict(state_dict)
        print(f"[DINOv3] Loaded projector weights from {projector_path}")
304
+
305
+
306
def create_dino_encoder(
    model_name: str = "dinov3_vith16plus",
    cross_attention_dim: int = 1024,
    weights_path: Optional[str] = None,
    projector_path: Optional[str] = None,
    device: torch.device = None,
    dtype: torch.dtype = None,
    freeze_encoder: bool = True,
) -> DINOv3Encoder:
    """Build a DINOv3Encoder and optionally load backbone/projector weights.

    Args:
        model_name: DINOv3 model name
        cross_attention_dim: Target dimension for cross-attention
        weights_path: Path to DINOv3 pretrained weights
        projector_path: Path to projector weights (optional)
        device: Device to load the model on
        dtype: Data type for the model
        freeze_encoder: Whether to freeze the backbone

    Returns:
        Initialized DINOv3Encoder
    """
    encoder = DINOv3Encoder(
        model_name=model_name,
        cross_attention_dim=cross_attention_dim,
        weights_path=weights_path,
        freeze_encoder=freeze_encoder,
    )

    # Backbone weights are only fetched when a path was supplied.
    if weights_path is not None:
        encoder.load_dino_model(device=device, dtype=dtype)

    # A pretrained projector overrides the fresh Xavier initialization.
    if projector_path is not None:
        encoder.load_projector(projector_path, device=device)

    # Final placement: device first, then dtype (order matters for parity
    # with the loading code above).
    for target in (device, dtype):
        if target is not None:
            encoder = encoder.to(target)

    return encoder
transnormal/pipeline.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TransNormal Pipeline for Surface Normal Estimation
3
+
4
+ This pipeline is designed for transparent object surface normal estimation,
5
+ using DINOv3 encoder for semantic-guided geometry estimation.
6
+
7
+ Based on the Lotus-D deterministic pipeline architecture.
8
+ """
9
+
10
+ import inspect
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from PIL import Image
17
+ import numpy as np
18
+
19
+ from diffusers import DiffusionPipeline, StableDiffusionMixin
20
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
21
+ from diffusers.schedulers import KarrasDiffusionSchedulers
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.utils import logging
24
+ from transformers import CLIPTextModel, CLIPTokenizer
25
+
26
+ from .utils import resize_max_res, resize_back, get_tv_resample_method
27
+ from torchvision.transforms import InterpolationMode
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Resolve the timestep schedule from a scheduler.

    Args:
        scheduler: The scheduler to get timesteps from
        num_inference_steps: Number of diffusion steps
        device: Device to move timesteps to
        timesteps: Custom timesteps (optional)

    Returns:
        Tuple of (timesteps, num_inference_steps)

    Raises:
        ValueError: when custom timesteps are given but the scheduler's
            set_timesteps does not accept a ``timesteps`` argument.
    """
    # Common path: let the scheduler derive its own schedule.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: only valid if the scheduler supports it explicitly.
    sig_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in sig_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__} does not support custom "
            f"timestep schedules."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    resolved = scheduler.timesteps
    return resolved, len(resolved)
65
+
66
+
67
class TransNormalPipeline(DiffusionPipeline, StableDiffusionMixin):
    """
    TransNormal Pipeline for Surface Normal Estimation

    This pipeline uses DINOv3 encoder for semantic-guided geometry estimation,
    particularly effective for transparent objects where traditional methods fail.

    Args:
        vae: Variational Autoencoder for encoding/decoding images
        text_encoder: CLIP text encoder (kept for compatibility)
        tokenizer: CLIP tokenizer (kept for compatibility)
        unet: UNet2DConditionModel for denoising
        scheduler: Noise scheduler
        dino_encoder: Optional DINOv3 encoder for semantic features
    """

    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["text_encoder", "tokenizer", "dino_encoder"]

    # Default processing resolution
    default_processing_resolution = 768

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        dino_encoder: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            dino_encoder=dino_encoder,
        )

        # VAE scale factor (typically 8 for SD)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        # DINOv3 encoder usage flag
        self._use_dino_for_cross_attention = dino_encoder is not None

    def set_dino_encoder(self, dino_encoder: Optional[nn.Module], device: torch.device = None):
        """
        Set or remove the DINOv3 encoder.

        Args:
            dino_encoder: DINOv3 encoder module, or None to disable
            device: Target device for the encoder
        """
        if dino_encoder is not None and device is not None:
            dino_encoder = dino_encoder.to(device)
            # NOTE(review): likely redundant — nn.Module.to already moves
            # registered submodules — but kept for safety with custom loaders.
            if hasattr(dino_encoder, 'dino_backbone') and dino_encoder.dino_backbone is not None:
                dino_encoder.dino_backbone = dino_encoder.dino_backbone.to(device)

        # Update registered module
        self.register_modules(dino_encoder=dino_encoder)
        self._use_dino_for_cross_attention = dino_encoder is not None

    def encode_prompt(
        self,
        prompt: str,
        device: torch.device,
        num_images_per_prompt: int = 1,
    ) -> torch.Tensor:
        """
        Encode text prompt using CLIP text encoder.

        Args:
            prompt: Text prompt
            device: Target device
            num_images_per_prompt: Number of images per prompt

        Returns:
            Text embeddings tensor
        """
        # "do_not_pad": sequence keeps its natural length (no padding to max)
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]

        # Duplicate embeddings per requested image, then flatten batch dim
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_encoder_hidden_states(
        self,
        rgb_in: torch.Tensor,
        prompt: str,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Get encoder hidden states for cross-attention.

        Uses DINOv3 features if encoder is available, otherwise uses CLIP text embeddings.

        Args:
            rgb_in: Input RGB image tensor, shape (B, 3, H, W), range [-1, 1]
            prompt: Text prompt (used only if DINO encoder is not available)
            device: Target device

        Returns:
            Encoder hidden states for cross-attention
        """
        if self._use_dino_for_cross_attention and self.dino_encoder is not None:
            # Use DINOv3 to extract semantic features
            encoder_hidden_states = self.dino_encoder.get_cross_attention_features(rgb_in)

            # Ensure dtype matches UNet
            if self.unet is not None:
                encoder_hidden_states = encoder_hidden_states.to(dtype=self.unet.dtype)
            return encoder_hidden_states
        else:
            # Fallback to CLIP text encoder
            return self.encode_prompt(prompt, device)

    def preprocess_image(
        self,
        image: Union[torch.Tensor, Image.Image, np.ndarray, str],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Preprocess input image to tensor format.

        Args:
            image: Input image (PIL, numpy, tensor, or path)
            device: Target device
            dtype: Target dtype

        Returns:
            Preprocessed image tensor, shape (1, 3, H, W), range [-1, 1]
        """
        # Load image if path is provided
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        # Convert PIL to numpy
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Convert numpy to tensor
        if isinstance(image, np.ndarray):
            # Ensure HWC format; grayscale gets replicated to 3 channels
            if image.ndim == 2:
                image = np.stack([image] * 3, axis=-1)
            elif image.shape[0] == 3:  # CHW format
                # NOTE(review): heuristic — a genuine HWC image of height 3
                # would be misread as CHW here; acceptable for typical photos.
                image = np.transpose(image, (1, 2, 0))

            # Normalize to [0, 1]
            if image.dtype == np.uint8:
                image = image.astype(np.float32) / 255.0

            # Convert to tensor (B, C, H, W)
            image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)

        # Ensure batch dimension
        if image.dim() == 3:
            image = image.unsqueeze(0)

        # Normalize to [-1, 1]; tensors already in [-1, 1] are left untouched
        if image.min() >= 0 and image.max() <= 1:
            image = image * 2.0 - 1.0

        return image.to(device=device, dtype=dtype)

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, Image.Image, np.ndarray, str],
        prompt: str = "",
        timestep: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        output_type: str = "np",
        return_dict: bool = False,
        **kwargs,
    ):
        """
        Run surface normal estimation on input image.

        Args:
            image: Input RGB image (PIL, numpy, tensor, or file path)
            prompt: Text prompt (optional, used only if DINO encoder is not available)
            timestep: Diffusion timestep for deterministic prediction (default: 1)
            processing_res: Processing resolution (default: 768)
            match_input_res: Whether to resize output to match input resolution
            resample_method: Resampling method for resizing
            output_type: Output format - "np" (numpy), "pt" (tensor), or "pil" (PIL Image)
            return_dict: Whether to return a dict with additional info
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Normal map in specified format. Normal vectors are in camera coordinates:
            - X: right (positive = right)
            - Y: down (positive = down)
            - Z: forward (positive = into screen)

            Output range is [0, 1] where 0.5 represents zero in each axis.

        Raises:
            ValueError: for an unknown ``output_type``.
        """
        # Set default processing resolution
        if processing_res is None:
            processing_res = self.default_processing_resolution

        device = self._execution_device
        dtype = self.unet.dtype if self.unet is not None else torch.float32

        # Preprocess input image; remember original size for resize-back
        rgb_in = self.preprocess_image(image, device, dtype)
        input_size = rgb_in.shape[-2:]

        # Resize to processing resolution (processing_res <= 0 disables resize)
        resample_method_tv = get_tv_resample_method(resample_method)
        if processing_res > 0:
            rgb_in = resize_max_res(
                rgb_in,
                max_edge_resolution=processing_res,
                resample_method=resample_method_tv,
            )

        # Get encoder hidden states (DINO or CLIP)
        encoder_hidden_states = self._get_encoder_hidden_states(
            rgb_in=rgb_in,
            prompt=prompt,
            device=device,
        )

        # Prepare timestep
        timesteps = torch.tensor([timestep], device=device).long()

        # Encode RGB to latent space
        rgb_latents = self.vae.encode(rgb_in).latent_dist.sample()
        rgb_latents = rgb_latents * self.vae.config.scaling_factor

        # Task embedding for normal estimation (sin/cos encoding of [1, 0])
        task_emb = torch.tensor([1, 0], dtype=dtype, device=device).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

        # Single-step deterministic prediction (Lotus-D style: the UNet output
        # at a fixed timestep IS the target latent; no iterative denoising)
        t = timesteps[0]
        pred = self.unet(
            rgb_latents,
            t,
            encoder_hidden_states=encoder_hidden_states,
            return_dict=False,
            class_labels=task_emb,
        )[0]

        # Decode prediction
        normal_latent = pred / self.vae.config.scaling_factor
        normal_image = self.vae.decode(normal_latent, return_dict=False)[0]

        # Post-process to [0, 1] range
        normal_image = (normal_image / 2 + 0.5).clamp(0, 1)

        # Resize back to input resolution if requested
        # NOTE(review): this always uses bilinear, regardless of the
        # resample_method argument used for the downscale — confirm intended.
        if match_input_res and processing_res > 0:
            normal_image = F.interpolate(
                normal_image,
                size=input_size,
                mode='bilinear',
                align_corners=False,
            )

        # Convert to output format
        if output_type == "pt":
            output = normal_image  # (B, 3, H, W), range [0, 1]
        elif output_type == "np":
            # Convert to float32 first (bfloat16 not supported by numpy)
            output = normal_image.float().cpu().permute(0, 2, 3, 1).numpy()  # (B, H, W, 3)
            if output.shape[0] == 1:
                output = output[0]  # (H, W, 3)
        elif output_type == "pil":
            # Convert to float32 first (bfloat16 not supported by numpy)
            output = normal_image.float().cpu().permute(0, 2, 3, 1).numpy()
            output = (output * 255).astype(np.uint8)
            if output.shape[0] == 1:
                output = Image.fromarray(output[0])
            else:
                output = [Image.fromarray(img) for img in output]
        else:
            raise ValueError(f"Unknown output_type: {output_type}")

        if return_dict:
            return {"normal": output, "resolution": normal_image.shape[-2:]}
        return output

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        dino_encoder: Optional[nn.Module] = None,
        **kwargs,
    ):
        """
        Load TransNormalPipeline from pretrained weights.

        Args:
            pretrained_model_name_or_path: Path to pretrained model or HuggingFace model ID
            dino_encoder: Optional pre-loaded DINO encoder
            **kwargs: Additional arguments passed to DiffusionPipeline.from_pretrained

        Returns:
            TransNormalPipeline instance
        """
        # Load base pipeline components
        pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Set DINO encoder if provided (also flips the cross-attention flag)
        if dino_encoder is not None:
            pipeline.set_dino_encoder(dino_encoder)

        return pipeline
transnormal/utils.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for TransNormal pipeline.
3
+
4
+ Includes image processing utilities for preprocessing and postprocessing.
5
+ """
6
+
7
+ from typing import List, Union
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torch
11
+ from torchvision.transforms import InterpolationMode
12
+ from torchvision.transforms.functional import resize
13
+
14
+
15
def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image so its longer edge equals ``max_edge_resolution``,
    keeping the aspect ratio.

    NOTE: if both edges are already smaller than ``max_edge_resolution``
    the scale factor exceeds 1 and the image is upscaled — this matches
    the original behavior.

    Args:
        img: Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution: Maximum edge length (pixels)
        resample_method: Resampling method used to resize images

    Returns:
        Resized image tensor of shape [B, C, new_H, new_W]
    """
    assert img.dim() == 4, f"Invalid input shape {img.shape}, expected [B, C, H, W]"

    original_height, original_width = img.shape[-2:]
    scale_factor = min(
        max_edge_resolution / original_width,
        max_edge_resolution / original_height,
    )

    # Clamp to at least 1 pixel: extreme aspect ratios (e.g. 2000x1 at
    # max 100) would otherwise truncate an edge to 0 and crash resize().
    new_width = max(1, int(original_width * scale_factor))
    new_height = max(1, int(original_height * scale_factor))

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img
44
+
45
+
46
def resize_back(
    img: Union[torch.Tensor, np.ndarray, Image.Image, List[Image.Image]],
    target_size: Union[int, tuple],
    resample_method: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> Union[torch.Tensor, np.ndarray, Image.Image, List[Image.Image]]:
    """
    Resize image back to target size.

    Args:
        img: Image to be resized (tensor, numpy, PIL, or list of PIL)
        target_size: Target size (H, W) or single int for square
        resample_method: Resampling method; torchvision ``InterpolationMode``
            (translated to a PIL constant for PIL inputs) or a raw PIL
            resampling constant.

    Returns:
        Resized image in the same format as input

    Raises:
        TypeError: If ``img`` is not one of the supported types.
    """
    if isinstance(target_size, int):
        target_size = (target_size, target_size)

    # BUGFIX: PIL's Image.resize does not accept torchvision's
    # InterpolationMode (the default here), so translate it to the
    # corresponding PIL constant for the PIL branches below.
    pil_resample = resample_method
    if isinstance(resample_method, InterpolationMode):
        _tv_to_pil = {
            InterpolationMode.BILINEAR: Image.BILINEAR,
            InterpolationMode.BICUBIC: Image.BICUBIC,
            InterpolationMode.NEAREST: Image.NEAREST,
            InterpolationMode.NEAREST_EXACT: Image.NEAREST,
        }
        pil_resample = _tv_to_pil.get(resample_method, Image.BILINEAR)

    if isinstance(img, torch.Tensor):
        resized_img = resize(img, target_size, resample_method, antialias=True)
    elif isinstance(img, np.ndarray):
        # Convert to tensor (channels-first) so torchvision can resize it.
        if img.ndim == 3:  # HWC
            img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        else:  # BHWC
            img_tensor = torch.from_numpy(img).permute(0, 3, 1, 2)

        resized_tensor = resize(img_tensor, target_size, resample_method, antialias=True)

        # Convert back to the original layout.
        if img.ndim == 3:
            resized_img = resized_tensor.squeeze(0).permute(1, 2, 0).numpy()
        else:
            resized_img = resized_tensor.permute(0, 2, 3, 1).numpy()
    elif isinstance(img, Image.Image):
        # PIL uses (width, height) ordering.
        pil_size = (target_size[1], target_size[0])
        resized_img = img.resize(pil_size, pil_resample)
    elif isinstance(img, list) and all(isinstance(i, Image.Image) for i in img):
        pil_size = (target_size[1], target_size[0])
        resized_img = [i.resize(pil_size, pil_resample) for i in img]
    else:
        raise TypeError(f"Unsupported image type: {type(img)}")

    return resized_img
92
+
93
+
94
def get_tv_resample_method(method_str: str) -> InterpolationMode:
    """
    Map a resampling-method name to a torchvision ``InterpolationMode``.

    Args:
        method_str: Resampling method name, case-insensitive
            ("bilinear", "bicubic", "nearest", "nearest-exact")

    Returns:
        The matching InterpolationMode

    Raises:
        ValueError: If the name is not recognized.
    """
    lookup = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
    }
    try:
        return lookup[method_str.lower()]
    except KeyError:
        raise ValueError(f"Unknown resampling method: {method_str}") from None
114
+
115
+
116
def get_pil_resample_method(method_str: str) -> int:
    """
    Map a resampling-method name to a PIL resampling constant.

    Args:
        method_str: Resampling method name, case-insensitive
            ("bilinear", "bicubic", "nearest")

    Returns:
        The matching PIL resampling constant

    Raises:
        ValueError: If the name is not recognized.
    """
    lookup = {
        "bilinear": Image.BILINEAR,
        "bicubic": Image.BICUBIC,
        "nearest": Image.NEAREST,
    }
    try:
        return lookup[method_str.lower()]
    except KeyError:
        raise ValueError(f"Unknown resampling method: {method_str}") from None
135
+
136
+
137
def normal_to_rgb(normal: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    """
    Convert a normal map to an RGB visualization.

    Input values are expected in [-1, 1] or [0, 1]; a [-1, 1] input is
    detected heuristically (any negative value present) and remapped to
    [0, 1]. Batched inputs keep only the first sample.

    Args:
        normal: Normal map, shape (H, W, 3), (B, H, W, 3) or (B, 3, H, W)

    Returns:
        uint8 RGB array of shape (H, W, 3) with values in [0, 255]
    """
    if isinstance(normal, torch.Tensor):
        normal = normal.cpu().numpy()

    # Normalize batched layouts to a single HWC image.
    if normal.ndim == 4:
        if normal.shape[1] == 3:  # BCHW -> BHWC
            normal = np.transpose(normal, (0, 2, 3, 1))
        normal = normal[0]

    # Heuristic range detection: negatives imply a [-1, 1] encoding.
    if normal.min() < 0:
        normal = (normal + 1.0) / 2.0

    return (np.clip(normal, 0, 1) * 255).astype(np.uint8)
168
+
169
+
170
def save_normal_map(
    normal: Union[torch.Tensor, np.ndarray],
    output_path: str,
    as_rgb: bool = True,
):
    """
    Save a normal map to disk.

    Args:
        normal: Normal map tensor/array
        output_path: Output file path
        as_rgb: If True, save an RGB visualization image; if False, save
            the raw values as a compressed NPZ (key: "normal")
    """
    if not as_rgb:
        # Raw export path: materialize to numpy and store losslessly.
        data = normal.cpu().numpy() if isinstance(normal, torch.Tensor) else normal
        np.savez_compressed(output_path, normal=data)
        return

    Image.fromarray(normal_to_rgb(normal)).save(output_path)
190
+
191
+
192
def load_image(image_path: str) -> Image.Image:
    """
    Load an image from disk and normalize it to RGB mode.

    Args:
        image_path: Path to the image file

    Returns:
        PIL Image in RGB mode (alpha/palette/grayscale are converted)
    """
    image = Image.open(image_path)
    return image.convert("RGB")
203
+
204
+
205
def concatenate_images(*image_lists) -> Image.Image:
    """
    Concatenate multiple rows of images into a single image.

    Each argument is one row; images in a row are placed left-to-right,
    rows are stacked top-to-bottom. A row's height is the height of its
    tallest image, so mixed-height rows no longer overlap. Empty rows
    are allowed and contribute zero height.

    Args:
        *image_lists: Variable number of image lists, each list is a row

    Returns:
        Concatenated PIL Image (RGB)

    Raises:
        ValueError: If no lists are given or the first list is empty.
    """
    if not image_lists or not image_lists[0]:
        raise ValueError("At least one non-empty image list must be provided")

    # BUGFIX: record exactly one height per row (0 for empty rows).
    # The original skipped empty rows here but still indexed row_heights
    # by row position below, producing wrong offsets / IndexError.
    max_width = 0
    row_heights = []
    for image_list in image_lists:
        if image_list:
            max_width = max(max_width, sum(img.width for img in image_list))
            row_heights.append(max(img.height for img in image_list))
        else:
            row_heights.append(0)
    total_height = sum(row_heights)

    new_image = Image.new('RGB', (max_width, total_height))

    y_offset = 0
    for image_list, row_height in zip(image_lists, row_heights):
        x_offset = 0
        for img in image_list:
            new_image.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_height

    return new_image