import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# Base channel widths of the five encoder levels before ``width_mult`` scaling.
_ENCODER_BASE_CHANNELS = (16, 32, 64, 128, 256)

# Indices into ``mobilenet_v2().features`` whose outputs are tapped as skip
# features, and the channel count of each tap. In ``forward`` each tap is
# concatenated channel-wise with the matching ParallelEncoder level, so the
# two must share spatial size.
_MOBILENET_TAP_INDICES = (1, 3, 6, 13, 18)
_MOBILENET_TAP_CHANNELS = (16, 24, 32, 96, 1280)

# With ``freeze_backbone=True``, backbone layers at this index and later stay
# trainable; earlier layers are frozen.
_FIRST_TRAINABLE_LAYER = 10


def _num_groups(channels: int, max_groups: int = 32) -> int:
    """Return the largest group count <= ``max_groups`` that divides ``channels``.

    Used to size ``nn.GroupNorm`` for arbitrary channel counts (falls back to 1).
    """
    groups = min(max_groups, channels)
    while channels % groups != 0:
        groups -= 1
    return groups


def _scaled_channels(base: int, width_mult: float) -> int:
    """Scale ``base`` by ``width_mult``, round down to a multiple of 8, minimum 8."""
    return max(8, int(base * width_mult / 8) * 8)


def _match_spatial(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Resize ``x`` bilinearly to ``ref``'s (H, W) if they differ; otherwise return it unchanged.

    The stride-2 encoders round odd sizes up, so 2x upsampling in the decoder can
    overshoot the matching skip by one pixel when the input side is not a
    multiple of 32. Matching sizes (the common case) take the no-op path.
    """
    if x.shape[-2:] != ref.shape[-2:]:
        x = F.interpolate(x, size=tuple(ref.shape[-2:]), mode="bilinear", align_corners=False)
    return x


class ConvMemoryLayer(nn.Module):
    """Add a learned, zero-initialised per-channel offset to a feature map.

    The offset has shape ``(1, channels, 1, 1)`` and broadcasts over batch and
    spatial dimensions, so the layer starts out as an identity.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.memory = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.memory


class ConvBlock(nn.Module):
    """Residual double-conv block with channel and spatial attention.

    conv3x3 -> GroupNorm -> ELU -> conv3x3 -> GroupNorm
    -> channel attention (squeeze-excitation style MLP, reduction 16)
    -> spatial attention (single-channel sigmoid map)
    -> + shortcut -> ELU -> learned per-channel offset.

    Spatial size is preserved; channels go from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        num_groups = _num_groups(out_ch)
        # Submodules are created in the original order so that seeded
        # initialisation and state-dict keys stay unchanged. The ``bn*`` names
        # are historical: these layers are GroupNorm, not BatchNorm.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.GroupNorm(num_groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.GroupNorm(num_groups, out_ch)
        self.memory = ConvMemoryLayer(out_ch)

        hidden = max(1, out_ch // 16)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(out_ch, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_ch),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(
            nn.Conv2d(out_ch, 1, kernel_size=3, padding="same"),
            nn.Sigmoid(),
        )
        # Tiny weights and zero bias keep the initial attention logits near 0,
        # so the spatial map starts out roughly uniform (~0.5).
        nn.init.xavier_uniform_(self.spatial_att[0].weight, gain=0.01)
        nn.init.constant_(self.spatial_att[0].bias, 0.0)

        # A 1x1 projection is only needed when the residual changes width.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.shortcut(x)
        x = F.elu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        b, c = x.shape[:2]
        x = x * self.channel_att(x).view(b, c, 1, 1)
        x = x * self.spatial_att(x)
        x = F.elu(x + identity)
        return self.memory(x)


class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat skip, two ConvBlocks.

    Args:
        in_ch: channels of the coarse input.
        skip_ch: channels of the skip feature map.
        out_ch: channels after upsampling and of the block output.
    """

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.block1 = ConvBlock(out_ch + skip_ch, out_ch)
        self.block2 = ConvBlock(out_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.up(x)
        # Align to the skip's resolution (no-op when sizes already match).
        x = _match_spatial(x, skip)
        x = torch.cat([x, skip], dim=1)
        return self.block2(self.block1(x))


def _down_stage(in_ch: int, out_ch: int) -> nn.Sequential:
    """ConvBlock followed by a stride-2 3x3 conv that halves H and W (rounding up)."""
    return nn.Sequential(
        ConvBlock(in_ch, out_ch),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1),
    )


class ParallelEncoder(nn.Module):
    """Five-stage convolutional encoder trained alongside the MobileNetV2 backbone.

    Each stage halves the spatial size, so ``forward`` returns features at
    strides 2, 4, 8, 16 and 32 with channel widths
    ``_scaled_channels(b, width_mult)`` for ``b`` in (16, 32, 64, 128, 256).
    """

    def __init__(self, in_ch: int = 3, width_mult: float = 1.0):
        super().__init__()
        c0, c1, c2, c3, c4 = (_scaled_channels(b, width_mult) for b in _ENCODER_BASE_CHANNELS)
        self.s1 = _down_stage(in_ch, c0)
        self.s2 = _down_stage(c0, c1)
        self.s3 = _down_stage(c1, c2)
        self.s4 = _down_stage(c2, c3)
        self.s5 = _down_stage(c3, c4)

    def forward(self, x):
        c0 = self.s1(x)
        c1 = self.s2(c0)
        c2 = self.s3(c1)
        c3 = self.s4(c2)
        c4 = self.s5(c3)
        return c0, c1, c2, c3, c4


def _proj_norm(in_ch: int, out_ch: int) -> nn.Sequential:
    """1x1 conv (no bias) + GroupNorm, used to fuse concatenated skip features."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        nn.GroupNorm(_num_groups(out_ch), out_ch),
    )


class MobileNetUNet(nn.Module):
    """U-Net with a dual encoder (ImageNet MobileNetV2 + ParallelEncoder) and four heads.

    ``forward(x, mode=0)`` expects a 3-channel image. The code maps it with
    ``(x + 1) / 2`` and then applies ImageNet normalisation, which suggests an
    input range of [-1, 1]. Both encoders run and their per-level features
    are fused by 1x1 ``compress*`` projections.

    Any other ``mode`` expects a 9-channel input. It is projected to 3
    channels by ``inv_proj``, and only the ParallelEncoder is used.

    Returns a dict with keys ``'basecolor'``, ``'normal'``, ``'rmd'`` and
    ``'rgb'``. Each value is a 3-channel map clamped to [-1, 1] at the
    input's spatial size.

    Args:
        freeze_backbone: freeze MobileNetV2 layers before index 10 and keep the
            backbone in eval mode during training.
        width_mult: channel multiplier for the custom encoder/decoder.
    """

    def __init__(self, freeze_backbone: bool = True, width_mult: float = 1.0):
        super().__init__()
        backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        self.features = backbone.features
        if freeze_backbone:
            for idx, layer in enumerate(self.features):
                trainable = idx >= _FIRST_TRAINABLE_LAYER
                for p in layer.parameters():
                    p.requires_grad = trainable

        # ImageNet statistics used to normalise the RGB input (saved in state_dict).
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self.custom_enc = ParallelEncoder(width_mult=width_mult)
        enc_c = [_scaled_channels(b, width_mult) for b in _ENCODER_BASE_CHANNELS]
        mobile_c = _MOBILENET_TAP_CHANNELS

        # Input projection for the 9-channel (mode != 0) path.
        self.inv_proj = nn.Sequential(
            nn.Conv2d(9, 3, kernel_size=1),
            nn.GroupNorm(3, 3),
        )

        # Fuse [MobileNet tap, custom feature] at each level back to enc_c width.
        self.compress0 = _proj_norm(mobile_c[0] + enc_c[0], enc_c[0])
        self.compress1 = _proj_norm(mobile_c[1] + enc_c[1], enc_c[1])
        self.compress2 = _proj_norm(mobile_c[2] + enc_c[2], enc_c[2])
        self.compress3 = _proj_norm(mobile_c[3] + enc_c[3], enc_c[3])
        self.compress4 = _proj_norm(mobile_c[4] + enc_c[4], enc_c[4])

        self.up1 = UpBlock(in_ch=enc_c[4], skip_ch=enc_c[3], out_ch=enc_c[4])
        self.up2 = UpBlock(in_ch=enc_c[4], skip_ch=enc_c[2], out_ch=enc_c[3])
        self.up3 = UpBlock(in_ch=enc_c[3], skip_ch=enc_c[1], out_ch=enc_c[2])
        self.up4 = UpBlock(in_ch=enc_c[2], skip_ch=enc_c[0], out_ch=enc_c[1])
        # Final upsample back to input resolution (no skip at this level).
        self.up5 = nn.Sequential(
            nn.ConvTranspose2d(enc_c[1], enc_c[0], kernel_size=2, stride=2),
            ConvBlock(enc_c[0], enc_c[0]),
            ConvBlock(enc_c[0], enc_c[0]),
        )

        # The head also sees the 3-channel input (raw x or inv_proj output).
        self.shared_head = ConvBlock(enc_c[0] + 3, enc_c[1])
        self.basecolor_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.normal_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.rmd_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.rgb_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        # Zero-init: 'basecolor' and 'rgb' outputs start at exactly 0.
        nn.init.zeros_(self.basecolor_proj.weight)
        nn.init.zeros_(self.basecolor_proj.bias)
        nn.init.zeros_(self.rgb_proj.weight)
        nn.init.zeros_(self.rgb_proj.bias)

    def train(self, mode: bool = True):
        """Set training mode, but keep the backbone in eval mode if any of it is frozen.

        Note that ``eval()`` applies to every backbone layer, including the
        trainable ones, so their normalisation layers also run in eval mode.
        """
        super().train(mode)
        if any(not p.requires_grad for p in self.features.parameters()):
            self.features.eval()
        return self

    def _encode_rgb(self, x: torch.Tensor):
        """Run both encoders on a 3-channel input and fuse their five skip levels."""
        out = (x + 1.0) / 2.0
        out = (out - self.mean) / self.std

        mobile_feats = []
        h = out
        for idx, layer in enumerate(self.features):
            h = layer(h)
            if idx in _MOBILENET_TAP_INDICES:
                mobile_feats.append(h)

        custom_feats = self.custom_enc(out)
        compressors = (self.compress0, self.compress1, self.compress2, self.compress3, self.compress4)
        return tuple(
            proj(torch.cat([m, c], dim=1))
            for proj, m, c in zip(compressors, mobile_feats, custom_feats)
        )

    def forward(self, x: torch.Tensor, mode: int = 0) -> dict:
        if mode == 0:
            c0, c1, c2, c3, c4 = self._encode_rgb(x)
            x_res = x
        else:
            x_res = self.inv_proj(x)
            c0, c1, c2, c3, c4 = self.custom_enc(x_res)

        d = self.up1(c4, c3)
        d = self.up2(d, c2)
        d = self.up3(d, c1)
        d = self.up4(d, c0)
        d = self.up5(d)
        # Align with the full-resolution residual (no-op for sizes divisible by 32).
        d = _match_spatial(d, x_res)

        h = self.shared_head(torch.cat([d, x_res], dim=1))
        basecolor = torch.clamp(self.basecolor_proj(h), -1, 1)
        normal = torch.clamp(self.normal_proj(h), -1, 1)
        rmd = torch.clamp(self.rmd_proj(h), -1, 1)
        rgb = torch.clamp(self.rgb_proj(h), -1, 1)
        return {'basecolor': basecolor, 'normal': normal, 'rmd': rmd, 'rgb': rgb}


def create_model(image_size=128, width_mult=1.0):
    """Build a MobileNetUNet with a frozen-early backbone.

    ``image_size`` is accepted for interface compatibility but is not used.
    """
    return MobileNetUNet(freeze_backbone=True, width_mult=width_mult)


def load_from_checkpoint(ckpt_path, device="cpu", width_mult=2.0):
    """Build a MobileNetUNet, load weights from ``ckpt_path``, move it to ``device``, set eval.

    Accepts either a raw state dict or a dict with a ``"state_dict"`` entry.
    A leading ``"model."`` prefix on keys is stripped. Loading is non-strict,
    and missing or unexpected keys are printed rather than raised.
    ``width_mult`` must match the value the checkpoint was trained with.
    """
    model = MobileNetUNet(freeze_backbone=True, width_mult=width_mult)
    state = torch.load(ckpt_path, map_location=device, weights_only=True)
    sd = state["state_dict"] if "state_dict" in state else state

    prefix = "model."
    renamed = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

    missing, unexpected = model.load_state_dict(renamed, strict=False)
    if missing:
        print(f"Missing keys: {missing}")
    if unexpected:
        print(f"Unexpected keys: {unexpected}")

    # load_state_dict copies into the model's existing (CPU) parameters, so the
    # model itself must be moved for ``device`` to take effect.
    model.to(device)
    model.eval()
    return model