| | |
| | |
| | |
| | |
| | |
| |
|
| | import pickle |
| | import textwrap |
| | import unittest |
| | from dataclasses import dataclass, field, is_dataclass |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional, Tuple |
| | from unittest.mock import Mock |
| |
|
| | from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError |
| | from pytorch3d.implicitron.tools.config import ( |
| | _get_type_to_process, |
| | _is_actually_dataclass, |
| | _ProcessType, |
| | _Registry, |
| | Configurable, |
| | enable_get_default_args, |
| | expand_args_fields, |
| | get_default_args, |
| | get_default_args_field, |
| | registry, |
| | remove_unused_components, |
| | ReplaceableBase, |
| | run_auto_creation, |
| | ) |
| |
|
| |
|
@dataclass
class Animal(ReplaceableBase):
    """Replaceable base unrelated to Fruit.

    Lets the tests check that registry lookups are scoped to one base class
    (a registered Fruit cannot be fetched as an Animal).
    """
| |
|
| |
|
class Fruit(ReplaceableBase):
    """Replaceable base class for the fruit implementations in this module."""
| |
|
| |
|
@registry.register
class Banana(Fruit):
    """Registered fruit whose fields all lack defaults.

    In default args these fields therefore show up as missing values (``???``).
    """

    pips: int
    spots: int
    bananame: str
| |
|
| |
|
@registry.register
class Pear(Fruit):
    """Registered fruit with one defaulted field; LargePear subclasses it."""

    n_pips: int = 13
| |
|
| |
|
class Pineapple(Fruit):
    """Fruit that is deliberately NOT registered.

    MainTest builds it by hand in its custom creation function.
    """
| |
|
| |
|
@registry.register
class Orange(Fruit):
    """Registered fruit without fields; used as a default class type."""
| |
|
| |
|
@registry.register
class Kiwi(Fruit):
    """Registered fruit without fields; used as a default class type."""
| |
|
| |
|
@registry.register
class LargePear(Pear):
    """Registered subclass of Pear.

    Because it subclasses Pear, it is also reachable through
    ``registry.get(Pear, ...)``.
    """
| |
|
| |
|
class BoringConfigurable(Configurable):
    """Configurable without fields, used as a plain or Optional member."""
| |
|
| |
|
class MainTest(Configurable):
    """Configurable with two replaceable Fruit members and two int fields.

    ``the_second_fruit`` has a hand-written ``create_the_second_fruit``.
    Whatever class type is configured for it, a Pineapple is created.
    """

    # Field order matters: the creation-function order
    # ("create_the_fruit", "create_the_second_fruit") is asserted in the tests.
    the_fruit: Fruit
    n_ids: int
    n_reps: int = 8
    the_second_fruit: Fruit

    def create_the_second_fruit(self):
        # Pineapple is not registered, so it is expanded and built directly
        # instead of being looked up from the configured class type.
        replacement_cls = Pineapple
        expand_args_fields(replacement_cls)
        self.the_second_fruit = replacement_cls()

    def __post_init__(self):
        # Hand member creation to the config machinery. It presumably calls the
        # create_* functions listed in _creation_functions -- confirm.
        run_auto_creation(self)
| |
|
| |
|
class TestConfig(unittest.TestCase):
    """Tests for Configurable, ReplaceableBase, the registry and the helpers in
    pytorch3d.implicitron.tools.config."""

    def test_is_actually_dataclass(self):
        @dataclass
        class A:
            pass

        self.assertTrue(_is_actually_dataclass(A))
        self.assertTrue(is_dataclass(A))

        # B inherits from a dataclass but is not decorated itself:
        # is_dataclass reports True while _is_actually_dataclass reports False.
        class B(A):
            a: int

        self.assertFalse(_is_actually_dataclass(B))
        self.assertTrue(is_dataclass(B))

    def test_get_type_to_process(self):
        # Only Configurable / ReplaceableBase types, bare or wrapped in
        # Optional, are processed; plain types and containers map to None.
        gt = _get_type_to_process
        self.assertIsNone(gt(int))
        self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
        self.assertEqual(
            gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
        )
        self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
        self.assertEqual(
            gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
        )
        self.assertIsNone(gt(Optional[int]))
        self.assertIsNone(gt(Tuple[Fruit]))
        self.assertIsNone(gt(Tuple[Fruit, Animal]))
        self.assertIsNone(gt(Optional[List[int]]))

    def test_simple_replacement(self):
        struct = get_default_args(MainTest)
        struct.n_ids = 9780
        struct.the_fruit_Pear_args.n_pips = 3
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Pear"

        main = MainTest(**struct)
        self.assertIsInstance(main.the_fruit, Pear)
        self.assertEqual(main.n_reps, 8)
        self.assertEqual(main.n_ids, 9780)
        self.assertEqual(main.the_fruit.n_pips, 3)
        # MainTest.create_the_second_fruit overrides automatic creation, so the
        # requested "Pear" is ignored for this member.
        self.assertIsInstance(main.the_second_fruit, Pineapple)

        # Editing one default-args struct must not change later ones.
        struct2 = get_default_args(MainTest)
        self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)

        self.assertEqual(
            MainTest._creation_functions,
            ("create_the_fruit", "create_the_second_fruit"),
        )

    def test_detect_bases(self):
        # Anything under Fruit (registered or not) resolves to Fruit;
        # ReplaceableBase itself and non-replaceables resolve to None.
        self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
        self.assertIsNone(_Registry._base_class_from_class(MainTest))
        self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
        self.assertIs(_Registry._base_class_from_class(Pear), Fruit)

        class PricklyPear(Pear):
            pass

        self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)

    def test_registry_entries(self):
        self.assertIs(registry.get(Fruit, "Banana"), Banana)
        with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
            registry.get(Animal, "Banana")
        with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
            registry.get(Fruit, "PricklyPear")

        # Lookups may be narrowed to a subclass of the base.
        self.assertIs(registry.get(Pear, "Pear"), Pear)
        self.assertIs(registry.get(Pear, "LargePear"), LargePear)
        with self.assertRaisesRegex(ValueError, "Banana resolves to"):
            registry.get(Pear, "Banana")

        all_fruit = set(registry.get_all(Fruit))
        self.assertIn(Banana, all_fruit)
        self.assertIn(Pear, all_fruit)
        self.assertIn(LargePear, all_fruit)
        # get_all(Pear) lists Pear's registered subclasses, not Pear itself.
        self.assertEqual(registry.get_all(Pear), [LargePear])

        @registry.register
        class Apple(Fruit):
            pass

        @registry.register
        class CrabApple(Apple):
            pass

        self.assertEqual(registry.get_all(Apple), [CrabApple])

        self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)

        # Registering a class outside any ReplaceableBase hierarchy fails.
        with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):

            @registry.register
            class NotAFruit:
                pass

    def test_recursion(self):
        class Shape(ReplaceableBase):
            pass

        @registry.register
        class Triangle(Shape):
            a: float = 5.0

        @registry.register
        class Square(Shape):
            a: float = 3.0

        @registry.register
        class LargeShape(Shape):
            inner: Shape

            def __post_init__(self):
                run_auto_creation(self)

        class ShapeContainer(Configurable):
            shape: Shape

        container = ShapeContainer(**get_default_args(ShapeContainer))
        # ShapeContainer has no __post_init__ calling run_auto_creation, so its
        # member is never created.
        with self.assertRaises(AttributeError):
            container.shape

        class ShapeContainer2(Configurable):
            x: Shape
            x_class_type: str = "LargeShape"

            def __post_init__(self):
                self.x_LargeShape_args.inner_class_type = "Triangle"
                run_auto_creation(self)

        container2_args = get_default_args(ShapeContainer2)
        container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
        self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
        # LargeShape's args are not nested inside LargeShape itself (that
        # expansion would presumably never terminate).
        self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
        container2_args.x_LargeShape_args.inner_Square_args.a += 100
        container2 = ShapeContainer2(**container2_args)
        self.assertIsInstance(container2.x, LargeShape)
        self.assertIsInstance(container2.x.inner, Triangle)
        self.assertEqual(container2.x.inner.a, 15.0)

    def test_simpleclass_member(self):
        # Members whose types are neither Configurable nor ReplaceableBase
        # (here a plain dataclass) are tolerated but left unprocessed.
        class Foo:
            def __init__(self, a: Any = 1, b: Any = 2):
                self.a, self.b = a, b

        enable_get_default_args(Foo)

        @dataclass()
        class Bar:
            aa: int = 9
            bb: int = 9

        class Container(Configurable):
            # Bar instances are unhashable (eq=True dataclass). Python 3.11+
            # dataclasses reject unhashable plain defaults as "mutable
            # default", so a factory is required here.
            bar: Bar = field(default_factory=Bar)
            fruit: Fruit
            fruit_class_type: str = "Orange"

            def __post_init__(self):
                run_auto_creation(self)

        self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
        container_args = get_default_args(Container)
        container = Container(**container_args)
        self.assertIsInstance(container.fruit, Orange)
        self.assertEqual(Container._processed_members, {"fruit": Fruit})
        self.assertEqual(container._processed_members, {"fruit": Fruit})

        # Mutating an instance's args must not leak into the class defaults.
        container_defaulted = Container()
        container_defaulted.fruit_Pear_args.n_pips += 4

        container_args2 = get_default_args(Container)
        container = Container(**container_args2)
        self.assertEqual(container.fruit_Pear_args.n_pips, 13)

    def test_inheritance(self):
        # Replaceable members are inherited; also exercises Optional members,
        # including one whose class type is None.
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Orange"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        class LargeFruitBowl(FruitBowl):
            extra_fruit: Optional[Fruit]
            extra_fruit_class_type: str = "Kiwi"
            no_fruit: Optional[Fruit]
            no_fruit_class_type: Optional[str] = None

            def __post_init__(self):
                run_auto_creation(self)

        large_args = get_default_args(LargeFruitBowl)
        self.assertNotIn("extra_fruit", large_args)
        self.assertNotIn("main_fruit", large_args)
        large = LargeFruitBowl(**large_args)
        self.assertIsInstance(large.main_fruit, Orange)
        self.assertIsInstance(large.extra_fruit, Kiwi)
        self.assertIsNone(large.no_fruit)
        self.assertIn("no_fruit_Kiwi_args", large_args)

        # After pruning, only the args of the selected classes remain.
        remove_unused_components(large_args)
        large2 = LargeFruitBowl(**large_args)
        self.assertIsInstance(large2.main_fruit, Orange)
        self.assertIsInstance(large2.extra_fruit, Kiwi)
        self.assertIsNone(large2.no_fruit)
        needed_args = [
            "extra_fruit_Kiwi_args",
            "extra_fruit_class_type",
            "main_fruit_Orange_args",
            "main_fruit_class_type",
            "no_fruit_class_type",
        ]
        self.assertEqual(sorted(large_args.keys()), needed_args)

        with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."):
            LargeFruitBowl(extra_fruit_class_type="NotAFruit")

    def test_inheritance2(self):
        # A registered implementation (Derived) that also subclasses the
        # containing class (Main) is not expanded when Main is processed.
        class Parent(ReplaceableBase):
            pass

        class Main(Configurable):
            parent: Parent
            # Note: no __post_init__, so nothing is created automatically.

        @registry.register
        class Derived(Parent, Main):
            pass

        args = get_default_args(Main)
        # Derived has been ignored in processing Main.
        self.assertCountEqual(args.keys(), ["parent_class_type"])

        main = Main(**args)

        with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
            run_auto_creation(main)

        main.parent_class_type = "Derived"
        # A plain dict is accepted in place of a DictConfig.
        main.parent_Derived_args = {}
        with self.assertRaises(AttributeError):
            main.parent
        run_auto_creation(main)
        self.assertIsInstance(main.parent, Derived)

    def test_redefine(self):
        # Re-registering a class under an existing name replaces the earlier
        # implementation, with a warning when the name is resolved.
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Grape"

            def __post_init__(self):
                run_auto_creation(self)

        @registry.register
        @dataclass
        class Grape(Fruit):
            large: bool = False

            def get_color(self):
                return "red"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        bowl_args = get_default_args(FruitBowl)

        @registry.register
        @dataclass
        class Grape(Fruit):  # noqa: F811
            large: bool = True

            def get_color(self):
                return "green"

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            defaulted_bowl = FruitBowl()
        self.assertIsInstance(defaulted_bowl.main_fruit, Grape)
        self.assertEqual(defaulted_bowl.main_fruit.large, True)
        self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green")

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            args_bowl = FruitBowl(**bowl_args)
        self.assertIsInstance(args_bowl.main_fruit, Grape)
        # The default captured earlier in bowl_args (large=False) still applies,
        self.assertEqual(args_bowl.main_fruit.large, False)
        # ...but the class actually used is the newer registration.
        self.assertEqual(args_bowl.main_fruit.get_color(), "green")

        # Redefine again, this time without the @dataclass decorator.
        @registry.register
        class Grape(Fruit):  # noqa: F811
            large: bool = True

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            FruitBowl(**bowl_args)

        # A Fruit registered after FruitBowl was first expanded is not picked
        # up by later get_default_args calls on FruitBowl.
        @registry.register
        class Fig(Fruit):
            pass

        bowl_args2 = get_default_args(FruitBowl)
        self.assertIn("main_fruit_Grape_args", bowl_args2)
        self.assertNotIn("main_fruit_Fig_args", bowl_args2)

    def test_no_replacement(self):
        # Configurable (non-replaceable) members are created directly, with no
        # class_type switch; Optional ones are gated by <name>_enabled.
        class A(Configurable):
            n: int = 9

        class B(Configurable):
            a: A

            def __post_init__(self):
                run_auto_creation(self)

        class C(Configurable):
            b1: B
            b2: Optional[B]
            b3: Optional[B]
            b2_enabled: bool = True
            b3_enabled: bool = False

            def __post_init__(self):
                run_auto_creation(self)

        c_args = get_default_args(C)
        c = C(**c_args)
        self.assertIsInstance(c.b1.a, A)
        self.assertEqual(c.b1.a.n, 9)
        self.assertFalse(hasattr(c, "b1_enabled"))
        self.assertIsInstance(c.b2.a, A)
        self.assertEqual(c.b2.a.n, 9)
        self.assertTrue(c.b2_enabled)
        self.assertIsNone(c.b3)
        self.assertFalse(c.b3_enabled)

    def test_doc(self):
        # Minimal replaceable-member example: B selects A2 via a_class_type,
        # and the member itself does not appear in the default args.
        class A(ReplaceableBase):
            k: int = 1

        @registry.register
        class A1(A):
            m: int = 3

        @registry.register
        class A2(A):
            n: str = "2"

        class B(Configurable):
            a: A
            a_class_type: str = "A2"

            def __post_init__(self):
                run_auto_creation(self)

        b_args = get_default_args(B)
        self.assertNotIn("a", b_args)
        b = B(**b_args)
        self.assertEqual(b.a.n, "2")

    def test_raw_types(self):
        @dataclass
        class MyDataclass:
            int_field: int = 0
            none_field: Optional[int] = None
            float_field: float = 9.3
            bool_field: bool = True
            tuple_field: Tuple[int, ...] = (3,)

        class SimpleClass:
            def __init__(
                self,
                tuple_member_: Tuple[int, int] = (3, 4),
            ):
                self.tuple_member = tuple_member_

            def get_tuple(self):
                return self.tuple_member

        enable_get_default_args(SimpleClass)

        def f(*, a: int = 3, b: str = "kj"):
            self.assertEqual(a, 3)
            self.assertEqual(b, "kj")

        enable_get_default_args(f)

        class C(Configurable):
            simple: DictConfig = get_default_args_field(SimpleClass)
            # Plain dataclasses are embedded the same way.
            mydata: DictConfig = get_default_args_field(MyDataclass)
            # Variable-length annotation: the default has two elements.
            a_tuple: Tuple[float, ...] = (4.0, 3.0)
            f_args: DictConfig = get_default_args_field(f)

        args = get_default_args(C)
        c = C(**args)
        self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])

        mydata = MyDataclass(**c.mydata)
        simple = SimpleClass(**c.simple)

        # Tuple-valued defaults come back from the config as ListConfig.
        self.assertEqual(simple.get_tuple(), [3, 4])
        self.assertIsInstance(simple.get_tuple(), ListConfig)
        self.assertEqual(c.a_tuple, [4.0, 3.0])
        self.assertIsInstance(c.a_tuple, ListConfig)
        self.assertEqual(mydata.tuple_field, (3,))
        self.assertIsInstance(mydata.tuple_field, ListConfig)
        f(**c.f_args)

    def test_irrelevant_bases(self):
        class NotADataclass:
            # Annotated, but not a dataclass: b (no default) follows a
            # (default), so @dataclass itself would reject this class -- see
            # the final assertion. Configurable subclasses must still work.
            a: int = 9
            b: int

        class LeftConfigured(Configurable, NotADataclass):
            left: int = 1

        class RightConfigured(NotADataclass, Configurable):
            right: int = 2

        class Outer(Configurable):
            left: LeftConfigured
            right: RightConfigured

            def __post_init__(self):
                run_auto_creation(self)

        outer = Outer(**get_default_args(Outer))
        self.assertEqual(outer.left.left, 1)
        self.assertEqual(outer.right.right, 2)
        with self.assertRaisesRegex(TypeError, "non-default argument"):
            dataclass(NotADataclass)

    def test_unprocessed(self):
        # Configurable / ReplaceableBase subclasses are not dataclasses until
        # first instantiated.
        class UnprocessedConfigurable(Configurable):
            a: int = 9

        class UnprocessedReplaceable(ReplaceableBase):
            a: int = 9

        for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]:
            self.assertFalse(_is_actually_dataclass(Unprocessed))
            unprocessed = Unprocessed()
            self.assertTrue(_is_actually_dataclass(Unprocessed))
            self.assertIsInstance(unprocessed, Unprocessed)
            self.assertEqual(unprocessed.a, 9)

    def test_enum(self):
        # Enum-typed defaults work for Configurable members, plain functions
        # and plain classes alike.
        class A(Enum):
            B1 = "b1"
            B2 = "b2"

        # Configurable case.
        class C(Configurable):
            a: A = A.B1

        # Function case.
        def C_fn(a: A = A.B1):
            pass

        enable_get_default_args(C_fn)

        # Plain class case.
        class C_cl:
            def __init__(self, a: A = A.B1) -> None:
                pass

        enable_get_default_args(C_cl)

        for C_ in [C, C_fn, C_cl]:
            base = get_default_args(C_)
            self.assertEqual(OmegaConf.to_yaml(base), "a: B1\n")
            self.assertEqual(base.a, A.B1)
            replaced = OmegaConf.merge(base, {"a": "B2"})
            self.assertEqual(replaced.a, A.B2)
            with self.assertRaises(ValidationError):
                # Members are referred to by name ("B2"), not by value ("b2").
                OmegaConf.merge(base, {"a": "b2"})

            # Round-tripping through YAML preserves the enum value.
            remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
            self.assertEqual(remerged.a, A.B1)

    def test_pickle(self):
        def func(a: int = 1, b: str = "3"):
            pass

        enable_get_default_args(func)

        # Only data produced by this test is unpickled.
        args = get_default_args(func)
        args2 = pickle.loads(pickle.dumps(args))
        self.assertEqual(args2.a, 1)
        self.assertEqual(args2.b, "3")

        # Pickling must keep working after get_default_args is called again.
        args_regenerated = get_default_args(func)
        pickle.dumps(args_regenerated)
        pickle.dumps(args)

    def test_remove_unused_components(self):
        struct = get_default_args(MainTest)
        struct.n_ids = 32
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Banana"
        remove_unused_components(struct)
        expected_keys = [
            "n_ids",
            "n_reps",
            "the_fruit_Pear_args",
            "the_fruit_class_type",
            "the_second_fruit_Banana_args",
            "the_second_fruit_class_type",
        ]
        expected_yaml = textwrap.dedent(
            """\
            n_ids: 32
            n_reps: 8
            the_fruit_class_type: Pear
            the_fruit_Pear_args:
              n_pips: 13
            the_second_fruit_class_type: Banana
            the_second_fruit_Banana_args:
              pips: ???
              spots: ???
              bananame: ???
            """
        )
        self.assertEqual(sorted(struct.keys()), expected_keys)

        # The pruned struct has exactly the expected content...
        expected = OmegaConf.create(expected_yaml)
        self.assertEqual(struct, expected)

        # ...and the expected key order when written as YAML.
        self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)

        # Pruning the config of a constructed instance gives the same result.
        main = MainTest(**struct)
        instance_data = OmegaConf.structured(main)
        remove_unused_components(instance_data)
        self.assertEqual(sorted(instance_data.keys()), expected_keys)
        self.assertEqual(instance_data, expected)

    def test_remove_unused_components_optional(self):
        class MainTestWrapper(Configurable):
            mt: Optional[MainTest]
            mt_enabled: bool = False

        args = get_default_args(MainTestWrapper)
        self.assertEqual(list(args.keys()), ["mt_enabled", "mt_args"])
        # A disabled Optional member's args are pruned.
        remove_unused_components(args)
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")

    def test_get_instance_args(self):
        mt1, mt2 = [
            MainTest(
                n_ids=0,
                n_reps=909,
                the_fruit_class_type="Pear",
                the_second_fruit_class_type="Pear",
                the_fruit_Pear_args=DictConfig({}),
                the_second_fruit_Pear_args={},
            )
            for _ in range(2)
        ]
        # get_default_args on an instance agrees with OmegaConf.structured.
        cfg1 = OmegaConf.structured(mt1)
        cfg2 = get_default_args(mt2)
        self.assertEqual(cfg1, cfg2)
        self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
        self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)

        from_cfg = MainTest(**cfg2)
        self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)

        # Merging onto the class defaults restores Pear's defaulted field while
        # keeping the instance-specific values (n_reps).
        merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
        from_merged = MainTest(**merged_args)
        self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
        self.assertEqual(from_merged.n_reps, 909)

    def test_tweak_hook(self):
        # Classmethods named <member>_tweak_args get the member's type and
        # default args and may edit them in place; members without a hook
        # (a2, fruit2) keep the untouched defaults.
        class A(Configurable):
            n: int = 9

        class Wrapper(Configurable):
            fruit: Fruit
            fruit_class_type: str = "Pear"
            fruit2: Fruit
            fruit2_class_type: str = "Pear"
            a: A
            a2: A
            a3: A

            @classmethod
            def a_tweak_args(cls, type, args):
                assert type == A
                args.n = 993

            @classmethod
            def a3_tweak_args(cls, type, args):
                del args["n"]

            @classmethod
            def fruit_tweak_args(cls, type, args):
                # Called for every candidate Fruit type; only Pear is changed.
                assert issubclass(type, Fruit)
                if type == Pear:
                    assert args.n_pips == 13
                    args.n_pips = 19

        args = get_default_args(Wrapper)
        self.assertEqual(args.a_args.n, 993)
        self.assertEqual(args.a2_args.n, 9)
        self.assertEqual(args.a3_args, {})
        self.assertEqual(args.fruit_Pear_args.n_pips, 19)
        self.assertEqual(args.fruit2_Pear_args.n_pips, 13)

    def test_impls(self):
        # create_<member> delegates to create_<member>_impl. Replacing the
        # impl with fake_impl means nothing is created, and the control
        # argument each call receives is recorded instead.
        control_args = []

        def fake_impl(self, control, args):
            control_args.append(control)

        for fake in [False, True]:

            class MyClass(Configurable):
                fruit: Fruit
                fruit_class_type: str = "Orange"
                fruit_o: Optional[Fruit]
                fruit_o_class_type: str = "Orange"
                fruit_0: Optional[Fruit]
                fruit_0_class_type: Optional[str] = None
                boring: BoringConfigurable
                boring_o: Optional[BoringConfigurable]
                boring_o_enabled: bool = True
                boring_0: Optional[BoringConfigurable]
                boring_0_enabled: bool = False

                def __post_init__(self):
                    run_auto_creation(self)

            if fake:
                MyClass.create_fruit_impl = fake_impl
                MyClass.create_fruit_o_impl = fake_impl
                MyClass.create_boring_impl = fake_impl
                MyClass.create_boring_o_impl = fake_impl

            expand_args_fields(MyClass)
            instance = MyClass()
            for name in ["fruit", "fruit_o", "boring", "boring_o"]:
                self.assertEqual(
                    hasattr(instance, name), not fake, msg=f"{name} {fake}"
                )

            # Disabled members are None whether or not the impls are faked.
            self.assertIsNone(instance.fruit_0)
            self.assertIsNone(instance.boring_0)
            if not fake:
                self.assertIsInstance(instance.fruit, Orange)
                self.assertIsInstance(instance.fruit_o, Orange)
                self.assertIsInstance(instance.boring, BoringConfigurable)
                self.assertIsInstance(instance.boring_o, BoringConfigurable)

        self.assertEqual(control_args, ["Orange", "Orange", True, True])

    def test_pre_expand(self):
        # expand_args_fields calls the class's pre_expand hook.
        class A(Configurable):
            n: int = 9

            @classmethod
            def pre_expand(cls):
                pass

        A.pre_expand = Mock()
        expand_args_fields(A)
        A.pre_expand.assert_called()

    def test_pre_expand_replaceable(self):
        # Expanding an implementation calls the pre_expand hook it inherits
        # from its replaceable base.
        class A(ReplaceableBase):
            @classmethod
            def pre_expand(cls):
                pass

        class A1(A):
            n: int = 9

        A.pre_expand = Mock()
        expand_args_fields(A1)
        A.pre_expand.assert_called()
| |
|
| |
|
@dataclass(eq=False)
class MockDataclass:
    """Dataclass fixture covering the field kinds get_default_args handles.

    ``eq=False`` keeps identity-based equality (no generated ``__eq__``).
    """

    # Required: no default at all.
    field_no_default: int
    field_primitive_type: int = 42
    field_optional_none: Optional[int] = None
    field_optional_dict_none: Optional[Dict] = None
    field_optional_with_value: Optional[int] = 42
    # Each instance gets its own fresh, empty list.
    field_list_type: List[int] = field(default_factory=list)
| |
|
| |
|
class RefObject:
    """Opaque, non-primitive type used as a reference-typed default value."""
| |
|
| |
|
# Single shared instance used as the default of
# MockClassWithInit.field_reference_type.
REF_OBJECT = RefObject()
| |
|
| |
|
class MockClassWithInit:
    """Plain (non-dataclass) fixture whose ``__init__`` signature covers the
    argument kinds get_default_args must handle.

    ``field_no_nothing`` has neither an annotation nor a default.
    ``field_no_default`` is annotated but has no default.
    ``field_reference_type`` defaults to an arbitrary object instance.
    ``field_list_type`` deliberately keeps a shared mutable ``[]`` default:
    TestRawClasses checks that mutating the list returned by get_default_args
    leaves the instances' default list untouched.
    """

    def __init__(
        self,
        field_no_nothing,
        field_no_default: int,
        field_primitive_type: int = 42,
        field_optional_none: Optional[int] = None,
        field_optional_dict_none: Optional[Dict] = None,
        field_optional_with_value: Optional[int] = 42,
        field_list_type: List[int] = [],
        field_reference_type: RefObject = REF_OBJECT,
    ):
        # Store every argument unchanged (no copies), in signature order.
        received = {
            "field_no_nothing": field_no_nothing,
            "field_no_default": field_no_default,
            "field_primitive_type": field_primitive_type,
            "field_optional_none": field_optional_none,
            "field_optional_dict_none": field_optional_dict_none,
            "field_optional_with_value": field_optional_with_value,
            "field_list_type": field_list_type,
            "field_reference_type": field_reference_type,
        }
        for attr_name, attr_value in received.items():
            setattr(self, attr_name, attr_value)
| |
|
| |
|
# Enable get_default_args(MockClassWithInit), which TestRawClasses relies on.
enable_get_default_args(MockClassWithInit)
| |
|
| |
|
class TestRawClasses(unittest.TestCase):
    """get_default_args on a plain dataclass and on a plain class enabled via
    enable_get_default_args."""

    def setUp(self) -> None:
        # One instance per class, built with only the required arguments, so
        # every other attribute holds the constructor default.
        self._instances = {
            MockDataclass: MockDataclass(field_no_default=0),
            MockClassWithInit: MockClassWithInit(
                field_no_nothing="tratata", field_no_default=0
            ),
        }

    def test_get_default_args(self):
        for cls in [MockDataclass, MockClassWithInit]:
            # subTest names the class under test in any failure report.
            with self.subTest(cls=cls.__name__):
                defaults = get_default_args(cls)
                # Arguments without a usable default are not `in` the config.
                self.assertNotIn("field_no_default", defaults)
                self.assertNotIn("field_no_nothing", defaults)
                self.assertNotIn("field_reference_type", defaults)
                expected_defaults = [
                    "field_primitive_type",
                    "field_optional_none",
                    "field_optional_dict_none",
                    "field_optional_with_value",
                    "field_list_type",
                ]

                if cls is MockDataclass:
                    # The dataclass config keeps field_no_default as a key
                    # (presumably holding a missing value, hence not `in`
                    # above); once set it appears first, in declaration order.
                    defaults.field_no_default = 0
                    expected_defaults.insert(0, "field_no_default")
                self.assertEqual(list(defaults), expected_defaults)
                # Every default matches what the constructor actually stored.
                for name, val in defaults.items():
                    self.assertTrue(hasattr(self._instances[cls], name))
                    self.assertEqual(val, getattr(self._instances[cls], name))

    def test_get_default_args_readonly(self):
        for cls in [MockDataclass, MockClassWithInit]:
            with self.subTest(cls=cls.__name__):
                defaults = get_default_args(cls)
                # Mutating the returned config must not touch the default list
                # object held by instances (shared `[]` for MockClassWithInit).
                defaults["field_list_type"].append(13)
                self.assertEqual(self._instances[cls].field_list_type, [])
| |
|