"""U-Net model for PBR map prediction: basecolor -> normal + roughness + metallic.""" import torch import torch.nn as nn import segmentation_models_pytorch as smp from src.height_to_normal import height_to_normal # Canonical category list — order matters (maps category string -> index). # "unknown" is index 0 and acts as fallback for unseen categories. CATEGORIES = [ "unknown", "Ceramic", "Concrete", "Fabric", "Ground", "Leather", "Marble", "Metal", "Misc", "Plaster", "Plastic", "Stone", "Terracotta", "Wood", ] CATEGORY_TO_IDX = {name: i for i, name in enumerate(CATEGORIES)} def category_to_index(category: str) -> int: """Convert category string to integer index, falling back to 0 ('unknown').""" return CATEGORY_TO_IDX.get(category, 0) class PBRUNet(nn.Module): """Predicts normal, roughness, and metallic maps from a basecolor input. Optionally conditioned on material category via a learned embedding that is broadcast spatially and concatenated to the input image. Input: basecolor (B, 3, H, W) float32 [0, 1] category (B,) int64 indices (optional) Output: dict with: "normal" (B, 3, H, W) float32 [0, 1] "roughness" (B, 1, H, W) float32 [0, 1] "metallic" (B, 1, H, W) float32 [0, 1] """ def __init__( self, encoder_name: str = "resnet34", encoder_weights: str = "imagenet", n_categories: int = len(CATEGORIES), category_embed_dim: int = 8, use_category: bool = False, normal_xy_only: bool = False, separate_normal_decoder: bool = False, predict_height: bool = False, ): super().__init__() self.use_category = use_category self.normal_xy_only = normal_xy_only self.separate_normal_decoder = separate_normal_decoder self.predict_height = predict_height in_channels = 3 if use_category: self.category_embed = nn.Embedding(n_categories, category_embed_dim) in_channels = 3 + category_embed_dim normal_ch = 2 if normal_xy_only else 3 _weights = encoder_weights if encoder_weights != "none" else None if separate_normal_decoder: normal_out_ch = 1 if predict_height else normal_ch self.normal_unet = smp.Unet( encoder_name=encoder_name, encoder_weights=_weights, in_channels=in_channels, classes=normal_out_ch, ) self.material_unet = smp.Unet( encoder_name=encoder_name, encoder_weights=_weights, in_channels=in_channels, classes=2, ) else: out_ch = (1 if predict_height else normal_ch) + 2 self.unet = smp.Unet( encoder_name=encoder_name, encoder_weights=_weights, in_channels=in_channels, classes=out_ch, ) self.sigmoid = nn.Sigmoid() def _prepare_input(self, basecolor, category): """Build network input: basecolor + optional category embedding.""" x = basecolor if self.use_category: if category is None: category = torch.zeros( basecolor.shape[0], dtype=torch.long, device=basecolor.device ) emb = self.category_embed(category) emb = emb[:, :, None, None].expand(-1, -1, x.shape[2], x.shape[3]) x = torch.cat([x, emb], dim=1) return x def _decode_normal(self, raw_normal): """Decode raw normal predictions to [0,1] range.""" if self.normal_xy_only: xy = torch.tanh(raw_normal) xy_sq = (xy ** 2).sum(dim=1, keepdim=True).clamp(max=1.0 - 1e-6) z = torch.sqrt(1.0 - xy_sq) normal = torch.cat([xy * 0.5 + 0.5, z * 0.5 + 0.5], dim=1) else: normal = self.sigmoid(raw_normal) return normal def forward( self, basecolor: torch.Tensor, category: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: x = self._prepare_input(basecolor, category) if self.separate_normal_decoder: raw_normal = self.normal_unet(x) raw_material = self.material_unet(x) if self.predict_height: height = self.sigmoid(raw_normal) normal = height_to_normal(height, intensity=1.0) else: normal = self._decode_normal(raw_normal) roughness, metallic = self.sigmoid(raw_material).split([1, 1], dim=1) else: out = self.unet(x) if self.predict_height: height_raw, rest = out.split([1, 2], dim=1) height = self.sigmoid(height_raw) normal = height_to_normal(height, intensity=1.0) roughness, metallic = self.sigmoid(rest).split([1, 1], dim=1) elif self.normal_xy_only: normal_raw, rest = out.split([2, 2], dim=1) normal = self._decode_normal(normal_raw) roughness, metallic = self.sigmoid(rest).split([1, 1], dim=1) else: out = self.sigmoid(out) normal, roughness, metallic = out.split([3, 1, 1], dim=1) result = {"normal": normal, "roughness": roughness, "metallic": metallic} if self.predict_height: result["height"] = height return result