Spaces:
Running
Running
| """Shared image transforms for PBR map processing.""" | |
| import random | |
| import torch | |
| from torchvision import transforms | |
# Map keys whose values are used in the standard [0, 1] range.
STANDARD_MAPS = (
    "basecolor",
    "roughness",
    "metallic",
)

# Key of the normal map, which needs special handling ([-1, 1] range).
NORMAL_MAP = "normal"

# Every map key this module processes, in iteration order.
MAP_NAMES = (
    "basecolor",
    "normal",
    "roughness",
    "metallic",
)
def get_resize_transform(size: int = 256):
    """Build the plain resize + to-tensor pipeline used for every map type.

    Args:
        size: Edge length of the square output, in pixels.

    Returns:
        A ``transforms.Compose`` that resizes its input to ``size`` x ``size``
        with Lanczos interpolation and then converts it to a tensor.
    """
    lanczos = transforms.InterpolationMode.LANCZOS
    target_hw = (size, size)
    steps = [
        transforms.Resize(target_hw, interpolation=lanczos),
        # [0,255] -> [0,1], HWC -> CHW
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
def get_train_transform(size: int = 256):
    """Build the training pipeline: resize, random square crop, to-tensor.

    The input is resized with Lanczos interpolation using an integer target
    (torchvision scales the shorter edge to ``size`` and keeps the aspect
    ratio), then a ``size`` x ``size`` window is cropped at a random offset.
    No flip is part of this pipeline.

    NOTE(review): the crop offset is drawn each time the transform is called,
    so applying it separately to each map of one material presumably yields
    misaligned crops unless the caller synchronizes the RNG - confirm against
    the dataset code.

    Args:
        size: Target length of the shorter edge and of the square crop.

    Returns:
        A ``transforms.Compose`` implementing the steps above.
    """
    lanczos = transforms.InterpolationMode.LANCZOS
    steps = [
        transforms.Resize(size, interpolation=lanczos),
        transforms.RandomCrop(size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
def get_preview_transform(size: int = 512):
    """Build a resize + to-tensor pipeline sized for visual inspection.

    Args:
        size: Edge length of the square output, in pixels (larger default
            than the training/resize pipelines).

    Returns:
        A ``transforms.Compose`` that resizes its input to ``size`` x ``size``
        with Lanczos interpolation and then converts it to a tensor.
    """
    lanczos = transforms.InterpolationMode.LANCZOS
    target_hw = (size, size)
    steps = [
        transforms.Resize(target_hw, interpolation=lanczos),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
class PBRAugmentation:
    """Augmentation for PBR map samples with correct normal map handling.

    Applies identical spatial transforms (random H-flip, V-flip and a random
    number of quarter turns) to every map in the sample, then corrects the
    normal map X/Y channels so the encoded vectors follow the pixels.

    Normal map convention (tangent space):
        R (ch 0) = X (right), stored as [0,1], 0.5 = zero
        G (ch 1) = Y (up),    stored as [0,1], 0.5 = zero
        B (ch 2) = Z (out),   always positive

    ``torch.rot90(t, k, dims=(-2, -1))`` with k > 0 rotates the image
    counter-clockwise (the right edge becomes the top edge), so the vector
    corrections are:
        H-flip:        negate X            -> R = 1.0 - R
        V-flip:        negate Y            -> G = 1.0 - G
        k=1 (90° CCW): (X, Y) -> (-Y,  X)  -> R_new = 1.0 - G_old, G_new = R_old
        k=2 (180°):    (X, Y) -> (-X, -Y)  -> R = 1.0 - R, G = 1.0 - G
        k=3 (90° CW):  (X, Y) -> ( Y, -X)  -> R_new = G_old, G_new = 1.0 - R_old

    Args:
        hflip: Enable the random horizontal flip (probability 0.5).
        vflip: Enable the random vertical flip (probability 0.5).
        rot90: Enable the random rotation by 0-3 quarter turns.
    """

    def __init__(self, hflip=True, vflip=True, rot90=True):
        self.hflip = hflip
        self.vflip = vflip
        self.rot90 = rot90

    def __call__(self, sample: dict) -> dict:
        """Augment a sample dict in-place. All map tensors must be (C, H, W).

        Args:
            sample: Mapping from map name to tensor. Only keys listed in
                ``MAP_NAMES`` are transformed; missing keys are skipped.

        Returns:
            The same ``sample`` dict, with its map entries replaced by the
            augmented tensors.
        """
        # Draw order is kept stable (hflip, vflip, rotation) so seeded runs
        # consume the RNG exactly as before.
        do_hflip = self.hflip and random.random() < 0.5
        do_vflip = self.vflip and random.random() < 0.5
        n_rot90 = random.randint(0, 3) if self.rot90 else 0

        # Apply the same spatial transform to every map that is present.
        for key in MAP_NAMES:
            if key not in sample:
                continue
            t = sample[key]
            if do_hflip:
                t = t.flip(-1)  # flip W
            if do_vflip:
                t = t.flip(-2)  # flip H
            if n_rot90 > 0:
                # k > 0 rotates counter-clockwise in the (H, W) plane.
                t = torch.rot90(t, n_rot90, dims=(-2, -1))
            sample[key] = t

        if "normal" in sample:
            sample["normal"] = self._correct_normal(
                sample["normal"], do_hflip, do_vflip, n_rot90
            )
        return sample

    @staticmethod
    def _correct_normal(n, do_hflip, do_vflip, n_rot90):
        """Re-orient the X/Y channels of an already spatially transformed normal map."""
        # Flip corrections first, in the same order the flips were applied.
        if do_hflip:
            n[0] = 1.0 - n[0]  # negate X
        if do_vflip:
            n[1] = 1.0 - n[1]  # negate Y

        # Rotation corrections; directions follow torch.rot90 (k > 0 is CCW).
        if n_rot90 == 1:  # 90° CCW
            r_old, g_old = n[0].clone(), n[1].clone()
            n[0] = 1.0 - g_old  # X_new = -Y_old
            n[1] = r_old        # Y_new = X_old
        elif n_rot90 == 2:  # 180°
            n[0] = 1.0 - n[0]  # negate X
            n[1] = 1.0 - n[1]  # negate Y
        elif n_rot90 == 3:  # 270° CCW (= 90° CW)
            r_old, g_old = n[0].clone(), n[1].clone()
            n[0] = g_old        # X_new = Y_old
            n[1] = 1.0 - r_old  # Y_new = -X_old
        return n