Longxiang-ai's picture
Add TransNormal-2 Gradio demo
4a8f4e7 verified
Raw
History Blame Contribute Delete
3.16 kB
"""Image IO and resize helpers for TransNormal-2 inference."""
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def load_image(path: str) -> torch.Tensor:
    """Load an image file as a (1, 3, H, W) float tensor in [-1, 1].

    Args:
        path: Location of an image file that PIL can open.

    Returns:
        A float32 tensor with a leading batch axis of size 1 and channels
        first. The 8-bit RGB values 0..255 are mapped linearly so that
        0 -> -1.0 and 255 -> 1.0.
    """
    # Open through a context manager so the file handle is released as soon
    # as the pixels have been decoded. ``convert`` fully loads the data and
    # returns a separate image object, so the array is built while the source
    # is still open and stays valid afterwards.
    with Image.open(path) as img:
        arr = np.asarray(img.convert("RGB")).astype(np.float32)
    # (H, W, 3) -> (3, H, W) -> (1, 3, H, W)
    ts = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    return ts / 127.5 - 1.0
def resize_to_multiple_of_16(image_tensor: torch.Tensor) -> torch.Tensor:
    """Rescale (B, C, H, W) so both sides are multiples of 16 (aspect kept).

    The shorter side is snapped down to the nearest multiple of 16 (or up to
    16 when it is smaller than that). The other side is scaled by the same
    ratio and then floored to a multiple of 16, so the aspect ratio is kept
    only up to that final rounding.

    Args:
        image_tensor: Batch of images laid out as (B, C, H, W).

    Returns:
        ``image_tensor`` itself when it already has the target size,
        otherwise a bilinearly resampled tensor.

    Raises:
        ValueError: If the height or width of ``image_tensor`` is zero.
    """
    h, w = image_tensor.shape[2], image_tensor.shape[3]
    min_side = min(h, w)
    if min_side <= 0:
        raise ValueError(
            f"Cannot resize an image with an empty spatial dimension: {h}x{w}"
        )

    # Never target less than 16 pixels: flooring a side shorter than 16 would
    # give 0 and an invalid interpolation size.
    target_min = max(16, (min_side // 16) * 16)

    # Exact integer arithmetic (floor(side * target_min / min_side)) avoids
    # the float round-trip, where a product that should be an exact multiple
    # of 16 can land just below it and lose a whole 16-pixel step.
    new_h = (h * target_min // min_side) // 16 * 16
    new_w = (w * target_min // min_side) // 16 * 16

    if (new_h, new_w) == (h, w):
        return image_tensor
    return F.interpolate(
        image_tensor, size=(new_h, new_w), mode="bilinear", align_corners=False
    )
def resize_image_first(
    image_tensor: torch.Tensor, process_res: Optional[int] = None
) -> torch.Tensor:
    """Optionally cap the max edge at ``process_res``, then snap to /16.

    Args:
        image_tensor: Batch of images laid out as (B, C, H, W).
        process_res: Upper bound for the longer spatial side. ``None`` or 0
            disables the cap. Images whose longer side is already within the
            bound are not rescaled by this step.

    Returns:
        The result of ``resize_to_multiple_of_16`` applied to the (possibly
        downscaled) tensor.

    Raises:
        ValueError: If ``process_res`` is negative.
    """
    if process_res is not None and process_res < 0:
        raise ValueError(f"process_res must be non-negative, got {process_res}")

    if process_res:
        height, width = image_tensor.shape[2], image_tensor.shape[3]
        max_edge = max(height, width)
        if max_edge > process_res:
            # Floor-divide instead of multiplying by a float scale: the longer
            # side then lands exactly on ``process_res`` rather than possibly
            # one pixel short. Clamp to 1 so an extreme aspect ratio cannot
            # produce a zero-sized target.
            new_height = max(1, int(height * process_res // max_edge))
            new_width = max(1, int(width * process_res // max_edge))
            image_tensor = F.interpolate(
                image_tensor,
                size=(new_height, new_width),
                mode="bilinear",
                align_corners=False,
            )
    return resize_to_multiple_of_16(image_tensor)
def tensor_to_output(
    normal_01: torch.Tensor, output_type: str = "pt"
) -> Union[torch.Tensor, np.ndarray, List[Image.Image]]:
    """Convert a (B, 3, H, W) [0, 1] tensor to the requested output format.

    Args:
        normal_01: Batch laid out as (B, 3, H, W), expected in [0, 1].
        output_type: One of
            ``"pt"``  - return ``normal_01`` unchanged (no clamp, no copy);
            ``"np"``  - float32 array of shape (B, H, W, 3) clamped to [0, 1];
            ``"pil"`` - list of B 8-bit PIL images.

    Returns:
        The tensor, array, or list of images selected by ``output_type``.

    Raises:
        ValueError: If ``output_type`` is not one of the supported values.
    """
    if output_type == "pt":
        return normal_01
    # Reject unknown formats before paying for the device transfer.
    if output_type not in ("np", "pil"):
        raise ValueError(f"Unsupported output_type: {output_type} (use 'pt', 'np' or 'pil')")

    # ``detach`` lets tensors that still track gradients be exported;
    # ``numpy()`` refuses those otherwise.
    arr = normal_01.detach().float().clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()
    if output_type == "np":
        return arr
    return [Image.fromarray((a * 255).round().astype(np.uint8)) for a in arr]
def save_normal_map(
    normal: Union[torch.Tensor, np.ndarray],
    path: str,
    save_npy: Optional[str] = None,
) -> None:
    """Write a normal map prediction to an image file, plus an optional .npy.

    Input may be a (3, H, W) or (1, 3, H, W) tensor, or an (H, W, 3) array,
    with values in [0, 1] (the ``(n + 1) / 2`` encoding: camera-space X
    right, Y up, Z toward camera mapped to RGB). For 4-D input only the
    first item is written.

    Args:
        normal: Prediction to save.
        path: Destination of the 8-bit image.
        save_npy: When truthy, the clipped float array is also stored at
            this location with ``np.save``.
    """
    if not isinstance(normal, torch.Tensor):
        pixels = np.asarray(normal, dtype=np.float32)
        # Batched array: keep the first item only.
        pixels = pixels[0] if pixels.ndim == 4 else pixels
    else:
        data = normal.detach().float().cpu()
        if data.dim() == 4:
            data = data[0]
        # Channels-first tensors are moved to channels-last for PIL.
        channels_first = data.dim() == 3 and data.shape[0] == 3
        if channels_first:
            data = data.permute(1, 2, 0)
        pixels = data.numpy()

    pixels = np.clip(pixels, 0.0, 1.0)
    encoded = (pixels * 255).round().astype(np.uint8)
    Image.fromarray(encoded).save(path)

    if save_npy:
        np.save(save_npy, pixels)