|
|
from escnn import gspaces, nn |
|
|
import torch |
|
|
|
|
|
class EquiResBlock(torch.nn.Module):
    """Equivariant residual block over a 2D gspace.

    Computes ``relu(layer2(layer1(x)) + shortcut(x))`` where ``layer1`` is an
    ``R2Conv`` followed by ReLU, ``layer2`` is a single ``R2Conv``, and the
    shortcut is the identity unless the channel count or the stride changes,
    in which case a 1x1 ``R2Conv`` (``self.upscale``) projects the input.

    Args:
        group: The 2D gspace the block is equivariant to.
        input_channels: Number of regular-representation fields in the input.
        hidden_dim: Number of regular-representation fields in the output.
        kernel_size: Kernel size of the two main convolutions. The padding
            ``(kernel_size - 1) // 2`` keeps the spatial size unchanged at
            stride 1 when ``kernel_size`` is odd.
        stride: Stride of the first convolution and of the shortcut projection.
        initialize: Forwarded unchanged to every ``R2Conv`` in the block.
    """

    def __init__(
        self,
        group: gspaces.GSpace2D,
        input_channels: int,
        hidden_dim: int,
        kernel_size: int = 3,
        stride: int = 1,
        initialize: bool = True,
    ):
        super().__init__()
        self.group = group
        rep = self.group.regular_repr
        feat_type_in = nn.FieldType(self.group, input_channels * [rep])
        feat_type_hid = nn.FieldType(self.group, hidden_dim * [rep])
        padding = (kernel_size - 1) // 2

        self.layer1 = nn.SequentialModule(
            nn.R2Conv(
                feat_type_in,
                feat_type_hid,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                initialize=initialize,
            ),
            nn.ReLU(feat_type_hid, inplace=True),
        )
        self.layer2 = nn.SequentialModule(
            nn.R2Conv(
                feat_type_hid,
                feat_type_hid,
                kernel_size=kernel_size,
                padding=padding,
                initialize=initialize,
            ),
        )
        self.relu = nn.ReLU(feat_type_hid, inplace=True)

        # The identity shortcut only lines up with the main path when the field
        # type and stride are unchanged; otherwise project with a 1x1 conv.
        self.upscale = None
        if input_channels != hidden_dim or stride != 1:
            self.upscale = nn.SequentialModule(
                nn.R2Conv(
                    feat_type_in,
                    feat_type_hid,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                    initialize=initialize,
                ),
            )

    def forward(self, xx: nn.GeometricTensor) -> nn.GeometricTensor:
        """Apply the main path and the shortcut to ``xx`` and return ReLU of their sum."""
        out = self.layer2(self.layer1(xx))
        # Explicit None check instead of module truthiness: container modules
        # can define __len__ (e.g. torch.nn.Sequential), which makes bool() fragile.
        shortcut = xx if self.upscale is None else self.upscale(xx)
        out += shortcut
        return self.relu(out)
|
|
|
|
|
class EquivariantResEncoder76Cyclic(torch.nn.Module):
    """Equivariant residual CNN encoder for 2D observations (nominally 76x76).

    The input is interpreted as ``obs_channel`` trivial (scalar) fields over
    ``gspaces.rot2dOnR2(N)``. The output has ``n_out`` regular-representation
    fields. Widths grow n_out // 8 -> n_out // 4 -> n_out // 2 -> n_out over
    four stages of two ``EquiResBlock``s each, separated by max pools.

    Spatial sizes for a 76x76 input, assuming each max pool's stride equals its
    kernel size (escnn default -- confirm): 76 -> 72 (5x5 conv, no padding)
    -> 36 -> 18 -> 9 -> 3 -> 1 (final 3x3 conv, no padding).

    Args:
        obs_channel: Number of scalar input channels.
        n_out: Number of regular fields in the output. Widths use integer
            division, so a multiple of 8 avoids truncation.
        initialize: Forwarded to every convolution, including the ones inside
            the residual blocks.
        N: Argument to ``gspaces.rot2dOnR2`` (number of discrete rotations).
    """

    def __init__(self, obs_channel: int = 2, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR2(N)

        def regular(num_fields):
            # FieldType made of `num_fields` copies of the regular representation.
            return nn.FieldType(self.group, num_fields * [self.group.regular_repr])

        def res_block(c_in, c_out):
            # Pass the caller's `initialize` through so the residual blocks honour it too.
            return EquiResBlock(self.group, c_in, c_out, initialize=initialize)

        self.conv = torch.nn.Sequential(
            nn.R2Conv(
                nn.FieldType(self.group, obs_channel * [self.group.trivial_repr]),
                regular(n_out // 8),
                kernel_size=5,
                padding=0,
                initialize=initialize,
            ),
            nn.ReLU(regular(n_out // 8), inplace=True),
            res_block(n_out // 8, n_out // 8),
            res_block(n_out // 8, n_out // 8),
            nn.PointwiseMaxPool(regular(n_out // 8), 2),

            res_block(n_out // 8, n_out // 4),
            res_block(n_out // 4, n_out // 4),
            nn.PointwiseMaxPool(regular(n_out // 4), 2),

            res_block(n_out // 4, n_out // 2),
            res_block(n_out // 2, n_out // 2),
            nn.PointwiseMaxPool(regular(n_out // 2), 2),

            res_block(n_out // 2, n_out),
            res_block(n_out, n_out),
            nn.PointwiseMaxPool(regular(n_out), 3),

            nn.R2Conv(
                regular(n_out),
                regular(n_out),
                kernel_size=3,
                padding=0,
                initialize=initialize,
            ),
            nn.ReLU(regular(n_out), inplace=True),
        )

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``.

        A plain ``torch.Tensor`` (including subclasses) is first wrapped as a
        ``GeometricTensor`` of ``obs_channel`` trivial fields. A
        ``GeometricTensor`` is passed through as-is.
        """
        if isinstance(x, torch.Tensor):
            x = nn.GeometricTensor(x, nn.FieldType(self.group, self.obs_channel * [self.group.trivial_repr]))
        return self.conv(x)
|
|
|
|
|
class EquivariantVoxelEncoder58Cyclic(torch.nn.Module):
    """Equivariant 3D conv encoder for voxel grids (nominally 58^3).

    Uses the ``gspaces.rot2dOnR3(N)`` gspace. The input is ``obs_channel``
    trivial (scalar) fields and the output is ``n_out`` regular-representation
    fields. All convolutions are 3x3x3 ``R3Conv``s, each followed by ReLU.

    Spatial sizes for a 58^3 input, assuming each max pool's stride equals its
    kernel size (escnn default -- confirm): 58 -> 56 -> 28 -> 28 -> 14 -> 12
    -> 6 -> 6 -> 3 -> 1.

    Args:
        obs_channel: Number of scalar input channels.
        n_out: Number of regular fields in the output. Widths use integer
            division by up to 16, so a multiple of 16 avoids truncation.
        initialize: Forwarded to every ``R3Conv``.
        N: Argument to ``gspaces.rot2dOnR3`` (number of discrete rotations).
    """

    def __init__(self, obs_channel: int = 4, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR3(N)

        def regular(num_fields):
            # FieldType made of `num_fields` copies of the regular representation.
            return nn.FieldType(self.group, num_fields * [self.group.regular_repr])

        def conv_relu(in_type, out_fields, padding):
            # A 3x3x3 R3Conv and its ReLU, returned as two separate Sequential entries.
            out_type = regular(out_fields)
            return [
                nn.R3Conv(in_type, out_type, kernel_size=3, padding=padding, initialize=initialize),
                nn.ReLU(out_type, inplace=True),
            ]

        c16, c8, c4, c2 = n_out // 16, n_out // 8, n_out // 4, n_out // 2
        trivial_in = nn.FieldType(self.group, obs_channel * [self.group.trivial_repr])

        # Flat layer list: indices match a hand-written Sequential, so state_dict keys are stable.
        layers = [
            *conv_relu(trivial_in, c16, padding=0),     # 58 -> 56
            nn.PointwiseMaxPool3D(regular(c16), 2),     # 56 -> 28

            *conv_relu(regular(c16), c8, padding=1),
            *conv_relu(regular(c8), c8, padding=1),
            nn.PointwiseMaxPool3D(regular(c8), 2),      # 28 -> 14

            *conv_relu(regular(c8), c4, padding=0),     # 14 -> 12
            *conv_relu(regular(c4), c4, padding=1),
            nn.PointwiseMaxPool3D(regular(c4), 2),      # 12 -> 6

            *conv_relu(regular(c4), c2, padding=1),
            *conv_relu(regular(c2), c2, padding=1),
            nn.PointwiseMaxPool3D(regular(c2), 2),      # 6 -> 3

            *conv_relu(regular(c2), n_out, padding=0),  # 3 -> 1
        ]
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``.

        A plain ``torch.Tensor`` (including subclasses) is first wrapped as a
        ``GeometricTensor`` of ``obs_channel`` trivial fields. A
        ``GeometricTensor`` is passed through as-is.
        """
        if isinstance(x, torch.Tensor):
            x = nn.GeometricTensor(x, nn.FieldType(self.group, self.obs_channel * [self.group.trivial_repr]))
        return self.conv(x)
|
|
|
|
|
class EquivariantVoxelEncoder64Cyclic(torch.nn.Module):
    """Equivariant 3D conv encoder for voxel grids (nominally 64^3).

    Uses the ``gspaces.rot2dOnR3(N)`` gspace. The input is ``obs_channel``
    trivial (scalar) fields and the output is ``n_out`` regular-representation
    fields. All convolutions are 3x3x3 ``R3Conv``s, each followed by ReLU.

    Spatial sizes for a 64^3 input, assuming each max pool's stride equals its
    kernel size (escnn default -- confirm): 64 -> 64 -> 32 -> 32 -> 16 -> 16
    -> 8 -> 6 -> 3 -> 1.

    Args:
        obs_channel: Number of scalar input channels.
        n_out: Number of regular fields in the output. Widths use integer
            division by up to 16, so a multiple of 16 avoids truncation.
        initialize: Forwarded to every ``R3Conv``.
        N: Argument to ``gspaces.rot2dOnR3`` (number of discrete rotations).
    """

    def __init__(self, obs_channel: int = 4, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR3(N)

        def regular(num_fields):
            # FieldType made of `num_fields` copies of the regular representation.
            return nn.FieldType(self.group, num_fields * [self.group.regular_repr])

        def conv_relu(in_type, out_fields, padding):
            # A 3x3x3 R3Conv and its ReLU, returned as two separate Sequential entries.
            out_type = regular(out_fields)
            return [
                nn.R3Conv(in_type, out_type, kernel_size=3, padding=padding, initialize=initialize),
                nn.ReLU(out_type, inplace=True),
            ]

        c16, c8, c4, c2 = n_out // 16, n_out // 8, n_out // 4, n_out // 2
        trivial_in = nn.FieldType(self.group, obs_channel * [self.group.trivial_repr])

        # Flat layer list: indices match a hand-written Sequential, so state_dict keys are stable.
        layers = [
            *conv_relu(trivial_in, c16, padding=1),     # 64 -> 64
            nn.PointwiseMaxPool3D(regular(c16), 2),     # 64 -> 32

            *conv_relu(regular(c16), c8, padding=1),
            *conv_relu(regular(c8), c8, padding=1),
            nn.PointwiseMaxPool3D(regular(c8), 2),      # 32 -> 16

            *conv_relu(regular(c8), c4, padding=1),
            *conv_relu(regular(c4), c4, padding=1),
            nn.PointwiseMaxPool3D(regular(c4), 2),      # 16 -> 8

            *conv_relu(regular(c4), c2, padding=0),     # 8 -> 6
            *conv_relu(regular(c2), c2, padding=1),
            nn.PointwiseMaxPool3D(regular(c2), 2),      # 6 -> 3

            *conv_relu(regular(c2), n_out, padding=0),  # 3 -> 1
        ]
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``.

        A plain ``torch.Tensor`` (including subclasses) is first wrapped as a
        ``GeometricTensor`` of ``obs_channel`` trivial fields. A
        ``GeometricTensor`` is passed through as-is.
        """
        if isinstance(x, torch.Tensor):
            x = nn.GeometricTensor(x, nn.FieldType(self.group, self.obs_channel * [self.group.trivial_repr]))
        return self.conv(x)
|
|
|
|
|
class EquivariantVoxelEncoder16Cyclic(torch.nn.Module):
    """Equivariant 3D conv encoder for small voxel grids (nominally 16^3).

    Uses the ``gspaces.rot2dOnR3(N)`` gspace. The input is ``obs_channel``
    trivial (scalar) fields and the output is ``n_out`` regular-representation
    fields. All convolutions are 3x3x3 ``R3Conv``s, each followed by ReLU.

    Spatial sizes for a 16^3 input, assuming each max pool's stride equals its
    kernel size (escnn default -- confirm): 16 -> 16 -> 8 -> 6 -> 6 -> 3 -> 1.

    Args:
        obs_channel: Number of scalar input channels.
        n_out: Number of regular fields in the output. Widths use integer
            division by up to 8, so a multiple of 8 avoids truncation.
        initialize: Forwarded to every ``R3Conv``.
        N: Argument to ``gspaces.rot2dOnR3`` (number of discrete rotations).
    """

    def __init__(self, obs_channel: int = 4, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR3(N)

        def regular(num_fields):
            # FieldType made of `num_fields` copies of the regular representation.
            return nn.FieldType(self.group, num_fields * [self.group.regular_repr])

        def conv_relu(in_type, out_fields, padding):
            # A 3x3x3 R3Conv and its ReLU, returned as two separate Sequential entries.
            out_type = regular(out_fields)
            return [
                nn.R3Conv(in_type, out_type, kernel_size=3, padding=padding, initialize=initialize),
                nn.ReLU(out_type, inplace=True),
            ]

        c8, c4, c2 = n_out // 8, n_out // 4, n_out // 2
        trivial_in = nn.FieldType(self.group, obs_channel * [self.group.trivial_repr])

        # Flat layer list: indices match a hand-written Sequential, so state_dict keys are stable.
        layers = [
            *conv_relu(trivial_in, c8, padding=1),      # 16 -> 16
            nn.PointwiseMaxPool3D(regular(c8), 2),      # 16 -> 8

            *conv_relu(regular(c8), c4, padding=0),     # 8 -> 6
            *conv_relu(regular(c4), c2, padding=1),
            nn.PointwiseMaxPool3D(regular(c2), 2),      # 6 -> 3

            *conv_relu(regular(c2), n_out, padding=0),  # 3 -> 1
        ]
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``.

        A plain ``torch.Tensor`` (including subclasses) is first wrapped as a
        ``GeometricTensor`` of ``obs_channel`` trivial fields. A
        ``GeometricTensor`` is passed through as-is.
        """
        if isinstance(x, torch.Tensor):
            x = nn.GeometricTensor(x, nn.FieldType(self.group, self.obs_channel * [self.group.trivial_repr]))
        return self.conv(x)