File size: 4,839 Bytes
a345d50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Tests for custom config support in get_model and get_processor."""

from __future__ import annotations

from pathlib import Path

import pytest
import torchio as tio
from omegaconf import DictConfig
from omegaconf import OmegaConf

from colipri.checkpoint import load_model_config
from colipri.checkpoint import load_processor_config
from colipri.processor import Processor
from colipri.processor import get_processor

# Minimal transform-only config: a Compose wrapping a single Clamp to
# [-500, 500]. It contains no ``${...}`` interpolations, so it can be loaded
# on its own without the default processor config.
TRANSFORM_YAML_CONTENT = (
    "_target_: torchio.transforms.augmentation.composition.Compose\n"
    "transforms:\n"
    "  - _target_: torchio.transforms.preprocessing.intensity.clamp.Clamp\n"
    "    out_min: -500\n"
    "    out_max: 500\n"
)


@pytest.fixture
def transform_yaml(tmp_path: Path) -> Path:
    """Write ``TRANSFORM_YAML_CONTENT`` to ``transform.yaml`` and return its path.

    The file lives in pytest's per-test ``tmp_path`` directory. Its contents
    have no interpolation variables, so it is self-contained.
    """
    yaml_file = tmp_path.joinpath("transform.yaml")
    yaml_file.write_text(TRANSFORM_YAML_CONTENT)
    return yaml_file


def _resolved_copy(config: DictConfig) -> DictConfig:
    """Return a new DictConfig equal to *config* with interpolations resolved.

    *config* is converted to a plain container with ``resolve=True`` and then
    wrapped again, so the result is a fresh object independent of *config*.
    """
    resolved = OmegaConf.to_container(config, resolve=True)
    # to_container may return a list or other primitive for non-dict configs;
    # narrow to dict both for the type checker and to fail loudly otherwise.
    assert isinstance(resolved, dict)
    return OmegaConf.create(resolved)


@pytest.fixture
def resolved_processor_config() -> DictConfig:
    """Default processor config with all interpolations resolved."""
    return _resolved_copy(load_processor_config())


@pytest.fixture
def resolved_model_config() -> DictConfig:
    """Default model config with all interpolations resolved."""
    return _resolved_copy(load_model_config())


class TestGetProcessorCustomConfig:
    """get_processor() with custom transform or full processor configs."""

    @staticmethod
    def _assert_custom_clamp_transform(processor: Processor) -> None:
        """Assert *processor* uses exactly the transform in TRANSFORM_YAML_CONTENT.

        That is a Compose holding a single Clamp with bounds [-500, 500].
        Checking the length and the bounds, not just the Compose type,
        separates the custom transform from the default pipeline, which is
        also a Compose.
        """
        transform = processor._image_transform
        assert isinstance(transform, tio.Compose)
        assert len(transform.transforms) == 1
        clamp = transform.transforms[0]
        assert isinstance(clamp, tio.Clamp)
        assert clamp.out_min == -500
        assert clamp.out_max == 500

    def test_with_transform_yaml_path(self, transform_yaml: Path) -> None:
        """Transform YAML path → Processor with custom transform."""
        processor = get_processor(config=transform_yaml, image_only=True)
        assert isinstance(processor, Processor)
        self._assert_custom_clamp_transform(processor)

    def test_with_transform_dictconfig(self) -> None:
        """Transform DictConfig object → Processor with custom transform."""
        config = OmegaConf.create(TRANSFORM_YAML_CONTENT)
        processor = get_processor(config=config, image_only=True)
        assert isinstance(processor, Processor)
        self._assert_custom_clamp_transform(processor)

    def test_with_full_processor_config(
        self,
        resolved_processor_config: DictConfig,
    ) -> None:
        """Full processor DictConfig → Processor matching that config."""
        # Remove all but the first transform to distinguish from default (5 transforms).
        # The fixture is function-scoped, so this mutation does not leak into
        # other tests.
        resolved_processor_config.image_transform.transforms = (
            resolved_processor_config.image_transform.transforms[:1]
        )
        processor = get_processor(
            config=resolved_processor_config,
            image_only=True,
        )
        assert isinstance(processor, Processor)
        assert isinstance(processor._image_transform, tio.Compose)
        assert len(processor._image_transform.transforms) == 1

    def test_transform_yaml_wraps_with_default_tokenizer(
        self,
        transform_yaml: Path,
    ) -> None:
        """Transform-only config is wrapped with default tokenizer config."""
        # Without image_only, the transform YAML should be wrapped into a full
        # processor config that includes the default tokenizer.
        processor = get_processor(config=transform_yaml)
        assert isinstance(processor, Processor)
        # Should have both the custom transform (not merely *a* Compose, which
        # the default would also satisfy) and the default tokenizer.
        self._assert_custom_clamp_transform(processor)
        assert processor._text_tokenizer is not None

    def test_default_unchanged(self) -> None:
        """get_processor() without config still works (backward compat)."""
        processor = get_processor(image_only=True)
        assert isinstance(processor, Processor)
        assert isinstance(processor._image_transform, tio.Compose)


class TestGetModelCustomConfig:
    """get_model() with an explicit config and with its default config."""

    def test_with_config(self, resolved_model_config: DictConfig) -> None:
        """A resolved model DictConfig passed to get_model yields a Model."""
        # Deferred import: the model module is only loaded when this test runs.
        import colipri.model.multimodal as multimodal

        built = multimodal.get_model(
            pretrained=False,
            config=resolved_model_config,
        )
        assert isinstance(built, multimodal.Model)

    def test_default_unchanged(self) -> None:
        """get_model(pretrained=False) with no config still builds a Model."""
        # Deferred import: the model module is only loaded when this test runs.
        import colipri.model.multimodal as multimodal

        built = multimodal.get_model(pretrained=False)
        assert isinstance(built, multimodal.Model)