File size: 2,297 Bytes
6eb076f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Tests for config loading, including GlobalHydra conflict regression."""

from omegaconf import OmegaConf

from colipri.checkpoint import _load_config
from colipri.checkpoint import load_model_config
from colipri.checkpoint import load_processor_config


def test_load_config():
    """Check top-level settings and both component targets of the full config."""
    cfg = _load_config()

    # Top-level scalar settings of the default config.
    assert cfg.input_size == 192
    assert cfg.spacing == 2

    # Each component section names the class to instantiate via ``_target_``.
    expected_targets = {
        "model": "colipri.model.multimodal.Model",
        "processor": "colipri.processor.Processor",
    }
    for section, target in expected_targets.items():
        assert cfg[section]._target_ == target


def test_load_model_config():
    """Check the model sub-config's target and that it has both encoder sections."""
    model_cfg = load_model_config()
    assert model_cfg._target_ == "colipri.model.multimodal.Model"
    for section in ("image_encoder", "text_encoder"):
        assert section in model_cfg


def test_load_processor_config():
    """Check the processor sub-config's target and its transform/tokenizer sections."""
    proc_cfg = load_processor_config()
    assert proc_cfg._target_ == "colipri.processor.Processor"
    for section in ("image_transform", "tokenizer"):
        assert section in proc_cfg


def test_overrides():
    """Check that ``key=value`` overrides replace the default top-level values."""
    expected = {"input_size": 256, "spacing": 3}
    # Builds ["input_size=256", "spacing=3"], in that order.
    override_args = [f"{key}={value}" for key, value in expected.items()]
    overridden = _load_config(overrides=override_args)
    for key, value in expected.items():
        assert overridden[key] == value


def test_interpolation():
    """Check that ``${...}`` interpolations resolve to concrete values."""
    container = OmegaConf.to_container(_load_config(), resolve=True)
    assert isinstance(container, dict)

    image_backbone = container["model"]["image_encoder"]["backbone"]
    # ``embed_dim`` is populated from ``${image_embed_dim}``.
    assert image_backbone["embed_dim"] == 864
    # ``input_shape`` has three entries, each populated from ``${input_size}``.
    assert image_backbone["input_shape"] == [192] * 3


def test_config_loading_with_hydra_preinitialized():
    """Regression test: COLIPRI must work when GlobalHydra is already initialized.

    Loading COLIPRI configs inside a caller's active Hydra context must not
    fail and must not tear down that context. Overrides passed to one call
    must not leak into later calls.

    See https://huggingface.co/microsoft/colipri/discussions/3
    """
    from hydra import initialize
    from hydra.core.global_hydra import GlobalHydra

    # Precondition: start from a clean Hydra state. Otherwise entering
    # ``initialize`` below fails before COLIPRI is exercised, with a less
    # obvious error.
    assert not GlobalHydra.instance().is_initialized()

    with initialize(config_path=None, version_base=None):
        assert GlobalHydra.instance().is_initialized()

        model_cfg = load_model_config()
        proc_cfg = load_processor_config()
        assert model_cfg._target_ == "colipri.model.multimodal.Model"
        assert proc_cfg._target_ == "colipri.processor.Processor"

        # Overrides must also work
        config = _load_config(overrides=["input_size=256"])
        assert config.input_size == 256

        # ...and must not leak into a subsequent default load.
        assert _load_config().input_size == 192

        # The caller's Hydra context must still be active after COLIPRI's
        # loads. Hydra's ``initialize`` restores the GlobalHydra state it
        # saved on entry when it exits. So the post-``with`` check below
        # cannot by itself detect COLIPRI clearing the caller's context.
        assert GlobalHydra.instance().is_initialized()

    # GlobalHydra state must not be corrupted
    assert not GlobalHydra.instance().is_initialized()