Andrid1's picture
Add PBR material predictor demo (3 curated runs)
94316fa verified
Raw
History Blame Contribute Delete
5.6 kB
"""U-Net model for PBR map prediction: basecolor -> normal + roughness + metallic."""
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from src.height_to_normal import height_to_normal
# Material-category vocabulary. A name's position in this list is the integer
# index fed to the category embedding, so reordering or inserting entries
# remaps every index. Slot 0, "unknown", is the fallback for names that are
# not listed here.
CATEGORIES = [
    "unknown",
    "Ceramic",
    "Concrete",
    "Fabric",
    "Ground",
    "Leather",
    "Marble",
    "Metal",
    "Misc",
    "Plaster",
    "Plastic",
    "Stone",
    "Terracotta",
    "Wood",
]
# Reverse lookup table: category name -> position in CATEGORIES.
CATEGORY_TO_IDX = dict(zip(CATEGORIES, range(len(CATEGORIES))))
def category_to_index(category: str) -> int:
    """Return the embedding index for a category name.

    The lookup in ``CATEGORY_TO_IDX`` is exact (case-sensitive). Any name
    that is not present resolves to 0, the ``"unknown"`` slot.
    """
    try:
        return CATEGORY_TO_IDX[category]
    except KeyError:
        return 0
class PBRUNet(nn.Module):
    """Predict normal, roughness and metallic maps from a basecolor image.

    The network is either one shared ``smp.Unet`` or two of them (one for
    the normal/height head, one for roughness + metallic). It can optionally
    be conditioned on a material category through a learned embedding. The
    embedding is broadcast over the spatial dimensions and concatenated to
    the input image channels.

    Normal-head variants, selected at construction time:

    * default: 3 raw channels passed through a sigmoid.
    * ``normal_xy_only``: 2 channels squashed with ``tanh``. The x/y vector
      is limited to the unit disk and z is reconstructed so the decoded
      normal has unit length.
    * ``predict_height``: 1 channel passed through a sigmoid to form a
      height map, which ``height_to_normal`` converts to a normal map.
      This takes precedence over ``normal_xy_only``.

    Input:  basecolor (B, 3, H, W) float32 [0, 1]
            category  (B,) int64 indices (optional)
    Output: dict with:
            "normal"    (B, 3, H, W) float32 [0, 1]
            "roughness" (B, 1, H, W) float32 [0, 1]
            "metallic"  (B, 1, H, W) float32 [0, 1]
            "height"    (B, 1, H, W) float32 [0, 1], present only when
                        ``predict_height`` is True
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_weights: str | None = "imagenet",
        n_categories: int = len(CATEGORIES),
        category_embed_dim: int = 8,
        use_category: bool = False,
        normal_xy_only: bool = False,
        separate_normal_decoder: bool = False,
        predict_height: bool = False,
    ):
        """Build the U-Net(s) and the optional category embedding.

        Args:
            encoder_name: smp encoder backbone name.
            encoder_weights: Pretrained encoder weights forwarded to smp.
                The string ``"none"`` (or ``None``) means no pretrained
                weights.
            n_categories: Number of rows in the category embedding table.
            category_embed_dim: Width of each category embedding vector.
            use_category: Concatenate a category embedding to the input.
            normal_xy_only: Predict only the x/y normal components and
                reconstruct z. Ignored when ``predict_height`` is True.
            separate_normal_decoder: Use a dedicated U-Net for the
                normal/height head and another for roughness + metallic,
                instead of one shared U-Net.
            predict_height: Predict a height map and derive the normal from
                it with ``height_to_normal``.
        """
        super().__init__()
        self.use_category = use_category
        self.normal_xy_only = normal_xy_only
        self.separate_normal_decoder = separate_normal_decoder
        self.predict_height = predict_height

        in_channels = 3
        if use_category:
            self.category_embed = nn.Embedding(n_categories, category_embed_dim)
            in_channels = 3 + category_embed_dim

        weights = None if encoder_weights == "none" else encoder_weights

        def make_unet(classes: int) -> nn.Module:
            # All U-Nets share the encoder configuration and input width.
            return smp.Unet(
                encoder_name=encoder_name,
                encoder_weights=weights,
                in_channels=in_channels,
                classes=classes,
            )

        normal_out_ch = self._normal_channels()
        # Submodules are registered in the same order as before, so
        # state_dict keys and initialisation order are unchanged.
        if separate_normal_decoder:
            self.normal_unet = make_unet(normal_out_ch)
            self.material_unet = make_unet(2)  # roughness + metallic
        else:
            self.unet = make_unet(normal_out_ch + 2)
        self.sigmoid = nn.Sigmoid()

    def _normal_channels(self) -> int:
        """Return the raw channel count of the normal/height head."""
        if self.predict_height:
            return 1
        return 2 if self.normal_xy_only else 3

    def _prepare_input(self, basecolor, category):
        """Build the network input: basecolor plus an optional category embedding.

        When category conditioning is enabled and ``category`` is None,
        every sample falls back to index 0 ("unknown").
        """
        x = basecolor
        if self.use_category:
            if category is None:
                category = torch.zeros(
                    basecolor.shape[0], dtype=torch.long, device=basecolor.device
                )
            emb = self.category_embed(category)
            # Broadcast each sample's embedding vector over every pixel.
            emb = emb[:, :, None, None].expand(-1, -1, x.shape[2], x.shape[3])
            x = torch.cat([x, emb], dim=1)
        return x

    def _decode_normal(self, raw_normal):
        """Decode raw normal-head output into a normal map encoded in [0, 1].

        Full (3-channel) mode applies a per-channel sigmoid. In xy-only mode,
        x and y are squashed into (-1, 1) with ``tanh`` and limited to the
        unit disk. Then ``z = sqrt(1 - x^2 - y^2)`` is reconstructed, and all
        three components are mapped from [-1, 1] to [0, 1].
        """
        if not self.normal_xy_only:
            return self.sigmoid(raw_normal)
        xy = torch.tanh(raw_normal)
        # tanh bounds each component but not the vector length, which can
        # reach sqrt(2). Without this rescale, z would be clamped to ~0 while
        # xy stayed long, giving a non-unit normal. Inside the unit disk the
        # divisor is exactly 1, so those outputs are unchanged.
        xy = xy / (xy ** 2).sum(dim=1, keepdim=True).clamp(min=1.0).sqrt()
        # Keep 1 - |xy|^2 strictly positive so sqrt and its gradient stay finite.
        xy_sq = (xy ** 2).sum(dim=1, keepdim=True).clamp(max=1.0 - 1e-6)
        z = torch.sqrt(1.0 - xy_sq)
        return torch.cat([xy * 0.5 + 0.5, z * 0.5 + 0.5], dim=1)

    def forward(
        self,
        basecolor: torch.Tensor,
        category: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Run the model and return the decoded maps (see the class docstring)."""
        x = self._prepare_input(basecolor, category)

        # Both architectures produce the same pair of raw tensors, so
        # decoding is shared below.
        if self.separate_normal_decoder:
            raw_normal = self.normal_unet(x)
            raw_material = self.material_unet(x)
        else:
            raw_normal, raw_material = self.unet(x).split(
                [self._normal_channels(), 2], dim=1
            )

        height = None
        if self.predict_height:
            height = self.sigmoid(raw_normal)
            normal = height_to_normal(height, intensity=1.0)
        else:
            normal = self._decode_normal(raw_normal)
        roughness, metallic = self.sigmoid(raw_material).split([1, 1], dim=1)

        result = {"normal": normal, "roughness": roughness, "metallic": metallic}
        if height is not None:
            result["height"] = height
        return result