Spaces:
Sleeping
Sleeping
File size: 7,055 Bytes
8f383ed 6fcc1b9 8f383ed 6fcc1b9 8f383ed 6fcc1b9 8f383ed 6fcc1b9 8f383ed 6fcc1b9 8f383ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
"""Unit tests for mosaic.inference.data module."""
import numpy as np
import pytest
import torch
from mosaic.inference.data import (
CANCER_TYPE_TO_INT_MAP,
INT_TO_CANCER_TYPE_MAP,
SiteType,
TileFeatureTensorDataset,
)
class TestCancerTypeMaps:
    """Sanity checks for the cancer-type lookup tables."""

    def test_cancer_type_to_int_map_has_entries(self):
        """The name-to-index table is non-empty."""
        entry_count = len(CANCER_TYPE_TO_INT_MAP)
        assert entry_count > 0

    def test_int_to_cancer_type_map_has_entries(self):
        """The index-to-name table is non-empty."""
        entry_count = len(INT_TO_CANCER_TYPE_MAP)
        assert entry_count > 0

    def test_maps_are_inverse(self):
        """Both tables have the same size and every name round-trips."""
        forward = CANCER_TYPE_TO_INT_MAP
        backward = INT_TO_CANCER_TYPE_MAP
        assert len(forward) == len(backward)
        # Look each name up forwards, then map the index back again.
        for name in forward:
            assert backward[forward[name]] == name

    def test_cancer_type_to_int_map_contains_known_types(self):
        """A handful of expected cancer-type codes are present as keys."""
        for expected in ("LUAD", "BRCA", "PRAD", "COAD"):
            assert expected in CANCER_TYPE_TO_INT_MAP

    def test_indices_are_unique(self):
        """No two cancer types share the same integer index."""
        all_indices = list(CANCER_TYPE_TO_INT_MAP.values())
        distinct_indices = set(all_indices)
        assert len(all_indices) == len(distinct_indices)
class TestSiteType:
    """Checks on the members of the SiteType enum."""

    def test_site_type_primary_value(self):
        """PRIMARY carries the string value "Primary"."""
        primary = SiteType.PRIMARY
        assert primary.value == "Primary"

    def test_site_type_metastasis_value(self):
        """METASTASIS carries the string value "Metastasis"."""
        metastasis = SiteType.METASTASIS
        assert metastasis.value == "Metastasis"

    def test_site_type_has_two_members(self):
        """Iterating SiteType yields exactly two members."""
        members = [member for member in SiteType]
        assert len(members) == 2
class TestTileFeatureTensorDataset:
    """Exercise construction and indexing of TileFeatureTensorDataset."""

    @staticmethod
    def _seeded_features(n_tiles, dim=768):
        """Seed NumPy's global RNG with 42 and draw an (n_tiles, dim) float32 array."""
        np.random.seed(42)
        raw = np.random.rand(n_tiles, dim)
        return raw.astype(np.float32)

    @pytest.fixture
    def sample_features(self):
        """Provide a 100 x 768 float32 feature array."""
        return self._seeded_features(100)

    @pytest.fixture
    def large_features(self):
        """Provide a 25000 x 768 float32 feature array (used by the truncation test)."""
        return self._seeded_features(25000)

    @pytest.fixture
    def small_features(self):
        """Provide a 50 x 768 float32 feature array (used by the padding test)."""
        return self._seeded_features(50)

    def test_dataset_initialization(self, sample_features):
        """Constructor arguments are exposed as attributes; features is a tensor."""
        max_tiles = 20000
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
            n_max_tiles=max_tiles,
        )
        assert dataset.site_type == SiteType.PRIMARY
        assert dataset.n_max_tiles == max_tiles
        assert isinstance(dataset.features, torch.Tensor)

    def test_dataset_length(self, sample_features):
        """The dataset reports a length of 1."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        n_items = len(dataset)
        assert n_items == 1

    def test_dataset_getitem_structure(self, sample_features):
        """Indexing yields a dict with "site" and "tile_tensor" entries."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.METASTASIS,
            tile_features=sample_features,
        )
        first = dataset[0]
        assert isinstance(first, dict)
        for required_key in ("site", "tile_tensor"):
            assert required_key in first
        assert first["site"] == "Metastasis"
        assert isinstance(first["tile_tensor"], torch.Tensor)

    def test_features_are_padded_when_small(self, small_features):
        """With fewer tiles than n_max_tiles, the row count equals n_max_tiles."""
        tile_cap = 1000
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=small_features,
            n_max_tiles=tile_cap,
        )
        stored_shape = dataset.features.shape
        assert stored_shape[0] == tile_cap
        assert stored_shape[1] == small_features.shape[1]

    def test_features_are_truncated_when_large(self, large_features):
        """With more tiles than n_max_tiles, the row count equals n_max_tiles."""
        tile_cap = 20000
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=large_features,
            n_max_tiles=tile_cap,
        )
        stored_shape = dataset.features.shape
        assert stored_shape[0] == tile_cap
        assert stored_shape[1] == large_features.shape[1]

    def test_features_dtype_is_float32(self, sample_features):
        """The stored feature tensor has dtype torch.float32."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        stored_dtype = dataset.features.dtype
        assert stored_dtype == torch.float32

    def test_site_type_primary(self, sample_features):
        """A PRIMARY dataset reports "Primary" under the "site" key."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        first = dataset[0]
        assert first["site"] == "Primary"

    def test_site_type_metastasis(self, sample_features):
        """A METASTASIS dataset reports "Metastasis" under the "site" key."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.METASTASIS,
            tile_features=sample_features,
        )
        first = dataset[0]
        assert first["site"] == "Metastasis"

    def test_features_exact_size(self):
        """Input with exactly n_max_tiles rows keeps its (rows, 768) shape."""
        tile_cap = 100
        exact_features = self._seeded_features(tile_cap, 768)
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=exact_features,
            n_max_tiles=tile_cap,
        )
        stored_shape = dataset.features.shape
        assert stored_shape[0] == tile_cap
        assert stored_shape[1] == 768

    def test_features_shape_preserved(self, sample_features):
        """The feature (column) dimension of the input is kept."""
        expected_dim = sample_features.shape[1]
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        assert dataset.features.shape[1] == expected_dim

    def test_different_feature_dimensions(self):
        """Several embedding widths each pass through with their width intact."""
        # Seed once up front; the arrays below are drawn consecutively
        # from the same global RNG stream.
        np.random.seed(42)
        for width in (256, 512, 768, 1024):
            tiles = np.random.rand(100, width).astype(np.float32)
            dataset = TileFeatureTensorDataset(
                site_type=SiteType.PRIMARY,
                tile_features=tiles,
            )
            assert dataset.features.shape[1] == width
|