| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
|
|
| from se3_transformer.model import SE3Transformer |
| from se3_transformer.model.fiber import Fiber |
| from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot |
|
|
| |
# Absolute tolerance passed as ``atol`` to the torch.allclose checks below.
TOL = 1e-3
# Channel count of each random input feature tensor.
CHANNELS = 32
# Number of nodes in the random test graph.
NODES = 512
|
|
|
|
def _get_outputs(model, R):
    """Evaluate ``model`` on random inputs and on the same inputs rotated by ``R``.

    Random type-0 features of shape (NODES, CHANNELS, 1), type-1 features of
    shape (NODES, CHANNELS, 3) and coordinates of shape (NODES, 3) are drawn,
    together with a random graph on NODES nodes. When CUDA is available, the
    tensors, ``R``, the graph and the model are all moved to the GPU (the
    model is moved in place).

    :param model: Callable taking ``(graph, node_feats, edge_feats)``.
    :param R:     Matrix applied on the right of the coordinates and of the
                  type-1 features (presumably a 3x3 rotation -- confirm with callers).
    :return:      Tuple ``(out_original, out_rotated)`` of the two model outputs.
    """
    # Keep this draw order: it fixes how the global RNG stream is consumed.
    feats0 = torch.randn(NODES, CHANNELS, 1)
    feats1 = torch.randn(NODES, CHANNELS, 3)
    coords = torch.randn(NODES, 3)
    graph = get_random_graph(NODES)

    if torch.cuda.is_available():
        feats0, feats1, coords, R = (t.cuda() for t in (feats0, feats1, coords, R))
        graph = graph.to('cuda')
        model.cuda()

    # First pass: untouched coordinates and features.
    out_original = model(
        assign_relative_pos(graph, coords),
        {'0': feats0, '1': feats1},
        {},
    )
    # Second pass: coordinates and type-1 features rotated; type-0 features unchanged.
    out_rotated = model(
        assign_relative_pos(graph, coords @ R),
        {'0': feats0, '1': feats1 @ R},
        {},
    )
    return out_original, out_rotated
|
|
|
|
def _get_model(**kwargs):
    """Build the SE3Transformer configuration shared by every test in this file.

    The input and output fibers are created with ``Fiber.create(2, CHANNELS)``
    (presumably degrees 0 and 1, matching the '0'/'1' feature keys used by the
    tests -- confirm against Fiber), the hidden fiber with
    ``Fiber.create(3, CHANNELS)``, and the edge fiber is empty.

    :param kwargs: Extra keyword arguments forwarded verbatim to SE3Transformer
                   (the tests pass ``pooling`` and ``return_type``).
    :return:       A freshly constructed SE3Transformer.
    """
    fiber_in = Fiber.create(2, CHANNELS)
    fiber_hidden = Fiber.create(3, CHANNELS)
    fiber_out = Fiber.create(2, CHANNELS)
    fiber_edge = Fiber({})  # no edge features are supplied by the tests

    return SE3Transformer(
        num_layers=4,
        fiber_in=fiber_in,
        fiber_hidden=fiber_hidden,
        fiber_out=fiber_out,
        fiber_edge=fiber_edge,
        num_heads=8,
        channels_div=2,
        **kwargs
    )
|
|
|
|
def test_equivariance():
    """Check per-node outputs: type-0 must be invariant, type-1 equivariant.

    The model is run on random inputs and on the same inputs transformed by a
    random matrix produced by ``rot``. The '0' outputs of both runs must agree
    within TOL, and the '1' output of the transformed run must match the
    original '1' output multiplied on the right by the same matrix.
    """
    # Build the model before drawing the angles so RNG consumption order is fixed.
    model = _get_model()
    angles = torch.rand(3)
    R = rot(*angles)
    if torch.cuda.is_available():
        R = R.cuda()

    plain, rotated = _get_outputs(model, R)

    assert torch.allclose(rotated['0'], plain['0'], atol=TOL), (
        f'type-0 features should be invariant {get_max_diff(plain["0"], rotated["0"])}'
    )
    assert torch.allclose(rotated['1'], plain['1'] @ R, atol=TOL), (
        f'type-1 features should be equivariant {get_max_diff(plain["1"] @ R, rotated["1"])}'
    )
|
|
|
|
def test_equivariance_pooled():
    """Check that the pooled output requested with ``return_type=1`` is equivariant.

    The model is built with ``pooling='avg'`` and ``return_type=1``; its output
    is used directly as a tensor here. The output for the transformed inputs
    must match the original output multiplied on the right by the same matrix,
    within TOL.
    """
    # Build the model before drawing the angles so RNG consumption order is fixed.
    model = _get_model(pooling='avg', return_type=1)
    angles = torch.rand(3)
    R = rot(*angles)
    if torch.cuda.is_available():
        R = R.cuda()

    plain, rotated = _get_outputs(model, R)

    assert torch.allclose(rotated, plain @ R, atol=TOL), (
        f'type-1 features should be equivariant {get_max_diff(plain @ R, rotated)}'
    )
|
|
|
|
def test_invariance_pooled():
    """Check that the pooled output requested with ``return_type=0`` is invariant.

    The model is built with ``pooling='avg'`` and ``return_type=0``; its output
    is used directly as a tensor here. The outputs for the original and the
    transformed inputs must agree within TOL.
    """
    # Build the model before drawing the angles so RNG consumption order is fixed.
    model = _get_model(pooling='avg', return_type=0)
    angles = torch.rand(3)
    R = rot(*angles)
    if torch.cuda.is_available():
        R = R.cuda()

    plain, rotated = _get_outputs(model, R)

    assert torch.allclose(rotated, plain, atol=TOL), (
        f'type-0 features should be invariant {get_max_diff(plain, rotated)}'
    )
|
|