|
|
import unittest |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
from deoldify.fastai_compat import ( |
|
|
conv_layer, |
|
|
NormType, |
|
|
relu, |
|
|
res_block, |
|
|
MergeLayer, |
|
|
SequentialEx |
|
|
) |
|
|
|
|
|
class TestFastAICompat(unittest.TestCase): |
|
|
def test_conv_layer(self): |
|
|
|
|
|
l = conv_layer(3, 64, ks=3, stride=1) |
|
|
self.assertIsInstance(l, nn.Sequential) |
|
|
|
|
|
self.assertEqual(len(l), 3) |
|
|
self.IsInstance(l[0], nn.Conv2d) |
|
|
|
|
|
def test_relu(self): |
|
|
r = relu(leaky=0.1) |
|
|
self.IsInstance(r, nn.LeakyReLU) |
|
|
self.assertEqual(r.negative_slope, 0.1) |
|
|
|
|
|
def test_merge_layer(self): |
|
|
m = MergeLayer(dense=False) |
|
|
x = torch.randn(1, 10) |
|
|
x.orig = torch.randn(1, 10) |
|
|
out = m(x) |
|
|
self.assertEqual(out.shape, (1, 10)) |
|
|
|
|
|
m_dense = MergeLayer(dense=True) |
|
|
out_dense = m_dense(x) |
|
|
self.assertEqual(out_dense.shape, (1, 20)) |
|
|
|
|
|
def test_sequential_ex(self): |
|
|
|
|
|
l1 = nn.Identity() |
|
|
l2 = nn.Identity() |
|
|
seq = SequentialEx(l1, l2) |
|
|
|
|
|
x = torch.randn(1, 10) |
|
|
out = seq(x) |
|
|
self.assertEqual(out.shape, x.shape) |
|
|
|
|
|
def IsInstance(self, obj, cls): |
|
|
self.assertTrue(isinstance(obj, cls)) |
|
|
|
|
|
if __name__ == '__main__': |
|
|
unittest.main() |
|
|
|