# colipri/tests/test_config.py
# Duplicated by mmrech from microsoft/colipri (commit 4886d4e).
"""Tests for config loading, including GlobalHydra conflict regression."""
from omegaconf import OmegaConf
from colipri.checkpoint import _load_config
from colipri.checkpoint import load_model_config
from colipri.checkpoint import load_processor_config
def test_load_config():
    """Load the full default config and check its top-level values.

    Covers the shared scalars (``input_size``, ``spacing``) and the
    ``_target_`` class paths of the ``model`` and ``processor`` sections.
    """
    cfg = _load_config()

    # Shared scalar defaults.
    assert cfg.input_size == 192
    assert cfg.spacing == 2

    # Class paths of the two top-level components.
    model_target = cfg.model._target_
    assert model_target == "colipri.model.multimodal.Model"
    processor_target = cfg.processor._target_
    assert processor_target == "colipri.processor.Processor"
def test_load_model_config():
    """The model section targets the multimodal Model and has both encoders."""
    model_config = load_model_config()
    assert model_config._target_ == "colipri.model.multimodal.Model"
    for section in ("image_encoder", "text_encoder"):
        assert section in model_config
def test_load_processor_config():
    """The processor section targets Processor and has a transform and tokenizer."""
    processor_config = load_processor_config()
    assert processor_config._target_ == "colipri.processor.Processor"
    for section in ("image_transform", "tokenizer"):
        assert section in processor_config
def test_overrides():
    """``key=value`` override strings replace the default scalar values."""
    expected = {"input_size": 256, "spacing": 3}
    # Builds ["input_size=256", "spacing=3"].
    overrides = [f"{name}={value}" for name, value in expected.items()]
    config = _load_config(overrides=overrides)
    for name, value in expected.items():
        assert getattr(config, name) == value
def test_interpolation():
    """Resolving the config expands interpolations into the image backbone."""
    raw = _load_config()
    container = OmegaConf.to_container(raw, resolve=True)
    # Also narrows the return type to dict for the indexing below.
    assert isinstance(container, dict)

    image_encoder = container["model"]["image_encoder"]
    backbone_cfg = image_encoder["backbone"]

    # embed_dim is written as ${image_embed_dim} in the config.
    assert backbone_cfg["embed_dim"] == 864
    # Each entry of input_shape comes from ${input_size}.
    assert backbone_cfg["input_shape"] == [192] * 3
def test_config_loading_with_hydra_preinitialized():
    """Regression test: COLIPRI must work when GlobalHydra is already initialized.

    A caller that has already set up Hydra (here via ``hydra.initialize``)
    must still be able to load COLIPRI's model and processor configs,
    including with overrides. COLIPRI must also leave the caller's
    GlobalHydra initialized once its own config loading is done.

    See https://huggingface.co/microsoft/colipri/discussions/3
    """
    from hydra import initialize
    from hydra.core.global_hydra import GlobalHydra

    with initialize(config_path=None, version_base=None):
        assert GlobalHydra.instance().is_initialized()

        model_cfg = load_model_config()
        proc_cfg = load_processor_config()
        assert model_cfg._target_ == "colipri.model.multimodal.Model"
        assert proc_cfg._target_ == "colipri.processor.Processor"

        # Overrides must also work
        config = _load_config(overrides=["input_size=256"])
        assert config.input_size == 256

        # The caller's GlobalHydra must survive COLIPRI's config loading.
        # Checking only after the ``with`` block would not catch a loader
        # that clears it, because ``initialize`` restores its own backup on
        # exit.
        assert GlobalHydra.instance().is_initialized()

    # Leaving the caller's ``initialize`` context must return Hydra to the
    # uninitialized state it required on entry.
    assert not GlobalHydra.instance().is_initialized()