Update architecture.md
Anirudh Balaraman committed

docs/architecture.md (+19 −19)
# Architecture

## Patch Extraction

Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):

- **With heatmaps**: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map → regions with high DWI and low ADC are sampled more frequently
- **Without heatmaps**: Crops are sampled from positive (prostate) regions based on the binary mask

Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).
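The weighted variant can be illustrated with a small NumPy sketch. `sample_weighted_patches` and its arguments are illustrative stand-ins for the idea, not the MONAI `RandWeightedCropd` API:

```python
import numpy as np

def sample_weighted_patches(volume, weight_map, n_patches=24, size=(64, 64, 3), rng=None):
    """Sample patch centers with probability proportional to a weight map.

    Illustrative sketch of weighted cropping (cf. MONAI's RandWeightedCropd);
    the function name and signature here are hypothetical.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w, d = size
    # Flatten the weight map into a sampling distribution over voxels.
    probs = weight_map.ravel().astype(np.float64)
    probs /= probs.sum()
    centers = rng.choice(probs.size, size=n_patches, p=probs)
    patches = []
    for c in centers:
        x, y, z = np.unravel_index(c, weight_map.shape)
        # Clamp the crop origin so the patch stays inside the volume.
        x0 = int(np.clip(x - h // 2, 0, volume.shape[0] - h))
        y0 = int(np.clip(y - w // 2, 0, volume.shape[1] - w))
        z0 = int(np.clip(z - d // 2, 0, volume.shape[2] - d))
        patches.append(volume[x0:x0 + h, y0:y0 + w, z0:z0 + d])
    return np.stack(patches)  # [N, H, W, D]
```

High-weight voxels (e.g. high DWI, low ADC inside the prostate mask) become patch centers more often, which is the behavior the transform provides.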
## Tensor Shape Convention

Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:

| Dim | Description | Default |
|-----|-------------|---------|
| B | Batch size | – |
| N | Patches per scan | 24 |
| C | Input channels | 3 |
| D | Patch depth | 3 |
| H | Patch height | 64 |
| W | Patch width | 64 |
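The convention can be checked with a quick NumPy sketch (shapes only; the batch size `B = 2` is an assumed example value):

```python
import numpy as np

B, N, C, D, H, W = 2, 24, 3, 3, 64, 64
x = np.zeros((B, N, C, D, H, W), dtype=np.float32)

# The backbone sees each patch independently: fold N into the batch axis.
patches = x.reshape(B * N, C, D, H, W)
assert patches.shape == (48, 3, 3, 64, 64)

# After feature extraction, per-patch features are grouped back per scan.
features = np.zeros((B * N, 512), dtype=np.float32).reshape(B, N, 512)
assert features.shape == (2, 24, 512)
```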
## MILModel3D

The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["ResNet18-3D Backbone"]
    B --> C["Transformer Encoder<br/>4 layers, 8 heads"]
    C --> D["Attention Pooling<br/>512 → 2048 → 1"]
    D --> E["Weighted Sum [B, 512]"]
    E --> F["FC Head [B, num_classes]"]
```
### Forward Pass

1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.

2. **Transformer**: The per-patch features from the ResNet18 encoder are forwarded to the transformer encoder, letting patches from the same scan attend to one another.

3. **Attention**: A two-layer attention network (`512 → 2048 → 1` with Tanh) computes a scalar weight per patch, normalized via softmax.
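Steps 3 and 4 of the aggregation can be sketched in NumPy; the weight matrices below are random placeholders, not trained parameters:

```python
import numpy as np

def attention_pool(h, w1, w2):
    """Attention pooling over patches; h is [N, 512].

    Follows the structure described above (512 -> 2048 -> 1 with Tanh,
    softmax-normalized over the N patches). Weights are illustrative.
    """
    scores = np.tanh(h @ w1) @ w2          # [N, 1] raw attention scores
    a = np.exp(scores - scores.max())
    a = a / a.sum()                        # softmax over patches, [N, 1]
    # Weighted sum collapses N patches into one 512-d bag-level feature.
    return (a * h).sum(axis=0), a.squeeze(-1)

rng = np.random.default_rng(0)
h = rng.standard_normal((24, 512))
pooled, att = attention_pool(h,
                             rng.standard_normal((512, 2048)) * 0.01,
                             rng.standard_normal((2048, 1)) * 0.01)
```

The attention weights sum to 1 over the patch axis, so the pooled vector is a convex combination of patch features.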
Wraps a frozen `MILModel_3D` backbone and replaces the classification head:

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone<br/>(ResNet18 + Transformer)"]
    B --> C["Pooled Features [B, 512]"]
    C --> D["SimpleNN Head<br/>512 → 256 → 128 → 1"]
    D --> E["Sigmoid → csPCa Probability"]
```
```
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```
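An inference-time sketch of this head in NumPy (dropout is inactive at inference and therefore omitted; the weights here are random placeholders, with shapes following the `512 → 256 → 128 → 1` layout shown above):

```python
import numpy as np

def simple_nn_head(x, params):
    """Inference-time forward pass of the head sketched above.

    x: pooled features [B, 512]. Parameter names are illustrative.
    """
    relu = lambda z: np.maximum(z, 0.0)
    h = relu(x @ params["w1"] + params["b1"])   # [B, 256]
    h = relu(h @ params["w2"] + params["b2"])   # [B, 128]
    logits = h @ params["w3"] + params["b3"]    # [B, 1]
    return 1.0 / (1.0 + np.exp(-logits))        # sigmoid -> csPCa probability
```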
During csPCa training, the backbone's `net` (ResNet18) and `transformer` are frozen; the `attention` module and the `SimpleNN` head remain trainable.
## Attention Loss

The total loss combines the classification and attention terms: `total_loss = class_loss + lambda_att * attention_loss`, where:

- `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
- The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification
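The linear warm-up is a one-liner; the function name is illustrative:

```python
def lambda_att_schedule(epoch, warmup_epochs=25, max_value=2.0):
    """Linear warm-up for the attention-loss weight described above.

    Ramps from 0 at epoch 0 to max_value at warmup_epochs, then stays flat.
    """
    return max_value * min(epoch / warmup_epochs, 1.0)
```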