| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from packaging import version |
| |
|
class Model(nn.Module):
    """Small network that exercises ``F.conv3d`` with several argument mixes.

    The weights/bias of the ``y`` branch are owned by the module as
    parameters; the weights/bias of the ``x`` branch are supplied to
    :meth:`forward` as ordinary inputs.
    """

    def __init__(self):
        super().__init__()
        # Keep the creation order w2 -> b2 -> w3: it fixes both the parameter
        # registration order and how the global torch RNG is consumed.
        self.w2 = nn.Parameter(torch.rand(12, 6, 4, 4, 4))
        self.b2 = nn.Parameter(torch.rand(12))
        self.w3 = nn.Parameter(torch.rand(6, 4, 3, 3, 3))

    def forward(self, x, w0, w1, b1, y):
        """Run both conv3d branches.

        Args:
            x: input of the first branch; convolved with ``w0`` (no bias),
               then with ``w1``/``b1`` (dilated, groups=2).
            w0: weight of the first convolution.
            w1: weight of the second, grouped and dilated convolution.
            b1: bias of the second convolution.
            y: input of the second branch; convolved with the module's own
               ``w2``/``b2`` and then ``w3`` (groups=3).

        Returns:
            Tuple ``(x, y)`` holding the outputs of the two branches.
        """
        x = F.conv3d(x, w0, None, stride=(2, 2, 2), padding=(1, 0, 1))

        # padding='same' is only accepted by F.conv3d from torch 1.9 onwards;
        # older versions fall back to explicit padding.
        # NOTE(review): for the 5x4x5 kernel with dilation (2, 2, 1) that
        # test() passes, 'same' pads (4, 3, 2) per side, so the (1, 1, 1)
        # fallback yields a smaller output than the >= 1.9 path -- confirm
        # this divergence between torch versions is intended.
        has_same_padding = version.parse(torch.__version__) >= version.parse('1.9')
        pad = 'same' if has_same_padding else (1, 1, 1)
        x = F.conv3d(x, w1, b1, stride=(1, 1, 1), padding=pad,
                     dilation=(2, 2, 1), groups=2)

        y = F.conv3d(y, self.w2, self.b2, stride=(2, 2, 2), padding=(2, 2, 2))
        y = F.conv3d(y, self.w3, None, stride=(2, 2, 2), padding=(1, 1, 1),
                     groups=3)
        return x, y
| |
|
def test():
    """Check that the pnnx-converted model reproduces ``Model``'s outputs.

    Traces ``Model`` with TorchScript, saves it as ``test_F_conv3d.pt`` in the
    current directory, runs the pnnx converter at ``../src/pnnx`` on it, then
    imports ``test_F_conv3d_pnnx`` (presumably the module pnnx generates) and
    compares its ``test_inference()`` outputs with the eager PyTorch outputs.

    Returns:
        True if the converter exited with status 0 and both outputs match
        exactly (``torch.equal``: same shape and identical elements), False
        otherwise.

    Raises:
        OSError: if the ``../src/pnnx`` executable cannot be launched.
    """
    # Seed before constructing the model so the parameters created with
    # torch.rand in Model.__init__ are reproducible as well as the inputs.
    torch.manual_seed(0)
    net = Model()
    net.eval()

    x = torch.rand(1, 12, 20, 32, 40)
    w0 = torch.rand(16, 12, 3, 2, 3)
    w1 = torch.rand(16, 8, 5, 4, 5)
    b1 = torch.rand(16)
    y = torch.rand(1, 6, 12, 11, 10)

    a0, a1 = net(x, w0, w1, b1, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, w0, w1, b1, y))
    mod.save("test_F_conv3d.pt")

    # torchscript to pnnx
    # argv is passed as a list (no shell), so the bracketed inputshape value
    # is never subject to shell glob expansion.
    import subprocess
    result = subprocess.run([
        "../src/pnnx",
        "test_F_conv3d.pt",
        "inputshape=[1,12,20,32,40],[16,12,3,2,3],[16,8,5,4,5],[16],[1,6,12,11,10]",
    ])
    # A failed conversion must fail the test; otherwise the import below could
    # pick up a stale generated module left over from an earlier run.
    if result.returncode != 0:
        return False

    # pnnx inference
    import test_F_conv3d_pnnx
    p0, p1 = test_F_conv3d_pnnx.test_inference()

    return torch.equal(a0, p0) and torch.equal(a1, p1)
| |
|
if __name__ == "__main__":
    # Exit status 0 when test() reports success, 1 otherwise.
    # sys.exit is used instead of the site-provided exit(), which is intended
    # for the interactive interpreter and is unavailable under `python -S`.
    import sys
    sys.exit(0 if test() else 1)
| |
|