Spaces:
Sleeping
Sleeping
| """Image IO and resize helpers for TransNormal-2 inference.""" | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
def load_image(path: str) -> torch.Tensor:
    """Load an image file as a (1, 3, H, W) float32 tensor scaled to [-1, 1].

    The file is converted with PIL's ``convert("RGB")`` and each 8-bit pixel
    value ``v`` in [0, 255] is mapped to ``v / 127.5 - 1``.

    NOTE(review): EXIF orientation is not applied here; rotated camera JPEGs
    are returned in their stored orientation.
    """
    # Close the underlying file handle as soon as the pixels are decoded;
    # ``convert`` loads the data and returns an independent image object.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    # astype() copies into a writable float32 buffer for torch.from_numpy.
    arr = np.asarray(rgb).astype(np.float32)
    ts = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # HWC -> 1CHW
    return ts / 127.5 - 1.0
def resize_to_multiple_of_16(image_tensor: torch.Tensor) -> torch.Tensor:
    """Rescale a (B, C, H, W) tensor so both spatial sides are multiples of 16.

    The shorter side is snapped down to the nearest multiple of 16, or up to 16
    if it is shorter than that. The longer side is scaled by the same factor and
    then floored to a multiple of 16, so the aspect ratio is kept only
    approximately: the longer side may lose up to 15 extra pixels.

    Returns:
        The input object itself when it already has the target size; otherwise
        a new tensor produced by bilinear interpolation.
    """
    h, w = image_tensor.shape[2], image_tensor.shape[3]
    min_side = min(h, w)
    # Short-side target. Never below 16, so inputs with a side < 16 are
    # upscaled instead of collapsing to a 0-sized output (which
    # F.interpolate rejects).
    target = max(16, (min_side // 16) * 16)
    # Exact integer floor of side * target / min_side. The float form
    # int(side * (target / min_side)) can land just below an exact integer,
    # truncate one pixel short, and then lose a further 16 px when snapped.
    new_h = (h * target // min_side) // 16 * 16
    new_w = (w * target // min_side) // 16 * 16
    if (new_h, new_w) == (h, w):
        return image_tensor
    return F.interpolate(
        image_tensor, size=(new_h, new_w), mode="bilinear", align_corners=False
    )
def resize_image_first(image_tensor: torch.Tensor, process_res: Optional[int] = None) -> torch.Tensor:
    """Optionally cap the longer edge at ``process_res``, then snap to /16.

    Args:
        image_tensor: (B, C, H, W) image batch.
        process_res: Maximum allowed length of the longer spatial edge. Falsy
            values (``None`` or ``0``) skip the cap. Images whose longer edge
            is already <= ``process_res`` are not resized at this step.

    Returns:
        The (possibly downscaled) tensor after ``resize_to_multiple_of_16``.
    """
    if process_res:
        h, w = image_tensor.shape[2], image_tensor.shape[3]
        max_edge = max(h, w)
        if max_edge > process_res:
            # Integer floor division instead of int(side * (process_res / max_edge)):
            # the float product can fall just short of an integer and truncate,
            # e.g. the long edge becoming process_res - 1. max(1, ...) keeps
            # extreme aspect ratios from collapsing a side to 0, which
            # F.interpolate rejects.
            new_height = max(1, int(h * process_res // max_edge))
            new_width = max(1, int(w * process_res // max_edge))
            image_tensor = F.interpolate(
                image_tensor, size=(new_height, new_width), mode="bilinear", align_corners=False
            )
    return resize_to_multiple_of_16(image_tensor)
def tensor_to_output(
    normal_01: torch.Tensor, output_type: str = "pt"
) -> Union[torch.Tensor, np.ndarray, List[Image.Image]]:
    """Convert a (B, 3, H, W) [0, 1] tensor to the requested output format.

    Args:
        normal_01: Channel-first batch of normal maps, expected in [0, 1].
        output_type: One of the following:
            - ``"pt"`` returns ``normal_01`` itself, unclamped and on its
              original device and dtype.
            - ``"np"`` returns a float32 (B, H, W, 3) array clamped to [0, 1].
            - ``"pil"`` returns a list of B images built from the clamped
              array quantized to uint8.

    Raises:
        ValueError: If ``output_type`` is not ``"pt"``, ``"np"`` or ``"pil"``.
    """
    # Validate before doing any (possibly device -> host) conversion work.
    if output_type not in ("pt", "np", "pil"):
        raise ValueError(f"Unsupported output_type: {output_type} (use 'pt', 'np' or 'pil')")
    if output_type == "pt":
        return normal_01
    # detach(): .numpy() refuses tensors that still require grad.
    # float(): lifts half/bfloat16 to float32 before the NumPy conversion.
    arr = normal_01.detach().float().clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()
    if output_type == "np":
        return arr
    return [Image.fromarray((a * 255).round().astype(np.uint8)) for a in arr]
def save_normal_map(
    normal: Union[torch.Tensor, np.ndarray],
    path: str,
    save_npy: Optional[str] = None,
) -> None:
    """Write a normal-map prediction to ``path`` as an 8-bit image, optionally also as .npy.

    Args:
        normal: Prediction in [0, 1] using the ``(n + 1) / 2`` encoding
            (camera-space X right, Y up, Z toward camera mapped to RGB).
            Tensors may be (3, H, W) or (1, 3, H, W). Arrays are expected as
            (H, W, 3) or (1, H, W, 3). For 4-D input only index 0 of the
            leading axis is saved.
        path: Destination passed to PIL's ``Image.save``; PIL infers the
            format from the extension, so use a ``.png`` path for a PNG.
        save_npy: When truthy, the clipped float array is also written with
            ``np.save`` to this path.
    """
    if not isinstance(normal, torch.Tensor):
        hwc = np.asarray(normal, dtype=np.float32)
        if hwc.ndim == 4:
            hwc = hwc[0]
    else:
        buf = normal.detach().float().cpu()
        if buf.dim() == 4:
            buf = buf[0]
        # Move channel-first 3-D tensors to channel-last for PIL.
        if buf.dim() == 3 and buf.shape[0] == 3:
            buf = buf.permute(1, 2, 0)
        hwc = buf.numpy()

    hwc = np.clip(hwc, 0.0, 1.0)
    quantized = (hwc * 255).round().astype(np.uint8)
    Image.fromarray(quantized).save(path)

    if save_npy:
        np.save(save_npy, hwc)