# Source: Image-GS / utils / saliency / decoder.py
# Julien Blanchon - "Deploy optimized Image-GS with dynamic dependencies" (d62394f)
# File size: 1.86 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
    """Fuse two groups of single-channel feature maps into one output map.

    Every incoming feature map gets its own small branch
    (3x3 ``Conv2d(1, 1)`` -> ``BatchNorm2d(1)`` -> ``ReLU``), is resized to
    ``shape`` with ``F.interpolate`` (default interpolation mode), and all
    resized maps are concatenated along the channel axis. A final
    3x3 ``Conv2d`` -> ``BatchNorm2d`` -> ``Sigmoid`` head reduces the stack
    to ``readout`` channel(s), so output values lie in [0, 1].

    Args:
        shape: Target spatial size handed to ``F.interpolate`` for every
            branch output.
        num_img_feat: Number of branches built for the first ("img")
            feature group.
        num_pla_feat: Number of branches built for the second ("pla")
            feature group.
    """

    def __init__(self, shape, num_img_feat, num_pla_feat):
        super().__init__()
        self.shape = shape
        self.img_model = self._make_layer(num_img_feat)
        self.pla_model = self._make_layer(num_pla_feat)
        self.combined = self._make_output(num_img_feat + num_pla_feat)

        # Override the default initialisation of every conv / batch-norm
        # layer created above (conv biases keep their default init).
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                # Zero-mean normal with std = sqrt(2 / (kh * kw * out_channels)).
                nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / fan))

    def _make_layer(self, num_feat):
        """Return a ``ModuleList`` of ``num_feat`` independent 1->1 channel branches."""
        return nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(1, 1, 3, padding=1),
                    nn.BatchNorm2d(1),
                    nn.ReLU(inplace=True),
                )
                for _ in range(num_feat)
            ]
        )

    def _make_output(self, planes, readout=1):
        """Return the head mapping ``planes`` channels to ``readout`` sigmoid channels."""
        conv = nn.Conv2d(planes, readout, 3, stride=1, padding=1)
        norm = nn.BatchNorm2d(readout)
        return nn.Sequential(conv, norm, nn.Sigmoid())

    def forward(self, x):
        """Run every feature map through its branch, resize, stack and fuse.

        Args:
            x: A pair ``(img_feat, pla_feat)`` of iterables of feature maps.
                Each map is fed to a ``Conv2d(1, 1)`` / ``BatchNorm2d(1)``
                branch, so it must be a 4-D tensor with one channel.

        Returns:
            The output of the fusion head applied to the channel-wise
            concatenation of all resized branch outputs ("img" maps first,
            then "pla" maps).
        """
        img_feat, pla_feat = x
        resized = []
        # ``zip`` pairs the i-th map with the i-th branch; surplus items on
        # either side are ignored.
        for maps, branches in ((img_feat, self.img_model), (pla_feat, self.pla_model)):
            for fmap, branch in zip(maps, branches):
                resized.append(F.interpolate(branch(fmap), self.shape))
        stacked = torch.cat(resized, dim=1)
        return self.combined(stacked)
def build_decoder(model_path, *args):
    """Construct a ``Decoder`` and populate it from a checkpoint file.

    Args:
        model_path: Path (or file-like object) accepted by ``torch.load``.
            The checkpoint must be a mapping with a ``"state_dict"`` entry
            whose keys match ``Decoder.state_dict()``.
        *args: Positional arguments forwarded unchanged to ``Decoder``.

    Returns:
        The ``Decoder`` with the checkpoint weights loaded. The module is
        returned as constructed; this function does not call ``.eval()``
        or move it to another device.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If ``load_state_dict`` finds missing, unexpected or
            mismatched keys.
    """
    decoder = Decoder(*args)
    # Deserialize onto the CPU so a checkpoint saved from CUDA tensors can be
    # loaded on a machine without a GPU. ``load_state_dict`` copies the values
    # into the decoder's own (CPU) parameters, so the result is unchanged.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
    decoder.load_state_dict(checkpoint["state_dict"])
    return decoder