# ShadeNet / model.py — uploaded by singam96, commit 12510fb ("Initial"), 9.37 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class ConvMemoryLayer(nn.Module):
    """Learnable per-channel offset added to a feature map.

    Holds a ``(1, channels, 1, 1)`` parameter, initialized to zeros, which is
    broadcast-added to the input. The layer is therefore an identity at
    initialization.
    """

    def __init__(self, channels: int):
        super().__init__()
        zeros = torch.zeros(1, channels, 1, 1)
        self.memory = nn.Parameter(zeros)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        offset = self.memory
        return x + offset
class ConvBlock(nn.Module):
    """Residual convolution unit with channel and spatial attention.

    Two 3x3 convolutions (each followed by GroupNorm, the first also by ELU)
    are gated first by a squeeze-and-excitation style channel attention and
    then by a single-map spatial attention. The result is added to a shortcut
    (a 1x1 projection when ``in_ch != out_ch``), passed through ELU, and
    offset by a learnable per-channel bias (``ConvMemoryLayer``).

    Spatial size is preserved; channels go from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Largest group count <= 32 that evenly divides out_ch.
        groups = min(32, out_ch)
        while out_ch % groups:
            groups -= 1

        # Submodules are created in the original order so parameter order and
        # RNG-driven initialization are unchanged.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.GroupNorm(groups, out_ch)
        self.memory = ConvMemoryLayer(out_ch)

        hidden = max(1, out_ch // 16)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(out_ch, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_ch),
            nn.Sigmoid(),
        )

        spatial_conv = nn.Conv2d(out_ch, 1, kernel_size=3, padding="same")
        self.spatial_att = nn.Sequential(spatial_conv, nn.Sigmoid())
        # Near-zero weights and zero bias: the spatial gate starts near 0.5.
        nn.init.xavier_uniform_(spatial_conv.weight, gain=0.01)
        nn.init.constant_(spatial_conv.bias, 0.0)

        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
            if in_ch != out_ch
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)

        out = F.elu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Channel gate: one weight per (sample, channel), broadcast over H, W.
        batch, chans, _, _ = out.shape
        out = out * self.channel_att(out).view(batch, chans, 1, 1)
        # Spatial gate: one weight per pixel, broadcast over channels.
        out = out * self.spatial_att(out)

        return self.memory(F.elu(out + residual))
class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, skip concat, two ConvBlocks.

    ``in_ch`` channels are upsampled to ``out_ch``. These are concatenated with
    the ``skip_ch``-channel skip tensor and refined back to ``out_ch`` channels.
    """

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        merged_ch = out_ch + skip_ch
        self.block1 = ConvBlock(merged_ch, out_ch)
        self.block2 = ConvBlock(out_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        # `skip` must have the same spatial size as the upsampled `x`.
        merged = torch.cat((self.up(x), skip), dim=1)
        return self.block2(self.block1(merged))
class ParallelEncoder(nn.Module):
    """Five-stage convolutional encoder producing a feature pyramid.

    Each stage is a ``ConvBlock`` followed by a 3x3 stride-2 convolution, so
    every stage halves the spatial resolution. Stage widths are the base
    widths 16/32/64/128/256 scaled by ``width_mult``, floored to a multiple
    of 8, with a minimum of 8.

    ``forward`` returns a 5-tuple with the output of each stage, finest first.
    """

    def __init__(self, in_ch: int = 3, width_mult: float = 1.0):
        super().__init__()
        widths = [
            max(8, int(base * width_mult / 8) * 8)
            for base in (16, 32, 64, 128, 256)
        ]

        def stage(cin: int, cout: int) -> nn.Sequential:
            # ConvBlock at the current resolution, then a stride-2 downsample.
            return nn.Sequential(
                ConvBlock(cin, cout),
                nn.Conv2d(cout, cout, kernel_size=3, stride=2, padding=1),
            )

        self.s1 = stage(in_ch, widths[0])
        self.s2 = stage(widths[0], widths[1])
        self.s3 = stage(widths[1], widths[2])
        self.s4 = stage(widths[2], widths[3])
        self.s5 = stage(widths[3], widths[4])

    def forward(self, x):
        pyramid = []
        for stage in (self.s1, self.s2, self.s3, self.s4, self.s5):
            x = stage(x)
            pyramid.append(x)
        return tuple(pyramid)
class MobileNetUNet(nn.Module):
    """U-Net decoder over a fused MobileNetV2 + ParallelEncoder feature pyramid.

    ``forward`` has two input paths that share the decoder and output heads:

    * ``mode == 0``: a 3-channel input is mapped by ``(x + 1) / 2`` and
      normalized with ``mean``/``std``. It is then encoded by both the
      MobileNetV2 ``features`` (tapped at indices 1, 3, 6, 13, 18) and
      ``custom_enc``. Features at each of the five scales are concatenated
      and reduced by ``compress0``..``compress4``.
    * any other ``mode``: a 9-channel input is projected to 3 channels by
      ``inv_proj`` and encoded by ``custom_enc`` alone.

    The decoder (``up1``..``up5``) upsamples back by 2x five times. Its output
    is concatenated with a 3-channel residual of the input, and four 1x1
    heads produce 3-channel maps clamped to [-1, 1].

    NOTE(review): H and W presumably must be multiples of 32. The encoders
    downsample 5 times by stride 2 and the decoder upsamples 5 times by 2x,
    so other sizes should fail at a skip concatenation. Confirm.

    NOTE(review): submodule creation order below determines ``parameters()``
    order and the random-init sequence; keep it stable for checkpoint and
    optimizer-state compatibility.

    Args:
        freeze_backbone: if True, ``requires_grad`` is turned off for
            ``features[0..9]`` while ``features[10:]`` stay trainable, and
            ``train`` keeps ``features`` in eval mode.
        width_mult: channel multiplier for ``custom_enc`` and the decoder.
    """
    def __init__(self, freeze_backbone: bool = True, width_mult: float = 1.0):
        super().__init__()
        # ImageNet-pretrained MobileNetV2 (torchvision may download the weights
        # on first use). Only its `features` trunk is kept.
        backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        self.features = backbone.features
        if freeze_backbone:
            # Freeze the whole backbone, then re-enable gradients for features[10:].
            for p in self.features.parameters():
                p.requires_grad = False
            for idx, layer in enumerate(self.features):
                if idx >= 10:
                    for p in layer.parameters():
                        p.requires_grad = True
        # ImageNet normalization statistics, shaped (1, 3, 1, 1) to broadcast over images.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        def _ch(base: int) -> int:
            # Same width rounding as ParallelEncoder: floor to a multiple of 8, min 8.
            return max(8, int(base * width_mult / 8) * 8)
        self.custom_enc = ParallelEncoder(width_mult=width_mult)
        # Channel counts of the five custom_enc outputs, finest first.
        enc_c = [_ch(16), _ch(32), _ch(64), _ch(128), _ch(256)]
        # Channel counts of the MobileNetV2 features tapped in forward
        # (indices 1, 3, 6, 13, 18). Assumed to sit at the same strides as
        # the custom_enc outputs so they can be concatenated; TODO confirm.
        mobile_c = [16, 24, 32, 96, 1280]
        # Channel count of each fused (compressed) scale; identical to enc_c.
        inv_target = list(enc_c)
        # mode != 0 path: 1x1 projection of a 9-channel input to 3 channels.
        self.inv_proj = nn.Sequential(
            nn.Conv2d(9, 3, kernel_size=1),
            nn.GroupNorm(3, 3),
        )
        def _proj_norm(in_ch: int, out_ch: int) -> nn.Sequential:
            # 1x1 conv + GroupNorm with the largest group count <= 32 that
            # divides out_ch (the same rule ConvBlock uses).
            g = min(32, out_ch)
            while out_ch % g != 0:
                g -= 1
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
                nn.GroupNorm(g, out_ch),
            )
        # Per-scale fusion of concatenated [MobileNet, custom] features (mode 0 only).
        self.compress0 = _proj_norm(mobile_c[0] + enc_c[0], inv_target[0])
        self.compress1 = _proj_norm(mobile_c[1] + enc_c[1], inv_target[1])
        self.compress2 = _proj_norm(mobile_c[2] + enc_c[2], inv_target[2])
        self.compress3 = _proj_norm(mobile_c[3] + enc_c[3], inv_target[3])
        self.compress4 = _proj_norm(mobile_c[4] + enc_c[4], inv_target[4])
        # Decoder: each UpBlock upsamples 2x and concatenates the next finer
        # encoder scale as a skip connection.
        self.up1 = UpBlock(in_ch=inv_target[4], skip_ch=inv_target[3], out_ch=enc_c[4])
        self.up2 = UpBlock(in_ch=enc_c[4], skip_ch=inv_target[2], out_ch=enc_c[3])
        self.up3 = UpBlock(in_ch=enc_c[3], skip_ch=inv_target[1], out_ch=enc_c[2])
        self.up4 = UpBlock(in_ch=enc_c[2], skip_ch=inv_target[0], out_ch=enc_c[1])
        # Final 2x upsample with no skip connection at this level.
        self.up5 = nn.Sequential(
            nn.ConvTranspose2d(enc_c[1], enc_c[0], kernel_size=2, stride=2),
            ConvBlock(enc_c[0], enc_c[0]),
            ConvBlock(enc_c[0], enc_c[0]),
        )
        # Shared trunk over [decoder output, 3-channel residual input].
        self.shared_head = ConvBlock(enc_c[0] + 3, enc_c[1])
        # One 1x1 projection to 3 channels per output map.
        self.basecolor_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.normal_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.rmd_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        self.rgb_proj = nn.Conv2d(enc_c[1], 3, kernel_size=1)
        # Zero-init: the basecolor and rgb heads output exactly 0 before training.
        nn.init.zeros_(self.basecolor_proj.weight)
        nn.init.zeros_(self.basecolor_proj.bias)
        nn.init.zeros_(self.rgb_proj.weight)
        nn.init.zeros_(self.rgb_proj.bias)
    def train(self, mode: bool = True) -> "MobileNetUNet":
        """Set train/eval mode, but keep ``features`` in eval if any of its params is frozen.

        This puts *all* of ``features`` in eval mode, including layers whose
        parameters are still trainable. Presumably this keeps the pretrained
        backbone's normalization running statistics fixed; TODO confirm intent.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        any_frozen = False
        for p in self.features.parameters():
            if not p.requires_grad:
                any_frozen = True
                break
        if any_frozen:
            self.features.eval()
        return self
    def forward(self, x: torch.Tensor, mode: int = 0) -> dict:
        """Run one of the two encoder paths, then the shared decoder and heads.

        Args:
            x: for ``mode == 0``, a 3-channel batch. The ``(x + 1) / 2``
                rescale suggests values in [-1, 1] (TODO confirm). For any
                other ``mode``, a 9-channel batch, as ``inv_proj`` expects 9
                input channels.
            mode: 0 selects the MobileNetV2 + custom-encoder path; any other
                value selects the ``inv_proj`` + custom-encoder path.

        Returns:
            Dict with keys ``'basecolor'``, ``'normal'``, ``'rmd'`` and
            ``'rgb'``, each a 3-channel tensor clamped to [-1, 1].
        """
        if mode == 0:
            # Map [-1, 1] to [0, 1], then apply ImageNet normalization.
            out = (x + 1.0) / 2.0
            out = (out - self.mean) / self.std
            mobile_out = out
            mobile_feats = []
            for idx, layer in enumerate(self.features):
                mobile_out = layer(mobile_out)
                # Tap the backbone levels whose channel counts are listed in mobile_c.
                if idx in (1, 3, 6, 13, 18):
                    mobile_feats.append(mobile_out)
            # The custom encoder sees the same normalized input as the backbone.
            custom_feats = self.custom_enc(out)
            c0 = self.compress0(torch.cat([mobile_feats[0], custom_feats[0]], dim=1))
            c1 = self.compress1(torch.cat([mobile_feats[1], custom_feats[1]], dim=1))
            c2 = self.compress2(torch.cat([mobile_feats[2], custom_feats[2]], dim=1))
            c3 = self.compress3(torch.cat([mobile_feats[3], custom_feats[3]], dim=1))
            c4 = self.compress4(torch.cat([mobile_feats[4], custom_feats[4]], dim=1))
            # Head residual: the raw (un-normalized) 3-channel input.
            x_res = x
        else:
            x_proj = self.inv_proj(x)
            c0, c1, c2, c3, c4 = self.custom_enc(x_proj)
            # Head residual: the 3-channel projection of the 9-channel input.
            x_res = x_proj
        # Decode coarsest to finest, consuming skips c3..c0.
        x = self.up1(c4, c3)
        x = self.up2(x, c2)
        x = self.up3(x, c1)
        x = self.up4(x, c0)
        x = self.up5(x)
        x_cat = torch.cat([x, x_res], dim=1)
        h = self.shared_head(x_cat)
        # Hard clamp: gradients are zero wherever a head's output exceeds [-1, 1].
        basecolor = torch.clamp(self.basecolor_proj(h), -1, 1)
        normal = torch.clamp(self.normal_proj(h), -1, 1)
        rmd = torch.clamp(self.rmd_proj(h), -1, 1)
        rgb = torch.clamp(self.rgb_proj(h), -1, 1)
        return {'basecolor': basecolor, 'normal': normal, 'rmd': rmd, 'rgb': rgb}
def create_model(image_size=128, width_mult=1.0):
    """Build a ``MobileNetUNet`` with ``freeze_backbone=True``.

    Args:
        image_size: accepted for call-site compatibility; not used here.
        width_mult: channel multiplier forwarded to ``MobileNetUNet``.

    Returns:
        A freshly constructed ``MobileNetUNet``.
    """
    model = MobileNetUNet(width_mult=width_mult, freeze_backbone=True)
    return model
def load_from_checkpoint(ckpt_path, device="cpu", width_mult=2.0):
    """Build a ``MobileNetUNet``, load checkpoint weights, and return it ready for inference.

    The checkpoint may be a bare state dict or a dict holding one under
    ``"state_dict"``. A leading ``"model."`` on any key is stripped,
    presumably because the checkpoint was saved from a wrapper module.
    Loading is non-strict: missing and unexpected keys are printed, not
    raised. Tensor shape mismatches, e.g. from a different ``width_mult``,
    still make ``load_state_dict`` raise.

    Args:
        ckpt_path: path to a checkpoint loadable with ``torch.load(..., weights_only=True)``.
        device: device the checkpoint tensors are mapped to and the returned
            model is placed on.
        width_mult: must match the width the checkpoint was trained with.

    Returns:
        The model in eval mode, on ``device``.
    """
    model = MobileNetUNet(freeze_backbone=True, width_mult=width_mult)
    # Previously only the checkpoint tensors were mapped to `device` while the
    # model stayed on CPU; move the model so the returned object honors `device`.
    model.to(device)
    state = torch.load(ckpt_path, map_location=device, weights_only=True)
    sd = state["state_dict"] if "state_dict" in state else state
    prefix = "model."
    renamed = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }
    missing, unexpected = model.load_state_dict(renamed, strict=False)
    if missing:
        print(f"Missing keys: {missing}")
    if unexpected:
        print(f"Unexpected keys: {unexpected}")
    model.eval()
    return model