# Architecture

## Patch Extraction

Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):

- **With heatmaps**: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map, so regions with high DWI and low ADC signal are sampled more frequently
- **Without heatmaps**: Crops are sampled from positive (prostate) regions of the binary mask

Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).
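As a rough illustration, the weight-map construction for the heatmap case can be sketched in NumPy. The array names (`heatmap_dwi`, `heatmap_adc`, `prostate_mask`) and the exact way the two heatmaps are combined are assumptions; the resulting map would be what gets passed as the weight image (`w_key`) to `RandWeightedCropd` with `num_samples=24`:

```python
import numpy as np

def build_weight_map(heatmap_dwi, heatmap_adc, prostate_mask):
    """Illustrative sampling-weight map: emphasize high-DWI / low-ADC
    regions and zero out everything outside the prostate. The (1 - ADC)
    inversion is one plausible way to realize 'high DWI and low ADC'."""
    combined = heatmap_dwi * (1.0 - heatmap_adc)
    return combined * prostate_mask

# Toy volume: one bright DWI voxel inside the mask, one outside
dwi = np.zeros((8, 8, 8)); dwi[2, 2, 2] = 1.0; dwi[6, 6, 6] = 1.0
adc = np.zeros((8, 8, 8))
mask = np.zeros((8, 8, 8)); mask[:4, :4, :4] = 1.0

w = build_weight_map(dwi, adc, mask)
print(w[2, 2, 2], w[6, 6, 6])  # only the in-mask voxel keeps its weight
```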

## Tensor Shape Convention

Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:

| Dim | Meaning | Typical Value |
|-----|---------|---------------|
| B | Batch size | 4–8 |
| N | Number of patches (instances) | 24 |
| C | Channels (T2 + DWI + ADC) | 3 |
| D | Depth (slices per patch) | 3 |
| H | Patch height | 64 |
| W | Patch width | 64 |
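The convention is easy to verify with a quick NumPy sketch (sizes taken from the table; a small `B` is used here). The collapse of the batch and instance dimensions is the same reshape the backbone performs in the forward pass:

```python
import numpy as np

B, N, C, D, H, W = 2, 24, 3, 3, 64, 64
batch = np.zeros((B, N, C, D, H, W), dtype=np.float32)

# Collapse batch and instance dims so each patch is an ordinary 3D volume
flat = batch.reshape(B * N, C, D, H, W)
print(flat.shape)     # (48, 3, 3, 64, 64)

# Regroup patches by scan before aggregation
grouped = flat.reshape(B, N, C, D, H, W)
print(grouped.shape)  # (2, 24, 3, 3, 64, 64)
```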

## MILModel3D

The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B[ResNet18-3D Backbone]
    B --> C[Transformer Encoder<br/>4 layers, 8 heads]
    C --> D[Attention Pooling<br/>512 → 2048 → 1]
    D --> E["Weighted Sum [B, 512]"]
    E --> F["FC Head [B, num_classes]"]
```

### Forward Pass

1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.

2. **Transformer**: The `[B*N, 512]` patch features are regrouped to `[B, N, 512]` and passed through the transformer encoder (4 layers, 8 heads), letting patches from the same scan attend to one another.

3. **Attention**: A two-layer attention network (`512 → 2048 → 1` with Tanh) computes a scalar weight per patch, normalized via softmax.

4. **Classification**: The attention-weighted sum of patch features produces a single `[B, 512]` vector per scan, which is projected to class logits by a linear layer.
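Steps 3–4 can be sketched in plain NumPy. The weight matrices below are random stand-ins for the learned attention network, so only the shapes and the softmax normalization are meaningful:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
B, N, F = 2, 24, 512
feats = rng.standard_normal((B, N, F))           # transformer output, one row per patch

# Two-layer attention net: 512 -> 2048 -> 1 with Tanh (random stand-in weights)
W1 = rng.standard_normal((F, 2048)) * 0.01
W2 = rng.standard_normal((2048, 1)) * 0.01
scores = np.tanh(feats @ W1) @ W2                # [B, N, 1] scalar score per patch
alpha = softmax(scores.squeeze(-1), axis=1)      # normalize over the N patches

pooled = (alpha[..., None] * feats).sum(axis=1)  # [B, 512] scan-level feature
print(pooled.shape)                              # weights sum to 1 per scan
```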

### MIL Modes

| Mode | Aggregation Strategy |
|------|---------------------|
| `mean` | Average logits across patches |
| `max` | Max logits across patches |
| `att` | Attention-weighted feature pooling |
| `att_trans` | Transformer encoder + attention pooling (primary mode) |
| `att_trans_pyramid` | Pyramid transformer on intermediate ResNet layers + attention |

The default and primary mode is `att_trans`.
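The aggregation strategies differ mainly in *what* gets pooled: `mean` and `max` operate on per-patch logits, while the attention modes pool features before classification. A shape-level sketch (random stand-ins; in the real model the attention weights come from the learned network, not a uniform vector):

```python
import numpy as np

rng = np.random.default_rng(1)
B, N, K, F = 2, 24, 4, 512                 # K = number of classes (illustrative)
logits = rng.standard_normal((B, N, K))    # per-patch class logits
feats = rng.standard_normal((B, N, F))     # per-patch features
alpha = np.full((B, N), 1.0 / N)           # attention weights (uniform stand-in)
W_fc = rng.standard_normal((F, K)) * 0.01  # stand-in for the FC head

out_mean = logits.mean(axis=1)                             # `mean`: average logits
out_max = logits.max(axis=1)                               # `max`: max logits
out_att = (alpha[..., None] * feats).sum(axis=1) @ W_fc    # `att`: pool features, then classify
print(out_mean.shape, out_max.shape, out_att.shape)        # all (2, 4)
```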

## csPCa_Model

Wraps a frozen `MILModel3D` backbone and replaces the classification head:

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone<br/>(ResNet18 + Transformer)"]
    B --> C["Pooled Features [B, 512]"]
    C --> D["SimpleNN Head<br/>512 → 256 → 128 → 1"]
    D --> E["Sigmoid → csPCa Probability"]
```

### SimpleNN

```
Linear(512, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```
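A PyTorch sketch of this head, assuming the layer ordering shown above (the actual class may differ in details such as initialization):

```python
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    """Sketch of the csPCa head: 512 -> 256 -> 128 -> 1 with a sigmoid output."""
    def __init__(self, in_features=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, x):  # x: [B, 512] pooled features
        return self.net(x)

head = SimpleNN()
probs = head(torch.randn(4, 512))
print(probs.shape)  # one probability in (0, 1) per scan
```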

During csPCa training, the backbone's `net` (ResNet18) and `transformer` modules are frozen; the `attention` module and `SimpleNN` head remain trainable.
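The freezing pattern can be demonstrated with a toy module. The `net` / `transformer` / `attention` attribute names follow the text; the layers themselves are stand-ins:

```python
import torch.nn as nn

backbone = nn.Module()
backbone.net = nn.Linear(8, 8)          # stands in for the ResNet18
backbone.transformer = nn.Linear(8, 8)  # stands in for the transformer encoder
backbone.attention = nn.Linear(8, 1)    # stays trainable

# Freeze CNN and transformer; leave attention (and any new head) trainable
for module in (backbone.net, backbone.transformer):
    for p in module.parameters():
        p.requires_grad = False

trainable = [n for n, p in backbone.named_parameters() if p.requires_grad]
print(trainable)  # only the attention parameters remain
```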

## Attention Loss

During PI-RADS training with heatmaps enabled, the model uses a dual-loss objective:

```
total_loss = class_loss + lambda_att * attention_loss
```

- **Classification loss**: Standard CrossEntropy on PI-RADS labels
- **Attention loss**: `1 - cosine_similarity(predicted_attention, heatmap_attention)`
    - Heatmap-derived attention labels are computed by summing spatial heatmap values per patch, squaring for sharpness, and normalizing
    - PI-RADS 2 samples get uniform attention (no expected lesion)
    - `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
    - The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification
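A NumPy sketch of the attention target and the warm-up schedule; the function names and exact normalization are illustrative:

```python
import numpy as np

def heatmap_attention(heat_sums, pirads2=False):
    """Per-patch attention target: square the summed heatmap values for
    sharpness, then normalize. PI-RADS 2 cases get a uniform target."""
    if pirads2:
        return np.full_like(heat_sums, 1.0 / len(heat_sums))
    target = heat_sums ** 2
    return target / target.sum()

def attention_loss(pred_att, target_att):
    """1 - cosine similarity between predicted and target attention."""
    cos = pred_att @ target_att / (np.linalg.norm(pred_att) * np.linalg.norm(target_att))
    return 1.0 - cos

def lambda_att(epoch, warmup=25, max_val=2.0):
    """Linear warm-up of the attention-loss weight: 0 -> 2.0 over 25 epochs."""
    return max_val * min(epoch / warmup, 1.0)

heat = np.array([0.9, 0.05, 0.03, 0.02])
target = heatmap_attention(heat)
print(attention_loss(target, target))              # ~0 for identical vectors
print(lambda_att(0), lambda_att(25), lambda_att(100))
```

In the real training loop the predicted attention would come from the detached transformer outputs, so this loss shapes only the attention network, not the classification path.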