File size: 7,384 Bytes
159500c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Utility functions for TransNormal pipeline.

Includes image processing utilities for preprocessing and postprocessing.
"""

from typing import List, Union
from PIL import Image
import numpy as np
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize


def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image so its longer edge matches a target length, keeping aspect ratio.

    The scale factor is applied unconditionally. Images whose longer edge is
    already shorter than ``max_edge_resolution`` are therefore upscaled, not
    left untouched.

    Args:
        img: Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution: Target length (pixels) of the longer edge
        resample_method: Resampling method used to resize images

    Returns:
        Resized image tensor of shape [B, C, new_H, new_W]

    Raises:
        ValueError: If ``img`` is not 4-dimensional or ``max_edge_resolution``
            is not positive.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if img.dim() != 4:
        raise ValueError(f"Invalid input shape {img.shape}, expected [B, C, H, W]")
    if max_edge_resolution <= 0:
        raise ValueError(
            f"max_edge_resolution must be positive, got {max_edge_resolution}"
        )

    original_height, original_width = img.shape[-2:]
    scale_factor = min(
        max_edge_resolution / original_width,
        max_edge_resolution / original_height,
    )

    # int() truncation is kept so output sizes match previous behaviour. The
    # max(1, ...) clamp stops very elongated images from collapsing an edge to 0.
    new_width = max(1, int(original_width * scale_factor))
    new_height = max(1, int(original_height * scale_factor))

    return resize(img, (new_height, new_width), resample_method, antialias=True)


def _to_pil_resample(resample_method: Union[InterpolationMode, int]) -> int:
    """
    Return the PIL resampling constant equivalent to ``resample_method``.

    PIL's ``Image.resize`` only accepts its own resampling constants, so a
    torchvision ``InterpolationMode`` must be translated first. Any other value
    is passed through unchanged, on the assumption that it already is a PIL
    constant.
    """
    if isinstance(resample_method, InterpolationMode):
        return {
            InterpolationMode.NEAREST: Image.NEAREST,
            InterpolationMode.NEAREST_EXACT: Image.NEAREST,
            InterpolationMode.BILINEAR: Image.BILINEAR,
            InterpolationMode.BICUBIC: Image.BICUBIC,
            InterpolationMode.BOX: Image.BOX,
            InterpolationMode.HAMMING: Image.HAMMING,
            InterpolationMode.LANCZOS: Image.LANCZOS,
        }[resample_method]
    return resample_method


def resize_back(
    img: Union[torch.Tensor, np.ndarray, Image.Image, List[Image.Image]],
    target_size: Union[int, tuple],
    resample_method: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> Union[torch.Tensor, np.ndarray, Image.Image, List[Image.Image]]:
    """
    Resize image back to target size.

    Args:
        img: Image to be resized. Accepted forms are a tensor (resized directly
            by torchvision), a numpy array in HW, HWC or BHWC layout, a PIL
            image, or a list of PIL images.
        target_size: Target size (H, W) or single int for square
        resample_method: Resampling method for resizing. Either a torchvision
            ``InterpolationMode`` or a PIL resampling constant. It is
            translated for the PIL branches when needed.

    Returns:
        Resized image in the same format (and numpy layout) as input

    Raises:
        TypeError: If ``img`` is of an unsupported type.
    """
    if isinstance(target_size, int):
        target_size = (target_size, target_size)

    if isinstance(img, torch.Tensor):
        resized_img = resize(img, target_size, resample_method, antialias=True)
    elif isinstance(img, np.ndarray):
        # torchvision resizes channel-first tensors, so move channels to dim 1.
        # ascontiguousarray: torch.from_numpy rejects negative strides (e.g. a[:, ::-1]).
        img_tensor = torch.from_numpy(np.ascontiguousarray(img))
        if img.ndim == 2:  # HW
            img_tensor = img_tensor[None, None]
        elif img.ndim == 3:  # HWC
            img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
        else:  # BHWC
            img_tensor = img_tensor.permute(0, 3, 1, 2)

        resized_tensor = resize(img_tensor, target_size, resample_method, antialias=True)

        # Restore the caller's original layout.
        if img.ndim == 2:
            resized_img = resized_tensor[0, 0].numpy()
        elif img.ndim == 3:
            resized_img = resized_tensor.squeeze(0).permute(1, 2, 0).numpy()
        else:
            resized_img = resized_tensor.permute(0, 2, 3, 1).numpy()
    elif isinstance(img, Image.Image):
        # PIL uses (width, height)
        pil_size = (target_size[1], target_size[0])
        resized_img = img.resize(pil_size, _to_pil_resample(resample_method))
    elif isinstance(img, list) and all(isinstance(i, Image.Image) for i in img):
        pil_size = (target_size[1], target_size[0])
        pil_resample = _to_pil_resample(resample_method)
        resized_img = [i.resize(pil_size, pil_resample) for i in img]
    else:
        raise TypeError(f"Unsupported image type: {type(img)}")

    return resized_img


def get_tv_resample_method(method_str: str) -> InterpolationMode:
    """
    Map a resampling method name to a torchvision interpolation mode.

    Matching is case-insensitive. Both "nearest" and "nearest-exact" resolve to
    ``InterpolationMode.NEAREST_EXACT``.

    Args:
        method_str: One of "bilinear", "bicubic", "nearest", "nearest-exact"

    Returns:
        The matching InterpolationMode

    Raises:
        ValueError: If the name is not recognised
    """
    key = method_str.lower()
    if key == "bilinear":
        return InterpolationMode.BILINEAR
    if key == "bicubic":
        return InterpolationMode.BICUBIC
    if key in ("nearest", "nearest-exact"):
        return InterpolationMode.NEAREST_EXACT
    raise ValueError(f"Unknown resampling method: {method_str}")


def get_pil_resample_method(method_str: str) -> int:
    """
    Get PIL resampling method from string.

    Matching is case-insensitive. "nearest-exact" is accepted as an alias of
    "nearest" so that every name accepted by ``get_tv_resample_method`` also
    works here.

    Args:
        method_str: Resampling method name ("bilinear", "bicubic", "nearest",
            "nearest-exact")

    Returns:
        Corresponding PIL resampling constant

    Raises:
        ValueError: If the name is not recognised
    """
    resample_method_dict = {
        "bilinear": Image.BILINEAR,
        "bicubic": Image.BICUBIC,
        "nearest": Image.NEAREST,
        "nearest-exact": Image.NEAREST,
    }
    resample_method = resample_method_dict.get(method_str.lower())
    if resample_method is None:
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method


def normal_to_rgb(normal: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    """
    Convert normal map to RGB visualization.

    The range is detected heuristically. If any value is negative, the input is
    treated as [-1, 1] and shifted to [0, 1]; otherwise it is assumed to already
    be in [0, 1]. Values are then clipped to [0, 1] and scaled to [0, 255].

    Args:
        normal: Normal map tensor/array, shape (H, W, 3), (3, H, W),
            (B, H, W, 3) or (B, 3, H, W). For batched input only the first
            element is visualised.

    Returns:
        RGB visualization as uint8 numpy array of shape (H, W, 3)
    """
    if isinstance(normal, torch.Tensor):
        # detach() allows tensors that require grad. NumPy has no bfloat16,
        # so that dtype is upcast before conversion.
        normal = normal.detach()
        if normal.dtype == torch.bfloat16:
            normal = normal.float()
        normal = normal.cpu().numpy()

    # Normalise the layout to HWC
    if normal.ndim == 4:
        if normal.shape[1] == 3:  # BCHW
            normal = np.transpose(normal, (0, 2, 3, 1))  # BHWC
        normal = normal[0]  # Take first batch
    elif normal.ndim == 3 and normal.shape[0] == 3 and normal.shape[-1] != 3:
        normal = np.transpose(normal, (1, 2, 0))  # CHW -> HWC

    # Convert from [-1, 1] to [0, 1] if needed
    if normal.min() < 0:
        normal = (normal + 1.0) / 2.0

    # Clamp and convert to uint8
    normal = np.clip(normal, 0, 1)
    return (normal * 255).astype(np.uint8)


def save_normal_map(
    normal: Union[torch.Tensor, np.ndarray],
    output_path: str,
    as_rgb: bool = True,
):
    """
    Save normal map to file.

    Args:
        normal: Normal map tensor/array. With ``as_rgb=True`` it is converted by
            ``normal_to_rgb``, so it must use one of the layouts that function
            accepts.
        output_path: Output file path. With ``as_rgb=True`` PIL picks the image
            format from the file extension. With ``as_rgb=False``,
            ``np.savez_compressed`` appends ``.npz`` if the path lacks it.
        as_rgb: If True, save an 8-bit RGB visualization. If False, save the raw
            values as NPZ under the key ``"normal"``.
    """
    if as_rgb:
        Image.fromarray(normal_to_rgb(normal)).save(output_path)
        return

    if isinstance(normal, torch.Tensor):
        # detach() allows tensors that require grad. NumPy has no bfloat16,
        # so that dtype is upcast before conversion.
        normal = normal.detach()
        if normal.dtype == torch.bfloat16:
            normal = normal.float()
        normal = normal.cpu().numpy()
    np.savez_compressed(output_path, normal=normal)


def load_image(image_path: str) -> Image.Image:
    """
    Load image from file path.

    The file is opened in a ``with`` block so its handle is released
    deterministically. ``Image.open`` is lazy and may otherwise keep the file
    open. ``convert`` returns a new, fully loaded image, so the result remains
    usable after the file is closed.

    Args:
        image_path: Path to image file

    Returns:
        PIL Image in RGB mode
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")


def concatenate_images(*image_lists) -> Image.Image:
    """
    Concatenate multiple rows of images into a single image.

    Each argument is one row. Images in a row are placed left to right, and rows
    are stacked top to bottom. Each row is as tall as its tallest image, and the
    canvas is as wide as the widest row; uncovered areas stay black. Empty rows
    are skipped.

    Args:
        *image_lists: Variable number of image lists, each list is a row

    Returns:
        Concatenated PIL Image in RGB mode

    Raises:
        ValueError: If no non-empty image list is provided
    """
    # Drop empty rows up front so heights and offsets stay aligned with the
    # rows that are actually drawn.
    rows = [row for row in image_lists if row]
    if not rows:
        raise ValueError("At least one non-empty image list must be provided")

    # Use the tallest image per row so mixed-height rows never overlap.
    row_heights = [max(img.height for img in row) for row in rows]
    max_width = max(sum(img.width for img in row) for row in rows)

    new_image = Image.new('RGB', (max_width, sum(row_heights)))

    y_offset = 0
    for row, row_height in zip(rows, row_heights):
        x_offset = 0
        for img in row:
            new_image.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_height

    return new_image