# Architecture
## Patch Extraction
Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):
- **With heatmaps**: The combined DWI/ADC heatmap, multiplied by the prostate mask, serves as the sampling weight map, so regions with high DWI and low ADC signal inside the prostate are sampled more frequently
- **Without heatmaps**: Crops are sampled from positive (prostate) regions based on the binary mask
Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).
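To make heatmap-weighted sampling concrete, here is a minimal NumPy sketch of the idea behind it: patch centers are drawn with probability proportional to `heatmap * mask`. This is an illustrative reimplementation, not MONAI's `RandWeightedCropd`. The helper names `weighted_patch_centers` and `crop_patch` are hypothetical, and MONAI may treat patches near the volume border differently (this sketch simply restricts centers so every patch fits inside the volume).

```python
import numpy as np


def weighted_patch_centers(heatmap, mask, patch_size, num_samples, rng):
    """Draw patch centers with probability proportional to heatmap * mask."""
    weights = (heatmap * mask).astype(np.float64)
    shape = weights.shape
    # Zero out centers whose patch would extend past the volume border.
    valid = np.zeros(shape, dtype=bool)
    valid[tuple(slice(p // 2, s - (p - p // 2) + 1)
                for s, p in zip(shape, patch_size))] = True
    weights = np.where(valid, weights, 0.0).ravel()
    total = weights.sum()
    if total <= 0:
        raise ValueError("weight map has no positive voxels where a full patch fits")
    idx = rng.choice(weights.size, size=num_samples, p=weights / total)
    return np.stack(np.unravel_index(idx, shape), axis=1)


def crop_patch(volume, center, patch_size):
    """Cut a patch of shape patch_size centred on center."""
    starts = [c - p // 2 for c, p in zip(center, patch_size)]
    return volume[tuple(slice(s, s + p) for s, p in zip(starts, patch_size))]


# Toy example: 24 patches of 64x64x3 from a synthetic volume.
rng = np.random.default_rng(0)
heatmap = rng.random((128, 128, 20))      # toy DWI/ADC heatmap
mask = np.zeros_like(heatmap)
mask[40:90, 40:90, 5:15] = 1.0            # toy prostate mask
centers = weighted_patch_centers(heatmap, mask, (64, 64, 3), 24, rng)
patches = np.stack([crop_patch(heatmap, c, (64, 64, 3)) for c in centers])
```

Passing `np.ones_like(mask)` as the heatmap reduces the weights to the binary mask, which roughly mirrors the mask-only mode where crops are sampled from positive (prostate) voxels.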
## Tensor Shape Convention
Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:
| Dim | Meaning | Typical Value |
|-----|---------|---------------|
| B | Batch size | 4–8 |
| N | Number of patches (instances) | 24 |
| C | Channels (T2 + DWI + ADC) | 3 |
| D | Depth (slices per patch) | 3 |
| H | Patch height | 64 |
| W | Patch width | 64 |
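A short PyTorch sketch of how this convention is typically used: the instance dimension `N` is folded into the batch before the backbone and unfolded afterwards. The sizes are the typical values from the table; the random `features` tensor is only a stand-in for real backbone output.

```python
import torch

B, N, C, D, H, W = 4, 24, 3, 3, 64, 64

# One batch of bags: B scans, each with N patches of shape [C, D, H, W].
x = torch.randn(B, N, C, D, H, W)

# Fold the instance dimension into the batch so a 3D CNN sees ordinary volumes.
flat = x.reshape(B * N, C, D, H, W)

# After the backbone, per-patch features are unfolded back into bags.
feat_dim = 512
features = torch.randn(B * N, feat_dim)  # stand-in for backbone output
bags = features.reshape(B, N, feat_dim)
```

Because `reshape` is row-major, row `b * N + n` of `flat` is patch `n` of scan `b`, so unfolding restores the original scan/patch pairing.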
## MILModel3D
The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.
```mermaid
flowchart TD
A["Input [B, N, C, D, H, W]"] --> B[ResNet18-3D Backbone]
B --> C["Transformer Encoder
4 layers, 8 heads"]
C --> D["Attention Pooling
512 → 2048 → 1"]
D --> E["Weighted Sum [B, 512]"]
E --> F["FC Head [B, num_classes]"]
```
### Forward Pass
1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.
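2. **Transformer**: The `[B*N, 512]` features are reshaped to `[B, N, 512]` and passed through the transformer encoder (4 layers, 8 heads), so each patch representation can attend to the other patches from the same scan.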
3. **Attention**: A two-layer attention network (`512 → 2048 → 1` with Tanh) computes a scalar weight per patch; the weights are normalized with a softmax across the `N` patches of each scan.
4. **Classification**: The attention-weighted sum of patch features produces a single `[B, 512]` vector per scan, which is projected to class logits by a linear layer.
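The four steps can be sketched in PyTorch as follows. This is a simplified illustration, not the repository's `MILModel3D`: `TinyBackbone` and `AttTransMIL` are hypothetical names, and `TinyBackbone` stands in for the 3D ResNet18 (any module mapping `[B*N, C, D, H, W]` to `[B*N, 512]` fits). The transformer uses PyTorch's stock `nn.TransformerEncoderLayer` with 4 layers and 8 heads as in the diagram; its other hyperparameters (feed-forward width, dropout) are PyTorch defaults, not values taken from this project.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the 3D ResNet18: [*, C, D, H, W] -> [*, 512]."""

    def __init__(self, in_channels: int = 3, feat_dim: int = 512):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, 32, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.proj = nn.Linear(32, feat_dim)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.proj(self.pool(x).flatten(1))


class AttTransMIL(nn.Module):
    def __init__(self, num_classes: int, feat_dim: int = 512):
        super().__init__()
        self.net = TinyBackbone(feat_dim=feat_dim)
        layer = nn.TransformerEncoderLayer(d_model=feat_dim, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=4)
        self.attention = nn.Sequential(
            nn.Linear(feat_dim, 2048), nn.Tanh(), nn.Linear(2048, 1)
        )
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        b, n = x.shape[:2]
        # 1. Backbone: fold patches into the batch dimension, then unfold.
        feats = self.net(x.reshape(b * n, *x.shape[2:])).reshape(b, n, -1)
        # 2. Transformer: patches of the same scan attend to each other.
        feats = self.transformer(feats)
        # 3. Attention: one scalar per patch, softmax over the N patches.
        weights = torch.softmax(self.attention(feats), dim=1)  # [B, N, 1]
        # 4. Weighted sum -> one [B, 512] vector per scan, then class logits.
        pooled = (weights * feats).sum(dim=1)
        return self.fc(pooled), weights.squeeze(-1)


# num_classes=4 is an arbitrary example value.
model = AttTransMIL(num_classes=4).eval()
with torch.no_grad():
    logits, weights = model(torch.randn(2, 24, 3, 3, 64, 64))
```

Returning the attention weights alongside the logits is a convenience for inspecting which patches drive a prediction and for supervising them with an attention loss.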
### MIL Modes
| Mode | Aggregation Strategy |
|------|---------------------|
| `mean` | Average logits across patches |
| `max` | Max logits across patches |
| `att` | Attention-weighted feature pooling |
| `att_trans` | Transformer encoder + attention pooling (primary mode) |
| `att_trans_pyramid` | Pyramid transformer on intermediate ResNet layers + attention |
The default and primary mode is `att_trans`.
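The non-transformer modes differ mainly in where aggregation happens: `mean` and `max` classify each patch first and pool the logits, while `att` pools features and classifies once. The hypothetical helper below illustrates that difference, assuming per-patch features of shape `[B, N, 512]` and the same `512 → 2048 → 1` attention network; `att_trans` would additionally pass `feats` through the transformer encoder before the `att` branch.

```python
import torch
import torch.nn as nn


def aggregate(feats: torch.Tensor, fc: nn.Linear, attention: nn.Module, mode: str) -> torch.Tensor:
    """Turn per-patch features [B, N, F] into scan-level logits [B, num_classes]."""
    if mode == "mean":
        # Classify every patch, then average the logits over patches.
        return fc(feats).mean(dim=1)
    if mode == "max":
        # Classify every patch, then keep the largest logit per class.
        return fc(feats).max(dim=1).values
    if mode == "att":
        # Attention-weighted pooling of features, then a single classification.
        weights = torch.softmax(attention(feats), dim=1)  # [B, N, 1]
        return fc((weights * feats).sum(dim=1))
    raise ValueError(f"unsupported mode: {mode}")


feats = torch.randn(2, 24, 512)
fc = nn.Linear(512, 3)
attention = nn.Sequential(nn.Linear(512, 2048), nn.Tanh(), nn.Linear(2048, 1))
outputs = {m: aggregate(feats, fc, attention, m) for m in ("mean", "max", "att")}
```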
## csPCa_Model
Wraps a `MILModel_3D` backbone, with its ResNet18 and transformer frozen, and replaces the classification head:
```mermaid
flowchart TD
A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone
(ResNet18 + Transformer)"]
B --> C["Pooled Features [B, 512]"]
C --> D["SimpleNN Head
512 → 256 → 128 → 1"]
D --> E["Sigmoid → csPCa Probability"]
```
### SimpleNN
```
Linear(512, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```
During csPCa training, the backbone's `net` (ResNet18) and `transformer` modules are frozen, while the `attention` module and the `SimpleNN` head remain trainable.
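A minimal PyTorch sketch of the head and the freezing step. It assumes the backbone exposes submodules named `net`, `transformer`, and `attention`, as the paragraph above describes; `freeze_for_cspca` is a hypothetical helper, not a function from this codebase.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    """Head from the diagram: 512 -> 256 -> 128 -> 1 with a sigmoid output."""

    def __init__(self, in_dim: int = 512):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.layers(x)


def freeze_for_cspca(backbone: nn.Module) -> None:
    """Freeze `net` and `transformer`; leave `attention` trainable."""
    for name in ("net", "transformer"):
        for p in getattr(backbone, name).parameters():
            p.requires_grad = False
```

Only parameters that still require gradients need to go to the optimizer, e.g. `[p for p in backbone.parameters() if p.requires_grad] + list(head.parameters())`.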
## Attention Loss
During PI-RADS training with heatmaps enabled, the model uses a dual-loss objective:
```
total_loss = class_loss + lambda_att * attention_loss
```
- **Classification loss**: Standard CrossEntropy on PI-RADS labels
- **Attention loss**: `1 - cosine_similarity(predicted_attention, heatmap_attention)`
- Heatmap-derived attention labels are computed by summing spatial heatmap values per patch, squaring for sharpness, and normalizing
- PI-RADS 2 samples get uniform attention (no expected lesion)
- `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
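- The attention predictions are computed from detached transformer outputs, so the attention loss updates only the attention module and does not interfere with the gradients that train the backbone and transformer for classification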
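The pieces of this objective can be sketched in PyTorch as follows. Two details are assumptions rather than documented behavior: the normalization step divides by the sum so each scan's targets add up to 1, and epochs are counted from 0 so `lambda_att` reaches 2.0 at epoch 25. All function names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def heatmap_attention_targets(patch_heatmaps: torch.Tensor, is_pirads2: torch.Tensor) -> torch.Tensor:
    """patch_heatmaps: [B, N, D, H, W] heatmap crops; is_pirads2: [B] bool. Returns [B, N]."""
    b, n = patch_heatmaps.shape[:2]
    scores = patch_heatmaps.reshape(b, n, -1).sum(dim=-1)  # heatmap mass per patch
    scores = scores ** 2                                    # square for sharpness
    targets = scores / scores.sum(dim=1, keepdim=True).clamp_min(1e-8)  # assumed: sum to 1
    uniform = torch.full_like(targets, 1.0 / n)             # PI-RADS 2: no expected lesion
    return torch.where(is_pirads2[:, None], uniform, targets)


def lambda_att(epoch: int, max_value: float = 2.0, warmup_epochs: int = 25) -> float:
    """Linear warm-up from 0 to max_value over the first warmup_epochs epochs."""
    return max_value * min(epoch / warmup_epochs, 1.0)


def predicted_attention(attention: nn.Module, trans_feats: torch.Tensor) -> torch.Tensor:
    # Detach so the attention loss cannot backpropagate into the transformer/backbone.
    return torch.softmax(attention(trans_feats.detach()), dim=1).squeeze(-1)


def total_loss(logits, labels, pred_att, target_att, epoch):
    class_loss = F.cross_entropy(logits, labels)
    att_loss = (1 - F.cosine_similarity(pred_att, target_att, dim=1)).mean()
    return class_loss + lambda_att(epoch) * att_loss
```

Because cosine similarity is invariant to positive rescaling of either argument, the exact normalization of the targets does not change the attention loss; what matters is the relative distribution across patches, which the squaring sharpens.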