| """Tests for config loading, including GlobalHydra conflict regression.""" |
|
|
| from omegaconf import OmegaConf |
|
|
| from colipri.checkpoint import _load_config |
| from colipri.checkpoint import load_model_config |
| from colipri.checkpoint import load_processor_config |
|
|
|
|
def test_load_config():
    """The default bundled config exposes the expected top-level values."""
    cfg = _load_config()

    # Scalar preprocessing parameters.
    assert cfg.input_size == 192
    assert cfg.spacing == 2

    # Hydra ``_target_`` of each sub-config that gets instantiated.
    expected_targets = {
        "model": "colipri.model.multimodal.Model",
        "processor": "colipri.processor.Processor",
    }
    for section, target in expected_targets.items():
        assert cfg[section]._target_ == target, section
|
|
|
|
def test_load_model_config():
    """The model sub-config targets the multimodal model and has both encoders."""
    model_cfg = load_model_config()
    assert model_cfg._target_ == "colipri.model.multimodal.Model"
    for encoder_key in ("image_encoder", "text_encoder"):
        assert encoder_key in model_cfg, f"missing {encoder_key!r}"
|
|
|
|
def test_load_processor_config():
    """The processor sub-config targets the Processor and has its components."""
    proc_cfg = load_processor_config()
    assert proc_cfg._target_ == "colipri.processor.Processor"
    for component_key in ("image_transform", "tokenizer"):
        assert component_key in proc_cfg, f"missing {component_key!r}"
|
|
|
|
def test_overrides():
    """``key=value`` overrides replace the corresponding default values."""
    overridden = _load_config(overrides=["input_size=256", "spacing=3"])
    assert (overridden.input_size, overridden.spacing) == (256, 3)
|
|
|
|
def test_interpolation():
    """Interpolated backbone values resolve to concrete numbers."""
    container = OmegaConf.to_container(_load_config(), resolve=True)
    # to_container on a DictConfig yields a plain dict; the isinstance check
    # also narrows the type for static checkers before indexing.
    assert isinstance(container, dict)
    image_backbone = container["model"]["image_encoder"]["backbone"]
    assert image_backbone["embed_dim"] == 864
    assert image_backbone["input_shape"] == [192] * 3
|
|
|
|
def test_config_loading_with_hydra_preinitialized():
    """Regression test: COLIPRI must work when GlobalHydra is already initialized.

    See https://huggingface.co/microsoft/colipri/discussions/3

    While a caller's Hydra context is active, config loading must:

    * not raise,
    * return the same configs as when no context is active, and
    * leave the caller's GlobalHydra instance initialized.
    """
    from hydra import initialize
    from hydra.core.global_hydra import GlobalHydra

    # Reference configs loaded with no Hydra context active.
    expected_model_cfg = load_model_config()
    expected_proc_cfg = load_processor_config()

    with initialize(config_path=None, version_base=None):
        assert GlobalHydra.instance().is_initialized()

        # Loading while the caller's context is active must not raise.
        model_cfg = load_model_config()
        proc_cfg = load_processor_config()
        assert model_cfg._target_ == "colipri.model.multimodal.Model"
        assert proc_cfg._target_ == "colipri.processor.Processor"

        # The pre-initialized context must not change what gets loaded.
        assert model_cfg == expected_model_cfg
        assert proc_cfg == expected_proc_cfg

        # Overrides must also work inside the caller's context.
        config = _load_config(overrides=["input_size=256"])
        assert config.input_size == 256

        # Loading must not tear down the caller's Hydra context.
        assert GlobalHydra.instance().is_initialized()

    # Leaving ``initialize`` restores the pre-test (uninitialized) state.
    assert not GlobalHydra.instance().is_initialized()
|
|