mosaic-zero / tests /inference /test_data.py
copilot-swe-agent[bot]
Address code review feedback: add deterministic seeds and improve mocks
6fcc1b9
"""Unit tests for mosaic.inference.data module."""
import numpy as np
import pytest
import torch
from mosaic.inference.data import (
CANCER_TYPE_TO_INT_MAP,
INT_TO_CANCER_TYPE_MAP,
SiteType,
TileFeatureTensorDataset,
)
class TestCancerTypeMaps:
    """Sanity checks for the cancer-type <-> integer index lookup tables."""

    def test_cancer_type_to_int_map_has_entries(self):
        """CANCER_TYPE_TO_INT_MAP must not be empty."""
        assert CANCER_TYPE_TO_INT_MAP

    def test_int_to_cancer_type_map_has_entries(self):
        """INT_TO_CANCER_TYPE_MAP must not be empty."""
        assert INT_TO_CANCER_TYPE_MAP

    def test_maps_are_inverse(self):
        """Every name -> index entry must map back to the same name.

        Together with equal sizes, this makes the two tables exact inverses.
        """
        assert len(INT_TO_CANCER_TYPE_MAP) == len(CANCER_TYPE_TO_INT_MAP)
        for name, index in CANCER_TYPE_TO_INT_MAP.items():
            assert INT_TO_CANCER_TYPE_MAP[index] == name

    def test_cancer_type_to_int_map_contains_known_types(self):
        """A handful of well-known cancer types must be present."""
        expected = ("LUAD", "BRCA", "PRAD", "COAD")
        missing = [name for name in expected if name not in CANCER_TYPE_TO_INT_MAP]
        assert not missing

    def test_indices_are_unique(self):
        """No two cancer types may share an integer index."""
        values = list(CANCER_TYPE_TO_INT_MAP.values())
        assert len(set(values)) == len(values)
class TestSiteType:
    """Checks on the values and size of the SiteType enum."""

    def test_site_type_primary_value(self):
        """PRIMARY must carry the string value "Primary"."""
        expected = "Primary"
        assert SiteType.PRIMARY.value == expected

    def test_site_type_metastasis_value(self):
        """METASTASIS must carry the string value "Metastasis"."""
        expected = "Metastasis"
        assert SiteType.METASTASIS.value == expected

    def test_site_type_has_two_members(self):
        """SiteType defines exactly two members."""
        assert len(SiteType) == 2
class TestTileFeatureTensorDataset:
    """Test TileFeatureTensorDataset class."""

    # Width of the synthetic per-tile feature vectors used by most tests.
    _FEATURE_DIM = 768

    @staticmethod
    def _random_features(n_tiles, dim=_FEATURE_DIM, seed=42):
        """Return a reproducible ``(n_tiles, dim)`` float32 array of uniform noise.

        A local ``np.random.Generator`` is used instead of ``np.random.seed`` so
        these tests never reseed NumPy's global RNG; reseeding global state
        leaks into unrelated tests and makes results depend on test order.
        """
        rng = np.random.default_rng(seed)
        return rng.random((n_tiles, dim), dtype=np.float32)

    @pytest.fixture
    def sample_features(self):
        """Create 100 tiles of sample features for testing."""
        return self._random_features(100)

    @pytest.fixture
    def large_features(self):
        """Create more tiles (150) than the n_max_tiles used in the truncation test.

        Truncation only requires rows > n_max_tiles, so a small array suffices;
        this avoids allocating tens of megabytes per test run.
        """
        return self._random_features(150)

    @pytest.fixture
    def small_features(self):
        """Create fewer tiles (50) than n_max_tiles for testing padding."""
        return self._random_features(50)

    def test_dataset_initialization(self, sample_features):
        """Test basic dataset initialization."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
            n_max_tiles=20000,
        )
        assert dataset.site_type == SiteType.PRIMARY
        assert dataset.n_max_tiles == 20000
        assert isinstance(dataset.features, torch.Tensor)

    def test_dataset_length(self, sample_features):
        """Test that dataset length is always 1."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        assert len(dataset) == 1

    def test_dataset_getitem_structure(self, sample_features):
        """Test that __getitem__ returns a dict with "site" and "tile_tensor"."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.METASTASIS,
            tile_features=sample_features,
        )
        item = dataset[0]
        assert isinstance(item, dict)
        assert "site" in item
        assert "tile_tensor" in item
        assert item["site"] == "Metastasis"
        assert isinstance(item["tile_tensor"], torch.Tensor)

    def test_features_are_padded_when_small(self, small_features):
        """Test that features are padded when fewer than n_max_tiles."""
        n_max_tiles = 1000
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=small_features,
            n_max_tiles=n_max_tiles,
        )
        assert dataset.features.shape[0] == n_max_tiles
        assert dataset.features.shape[1] == small_features.shape[1]

    def test_features_are_truncated_when_large(self, large_features):
        """Test that features are truncated when more than n_max_tiles."""
        n_max_tiles = 100
        # Guard the precondition so a fixture change cannot silently turn
        # this into a no-op test.
        assert large_features.shape[0] > n_max_tiles
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=large_features,
            n_max_tiles=n_max_tiles,
        )
        assert dataset.features.shape[0] == n_max_tiles
        assert dataset.features.shape[1] == large_features.shape[1]

    def test_features_dtype_is_float32(self, sample_features):
        """Test that the stored features tensor has dtype float32."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        assert dataset.features.dtype == torch.float32

    def test_site_type_primary(self, sample_features):
        """Test dataset with PRIMARY site type."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        item = dataset[0]
        assert item["site"] == "Primary"

    def test_site_type_metastasis(self, sample_features):
        """Test dataset with METASTASIS site type."""
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.METASTASIS,
            tile_features=sample_features,
        )
        item = dataset[0]
        assert item["site"] == "Metastasis"

    def test_features_exact_size(self):
        """Test that features with exactly n_max_tiles rows keep their shape."""
        n_max_tiles = 100
        features = self._random_features(n_max_tiles)
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=features,
            n_max_tiles=n_max_tiles,
        )
        assert dataset.features.shape[0] == n_max_tiles
        assert dataset.features.shape[1] == self._FEATURE_DIM

    def test_features_shape_preserved(self, sample_features):
        """Test that feature dimensionality is preserved."""
        original_dim = sample_features.shape[1]
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=sample_features,
        )
        assert dataset.features.shape[1] == original_dim

    # Parametrized (rather than looping inside one test) so every dimension
    # is reported independently instead of stopping at the first failure.
    @pytest.mark.parametrize("dim", [256, 512, 768, 1024])
    def test_different_feature_dimensions(self, dim):
        """Test dataset with different feature dimensions."""
        features = self._random_features(100, dim=dim)
        dataset = TileFeatureTensorDataset(
            site_type=SiteType.PRIMARY,
            tile_features=features,
        )
        assert dataset.features.shape[1] == dim