| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | from omegaconf import OmegaConf |
| | from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( |
| | ResNetFeatureExtractor, |
| | ) |
| | from pytorch3d.implicitron.models.generic_model import GenericModel |
| | from pytorch3d.implicitron.models.global_encoder.global_encoder import ( |
| | SequenceAutodecoder, |
| | ) |
| | from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( |
| | IdrFeatureField, |
| | ) |
| | from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( |
| | NeuralRadianceFieldImplicitFunction, |
| | ) |
| | from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer |
| | from pytorch3d.implicitron.models.renderer.multipass_ea import ( |
| | MultiPassEmissionAbsorptionRenderer, |
| | ) |
| | from pytorch3d.implicitron.models.view_pooler.feature_aggregator import ( |
| | AngleWeightedIdentityFeatureAggregator, |
| | ) |
| | from pytorch3d.implicitron.tools.config import ( |
| | get_default_args, |
| | remove_unused_components, |
| | ) |
| | from tests.common_testing import get_tests_dir |
| |
|
| | from .common_resources import provide_resnet34 |
| |
|
# Directory holding the implicitron golden files (e.g. overrides.yaml).
DATA_DIR = get_tests_dir().joinpath("implicitron", "data")
# When True, tests also dump their generated YAML next to the golden file
# so it can be inspected or used to refresh the expected output.
DEBUG: bool = False
| |
|
| | |
| |
|
| |
|
class TestGenericModel(unittest.TestCase):
    """Construction tests for ``GenericModel`` and its pluggable components."""

    def setUp(self):
        # Show the complete diff if the long YAML golden-file comparison fails.
        self.maxDiff = None

    def test_create_gm(self):
        """Default args yield the EA renderer + NeRF function and no optional parts."""
        args = get_default_args(GenericModel)
        gm = GenericModel(**args)
        self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
        self.assertIsInstance(
            gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
        )
        self.assertIsNone(gm.global_encoder)
        # The implicit function is only reachable via _implicit_functions,
        # not as a direct attribute.
        self.assertFalse(hasattr(gm, "implicit_function"))
        self.assertIsNone(gm.view_pooler)
        self.assertIsNone(gm.image_feature_extractor)

    def test_create_gm_overrides(self):
        """Overridden component choices are honored and serialize to the golden YAML."""
        # Presumably makes ResNet34 weights available for ResNetFeatureExtractor
        # (helper lives in common_resources) -- confirm there.
        provide_resnet34()
        args = get_default_args(GenericModel)
        args.view_pooler_enabled = True
        args.view_pooler_args.feature_aggregator_class_type = (
            "AngleWeightedIdentityFeatureAggregator"
        )
        args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
        args.implicit_function_class_type = "IdrFeatureField"
        args.global_encoder_class_type = "SequenceAutodecoder"
        idr_args = args.implicit_function_IdrFeatureField_args
        # Distinctive value so we can verify the nested args reached the component.
        idr_args.n_harmonic_functions_xyz = 1729

        args.renderer_class_type = "LSTMRenderer"
        gm = GenericModel(**args)
        self.assertIsInstance(gm.renderer, LSTMRenderer)
        self.assertIsInstance(
            gm.view_pooler.feature_aggregator,
            AngleWeightedIdentityFeatureAggregator,
        )
        self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
        self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
        self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
        self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
        self.assertFalse(hasattr(gm, "implicit_function"))

        instance_args = OmegaConf.structured(gm)
        if DEBUG:
            # Dump the config before pruning, for debugging only.
            full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
            (DATA_DIR / "overrides_full.yaml").write_text(full_yaml, encoding="utf-8")
        # Presumably drops args of components that were not selected -- it
        # mutates instance_args in place before it is serialized below.
        remove_unused_components(instance_args)
        yaml_str = OmegaConf.to_yaml(instance_args, sort_keys=False)
        if DEBUG:
            # Written to a separate file so the golden overrides.yaml is never
            # overwritten automatically.
            (DATA_DIR / "overrides_.yaml").write_text(yaml_str, encoding="utf-8")
        # Explicit UTF-8 so the golden-file comparison does not depend on the
        # platform's locale encoding.
        expected = (DATA_DIR / "overrides.yaml").read_text(encoding="utf-8")
        self.assertEqual(yaml_str, expected)
| |
|