# DeOldify / tests/test_fastai_compat.py
# thookham's picture
# Initial commit for Hugging Face sync (Clean History)
# e9f9fd3
import unittest
import torch
import torch.nn as nn
import sys
import os
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from deoldify.fastai_compat import (
conv_layer,
NormType,
relu,
res_block,
MergeLayer,
SequentialEx
)
class TestFastAICompat(unittest.TestCase):
    """Smoke tests for the layer helpers imported from ``deoldify.fastai_compat``."""

    def test_conv_layer(self):
        """``conv_layer`` yields a three-module ``nn.Sequential`` starting with a conv."""
        layer = conv_layer(3, 64, ks=3, stride=1)
        self.assertIsInstance(layer, nn.Sequential)
        # Expected contents are Conv2d, ReLU, BatchNorm2d; only the length and
        # the first entry are actually asserted here.
        self.assertEqual(len(layer), 3)
        self.assertIsInstance(layer[0], nn.Conv2d)

    def test_relu(self):
        """``relu(leaky=...)`` builds an ``nn.LeakyReLU`` with the given slope."""
        activation = relu(leaky=0.1)
        self.assertIsInstance(activation, nn.LeakyReLU)
        self.assertEqual(activation.negative_slope, 0.1)

    def test_merge_layer(self):
        """``MergeLayer`` combines the input with its ``.orig`` attribute."""
        x = torch.randn(1, 10)
        # The layer is expected to read the tensor to merge from ``x.orig``.
        x.orig = torch.randn(1, 10)

        # dense=False: the two tensors are expected to be added, so the
        # shape is unchanged.
        merged = MergeLayer(dense=False)(x)
        self.assertEqual(merged.shape, (1, 10))

        # dense=True: the two tensors are expected to be concatenated, which
        # doubles the second dimension.
        merged_dense = MergeLayer(dense=True)(x)
        self.assertEqual(merged_dense.shape, (1, 20))

    def test_sequential_ex(self):
        """``SequentialEx`` over identity layers preserves the input shape."""
        # NOTE(review): the original intent was to exercise ``.orig`` handling;
        # only the output shape is checked.
        seq = SequentialEx(nn.Identity(), nn.Identity())
        x = torch.randn(1, 10)
        out = seq(x)
        self.assertEqual(out.shape, x.shape)

    def IsInstance(self, obj, cls):
        """Assert that *obj* is an instance of *cls*.

        Retained for backward compatibility. Delegates to
        ``assertIsInstance`` so a failure names the object and the expected
        class instead of reporting a bare "False is not true".
        """
        self.assertIsInstance(obj, cls)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()