File size: 6,223 Bytes
95dc457
1baebae
95dc457
 
1baebae
95dc457
 
 
 
 
 
1baebae
 
95dc457
 
 
 
 
 
1baebae
 
95dc457
 
 
 
1baebae
95dc457
 
 
1baebae
95dc457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caf6ee7
1baebae
95dc457
 
 
1baebae
95dc457
 
 
1baebae
95dc457
 
1baebae
95dc457
1baebae
95dc457
 
 
caf6ee7
95dc457
 
 
caf6ee7
95dc457
 
 
 
 
 
 
 
 
 
 
 
 
1baebae
95dc457
 
1baebae
 
95dc457
 
1baebae
95dc457
 
 
1baebae
95dc457
 
 
1baebae
95dc457
 
1baebae
95dc457
 
caf6ee7
95dc457
 
 
caf6ee7
95dc457
 
 
 
 
1baebae
95dc457
 
1baebae
95dc457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import argparse

import pytest
import torch

from src.data.custom_transforms import NormalizeIntensity_custom
from src.data.data_loader import get_dataloader
from src.model.cspca_model import CSPCAModel
from src.model.mil import MILModel3D
from src.train import train_cspca, train_pirads
from src.train.train_pirads import get_attention_scores


@pytest.fixture
def mock_args():
    """Minimal stand-in for parsed CLI args, exposing only `.device`."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return argparse.Namespace(device=device)


def test_get_attention_scores_logic(mock_args):
    """Check attention-score normalization, PI-RADS-2 uniformity, and shapes."""
    n_samples, n_patches = 2, 4

    # Sample 0 carries target 3 (cancer); sample 1 carries target 0 (PI-RADS 2).
    images = torch.randn(n_samples, n_patches, 1, 8, 8)
    labels = torch.tensor([3.0, 0.0])

    # Heatmap: sample 0 concentrates attention on patch 0; sample 1 is flat
    # but should be forced uniform by the PI-RADS-2 branch regardless.
    attn_map = torch.zeros(n_samples, n_patches, 1, 8, 8)
    attn_map[0, 0] = 10.0
    attn_map[1, :] = 5.0

    att_labels, shuffled_images = get_attention_scores(images, labels, attn_map, mock_args)

    # 1) Per-sample attention scores must sum to one.
    torch.testing.assert_close(
        att_labels.sum(dim=1), torch.ones(n_samples).to(mock_args.device)
    )

    # 2) PI-RADS 2 samples receive a uniform distribution over patches.
    uniform = (torch.ones(n_patches) / n_patches).to(mock_args.device)
    torch.testing.assert_close(att_labels[1], uniform)

    # 3) Output shapes are preserved.
    assert att_labels.shape == (n_samples, n_patches)
    assert shuffled_images.shape == images.shape


def test_shuffling_consistency(mock_args):
    """Image patches and attention labels must be permuted identically."""
    n_patches = 10

    # Each patch holds its own index, so a value identifies its origin.
    images = torch.arange(n_patches).view(1, n_patches, 1, 1, 1).float()
    labels = torch.tensor([3.0])

    # Heatmap mirrors the patch indices so each label can be traced back.
    attn_map = torch.arange(n_patches).view(1, n_patches, 1, 1, 1).float()

    att_labels, shuffled_images = get_attention_scores(images, labels, attn_map, mock_args)

    # Wherever the hottest patch (value 9) landed, the score must be the max.
    pos = (shuffled_images[0, :, 0, 0, 0] == 9.0).nonzero(as_tuple=True)[0]
    assert att_labels[0, pos] == att_labels[0].max()

    # Wherever the coldest patch (value 0) landed, the score must be the min.
    pos = (shuffled_images[0, :, 0, 0, 0] == 0.0).nonzero(as_tuple=True)[0]
    assert att_labels[0, pos] == att_labels[0].min()

    flat_images = shuffled_images.cpu().squeeze()  # Shape [10]
    flat_labels = att_labels.cpu().squeeze()  # Shape [10]

    # Undoing the shuffle by sorting image values must yield monotonically
    # non-decreasing labels, i.e. both tensors shared one permutation.
    sorted_vals, perm = torch.sort(flat_images)
    ordered_labels = flat_labels[perm]
    for i in range(len(ordered_labels) - 1):
        assert ordered_labels[i] <= ordered_labels[i + 1], (
            f"Alignment broken at index {i}: Image val {sorted_vals[i]} has higher label than {sorted_vals[i + 1]}"
        )


def test_normalize_intensity_custom_masked_stats():
    """
    Test that statistics (mean/std) are calculated ONLY from the masked region,
    but applied to the whole image.
    """
    # Two channels; the mask selects just two pixels of the first row.
    image = torch.zeros((2, 4, 4), dtype=torch.float32)
    roi = torch.zeros((1, 4, 4), dtype=torch.float32)

    # Channel 0: background 100, masked pixels {10, 20} -> mean 15, std 5.
    image[0] = 100.0
    image[0, 0, 0] = 10.0
    image[0, 0, 1] = 20.0

    # Channel 1: background 50, masked pixels {2, 4} -> mean 3, std 1.
    image[1] = 50.0
    image[1, 0, 0] = 2.0
    image[1, 0, 1] = 4.0

    roi[0, 0, 0] = 1
    roi[0, 0, 1] = 1

    normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)
    result = normalizer(image, roi)

    # Masked pixels normalize to +/-1; background is scaled by the same stats.
    assert torch.isclose(result[0, 0, 0], torch.tensor(-1.0)), "Ch0 masked value 1 incorrect"
    assert torch.isclose(result[0, 0, 1], torch.tensor(1.0)), "Ch0 masked value 2 incorrect"
    assert torch.isclose(result[0, 1, 1], torch.tensor(17.0)), "Ch0 background normalization incorrect"

    assert torch.isclose(result[1, 0, 0], torch.tensor(-1.0)), "Ch1 masked value 1 incorrect"
    assert torch.isclose(result[1, 1, 1], torch.tensor(47.0)), "Ch1 background normalization incorrect"


def test_normalize_intensity_constant_area():
    """
    Test edge case where the area under the mask has 0 variance (constant value).
    Std should default to 1.0 to avoid division by zero.
    """
    # Constant image: output must be all zeros rather than NaN/inf.
    const_img = torch.ones((1, 4, 4)) * 10.0
    const_mask = torch.ones((1, 4, 4))
    const_out = NormalizeIntensity_custom(channel_wise=True)(const_img, const_mask)
    assert torch.allclose(const_out, torch.zeros_like(const_out))

    # Random image: result must match a reference masked-stats normalization.
    img = torch.rand(1, 10, 10)
    mask = torch.randint(0, 2, (1, 10, 10)).float()
    out = NormalizeIntensity_custom(nonzero=False, channel_wise=True)(img, mask)

    roi_values = img[mask != 0]
    mu = torch.mean(roi_values.float())
    # Population std (unbiased=False) to mirror the transform's statistics.
    sigma = torch.std(roi_values.float(), unbiased=False)
    expected = (img - mu) / (sigma + 1e-8)

    torch.testing.assert_close(out, expected)


def test_run_models():
    """Smoke-test one train/val epoch for both the PI-RADS and csPCa stages."""
    args = argparse.Namespace(
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        epochs=1,
        batch_size=2,
        tile_size=10,
        tile_count=5,
        use_heatmap=True,
        amp=False,
        num_classes=4,
        dry_run=True,
        depth=3,
    )

    # Stage 1: attention-based MIL classifier for PI-RADS.
    model = MILModel3D(num_classes=args.num_classes, mil_mode="att_trans").to(args.device)
    loader = get_dataloader(args, split="train")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-5)
    scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)

    _ = train_pirads.train_epoch(model, loader, optimizer, scaler=scaler, epoch=0, args=args)
    _ = train_pirads.val_epoch(model, loader, epoch=0, args=args)

    # Stage 2: csPCa model reusing the stage-1 network as its backbone.
    cspca_model = CSPCAModel(backbone=model).to(args.device)
    optimizer_cspca = torch.optim.AdamW(cspca_model.parameters(), lr=1e-5)
    _ = train_cspca.train_epoch(cspca_model, loader, optimizer_cspca, epoch=0, args=args)
    _ = train_cspca.val_epoch(cspca_model, loader, epoch=0, args=args)