|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
import torch |
|
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat |
|
|
|
|
|
from .common_testing import TestCaseMixin |
|
|
|
|
|
|
|
|
class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
    """Tests for ``LinearWithRepeat``.

    ``LinearWithRepeat((x, y))`` is checked against an ordinary
    ``torch.nn.Linear`` applied to ``x`` concatenated with ``y`` broadcast
    along ``x``'s second-to-last axis.
    """

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the random inputs and initialisation are reproducible.
        torch.manual_seed(42)

    def test_simple(self):
        """Compare LinearWithRepeat with Linear on an explicitly tiled input."""
        # Every size is distinct so that confusing one axis with another
        # (e.g. the repeated axis with the feature axis) cannot pass by
        # coincidence.
        batch_shape = (4, 6)
        n_repeat = 5  # length of the axis along which y is repeated
        x_features = 3
        y_features = 4
        in_features = x_features + y_features
        out_features = 8

        x = torch.rand(*batch_shape, n_repeat, x_features)
        y = torch.rand(*batch_shape, y_features)

        linear = torch.nn.Linear(in_features, out_features)
        # Same distributions as the original `.data`-based initialisation;
        # the init functions already run under no_grad.
        torch.nn.init.xavier_uniform_(linear.weight)
        torch.nn.init.uniform_(linear.bias)

        # Tile y along x's second-to-last axis and concatenate on features.
        y_tiled = y.unsqueeze(-2).expand(*x.shape[:-1], y.shape[-1])
        equivalent = torch.cat([x, y_tiled], dim=-1)
        expected = linear(equivalent)

        linear_with_repeat = LinearWithRepeat(in_features, out_features)
        linear_with_repeat.load_state_dict(linear.state_dict())
        actual = linear_with_repeat((x, y))

        self.assertEqual(
            tuple(actual.shape), (*batch_shape, n_repeat, out_features)
        )
        self.assertClose(actual, expected, rtol=1e-4)
|
|
|