Spaces:
Runtime error
Runtime error
| import argparse | |
| import pytest | |
| import torch | |
| from src.data.custom_transforms import NormalizeIntensity_custom | |
| from src.data.data_loader import get_dataloader | |
| from src.model.cspca_model import CSPCAModel | |
| from src.model.mil import MILModel3D | |
| from src.train import train_cspca, train_pirads | |
| from src.train.train_pirads import get_attention_scores | |
| def mock_args(): | |
| # Mocking argparse for the device | |
| args = argparse.Namespace() | |
| args.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| return args | |
| def test_get_attention_scores_logic(mock_args): | |
| # Setup: 2 samples, 4 patches, images of size 8x8 | |
| batch_size = 2 | |
| num_patches = 4 | |
| # Sample 0: Target = 3 (Cancer), Sample 1: Target = 0 (PI-RADS 2) | |
| data = torch.randn(batch_size, num_patches, 1, 8, 8) | |
| target = torch.tensor([3.0, 0.0]) | |
| # Create heatmaps: Sample 0 has one "hot" patch | |
| heatmap = torch.zeros(batch_size, num_patches, 1, 8, 8) | |
| heatmap[0, 0] = 10.0 # High attention on patch 0 for the first sample | |
| heatmap[1, :] = 5.0 # Should be overridden by PI-RADS 2 logic anyway | |
| att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args) | |
| # --- TEST 1: Normalization --- | |
| sums = att_labels.sum(dim=1) | |
| torch.testing.assert_close(sums, torch.ones(batch_size).to(mock_args.device)) | |
| # --- TEST 2: PI-RADS 2 Uniformity --- | |
| pirads_2_scores = att_labels[1] | |
| expected_uniform = torch.ones(num_patches).to(mock_args.device) / num_patches | |
| torch.testing.assert_close(pirads_2_scores, expected_uniform) | |
| # --- TEST 4: Output Shapes --- | |
| assert att_labels.shape == (batch_size, num_patches) | |
| assert shuffled_images.shape == data.shape | |
| def test_shuffling_consistency(mock_args): | |
| # Verify that the image and label are shuffled with the SAME permutation | |
| num_patches = 10 | |
| # Distinct data per patch: [0, 1, 2, 3...] | |
| data = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float() | |
| target = torch.tensor([3.0]) | |
| # Heatmap matches the data indices so we can track the "label" | |
| heatmap = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float() | |
| att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args) | |
| idx = (shuffled_images[0, :, 0, 0, 0] == 9.0).nonzero(as_tuple=True)[0] | |
| # The attention score at that same index should be the maximum | |
| assert att_labels[0, idx] == att_labels[0].max() | |
| idx = (shuffled_images[0, :, 0, 0, 0] == 0.0).nonzero(as_tuple=True)[0] | |
| # The attention score at that same index should be the minimum | |
| assert att_labels[0, idx] == att_labels[0].min() | |
| shuffled_images = shuffled_images.cpu().squeeze() # Shape [10] | |
| att_labels = att_labels.cpu().squeeze() # Shape [10] | |
| sorted_vals, original_indices = torch.sort(shuffled_images) | |
| sorted_labels = att_labels[original_indices] | |
| for i in range(len(sorted_labels) - 1): | |
| assert sorted_labels[i] <= sorted_labels[i + 1], ( | |
| f"Alignment broken at index {i}: Image val {sorted_vals[i]} has higher label than {sorted_vals[i + 1]}" | |
| ) | |
| def test_normalize_intensity_custom_masked_stats(): | |
| """ | |
| Test that statistics (mean/std) are calculated ONLY from the masked region, | |
| but applied to the whole image. | |
| """ | |
| img = torch.zeros((2, 4, 4), dtype=torch.float32) | |
| mask = torch.zeros((1, 4, 4), dtype=torch.float32) | |
| img[0, :, :] = 100.0 | |
| img[0, 0, 0] = 10.0 | |
| img[0, 0, 1] = 20.0 | |
| img[1, :, :] = 50.0 | |
| img[1, 0, 0] = 2.0 | |
| img[1, 0, 1] = 4.0 | |
| mask[0, 0, 0] = 1 | |
| mask[0, 0, 1] = 1 | |
| normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True) | |
| out = normalizer(img, mask) | |
| assert torch.isclose(out[0, 0, 0], torch.tensor(-1.0)), "Ch0 masked value 1 incorrect" | |
| assert torch.isclose(out[0, 0, 1], torch.tensor(1.0)), "Ch0 masked value 2 incorrect" | |
| assert torch.isclose(out[0, 1, 1], torch.tensor(17.0)), "Ch0 background normalization incorrect" | |
| assert torch.isclose(out[1, 0, 0], torch.tensor(-1.0)), "Ch1 masked value 1 incorrect" | |
| assert torch.isclose(out[1, 1, 1], torch.tensor(47.0)), "Ch1 background normalization incorrect" | |
| def test_normalize_intensity_constant_area(): | |
| """ | |
| Test edge case where the area under the mask has 0 variance (constant value). | |
| Std should default to 1.0 to avoid division by zero. | |
| """ | |
| img = torch.ones((1, 4, 4)) * 10.0 # All values are 10 | |
| mask = torch.ones((1, 4, 4)) | |
| normalizer = NormalizeIntensity_custom(channel_wise=True) | |
| out = normalizer(img, mask) | |
| assert torch.allclose(out, torch.zeros_like(out)) | |
| data = torch.rand(1, 10, 10) | |
| mask = torch.randint(0, 2, (1, 10, 10)).float() | |
| normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True) | |
| out = normalizer(data, mask) | |
| masked = data[mask != 0] | |
| mean_val = torch.mean(masked.float()) | |
| std_val = torch.std(masked.float(), unbiased=False) | |
| epsilon = 1e-8 | |
| normalized_data = (data - mean_val) / (std_val + epsilon) | |
| torch.testing.assert_close(out, normalized_data) | |
| def test_run_models(): | |
| args = argparse.Namespace() | |
| args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| args.epochs = 1 | |
| args.batch_size = 2 | |
| args.tile_size = 10 | |
| args.tile_count = 5 | |
| args.use_heatmap = True | |
| args.amp = False | |
| args.num_classes = 4 | |
| args.dry_run = True | |
| args.depth = 3 | |
| model = MILModel3D(num_classes=args.num_classes, mil_mode="att_trans") | |
| model.to(args.device) | |
| params = model.parameters() | |
| loader = get_dataloader(args, split="train") | |
| optimizer = torch.optim.AdamW(params, lr=1e-5, weight_decay=1e-5) | |
| scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp) | |
| _ = train_pirads.train_epoch(model, loader, optimizer, scaler=scaler, epoch=0, args=args) | |
| _ = train_pirads.val_epoch(model, loader, epoch=0, args=args) | |
| cspca_model = CSPCAModel(backbone=model).to(args.device) | |
| optimizer_cspca = torch.optim.AdamW(cspca_model.parameters(), lr=1e-5) | |
| _ = train_cspca.train_epoch(cspca_model, loader, optimizer_cspca, epoch=0, args=args) | |
| _ = train_cspca.val_epoch(cspca_model, loader, epoch=0, args=args) | |