| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import models |
|
|
|
|
class ConvMemoryLayer(nn.Module):
    """Learnable per-channel offset added to a feature map.

    The layer owns one parameter, ``memory``, of shape
    ``(1, channels, 1, 1)``. It starts at zero, so a freshly constructed
    layer returns its input unchanged.
    """

    def __init__(self, channels: int):
        super().__init__()
        offset_shape = (1, channels, 1, 1)
        self.memory = nn.Parameter(torch.zeros(offset_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the stored offset broadcast-added to it."""
        shifted = x + self.memory
        return shifted
|
|
|
|
class ConvBlock(nn.Module):
    """Residual block: two 3x3 convs with GroupNorm, then channel and spatial gates.

    Data flow in ``forward``:
    conv1 -> GroupNorm -> ELU -> conv2 -> GroupNorm -> channel gate ->
    spatial gate -> add shortcut -> ELU -> ``ConvMemoryLayer`` offset.

    The attributes are called ``bn1``/``bn2`` but hold ``nn.GroupNorm``
    modules; the names are kept because they are state-dict keys.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()

        # Largest group count not exceeding 32 that divides out_ch evenly.
        groups = min(32, out_ch)
        while out_ch % groups:
            groups -= 1

        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.GroupNorm(groups, out_ch)

        self.memory = ConvMemoryLayer(out_ch)

        # Channel gate: global average pool -> bottleneck MLP -> sigmoid.
        squeeze = max(1, out_ch // 16)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(out_ch, squeeze),
            nn.ReLU(),
            nn.Linear(squeeze, out_ch),
            nn.Sigmoid(),
        )

        # Spatial gate: one-channel 3x3 conv -> sigmoid. The tiny Xavier gain
        # and zero bias make the gate start out close to sigmoid(0) = 0.5.
        gate_conv = nn.Conv2d(out_ch, 1, kernel_size=3, padding="same")
        nn.init.xavier_uniform_(gate_conv.weight, gain=0.01)
        nn.init.constant_(gate_conv.bias, 0.0)
        self.spatial_att = nn.Sequential(gate_conv, nn.Sigmoid())

        # A 1x1 projection is only needed when the channel count changes.
        needs_projection = in_ch != out_ch
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return a tensor with ``out_ch`` channels."""
        residual = self.shortcut(x)

        feat = self.conv1(x)
        feat = F.elu(self.bn1(feat))
        feat = self.conv2(feat)
        feat = self.bn2(feat)

        # Per-channel weights, reshaped so they broadcast over H and W.
        batch, channels, _h, _w = feat.shape
        channel_gate = self.channel_att(feat).view(batch, channels, 1, 1)
        feat = feat * channel_gate

        # Per-pixel weights computed from the channel-gated features.
        spatial_gate = self.spatial_att(feat)
        feat = feat * spatial_gate

        activated = F.elu(feat + residual)
        return self.memory(activated)
|
|
|
|
class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat a skip, two ConvBlocks.

    Args:
        in_ch: channels of the tensor being upsampled.
        skip_ch: channels of the skip tensor concatenated after upsampling.
        out_ch: channels produced by the upsample and by both ConvBlocks.
    """

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.block1 = ConvBlock(out_ch + skip_ch, out_ch)
        self.block2 = ConvBlock(out_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, fuse it with ``skip`` along channels, and refine.

        The transposed conv exactly doubles H and W, but a stride-2 encoder
        produces ceil(H / 2), so the two sizes can differ by one pixel when
        an intermediate resolution is odd. In that case the upsampled tensor
        is resized to the skip's spatial size; matching sizes are untouched.
        """
        x = self.up(x)
        target_hw = skip.shape[-2:]
        if x.shape[-2:] != target_hw:
            x = F.interpolate(
                x, size=target_hw, mode="bilinear", align_corners=False
            )
        x = torch.cat([x, skip], dim=1)
        x = self.block1(x)
        return self.block2(x)
|
|
|
|
class ParallelEncoder(nn.Module):
    """Five-stage convolutional encoder returning every stage's output.

    Each stage is a ``ConvBlock`` followed by a stride-2 3x3 convolution.
    Base widths 16/32/64/128/256 are scaled by ``width_mult``, rounded down
    to a multiple of 8, and clamped to at least 8.
    """

    def __init__(self, in_ch: int = 3, width_mult: float = 1.0):
        super().__init__()

        def scaled(base):
            # Round down to a multiple of 8, never below 8.
            return max(8, int(base * width_mult / 8) * 8)

        def make_stage(src, dst):
            # ConvBlock changes the width; the strided conv downsamples.
            return nn.Sequential(
                ConvBlock(src, dst),
                nn.Conv2d(dst, dst, kernel_size=3, stride=2, padding=1),
            )

        w0, w1, w2, w3, w4 = (scaled(base) for base in (16, 32, 64, 128, 256))

        self.s1 = make_stage(in_ch, w0)
        self.s2 = make_stage(w0, w1)
        self.s3 = make_stage(w1, w2)
        self.s4 = make_stage(w2, w3)
        self.s5 = make_stage(w3, w4)

    def forward(self, x):
        """Run the stages in sequence and return their five outputs as a tuple."""
        outputs = []
        for stage in (self.s1, self.s2, self.s3, self.s4, self.s5):
            x = stage(x)
            outputs.append(x)
        return tuple(outputs)
|
|
|
|
class MobileNetUNet(nn.Module):
    """U-Net style model fusing a MobileNetV2 backbone with a custom encoder.

    ``forward`` has two paths selected by ``mode``:

    * ``mode == 0``: the 3-channel input is normalised and passed through
      both the MobileNetV2 feature stack and ``ParallelEncoder``; features
      from the two encoders are concatenated per level and compressed by a
      1x1 conv + GroupNorm.
    * any other ``mode``: the input is projected by ``inv_proj`` (a 1x1 conv
      that expects 9 input channels) down to 3 channels and only
      ``ParallelEncoder`` is used.

    Both paths share the decoder and the four 3-channel output heads.

    Args:
        freeze_backbone: when true, backbone layers with index below 10 are
            frozen; layers from index 10 on stay trainable.
        width_mult: scales the custom encoder/decoder channel widths.
    """

    # Indices into ``self.features`` whose outputs are used as skip features.
    _SKIP_TAPS = (1, 3, 6, 13, 18)
    # First backbone layer index left trainable when the backbone is frozen.
    _FIRST_TRAINABLE_LAYER = 10

    def __init__(self, freeze_backbone: bool = True, width_mult: float = 1.0):
        super().__init__()
        backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        self.features = backbone.features
        if freeze_backbone:
            # Freshly loaded parameters are trainable by default, so only the
            # early layers need to be switched off.
            for idx, layer in enumerate(self.features):
                if idx < self._FIRST_TRAINABLE_LAYER:
                    for p in layer.parameters():
                        p.requires_grad = False

        # Per-channel normalisation constants, broadcastable over (B, 3, H, W).
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        def _ch(base: int) -> int:
            # Same width rule as ParallelEncoder: multiple of 8, at least 8.
            return max(8, int(base * width_mult / 8) * 8)

        self.custom_enc = ParallelEncoder(width_mult=width_mult)

        enc_c = [_ch(16), _ch(32), _ch(64), _ch(128), _ch(256)]
        # Channel counts of the backbone outputs at the _SKIP_TAPS indices.
        mobile_c = [16, 24, 32, 96, 1280]

        self.inv_proj = nn.Sequential(
            nn.Conv2d(9, 3, kernel_size=1),
            nn.GroupNorm(3, 3),
        )

        def _proj_norm(in_ch, out_ch):
            # 1x1 conv + GroupNorm; largest group count <= 32 dividing out_ch.
            g = min(32, out_ch)
            while out_ch % g != 0:
                g -= 1
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
                nn.GroupNorm(g, out_ch),
            )

        # Fuse (backbone + custom) features back down to the custom widths.
        self.compress0 = _proj_norm(mobile_c[0] + enc_c[0], enc_c[0])
        self.compress1 = _proj_norm(mobile_c[1] + enc_c[1], enc_c[1])
        self.compress2 = _proj_norm(mobile_c[2] + enc_c[2], enc_c[2])
        self.compress3 = _proj_norm(mobile_c[3] + enc_c[3], enc_c[3])
        self.compress4 = _proj_norm(mobile_c[4] + enc_c[4], enc_c[4])

        self.up1 = UpBlock(in_ch=enc_c[4], skip_ch=enc_c[3], out_ch=enc_c[4])
        self.up2 = UpBlock(in_ch=enc_c[4], skip_ch=enc_c[2], out_ch=enc_c[3])
        self.up3 = UpBlock(in_ch=enc_c[3], skip_ch=enc_c[1], out_ch=enc_c[2])
        self.up4 = UpBlock(in_ch=enc_c[2], skip_ch=enc_c[0], out_ch=enc_c[1])
        self.up5 = nn.Sequential(
            nn.ConvTranspose2d(enc_c[1], enc_c[0], kernel_size=2, stride=2),
            ConvBlock(enc_c[0], enc_c[0]),
            ConvBlock(enc_c[0], enc_c[0]),
        )

        # The head sees decoder features plus the 3-channel residual input.
        self.shared_head = ConvBlock(enc_c[0] + 3, enc_c[1])
        self.basecolor_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.normal_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.rmd_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.rgb_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        # These two heads start at exactly zero output.
        nn.init.zeros_(self.basecolor_proj.weight)
        nn.init.zeros_(self.basecolor_proj.bias)
        nn.init.zeros_(self.rgb_proj.weight)
        nn.init.zeros_(self.rgb_proj.bias)

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping a partly frozen backbone in eval mode.

        If any backbone parameter has ``requires_grad == False`` the whole
        ``features`` stack is put in eval mode, including its layers that
        are still trainable.
        """
        super().train(mode)
        if any(not p.requires_grad for p in self.features.parameters()):
            self.features.eval()
        return self

    def forward(self, x: torch.Tensor, mode: int = 0) -> dict:
        """Run the network and return the four clamped output maps.

        Args:
            x: input tensor. For ``mode == 0`` it is rescaled with
                ``(x + 1) / 2`` before mean/std normalisation, so it is
                presumably expected in [-1, 1] with 3 channels (confirm
                against the caller). Otherwise it must have 9 channels.
            mode: ``0`` for the dual-encoder path, anything else for the
                ``inv_proj`` path.

        Returns:
            Dict with keys ``'basecolor'``, ``'normal'``, ``'rmd'`` and
            ``'rgb'``, each a 3-channel tensor clamped to [-1, 1].
        """
        if mode == 0:
            out = (x + 1.0) / 2.0
            out = (out - self.mean) / self.std

            # Walk the backbone once, collecting the tapped activations.
            mobile_out = out
            mobile_feats = []
            for idx, layer in enumerate(self.features):
                mobile_out = layer(mobile_out)
                if idx in self._SKIP_TAPS:
                    mobile_feats.append(mobile_out)

            custom_feats = self.custom_enc(out)
            fusers = (
                self.compress0,
                self.compress1,
                self.compress2,
                self.compress3,
                self.compress4,
            )
            c0, c1, c2, c3, c4 = (
                fuse(torch.cat([mobile_f, custom_f], dim=1))
                for fuse, mobile_f, custom_f in zip(fusers, mobile_feats, custom_feats)
            )
            x_res = x
        else:
            x_proj = self.inv_proj(x)
            c0, c1, c2, c3, c4 = self.custom_enc(x_proj)
            x_res = x_proj

        x = self.up1(c4, c3)
        x = self.up2(x, c2)
        x = self.up3(x, c1)
        x = self.up4(x, c0)
        x = self.up5(x)

        # up5 always emits an even H/W; with an odd input size it would be
        # one pixel larger than the residual and the concat would fail.
        if x.shape[-2:] != x_res.shape[-2:]:
            x = F.interpolate(
                x, size=x_res.shape[-2:], mode="bilinear", align_corners=False
            )
        x_cat = torch.cat([x, x_res], dim=1)

        h = self.shared_head(x_cat)
        heads = {
            'basecolor': self.basecolor_proj,
            'normal': self.normal_proj,
            'rmd': self.rmd_proj,
            'rgb': self.rgb_proj,
        }
        return {name: torch.clamp(proj(h), -1, 1) for name, proj in heads.items()}
|
|
|
|
def create_model(image_size=128, width_mult=1.0):
    """Build a ``MobileNetUNet`` with ``freeze_backbone=True``.

    ``image_size`` is accepted but is not passed on to the model;
    only ``width_mult`` influences the result.
    """
    model = MobileNetUNet(width_mult=width_mult, freeze_backbone=True)
    return model
|
|
|
|
def load_from_checkpoint(ckpt_path, device="cpu", width_mult=2.0):
    """Build a ``MobileNetUNet``, load weights from a checkpoint, return it in eval mode.

    Args:
        ckpt_path: path handed to ``torch.load``.
        device: passed to ``torch.load`` as ``map_location``; when it is a
            string or ``torch.device`` the model is also moved there.
        width_mult: width multiplier used to build the model. Note the
            default here is 2.0.

    Returns:
        The model in eval mode. Keys missing from, or unexpected in, the
        checkpoint are printed rather than raised (``strict=False``).
    """
    model = MobileNetUNet(freeze_backbone=True, width_mult=width_mult)
    state = torch.load(ckpt_path, map_location=device, weights_only=True)
    # Some checkpoints nest the weights under "state_dict".
    sd = state["state_dict"] if "state_dict" in state else state

    # Drop a leading "model." wrapper prefix from parameter names.
    prefix = "model."
    renamed = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

    # Previously `device` only affected map_location, so the returned model
    # stayed on CPU even when another device was requested. The type check
    # keeps other map_location forms (dict, callable) working as before.
    if isinstance(device, (str, torch.device)):
        model.to(device)

    missing, unexpected = model.load_state_dict(renamed, strict=False)
    if missing:
        print(f"Missing keys: {missing}")
    if unexpected:
        print(f"Unexpected keys: {unexpected}")
    model.eval()
    return model
|
|