| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Tests for registry.""" |
| |
|
| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| |
|
| | from unittest import mock |
| |
|
| | from absl.testing import absltest |
| | from big_vision.pp import registry |
| |
|
| |
|
class RegistryTest(absltest.TestCase):
  """Tests for `registry.parse_name` and `registry.Registry`."""

  def setUp(self):
    """Patches `Registry.global_registry` to return a fresh, empty dict."""
    super().setUp()
    # Registered before the patch is started so it is always undone.
    self.addCleanup(mock.patch.stopall)
    self.global_registry = {}
    patcher = mock.patch.object(
        registry.Registry, "global_registry",
        return_value=self.global_registry)
    self.mocked_method = patcher.start()

  def test_parse_name(self):
    """Checks the (name, args, kwargs) triple returned by `parse_name`."""
    # (input string, expected name, expected args, expected kwargs)
    well_formed = (
        ("f", "f", (), {}),
        ("f()", "f", (), {}),
        ("func(a=0,b=1,c='s')", "func", (), {"a": 0, "b": 1, "c": "s"}),
        ("func(1,'foo',3)", "func", (1, "foo", 3), {}),
        ("func(1,'2',a=3,foo='bar')", "func", (1, "2"),
         {"a": 3, "foo": "bar"}),
    )
    for spec, want_name, want_args, want_kwargs in well_formed:
      got_name, got_args, got_kwargs = registry.parse_name(spec)
      self.assertEqual(got_name, want_name)
      self.assertEqual(got_args, want_args)
      self.assertEqual(got_kwargs, want_kwargs)

    # Dotted name with a parenthesised literal; only name and kwargs are
    # checked for this input.
    got_name, _, got_kwargs = registry.parse_name(
        "foo.bar.func(a=0,b=(1),c='s')")
    self.assertEqual(got_name, "foo.bar.func")
    self.assertEqual(got_kwargs, {"a": 0, "b": 1, "c": "s"})

    # (input string, exception type expected from parse_name)
    malformed = (
        ("func(0", SyntaxError),
        ("func(a=0,,b=0)", SyntaxError),
        ("func(a=0,b==1,c='s')", SyntaxError),
        ("func(a=0,b=undefined_name,c='s')", ValueError),
    )
    for spec, expected_error in malformed:
      with self.assertRaises(expected_error):
        registry.parse_name(spec)

  def test_register(self):
    """Registering one function yields a registry of length one."""

    @registry.Registry.register("func1")
    def func1():
      pass

    entries = registry.Registry.global_registry()
    self.assertLen(entries, 1)

  def test_lookup_function(self):
    """Checks how lookup-string arguments combine with call-time ones."""

    @registry.Registry.register("func1")
    def func1(arg1, arg2, arg3):
      return arg1, arg2, arg3

    lookup = registry.Registry.lookup

    self.assertTrue(callable(lookup("func1")))

    # (lookup string, call-time args, call-time kwargs, expected result)
    accepted = (
        ("func1", (1, 2, 3), {}, (1, 2, 3)),
        ("func1(arg3=9)", (1, 2), {}, (1, 2, 9)),
        ("func1(arg2=9,arg1=99)", (), {"arg3": 3}, (99, 9, 3)),
        ("func1(arg2=9,arg1=99)", (), {"arg1": 1, "arg3": 3}, (1, 9, 3)),
        ("func1(1)", (1, 2), {}, (1, 1, 2)),
        ("func1(1)", (), {"arg3": 3, "arg2": 2}, (1, 2, 3)),
        ("func1(1, 2)", (3,), {}, (1, 2, 3)),
        ("func1(1, 2)", (), {"arg3": 3}, (1, 2, 3)),
        ("func1(1, arg2=2)", (), {"arg3": 3}, (1, 2, 3)),
        ("func1(1, arg3=2)", (), {"arg2": 3}, (1, 3, 2)),
        ("func1(1, arg3=2)", (3,), {}, (1, 3, 2)),
    )
    for spec, call_args, call_kwargs, expected in accepted:
      self.assertEqual(lookup(spec)(*call_args, **call_kwargs), expected)

    # (lookup string, call-time args, call-time kwargs, expected exception)
    rejected = (
        ("func1(1, arg2=2)", (3,), {}, TypeError),
        ("func1(1, arg3=3)", (), {"arg3": 3}, TypeError),
        ("func1(1, arg3=3)", (), {"arg1": 3}, TypeError),
        ("func1(arg1=1, 3)", (), {"arg2": 3}, SyntaxError),
    )
    for spec, call_args, call_kwargs, expected_error in rejected:
      with self.assertRaises(expected_error):
        lookup(spec)(*call_args, **call_kwargs)
| |
|
| |
|
if __name__ == "__main__":
  # Run the tests through absltest's runner when executed as a script.
  absltest.main()
| |
|