from escnn import gspaces, nn
import torch


def _regular_type(group, multiplicity: int) -> nn.FieldType:
    """Return a field type of ``multiplicity`` copies of ``group.regular_repr``."""
    return nn.FieldType(group, multiplicity * [group.regular_repr])


def _trivial_type(group, multiplicity: int) -> nn.FieldType:
    """Return a field type of ``multiplicity`` copies of ``group.trivial_repr``."""
    return nn.FieldType(group, multiplicity * [group.trivial_repr])


def _as_geometric(x, group, channels: int):
    """Wrap a raw tensor as a ``GeometricTensor`` of trivial fields.

    Anything that is not a ``torch.Tensor`` is returned untouched, so callers
    may pass an already wrapped ``GeometricTensor``.
    """
    # isinstance (rather than an exact type comparison) also accepts
    # torch.Tensor subclasses such as torch.nn.Parameter.
    if isinstance(x, torch.Tensor):
        return nn.GeometricTensor(x, _trivial_type(group, channels))
    return x


def _conv3d_relu(in_type, out_type, padding: int, initialize: bool):
    """Return a ``(R3Conv, ReLU)`` pair with a 3x3x3 kernel, for unpacking into a Sequential."""
    return (
        nn.R3Conv(in_type, out_type, kernel_size=3, padding=padding, initialize=initialize),
        nn.ReLU(out_type, inplace=True),
    )


class EquiResBlock(torch.nn.Module):
    """Residual block built from two equivariant 2D convolutions.

    The main path is ``R2Conv -> ReLU -> R2Conv``. The input is added back to
    the result (through a 1x1 ``R2Conv`` when the channel count or the stride
    changes) and a final ReLU is applied.

    Args:
        group: gspace the feature fields are defined on.
        input_channels: number of regular-representation fields of the input.
        hidden_dim: number of regular-representation fields of the output.
        kernel_size: kernel size of both main-path convolutions; the padding
            used is ``(kernel_size - 1) // 2``.
        stride: stride of the first convolution and of the 1x1 shortcut.
        initialize: forwarded to every ``R2Conv`` as its ``initialize`` flag.
    """

    def __init__(
        self,
        group: gspaces.GSpace2D,
        input_channels: int,
        hidden_dim: int,
        kernel_size: int = 3,
        stride: int = 1,
        initialize: bool = True,
    ):
        super().__init__()
        self.group = group

        feat_type_in = _regular_type(group, input_channels)
        feat_type_hid = _regular_type(group, hidden_dim)
        padding = (kernel_size - 1) // 2

        self.layer1 = nn.SequentialModule(
            nn.R2Conv(
                feat_type_in,
                feat_type_hid,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                initialize=initialize,
            ),
            nn.ReLU(feat_type_hid, inplace=True),
        )
        self.layer2 = nn.SequentialModule(
            nn.R2Conv(
                feat_type_hid,
                feat_type_hid,
                kernel_size=kernel_size,
                padding=padding,
                initialize=initialize,
            ),
        )
        self.relu = nn.ReLU(feat_type_hid, inplace=True)

        # A projection shortcut is only needed when the identity cannot be
        # added directly, i.e. when the field type or the resolution changes.
        self.upscale = None
        if input_channels != hidden_dim or stride != 1:
            self.upscale = nn.SequentialModule(
                nn.R2Conv(
                    feat_type_in,
                    feat_type_hid,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                    initialize=initialize,
                ),
            )

    def forward(self, xx: nn.GeometricTensor) -> nn.GeometricTensor:
        """Apply the block to ``xx`` and return the activated residual sum."""
        residual = xx
        out = self.layer2(self.layer1(xx))
        # Explicit None check: do not rely on the truthiness of a module.
        if self.upscale is not None:
            residual = self.upscale(residual)
        out += residual
        return self.relu(out)


class EquivariantResEncoder76Cyclic(torch.nn.Module):
    """Equivariant residual image encoder over ``gspaces.rot2dOnR2(N)``.

    The spatial-size comments below assume a 76x76 input, for which the
    output is 1x1 with ``n_out`` regular-representation fields.

    Args:
        obs_channel: number of input channels (treated as trivial fields).
        n_out: number of regular fields of the output; intermediate stages
            use ``n_out // 8``, ``n_out // 4`` and ``n_out // 2`` fields.
        initialize: forwarded to every convolution, including the ones inside
            the residual blocks.
        N: order of the rotation group passed to ``gspaces.rot2dOnR2``.
    """

    def __init__(self, obs_channel: int = 2, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR2(N)
        group = self.group

        c8, c4, c2 = n_out // 8, n_out // 4, n_out // 2

        # NOTE: module order must stay fixed so that the Sequential indices
        # (and therefore state_dict keys) keep matching saved checkpoints.
        self.conv = torch.nn.Sequential(
            # 76x76
            nn.R2Conv(
                _trivial_type(group, obs_channel),
                _regular_type(group, c8),
                kernel_size=5,
                padding=0,
                initialize=initialize,
            ),
            # 72x72
            nn.ReLU(_regular_type(group, c8), inplace=True),
            EquiResBlock(group, c8, c8, initialize=initialize),
            EquiResBlock(group, c8, c8, initialize=initialize),
            nn.PointwiseMaxPool(_regular_type(group, c8), 2),
            # 36x36
            EquiResBlock(group, c8, c4, initialize=initialize),
            EquiResBlock(group, c4, c4, initialize=initialize),
            nn.PointwiseMaxPool(_regular_type(group, c4), 2),
            # 18x18
            EquiResBlock(group, c4, c2, initialize=initialize),
            EquiResBlock(group, c2, c2, initialize=initialize),
            nn.PointwiseMaxPool(_regular_type(group, c2), 2),
            # 9x9
            EquiResBlock(group, c2, n_out, initialize=initialize),
            EquiResBlock(group, n_out, n_out, initialize=initialize),
            nn.PointwiseMaxPool(_regular_type(group, n_out), 3),
            # 3x3
            nn.R2Conv(
                _regular_type(group, n_out),
                _regular_type(group, n_out),
                kernel_size=3,
                padding=0,
                initialize=initialize,
            ),
            nn.ReLU(_regular_type(group, n_out), inplace=True),
            # 1x1
        )

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``; raw tensors are first wrapped as trivial-field geometric tensors."""
        return self.conv(_as_geometric(x, self.group, self.obs_channel))


class EquivariantVoxelEncoder58Cyclic(torch.nn.Module):
    """Equivariant voxel encoder over ``gspaces.rot2dOnR3(N)``.

    The size comments below assume an input of side 58, for which the output
    has side 1 with ``n_out`` regular-representation fields.

    Args:
        obs_channel: number of input channels (treated as trivial fields).
        n_out: number of regular fields of the output.
        initialize: forwarded to every ``R3Conv``.
        N: order of the rotation group passed to ``gspaces.rot2dOnR3``.
    """

    def __init__(self, obs_channel: int = 4, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR3(N)
        group = self.group

        t_in = _trivial_type(group, obs_channel)
        t16 = _regular_type(group, n_out // 16)
        t8 = _regular_type(group, n_out // 8)
        t4 = _regular_type(group, n_out // 4)
        t2 = _regular_type(group, n_out // 2)
        t1 = _regular_type(group, n_out)

        # NOTE: module order must stay fixed so that state_dict keys are stable.
        self.conv = torch.nn.Sequential(
            # 58
            *_conv3d_relu(t_in, t16, 0, initialize),
            # 56
            nn.PointwiseMaxPool3D(t16, 2),
            # 28
            *_conv3d_relu(t16, t8, 1, initialize),
            *_conv3d_relu(t8, t8, 1, initialize),
            nn.PointwiseMaxPool3D(t8, 2),
            # 14
            *_conv3d_relu(t8, t4, 0, initialize),
            # 12
            *_conv3d_relu(t4, t4, 1, initialize),
            nn.PointwiseMaxPool3D(t4, 2),
            # 6
            *_conv3d_relu(t4, t2, 1, initialize),
            *_conv3d_relu(t2, t2, 1, initialize),
            nn.PointwiseMaxPool3D(t2, 2),
            # 3
            *_conv3d_relu(t2, t1, 0, initialize),
            # 1x1
        )

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``; raw tensors are first wrapped as trivial-field geometric tensors."""
        return self.conv(_as_geometric(x, self.group, self.obs_channel))


class EquivariantVoxelEncoder64Cyclic(torch.nn.Module):
    """Equivariant voxel encoder over ``gspaces.rot2dOnR3(N)``.

    The size comments below assume an input of side 64, for which the output
    has side 1 with ``n_out`` regular-representation fields.

    Args:
        obs_channel: number of input channels (treated as trivial fields).
        n_out: number of regular fields of the output.
        initialize: forwarded to every ``R3Conv``.
        N: order of the rotation group passed to ``gspaces.rot2dOnR3``.
    """

    def __init__(self, obs_channel: int = 4, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR3(N)
        group = self.group

        t_in = _trivial_type(group, obs_channel)
        t16 = _regular_type(group, n_out // 16)
        t8 = _regular_type(group, n_out // 8)
        t4 = _regular_type(group, n_out // 4)
        t2 = _regular_type(group, n_out // 2)
        t1 = _regular_type(group, n_out)

        # NOTE: module order must stay fixed so that state_dict keys are stable.
        self.conv = torch.nn.Sequential(
            # 64
            *_conv3d_relu(t_in, t16, 1, initialize),
            nn.PointwiseMaxPool3D(t16, 2),
            # 32
            *_conv3d_relu(t16, t8, 1, initialize),
            *_conv3d_relu(t8, t8, 1, initialize),
            nn.PointwiseMaxPool3D(t8, 2),
            # 16
            *_conv3d_relu(t8, t4, 1, initialize),
            *_conv3d_relu(t4, t4, 1, initialize),
            nn.PointwiseMaxPool3D(t4, 2),
            # 8
            *_conv3d_relu(t4, t2, 0, initialize),
            *_conv3d_relu(t2, t2, 1, initialize),
            # 6
            nn.PointwiseMaxPool3D(t2, 2),
            # 3
            *_conv3d_relu(t2, t1, 0, initialize),
            # 1x1
        )

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``; raw tensors are first wrapped as trivial-field geometric tensors."""
        return self.conv(_as_geometric(x, self.group, self.obs_channel))


class EquivariantVoxelEncoder16Cyclic(torch.nn.Module):
    """Equivariant voxel encoder over ``gspaces.rot2dOnR3(N)``.

    The size comments below assume an input of side 16, for which the output
    has side 1 with ``n_out`` regular-representation fields.

    Args:
        obs_channel: number of input channels (treated as trivial fields).
        n_out: number of regular fields of the output.
        initialize: forwarded to every ``R3Conv``.
        N: order of the rotation group passed to ``gspaces.rot2dOnR3``.
    """

    def __init__(self, obs_channel: int = 4, n_out: int = 128, initialize: bool = True, N: int = 8):
        super().__init__()
        self.obs_channel = obs_channel
        self.group = gspaces.rot2dOnR3(N)
        group = self.group

        t_in = _trivial_type(group, obs_channel)
        t8 = _regular_type(group, n_out // 8)
        t4 = _regular_type(group, n_out // 4)
        t2 = _regular_type(group, n_out // 2)
        t1 = _regular_type(group, n_out)

        # NOTE: module order must stay fixed so that state_dict keys are stable.
        self.conv = torch.nn.Sequential(
            # 16
            *_conv3d_relu(t_in, t8, 1, initialize),
            nn.PointwiseMaxPool3D(t8, 2),
            # 8
            *_conv3d_relu(t8, t4, 0, initialize),
            # 6
            *_conv3d_relu(t4, t2, 1, initialize),
            nn.PointwiseMaxPool3D(t2, 2),
            # 3
            *_conv3d_relu(t2, t1, 0, initialize),
            # 1x1
        )

    def forward(self, x) -> nn.GeometricTensor:
        """Encode ``x``; raw tensors are first wrapped as trivial-field geometric tensors."""
        return self.conv(_as_geometric(x, self.group, self.obs_channel))