# Architecture

## Patch Extraction

Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):

- **With heatmaps**: The combined DWI/ADC heatmap, multiplied by the prostate mask, serves as the sampling weight map. Regions with high DWI signal and low ADC values are therefore sampled more often.
- **Without heatmaps**: Crops are sampled from positive (prostate) regions of the binary mask.

Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).

## Tensor Shape Convention

Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:

| Dim | Meaning | Typical Value |
|-----|---------|---------------|
| B | Batch size | 4–8 |
| N | Number of patches (instances) per scan | 24 |
| C | Channels (T2 + DWI + ADC) | 3 |
| D | Depth (slices per patch) | 3 |
| H | Patch height (voxels) | 64 |
| W | Patch width (voxels) | 64 |

## MILModel3D

The core model runs each patch independently through a CNN backbone. It then aggregates the patch-level features with a transformer encoder and attention pooling.

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B[ResNet18-3D Backbone]
    B --> C["Transformer Encoder<br/>4 layers, 8 heads"]
    C --> D["Attention Pooling<br/>512 → 2048 → 1"]
    D --> E["Weighted Sum [B, 512]"]
    E --> F["FC Head [B, num_classes]"]
```

### Forward Pass

1. **Backbone**: The input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 with 3 input channels. The final FC layer is removed, so the backbone yields a 512-dimensional feature vector per patch.
2. **Transformer**: The per-patch features are regrouped by scan (N patches each) and passed through the transformer encoder (4 layers, 8 heads). This lets each patch representation attend to the other patches from the same scan.
3. **Attention**: A two-layer attention network (`512 → 2048 → 1`, with a Tanh activation) computes a scalar score per patch. A softmax over the N patches turns these scores into weights.
4. **Classification**: The attention-weighted sum of patch features produces a single 512-dimensional vector per scan (`[B, 512]` for the batch). A linear layer projects this vector to class logits.

### MIL Modes

| Mode | Aggregation Strategy |
|------|---------------------|
| `mean` | Average of logits across patches |
| `max` | Maximum of logits across patches |
| `att` | Attention-weighted feature pooling |
| `att_trans` | Transformer encoder + attention pooling (primary mode) |
| `att_trans_pyramid` | Pyramid transformer on intermediate ResNet layers + attention |

The default and primary mode is `att_trans`.

## csPCa_Model

This model wraps a `MILModel_3D` backbone whose ResNet18 and transformer are frozen, and replaces the classification head:

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone<br/>(ResNet18 + Transformer)"]
    B --> C["Pooled Features [B, 512]"]
    C --> D["SimpleNN Head<br/>512 → 256 → 128 → 1"]
    D --> E["Sigmoid → csPCa Probability"]
```

### SimpleNN

```
Linear(512, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```

During csPCa training, the backbone's `net` (ResNet18) and `transformer` modules are frozen. The `attention` module and the `SimpleNN` head remain trainable.

## Attention Loss

During PI-RADS training with heatmaps enabled, the model optimizes a two-term objective:

```
total_loss = class_loss + lambda_att * attention_loss
```

- **Classification loss**: Standard cross-entropy on PI-RADS labels.
- **Attention loss**: `1 - cosine_similarity(predicted_attention, heatmap_attention)`.
  - Heatmap-derived attention labels are computed in three steps: sum the spatial heatmap values within each patch, square the sums to sharpen the distribution, then normalize.
  - PI-RADS 2 samples get uniform attention labels, because no lesion is expected.
- `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs.
- The attention predictions used in this loss are computed from detached transformer outputs. Gradients from the attention loss therefore do not flow into the transformer or the backbone, which avoids interference with classification learning.
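The attention-loss rules listed above can be sketched in plain Python. This is a minimal illustration, not the project's code, and several details are assumptions:

- The helper names (`heatmap_attention_labels`, `lambda_att_schedule`, and the others) are hypothetical.
- "Normalizing" is assumed to mean dividing by the sum, so the targets add up to 1.
- The warm-up is assumed to use a 0-based epoch counter.

Cosine similarity is invariant to positive rescaling. Any normalization that only rescales the vector, such as dividing by the sum or by the maximum, therefore leaves the loss unchanged.

```python
import math
from typing import List, Sequence


def heatmap_attention_labels(patch_heatmap_sums: Sequence[float], is_pirads2: bool) -> List[float]:
    """Hypothetical helper: turn per-patch heatmap sums into attention targets."""
    n = len(patch_heatmap_sums)
    if is_pirads2:
        # PI-RADS 2: no lesion expected, so every patch gets the same target weight.
        return [1.0 / n] * n
    squared = [s * s for s in patch_heatmap_sums]  # squaring sharpens the distribution
    total = sum(squared)
    if total == 0.0:
        # Assumption: fall back to uniform targets when the heatmap is empty.
        return [1.0 / n] * n
    return [s / total for s in squared]  # assumed sum-to-one normalization


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def attention_loss(pred_att: Sequence[float], target_att: Sequence[float]) -> float:
    return 1.0 - cosine_similarity(pred_att, target_att)


def lambda_att_schedule(epoch: int, max_lambda: float = 2.0, warmup_epochs: int = 25) -> float:
    # Linear warm-up: 0 at epoch 0, max_lambda at warmup_epochs, constant afterwards.
    return max_lambda * min(epoch / warmup_epochs, 1.0)


def total_loss(class_loss: float, pred_att: Sequence[float],
               target_att: Sequence[float], epoch: int) -> float:
    return class_loss + lambda_att_schedule(epoch) * attention_loss(pred_att, target_att)
```

For example, two patches with heatmap sums 1.0 and 3.0 square to 1 and 9, which gives targets of 0.1 and 0.9. The second patch dominates more strongly than the raw 1:3 ratio would suggest.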
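The heatmap-guided branch described under Patch Extraction amounts to drawing crop centers with probability proportional to `heatmap * mask`. MONAI's `RandWeightedCropd` does this internally; its `w_key` argument names the weight-map entry. The NumPy sketch below only illustrates the sampling idea, and it makes these assumptions:

- `sample_patch_centers` is a hypothetical name.
- The heatmap is assumed to be non-negative.
- How the DWI and ADC heatmaps are combined into one map is not reproduced here.
- Cutting the actual `tile_size x tile_size x depth` crops around each center, including border handling, is omitted.

```python
import numpy as np


def sample_patch_centers(heatmap: np.ndarray, mask: np.ndarray, num_samples: int,
                         rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: draw voxel coordinates with probability proportional to heatmap * mask.

    Assumes a non-negative heatmap; voxels outside the mask get zero weight.
    """
    weights = (heatmap * mask).astype(np.float64).ravel()
    total = weights.sum()
    if total <= 0:
        raise ValueError("weight map is zero everywhere inside the mask")
    # Sample flat voxel indices with replacement, weighted by the masked heatmap.
    flat_idx = rng.choice(weights.size, size=num_samples, replace=True, p=weights / total)
    # Convert flat indices back to per-axis coordinates: shape [num_samples, ndim].
    return np.stack(np.unravel_index(flat_idx, heatmap.shape), axis=-1)
```

Voxels with zero weight are never chosen. Several of the `N` centers can land on the same hotspot, which is the intended bias toward suspicious regions.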
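The `att_trans` forward pass described under MILModel3D can be sketched in PyTorch as follows. This is an illustrative reimplementation under stated assumptions, not the repository's class:

- `AttTransMIL` is a hypothetical name.
- A tiny `Conv3d` stack stands in for the 3D ResNet18. Any module mapping `[B*N, C, D, H, W]` to `[B*N, 512]` would fit.
- The transformer is PyTorch's `nn.TransformerEncoder` with `batch_first=True`.
- `num_classes=4` (PI-RADS 2 to 5) is assumed.

```python
import torch
import torch.nn as nn


class AttTransMIL(nn.Module):
    """Hypothetical sketch of the att_trans MIL aggregation."""

    def __init__(self, backbone: nn.Module, dim: int = 512, num_classes: int = 4):
        super().__init__()
        self.net = backbone  # must map [B*N, C, D, H, W] -> [B*N, dim]
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=4)
        # Two-layer attention scorer: dim -> 2048 -> 1 with a Tanh in between.
        self.attention = nn.Sequential(nn.Linear(dim, 2048), nn.Tanh(), nn.Linear(2048, 1))
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor):
        b, n, c, d, h, w = x.shape
        feats = self.net(x.reshape(b * n, c, d, h, w))     # [B*N, dim], one vector per patch
        feats = feats.reshape(b, n, -1)                     # [B, N, dim], regrouped by scan
        feats = self.transformer(feats)                     # patches attend to each other
        att = torch.softmax(self.attention(feats), dim=1)   # [B, N, 1], sums to 1 over N
        pooled = (att * feats).sum(dim=1)                   # [B, dim] attention-weighted sum
        return self.fc(pooled), att.squeeze(-1)             # logits [B, classes], weights [B, N]


# Stand-in backbone (NOT ResNet18): conv -> global average pool -> flatten to 512 features.
toy_backbone = nn.Sequential(
    nn.Conv3d(3, 512, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
)
```

Returning the per-patch weights alongside the logits makes them available to an attention loss such as the one in the Attention Loss section.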
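The csPCa head and its freezing rule can be sketched as follows. `SimpleNN` mirrors the layer listing in the SimpleNN section. `freeze_for_cspca` is a hypothetical helper, and it assumes the backbone exposes submodules named `net`, `transformer`, and `attention`, as the csPCa_Model section indicates.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    """512 -> 256 -> 128 -> 1 head, following the SimpleNN layer listing."""

    def __init__(self, in_features: int = 512, dropout: float = 0.3):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        return self.layers(pooled)  # [B, 512] -> [B, 1] csPCa probabilities in (0, 1)


def freeze_for_cspca(backbone: nn.Module) -> None:
    """Hypothetical helper: freeze `net` and `transformer`, keep `attention` trainable."""
    for name in ("net", "transformer"):
        for p in getattr(backbone, name).parameters():
            p.requires_grad = False
    for p in backbone.attention.parameters():
        p.requires_grad = True
```

With this setup, only parameters that still have `requires_grad=True` would be passed to the optimizer. Those are the attention module and the new head.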