|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
from se3_transformer.model import SE3Transformer |
|
|
from se3_transformer.model.fiber import Fiber |
|
|
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot |
|
|
|
|
|
|
|
|
# Absolute tolerance (``atol``) for the ``torch.allclose`` checks in the tests below.
TOL = 1e-3


# CHANNELS: feature channels -- the middle dimension of the random input features
#   and the value passed to ``Fiber.create`` for the in/hidden/out fibers.
# NODES: number of nodes in the random test graph (first dim of features/coords).
CHANNELS, NODES = 32, 512
|
|
|
|
|
|
|
|
def _get_outputs(model, R):
    """Run ``model`` on a random graph before and after rotating it by ``R``.

    Random type-0 features ``(NODES, CHANNELS, 1)``, type-1 features
    ``(NODES, CHANNELS, 3)`` and node coordinates ``(NODES, 3)`` are drawn
    for a random graph with ``NODES`` nodes. The model is evaluated twice:

    * once on the original inputs;
    * once with the coordinates and the type-1 features right-multiplied
      by ``R``. The type-0 features are passed unchanged.

    If CUDA is available, all inputs and ``model`` are moved to the GPU.
    ``model.cuda()`` moves the caller's model in place.

    Args:
        model: the SE3Transformer under test.
        R: rotation matrix applied as ``x @ R``. Presumably ``(3, 3)``;
            TODO confirm against ``tests.utils.rot``.

    Returns:
        ``(out1, out2)``: model outputs for the original and rotated inputs.
    """
    feats0 = torch.randn(NODES, CHANNELS, 1)
    feats1 = torch.randn(NODES, CHANNELS, 3)
    coords = torch.randn(NODES, 3)
    graph = get_random_graph(NODES)

    if torch.cuda.is_available():
        feats0 = feats0.cuda()
        feats1 = feats1.cuda()
        R = R.cuda()
        coords = coords.cuda()
        graph = graph.to('cuda')
        model.cuda()

    # The equivariance checks only compare output values. Disabling autograd
    # avoids retaining the computation graphs of two full forward passes.
    with torch.no_grad():
        graph1 = assign_relative_pos(graph, coords)
        out1 = model(graph1, {'0': feats0, '1': feats1}, {})
        # The same ``graph`` is passed again below. If assign_relative_pos
        # mutates it in place, graph1 and graph2 alias each other, so out1
        # must be computed before graph2 is built.
        graph2 = assign_relative_pos(graph, coords @ R)
        out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})

    return out1, out2
|
|
|
|
|
|
|
|
def _get_model(**kwargs):
    """Build the SE3Transformer used by the equivariance tests.

    Default hyperparameters:

    * 4 layers, 8 heads and ``channels_div=2``;
    * ``Fiber.create(2, CHANNELS)`` for the input and output fibers;
    * ``Fiber.create(3, CHANNELS)`` for the hidden fiber;
    * an empty edge fiber (no edge features).

    Args:
        **kwargs: extra keyword arguments for ``SE3Transformer``, e.g.
            ``pooling`` or ``return_type``. A key that matches one of the
            defaults overrides it. Previously such a key raised
            ``TypeError`` for a duplicate keyword argument.

    Returns:
        A freshly constructed ``SE3Transformer``.
    """
    params = dict(
        num_layers=4,
        fiber_in=Fiber.create(2, CHANNELS),
        fiber_hidden=Fiber.create(3, CHANNELS),
        fiber_out=Fiber.create(2, CHANNELS),
        fiber_edge=Fiber({}),
        num_heads=8,
        channels_div=2,
    )
    # Caller-supplied options win over the defaults above.
    params.update(kwargs)
    return SE3Transformer(**params)
|
|
|
|
|
|
|
|
def test_equivariance():
    """Per-node outputs: type-0 must be invariant and type-1 equivariant.

    The rotation comes from ``rot`` applied to three uniform ``[0, 1)``
    samples. Rotating the inputs by it must leave the type-0 output
    unchanged. It must rotate the type-1 output by the same matrix, within
    ``TOL``.
    """
    # Build the model before sampling the rotation. Model construction
    # presumably draws from the global torch RNG, so this order fixes which
    # samples are used.
    model = _get_model()
    rotation = rot(*torch.rand(3))
    if torch.cuda.is_available():
        rotation = rotation.cuda()
    before, after = _get_outputs(model, rotation)

    assert torch.allclose(after['0'], before['0'], atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(before["0"], after["0"])}'

    rotated_before = before['1'] @ rotation
    assert torch.allclose(after['1'], rotated_before, atol=TOL), \
        f'type-1 features should be equivariant {get_max_diff(rotated_before, after["1"])}'
|
|
|
|
|
|
|
|
def test_equivariance_pooled():
    """Average-pooled type-1 output must rotate with the inputs.

    With ``pooling='avg'`` and ``return_type=1`` the model output is used
    directly as a tensor. The output for rotated inputs must equal the
    original output multiplied by the same rotation, within ``TOL``.
    """
    # Build the model first so the rotation samples are drawn after it,
    # matching the RNG order of the other tests.
    model = _get_model(pooling='avg', return_type=1)
    rotation = rot(*torch.rand(3))
    if torch.cuda.is_available():
        rotation = rotation.cuda()
    before, after = _get_outputs(model, rotation)

    rotated_before = before @ rotation
    assert torch.allclose(after, rotated_before, atol=TOL), \
        f'type-1 features should be equivariant {get_max_diff(rotated_before, after)}'
|
|
|
|
|
|
|
|
def test_invariance_pooled():
    """Average-pooled type-0 output must not change when the inputs rotate.

    With ``pooling='avg'`` and ``return_type=0`` the model output is compared
    directly as a tensor. The outputs for the original and rotated inputs
    must match within ``TOL``.
    """
    # Build the model first so the rotation samples are drawn after it,
    # matching the RNG order of the other tests.
    model = _get_model(pooling='avg', return_type=0)
    rotation = rot(*torch.rand(3))
    if torch.cuda.is_available():
        rotation = rotation.cuda()
    before, after = _get_outputs(model, rotation)

    assert torch.allclose(after, before, atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(before, after)}'
|
|
|