File size: 7,055 Bytes
8f383ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fcc1b9
8f383ed
 
 
 
 
6fcc1b9
8f383ed
 
 
 
 
6fcc1b9
8f383ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fcc1b9
8f383ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fcc1b9
8f383ed
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""Unit tests for mosaic.inference.data module."""

import numpy as np
import pytest
import torch

from mosaic.inference.data import (
    CANCER_TYPE_TO_INT_MAP,
    INT_TO_CANCER_TYPE_MAP,
    SiteType,
    TileFeatureTensorDataset,
)


class TestCancerTypeMaps:
    """Sanity checks for the cancer-type <-> integer lookup tables."""

    def test_cancer_type_to_int_map_has_entries(self):
        """The name -> index table must not be empty."""
        entry_count = len(CANCER_TYPE_TO_INT_MAP)
        assert entry_count > 0

    def test_int_to_cancer_type_map_has_entries(self):
        """The index -> name table must not be empty."""
        entry_count = len(INT_TO_CANCER_TYPE_MAP)
        assert entry_count > 0

    def test_maps_are_inverse(self):
        """Looking a name up and mapping its index back yields the name."""
        forward_size = len(CANCER_TYPE_TO_INT_MAP)
        backward_size = len(INT_TO_CANCER_TYPE_MAP)
        assert forward_size == backward_size

        # Round-trip every name through both tables.
        for name, code in CANCER_TYPE_TO_INT_MAP.items():
            round_tripped = INT_TO_CANCER_TYPE_MAP[code]
            assert round_tripped == name

    def test_cancer_type_to_int_map_contains_known_types(self):
        """A handful of well-known cancer type codes must be present."""
        expected_codes = ("LUAD", "BRCA", "PRAD", "COAD")
        absent = [
            code for code in expected_codes if code not in CANCER_TYPE_TO_INT_MAP
        ]
        assert not absent

    def test_indices_are_unique(self):
        """No two cancer types may share the same integer index."""
        codes = tuple(CANCER_TYPE_TO_INT_MAP.values())
        distinct_codes = set(codes)
        assert len(codes) == len(distinct_codes)


class TestSiteType:
    """Checks on the members of the SiteType enum and their values."""

    def test_site_type_primary_value(self):
        """PRIMARY carries the string value "Primary"."""
        primary = SiteType.PRIMARY
        assert primary.value == "Primary"

    def test_site_type_metastasis_value(self):
        """METASTASIS carries the string value "Metastasis"."""
        metastasis = SiteType.METASTASIS
        assert metastasis.value == "Metastasis"

    def test_site_type_has_two_members(self):
        """Iterating the enum yields exactly two members."""
        members = [member for member in SiteType]
        assert len(members) == 2


class TestTileFeatureTensorDataset:
    """Test TileFeatureTensorDataset class.

    The tests build random ``(n_tiles, dim)`` float32 feature arrays, wrap
    them in a ``TileFeatureTensorDataset`` and check the resulting tensor's
    shape and dtype as well as the structure of the single item it yields.
    """

    @staticmethod
    def _random_features(n_tiles, dim=768):
        """Return a reproducible ``(n_tiles, dim)`` float32 array.

        A fresh, locally seeded ``Generator`` is used instead of
        ``np.random.seed`` so the tests do not mutate NumPy's global RNG
        state (which would leak into unrelated tests). Sampling directly in
        float32 also avoids allocating a float64 intermediate, which matters
        for the 25000-row fixture.

        Args:
            n_tiles: Number of rows (tiles) to generate.
            dim: Number of columns (feature dimension) per tile.

        Returns:
            A ``numpy.ndarray`` of shape ``(n_tiles, dim)`` and dtype float32.
        """
        rng = np.random.default_rng(42)
        return rng.random((n_tiles, dim), dtype=np.float32)

    @pytest.fixture
    def sample_features(self):
        """Create sample features (100 tiles x 768 dims) for testing."""
        return self._random_features(100)

    @pytest.fixture
    def large_features(self):
        """Create large sample features (25000 tiles) for testing truncation."""
        return self._random_features(25000)

    @pytest.fixture
    def small_features(self):
        """Create small sample features (50 tiles) for testing padding."""
        return self._random_features(50)

    def test_dataset_initialization(self, sample_features):
        """Test that constructor arguments are stored and features become a tensor."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
            n_max_tiles=20000,
        )
        assert dataset.site_type == SiteType.PRIMARY
        assert dataset.n_max_tiles == 20000
        assert isinstance(dataset.features, torch.Tensor)

    def test_dataset_length(self, sample_features):
        """Test that dataset length is always 1."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        assert len(dataset) == 1

    def test_dataset_getitem_structure(self, sample_features):
        """Test that __getitem__ returns a dict with "site" and "tile_tensor"."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.METASTASIS,
            tile_features=sample_features,
        )
        item = dataset[0]
        assert isinstance(item, dict)
        assert "site" in item
        assert "tile_tensor" in item
        assert item["site"] == "Metastasis"
        assert isinstance(item["tile_tensor"], torch.Tensor)

    def test_features_are_padded_when_small(self, small_features):
        """Test that features are padded when fewer than n_max_tiles."""
        n_max_tiles = 1000
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=small_features,
            n_max_tiles=n_max_tiles,
        )
        # Row count grows to n_max_tiles; the feature dimension is untouched.
        assert dataset.features.shape[0] == n_max_tiles
        assert dataset.features.shape[1] == small_features.shape[1]

    def test_features_are_truncated_when_large(self, large_features):
        """Test that features are truncated when more than n_max_tiles."""
        n_max_tiles = 20000
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=large_features,
            n_max_tiles=n_max_tiles,
        )
        # Row count shrinks to n_max_tiles; the feature dimension is untouched.
        assert dataset.features.shape[0] == n_max_tiles
        assert dataset.features.shape[1] == large_features.shape[1]

    def test_features_dtype_is_float32(self, sample_features):
        """Test that the stored feature tensor has dtype float32."""
        # NOTE(review): the input here is already float32, so this does not
        # exercise an actual dtype conversion. Consider adding a float64
        # input case once the intended conversion contract is confirmed.
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        assert dataset.features.dtype == torch.float32

    def test_site_type_primary(self, sample_features):
        """Test dataset with PRIMARY site type."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        item = dataset[0]
        assert item["site"] == "Primary"

    def test_site_type_metastasis(self, sample_features):
        """Test dataset with METASTASIS site type."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.METASTASIS,
            tile_features=sample_features,
        )
        item = dataset[0]
        assert item["site"] == "Metastasis"

    def test_features_exact_size(self):
        """Test that features of exactly n_max_tiles keep their shape."""
        n_max_tiles = 100
        features = self._random_features(n_max_tiles)
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=features,
            n_max_tiles=n_max_tiles,
        )
        assert dataset.features.shape[0] == n_max_tiles
        assert dataset.features.shape[1] == 768

    def test_features_shape_preserved(self, sample_features):
        """Test that feature dimensionality is preserved."""
        original_dim = sample_features.shape[1]
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        assert dataset.features.shape[1] == original_dim

    # Parametrized so each dimension is reported (and can fail) on its own
    # rather than the first failing dimension hiding the rest.
    @pytest.mark.parametrize("dim", [256, 512, 768, 1024])
    def test_different_feature_dimensions(self, dim):
        """Test that the feature dimension is preserved for several widths."""
        features = self._random_features(100, dim)
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=features,
        )
        assert dataset.features.shape[1] == dim