pbr-material-predictor / src /transforms.py
Andrid1's picture
Add PBR material predictor demo (3 curated runs)
94316fa verified
Raw
History Blame Contribute Delete
3.66 kB
"""Shared image transforms for PBR map processing."""
import random
import torch
from torchvision import transforms
# Maps whose values live in the standard [0,1] range.
STANDARD_MAPS = (
    "basecolor",
    "roughness",
    "metallic",
)

# Key of the normal map, which needs special handling ([-1,1] range).
NORMAL_MAP = "normal"

# Every map key handled by this module, in processing order.
MAP_NAMES = (
    "basecolor",
    "normal",
    "roughness",
    "metallic",
)
def get_resize_transform(size: int = 256):
    """Build the plain resize + to-tensor pipeline used for all map types.

    Args:
        size: Target edge length; the image is resized to ``(size, size)``
            with Lanczos interpolation.

    Returns:
        A ``transforms.Compose`` that resizes and then converts to a tensor.
    """
    target_shape = (size, size)
    resize_step = transforms.Resize(
        target_shape,
        interpolation=transforms.InterpolationMode.LANCZOS,
    )
    to_tensor_step = transforms.ToTensor()  # [0,255] -> [0,1], HWC -> CHW
    return transforms.Compose([resize_step, to_tensor_step])
def get_train_transform(size: int = 256):
    """Build the training pipeline: resize, random ``size`` crop, to-tensor.

    ``Resize`` receives an int here (not a tuple), and ``RandomCrop`` then
    cuts a ``size`` x ``size`` window out of the resized image.

    NOTE(review): this pipeline contains no flip, and ``RandomCrop`` picks its
    window each time the pipeline is called -- confirm the caller keeps the
    crop aligned across the maps of one sample.

    Args:
        size: Value passed to both ``Resize`` and ``RandomCrop``.

    Returns:
        A ``transforms.Compose`` of the three steps above.
    """
    lanczos = transforms.InterpolationMode.LANCZOS
    steps = [
        transforms.Resize(size, interpolation=lanczos),
        transforms.RandomCrop(size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
def get_preview_transform(size: int = 512):
    """Build a larger resize + to-tensor pipeline meant for visual inspection.

    Args:
        size: Target edge length; the image is resized to ``(size, size)``
            with Lanczos interpolation.

    Returns:
        A ``transforms.Compose`` that resizes and then converts to a tensor.
    """
    pipeline = [
        transforms.Resize(
            (size, size),
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)
class PBRAugmentation:
    """Random flip / 90-degree rotation augmentation for PBR map samples.

    The same spatial transform is applied to every map in the sample; the
    tangent-space normal map then has its X/Y channels remapped so the
    encoded vectors still point the right way in the transformed image.

    Normal map convention (tangent space):
        R (ch 0) = X (right), stored as [0,1], 0.5 = zero
        G (ch 1) = Y (up),    stored as [0,1], 0.5 = zero
        B (ch 2) = Z (out),   always positive

    Corrections after spatial transforms (negation is ``1.0 - v`` in the
    stored encoding):
        H-flip:       negate X           -> R = 1.0 - R
        V-flip:       negate Y           -> G = 1.0 - G
        rot90, k = 1: (X, Y) -> (-Y, X)  -> R_new = 1.0 - G_old, G_new = R_old
        rot90, k = 2: (X, Y) -> (-X, -Y) -> R = 1.0 - R, G = 1.0 - G
        rot90, k = 3: (X, Y) -> (Y, -X)  -> R_new = G_old, G_new = 1.0 - R_old

    ``torch.rot90`` with positive ``k`` on ``dims=(-2, -1)`` turns the image
    counter-clockwise as displayed (row 0 at the top), so ``k = 1`` makes a
    right-pointing vector point up, which is the ``(-Y, X)`` mapping above.

    Args:
        hflip: Enable a random horizontal flip (probability 0.5).
        vflip: Enable a random vertical flip (probability 0.5).
        rot90: Enable a random number (0-3) of quarter turns.
        map_names: Keys of ``sample`` to transform. ``None`` (default) uses
            the module-wide ``MAP_NAMES``.
    """

    def __init__(self, hflip=True, vflip=True, rot90=True, map_names=None):
        self.hflip = hflip
        self.vflip = vflip
        self.rot90 = rot90
        # Resolved once here; the default keeps the original behaviour.
        self.map_names = MAP_NAMES if map_names is None else tuple(map_names)

    def __call__(self, sample: dict) -> dict:
        """Augment a sample dict in-place and return it.

        Map tensors are expected as (C, H, W); leading batch dimensions are
        also accepted because only the last three dims are addressed.
        """
        # Draw the random decisions once so every map gets the same transform.
        do_hflip = self.hflip and random.random() < 0.5
        do_vflip = self.vflip and random.random() < 0.5
        n_rot90 = random.randint(0, 3) if self.rot90 else 0

        for key in self.map_names:
            if key in sample:
                sample[key] = self._spatial(
                    sample[key], do_hflip, do_vflip, n_rot90
                )

        # Only fix up a normal map that was actually spatially transformed.
        if "normal" in self.map_names and "normal" in sample:
            sample["normal"] = self._correct_normal(
                sample["normal"], do_hflip, do_vflip, n_rot90
            )
        return sample

    @staticmethod
    def _spatial(t, do_hflip, do_vflip, n_rot90):
        """Flip and/or rotate the last two (H, W) dims of one map tensor."""
        if do_hflip:
            t = t.flip(-1)  # flip W
        if do_vflip:
            t = t.flip(-2)  # flip H
        if n_rot90 > 0:
            t = torch.rot90(t, n_rot90, dims=(-2, -1))
        return t

    @staticmethod
    def _correct_normal(n, do_hflip, do_vflip, n_rot90):
        """Remap the X/Y channels of an already-transformed normal map."""
        k = n_rot90 % 4
        if not (do_hflip or do_vflip or k):
            return n  # identity transform: nothing to correct

        # Clone both channels: the rotation cases cross-assign them, so
        # working on views of ``n`` would read already-overwritten data.
        x = n[..., 0, :, :].clone()
        y = n[..., 1, :, :].clone()

        # Flips happen before the rotation spatially, so correct them first.
        if do_hflip:
            x = 1.0 - x  # negate X
        if do_vflip:
            y = 1.0 - y  # negate Y

        if k == 1:  # 90 deg counter-clockwise on screen
            x, y = 1.0 - y, x  # X_new = -Y_old, Y_new = X_old
        elif k == 2:  # 180 deg
            x, y = 1.0 - x, 1.0 - y  # negate X and Y
        elif k == 3:  # 270 deg counter-clockwise (= 90 deg clockwise)
            x, y = y, 1.0 - x  # X_new = Y_old, Y_new = -X_old

        # ``n`` is the fresh tensor produced by flip/rot90 in ``_spatial``,
        # so writing into it does not touch the caller's original data.
        n[..., 0, :, :] = x
        n[..., 1, :, :] = y
        return n