# Architecture

## Patch Extraction

Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):

- **With heatmaps**: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map, so regions with high DWI and low ADC are sampled more frequently.
- **Without heatmaps**: Crops are sampled from positive (prostate) regions based on the binary mask.

Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).
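The heatmap-weighted sampling idea can be illustrated without MONAI. The NumPy sketch below is not MONAI's implementation, and the helper `sample_weighted_patches` and its arguments are hypothetical. It draws patch positions with probability proportional to `heatmap * mask`, considers only positions where a full patch fits inside the volume, and samples with replacement.

```python
import numpy as np


def sample_weighted_patches(volume, heatmap, mask, num_patches=24,
                            tile_size=64, depth=3, seed=0):
    """Hypothetical helper: sample patches with probability ~ heatmap * mask.

    volume: [C, D, H, W]; heatmap, mask: [D, H, W] (same spatial size).
    Returns an array of shape [num_patches, C, depth, tile_size, tile_size].
    """
    rng = np.random.default_rng(seed)
    _, d, h, w = volume.shape
    weights = (heatmap * mask).astype(np.float64)

    # Keep only candidate centers whose patch fits entirely inside the volume.
    # Index i in `cand` corresponds to a patch starting at i along each axis.
    hd, ht = depth // 2, tile_size // 2
    cand = weights[hd:d - depth + hd + 1,
                   ht:h - tile_size + ht + 1,
                   ht:w - tile_size + ht + 1]

    probs = cand.ravel()
    if probs.sum() == 0:
        probs = np.ones_like(probs)  # fall back to uniform sampling
    probs = probs / probs.sum()

    # Draw patch start positions in proportion to the weight at the patch center.
    flat = rng.choice(probs.size, size=num_patches, p=probs)
    zs, ys, xs = np.unravel_index(flat, cand.shape)
    patches = [volume[:, z:z + depth, y:y + tile_size, x:x + tile_size]
               for z, y, x in zip(zs, ys, xs)]
    return np.stack(patches)
```

Sampling with replacement means a small, strongly weighted region can contribute several overlapping patches, which is the intended bias toward suspicious tissue.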
## Tensor Shape Convention

Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:

| Dim | Meaning | Typical Value |
|-----|---------|---------------|
| B | Batch size | 4–8 |
| N | Number of patches (instances) | 24 |
| C | Channels (T2 + DWI + ADC) | 3 |
| D | Depth (slices per patch) | 3 |
| H | Patch height | 64 |
| W | Patch width | 64 |
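A minimal NumPy sketch of the shape bookkeeping this convention implies, assuming the typical values from the table. `fold_patches` and `unfold_features` are hypothetical names, and a zero array stands in for the 512-dimensional backbone output.

```python
import numpy as np

B, N, C, D, H, W = 4, 24, 3, 3, 64, 64
FEAT_DIM = 512  # ResNet18 feature width


def fold_patches(x):
    """[B, N, C, D, H, W] -> [B*N, C, D, H, W]: each patch becomes one CNN sample."""
    b, n = x.shape[:2]
    return x.reshape(b * n, *x.shape[2:])


def unfold_features(f, b, n):
    """[B*N, F] -> [B, N, F]: patches from the same scan form one bag again."""
    return f.reshape(b, n, -1)


rng = np.random.default_rng(0)
x = rng.random((B, N, C, D, H, W), dtype=np.float32)
flat = fold_patches(x)                                          # [96, 3, 3, 64, 64]
feats = np.zeros((flat.shape[0], FEAT_DIM), dtype=np.float32)  # stand-in for backbone
bags = unfold_features(feats, B, N)                             # [4, 24, 512]
```

Because `reshape` keeps row-major order, patch `j` of scan `i` lands at row `i * N + j` of the folded tensor, so unfolding restores the original grouping.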
## MILModel3D

The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B[ResNet18-3D Backbone]
    B --> C[Transformer Encoder<br/>4 layers, 8 heads]
    C --> D["Attention Pooling<br/>512 → 2048 → 1"]
    D --> E["Weighted Sum [B, 512]"]
    E --> F["FC Head [B, num_classes]"]
```
### Forward Pass

1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 with 3 input channels. The final FC layer is removed, yielding a 512-dimensional feature vector per patch.
2. **Transformer**: The patch features are reshaped back to `[B, N, 512]` and passed through the transformer encoder (4 layers, 8 heads), so each patch embedding can attend to the other patches of the same scan.
3. **Attention**: A two-layer attention network (`512 → 2048 → 1` with Tanh) computes a scalar score per patch; a softmax over the `N` patches turns these scores into weights that sum to 1.
4. **Classification**: The attention-weighted sum of patch features produces a single `[B, 512]` vector per scan, which is projected to class logits by a linear layer.
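Steps 3 and 4 can be sketched in NumPy as follows. The weights are random placeholders rather than trained parameters, the input stands in for the transformer output, and `NUM_CLASSES = 4` is an assumed value for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, FEAT, HIDDEN, NUM_CLASSES = 2, 24, 512, 2048, 4  # NUM_CLASSES is hypothetical

# Random stand-ins for learned parameters.
W1, b1 = rng.normal(0, 0.02, (FEAT, HIDDEN)), np.zeros(HIDDEN)
W2, b2 = rng.normal(0, 0.02, (HIDDEN, 1)), np.zeros(1)
W_fc, b_fc = rng.normal(0, 0.02, (FEAT, NUM_CLASSES)), np.zeros(NUM_CLASSES)


def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_pool(feats):
    """feats: [B, N, 512] -> pooled [B, 512] and attention weights [B, N]."""
    scores = np.tanh(feats @ W1 + b1) @ W2 + b2      # 512 -> 2048 (Tanh) -> 1: [B, N, 1]
    weights = softmax(scores[..., 0], axis=1)         # normalise over the N patches
    pooled = np.einsum("bn,bnf->bf", weights, feats)  # attention-weighted sum
    return pooled, weights


feats = rng.normal(size=(B, N, FEAT))  # stand-in for transformer output
pooled, att = attention_pool(feats)
logits = pooled @ W_fc + b_fc          # linear head: [B, NUM_CLASSES]
```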
### MIL Modes

| Mode | Aggregation Strategy |
|------|---------------------|
| `mean` | Average logits across patches |
| `max` | Max logits across patches |
| `att` | Attention-weighted feature pooling |
| `att_trans` | Transformer encoder + attention pooling (primary mode) |
| `att_trans_pyramid` | Pyramid transformer on intermediate ResNet layers + attention |

The default and primary mode is `att_trans`.
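For the two logit-level modes, aggregation is a reduction over the patch axis. A sketch with a hypothetical `aggregate_logits` helper, assuming per-patch logits of shape `[B, N, num_classes]` and that `max` is taken independently per class:

```python
import numpy as np


def aggregate_logits(patch_logits, mode):
    """patch_logits: [B, N, num_classes] -> scan-level logits [B, num_classes].

    Only the logit-level modes are covered; the attention modes pool features instead.
    """
    if mode == "mean":
        return patch_logits.mean(axis=1)
    if mode == "max":
        return patch_logits.max(axis=1)  # per-class max, possibly from different patches
    raise ValueError(f"unsupported mode: {mode}")


# One scan (B=1) with three patches (N=3) and two classes.
patch_logits = np.array([[[0.1, 2.0], [1.5, -1.0], [0.4, 0.3]]])
# aggregate_logits(patch_logits, "max") -> [[1.5, 2.0]]
```

Note that in `max` mode each class score can come from a different patch, whereas the attention modes first pool one feature vector per scan and then classify it.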
## csPCa_Model

Wraps a `MILModel_3D` backbone (ResNet18 and transformer frozen) and replaces its classification head:
```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone<br/>(ResNet18 + Transformer)"]
    B --> C["Pooled Features [B, 512]"]
    C --> D["SimpleNN Head<br/>512 → 256 → 128 → 1"]
    D --> E["Sigmoid → csPCa Probability"]
```

### SimpleNN

```
Linear(512, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```
During csPCa training, the backbone's `net` (ResNet18) and `transformer` modules are frozen, while the `attention` module and the `SimpleNN` head remain trainable.
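A NumPy sketch of the SimpleNN forward pass listed above, with randomly initialised placeholder weights; `simple_nn` and `init_linear` are hypothetical names. Dropout is applied only when `train=True`, mirroring the usual train/eval distinction.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_linear(fan_in, fan_out):
    # Hypothetical random init; real weights come from training.
    return rng.normal(0, 1 / np.sqrt(fan_in), (fan_in, fan_out)), np.zeros(fan_out)


LAYERS = [init_linear(512, 256), init_linear(256, 128), init_linear(128, 1)]


def relu(x):
    return np.maximum(x, 0.0)


def simple_nn(pooled, train=False, p_drop=0.3):
    """pooled: [B, 512] attention-pooled features -> csPCa probability [B]."""
    (W1, b1), (W2, b2), (W3, b3) = LAYERS
    h = relu(pooled @ W1 + b1)   # Linear(512, 256) -> ReLU
    h = relu(h @ W2 + b2)        # Linear(256, 128) -> ReLU
    if train:                    # Dropout(0.3), inverted scaling, training only
        keep = rng.random(h.shape) >= p_drop
        h = h * keep / (1.0 - p_drop)
    z = h @ W3 + b3              # Linear(128, 1)
    return 1.0 / (1.0 + np.exp(-z[:, 0]))  # Sigmoid
```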
## Attention Loss

During PI-RADS training with heatmaps enabled, the model uses a dual-loss objective:

```
total_loss = class_loss + lambda_att * attention_loss
```
- **Classification loss**: Standard CrossEntropy on PI-RADS labels
- **Attention loss**: `1 - cosine_similarity(predicted_attention, heatmap_attention)`
- Heatmap-derived attention labels are computed by summing spatial heatmap values per patch, squaring for sharpness, and normalizing
- PI-RADS 2 samples get uniform attention (no expected lesion)
- `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
- The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification
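The pieces of this objective can be sketched in NumPy as below. The helper names are hypothetical; normalising the squared sums so they add up to 1 is an assumption (the text only says "normalizing"); and the warm-up assumes epochs are counted from 0 and that `lambda_att` stays at 2.0 after epoch 25. Gradient detaching is not modelled here.

```python
import numpy as np


def heatmap_attention_target(patch_heatmaps, pirads):
    """patch_heatmaps: [N, D, H, W] heatmap crops of one scan -> target weights [N]."""
    n = patch_heatmaps.shape[0]
    if pirads == 2:
        return np.full(n, 1.0 / n)  # uniform: no expected lesion
    scores = patch_heatmaps.reshape(n, -1).sum(axis=1) ** 2  # sum per patch, square
    total = scores.sum()
    # Assumed normalisation: scale to sum 1 (uniform fallback if the heatmap is empty).
    return scores / total if total > 0 else np.full(n, 1.0 / n)


def attention_loss(pred_att, target_att, eps=1e-8):
    """1 - cosine similarity between predicted and heatmap-derived attention."""
    cos = pred_att @ target_att / (np.linalg.norm(pred_att) * np.linalg.norm(target_att) + eps)
    return 1.0 - cos


def lambda_att(epoch, max_lambda=2.0, warmup_epochs=25):
    """Linear warm-up from 0 to max_lambda over warmup_epochs, constant afterwards."""
    return max_lambda * min(epoch / warmup_epochs, 1.0)


def total_loss(class_loss, pred_att, target_att, epoch):
    return class_loss + lambda_att(epoch) * attention_loss(pred_att, target_att)
```

Squaring before normalising widens the gap between strongly and weakly activated patches: per-patch sums of 2, 1 and 0 become targets of 0.8, 0.2 and 0.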