# pytorch3d/implicitron/third_party/hyperlayers.py
# a copy-paste from https://github.com/vsitzmann/scene-representation-networks/blob/master/hyperlayers.py
# fmt: off
# flake8: noqa
# pyre-unsafe
'''Pytorch implementations of hyper-network modules.
'''
import functools
import torch
import torch.nn as nn
from . import pytorch_prototyping
def partialclass(cls, *args, **kwds):
    """Return a subclass of ``cls`` whose ``__init__`` is pre-bound with the
    given positional and keyword arguments — ``functools.partial`` for class
    constructors."""

    class _Preconfigured(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwds)

    return _Preconfigured
class LookupLayer(nn.Module):
    """Per-object dense layer: an object-specific linear transform (looked up
    by object index) followed by non-affine LayerNorm and a ReLU."""

    def __init__(self, in_ch, out_ch, num_objects):
        super().__init__()
        self.out_ch = out_ch
        self.lookup_lin = LookupLinear(in_ch, out_ch, num_objects=num_objects)
        self.norm_nl = nn.Sequential(
            nn.LayerNorm([self.out_ch], elementwise_affine=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, obj_idx):
        """Return an ``nn.Sequential`` applying the object's linear layer,
        then LayerNorm + ReLU."""
        return nn.Sequential(self.lookup_lin(obj_idx), self.norm_nl)
class LookupFC(nn.Module):
    """Per-object fully connected network: a stack of LookupLayers, with the
    outermost layer optionally a bare LookupLinear (no norm/nonlinearity)."""

    def __init__(
        self,
        hidden_ch,
        num_hidden_layers,
        num_objects,
        in_ch,
        out_ch,
        outermost_linear=False,
    ):
        super().__init__()
        layers = [LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)]
        layers.extend(
            LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects)
            for _ in range(num_hidden_layers)
        )
        # Last layer is linear-only when requested, otherwise a full LookupLayer.
        last_cls = LookupLinear if outermost_linear else LookupLayer
        layers.append(last_cls(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects))
        self.layers = nn.ModuleList(layers)

    def forward(self, obj_idx):
        """Assemble and return the object-specific network as nn.Sequential."""
        return nn.Sequential(*[layer(obj_idx) for layer in self.layers])
class LookupLinear(nn.Module):
    """Stores one flattened (weight, bias) pair per object in an embedding
    table and emits a BatchLinear for the requested object indices."""

    def __init__(self, in_ch, out_ch, num_objects):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        n_weights = in_ch * out_ch
        # Each embedding row holds the flattened weight matrix, then the bias.
        self.hypo_params = nn.Embedding(num_objects, n_weights + out_ch)
        for obj in range(num_objects):
            row = self.hypo_params.weight.data[obj]
            nn.init.kaiming_normal_(
                row[:n_weights].view(out_ch, in_ch),
                a=0.0,
                nonlinearity="relu",
                mode="fan_in",
            )
            row[n_weights:].fill_(0.0)

    def forward(self, obj_idx):
        hypo_params = self.hypo_params(obj_idx)
        split = self.in_ch * self.out_ch
        # Indices explicit to catch errors in shape of output layer
        flat_w = hypo_params[..., :split]
        flat_b = hypo_params[..., split : split + self.out_ch]
        weights = flat_w.view(*flat_w.size()[:-1], self.out_ch, self.in_ch)
        biases = flat_b.view(*flat_b.size()[:-1], 1, self.out_ch)
        return BatchLinear(weights=weights, biases=biases)
class HyperLayer(nn.Module):
    """A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU."""

    def __init__(
        self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
    ):
        super().__init__()
        # Predicts the linear part; norm + nonlinearity are fixed (not predicted).
        self.hyper_linear = HyperLinear(
            in_ch=in_ch,
            out_ch=out_ch,
            hyper_in_ch=hyper_in_ch,
            hyper_num_hidden_layers=hyper_num_hidden_layers,
            hyper_hidden_ch=hyper_hidden_ch,
        )
        self.norm_nl = nn.Sequential(
            nn.LayerNorm([out_ch], elementwise_affine=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, hyper_input):
        """
        :param hyper_input: input to hypernetwork.
        :return: nn.Module; predicted fully connected network.
        """
        predicted_linear = self.hyper_linear(hyper_input)
        return nn.Sequential(predicted_linear, self.norm_nl)
class HyperFC(nn.Module):
    """Builds a hypernetwork that predicts a fully connected neural network."""

    def __init__(
        self,
        hyper_in_ch,
        hyper_num_hidden_layers,
        hyper_hidden_ch,
        hidden_ch,
        num_hidden_layers,
        in_ch,
        out_ch,
        outermost_linear=False,
    ):
        super().__init__()
        # Pre-bind the hypernetwork hyperparameters shared by every layer.
        hyper_kwargs = dict(
            hyper_in_ch=hyper_in_ch,
            hyper_num_hidden_layers=hyper_num_hidden_layers,
            hyper_hidden_ch=hyper_hidden_ch,
        )
        make_layer = functools.partial(HyperLayer, **hyper_kwargs)
        make_linear = functools.partial(HyperLinear, **hyper_kwargs)

        layers = [make_layer(in_ch=in_ch, out_ch=hidden_ch)]
        layers += [
            make_layer(in_ch=hidden_ch, out_ch=hidden_ch)
            for _ in range(num_hidden_layers)
        ]
        make_last = make_linear if outermost_linear else make_layer
        layers.append(make_last(in_ch=hidden_ch, out_ch=out_ch))
        self.layers = nn.ModuleList(layers)

    def forward(self, hyper_input):
        """
        :param hyper_input: Input to hypernetwork.
        :return: nn.Module; Predicted fully connected neural network.
        """
        return nn.Sequential(*(layer(hyper_input) for layer in self.layers))
class BatchLinear(nn.Module):
    """Applies a linear transform with externally supplied per-batch
    parameters. The weights and biases are stored as plain tensor attributes,
    not registered as nn.Parameters."""

    def __init__(self, weights, biases):
        """Implements a batch linear layer.
        :param weights: Shape: (batch, out_ch, in_ch)
        :param biases: Shape: (batch, 1, out_ch)
        """
        super().__init__()
        self.weights = weights
        self.biases = biases

    def __repr__(self):
        out_ch, in_ch = self.weights.shape[-2], self.weights.shape[-1]
        return "BatchLinear(in_ch=%d, out_ch=%d)" % (in_ch, out_ch)

    def forward(self, input):
        # input @ W^T + b, transposing W over its trailing two dimensions.
        return input.matmul(self.weights.transpose(-2, -1)) + self.biases
def last_hyper_layer_init(m) -> None:
    """Initializer for the final layer of a hypernetwork.

    Re-initializes linear layers with kaiming-normal weights and damps them by
    a factor of 0.1 so the predicted hypo-network parameters start small.
    Intended for use with ``nn.Module.apply``; other module types are left
    untouched (the bias is also left untouched).
    """
    # isinstance is the idiomatic type check (the original exact-type
    # comparison `type(m) == nn.Linear` would skip nn.Linear subclasses).
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
        m.weight.data *= 1e-1
class HyperLinear(nn.Module):
    """A hypernetwork that predicts a single linear layer (weights & biases)."""

    def __init__(
        self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
    ):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        # MLP mapping the hyper input to all flattened weights plus the biases.
        self.hypo_params = pytorch_prototyping.FCBlock(
            in_features=hyper_in_ch,
            hidden_ch=hyper_hidden_ch,
            num_hidden_layers=hyper_num_hidden_layers,
            out_features=(in_ch * out_ch) + out_ch,
            outermost_linear=True,
        )
        # Keep the initially predicted parameters small.
        self.hypo_params[-1].apply(last_hyper_layer_init)

    def forward(self, hyper_input):
        params = self.hypo_params(hyper_input)
        n_w = self.in_ch * self.out_ch
        # Indices explicit to catch errors in shape of output layer
        flat_w = params[..., :n_w]
        flat_b = params[..., n_w : n_w + self.out_ch]
        weights = flat_w.view(*flat_w.size()[:-1], self.out_ch, self.in_ch)
        biases = flat_b.view(*flat_b.size()[:-1], 1, self.out_ch)
        return BatchLinear(weights=weights, biases=biases)