Anirudh Balaraman committed on
Commit f6b233e · unverified · 1 Parent(s): cc831d6

Update architecture.md

Files changed (1)
  1. docs/architecture.md +19 -19
docs/architecture.md CHANGED
@@ -1,5 +1,14 @@
 # Architecture

 ## Tensor Shape Convention

 Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:

@@ -13,26 +22,24 @@ Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:
 | H | Patch height | 64 |
 | W | Patch width | 64 |

- ## MILModel_3D

 The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.

 ```mermaid
 flowchart TD
- A["Input [B, N, C, D, H, W]"] --> B["Reshape to [B*N, C, D, H, W]"]
- B --> C[ResNet18-3D Backbone]
- C --> D["Reshape to [B, N, 512]"]
- D --> E[Transformer Encoder\n4 layers, 8 heads]
- E --> F[Attention Pooling\n512 → 2048 → 1]
- F --> G["Weighted Sum [B, 512]"]
- G --> H["FC Head [B, num_classes]"]
 ```

 ### Forward Pass

 1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.

- 2. **Transformer**: Features are reshaped to `[B, N, 512]`, permuted to `[N, B, 512]` for the transformer encoder (4 layers, 8 attention heads), then permuted back.

 3. **Attention**: A two-layer attention network (`512 → 2048 → 1` with Tanh) computes a scalar weight per patch, normalized via softmax.

@@ -56,9 +63,9 @@ Wraps a frozen `MILModel_3D` backbone and replaces the classification head:

 ```mermaid
 flowchart TD
- A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone\n(ResNet18 + Transformer)"]
 B --> C["Pooled Features [B, 512]"]
- C --> D["SimpleNN Head\n512 → 256 → 128 → 1"]
 D --> E["Sigmoid → csPCa Probability"]
 ```

@@ -70,7 +77,7 @@ Linear(256, 128) → ReLU → Dropout(0.3)
 Linear(128, 1) → Sigmoid
 ```

- During csPCa training, the backbone's `net` (ResNet18), `transformer`, and `myfc` parameters are frozen. The `attention` module and `SimpleNN` head remain trainable.

 ## Attention Loss

@@ -87,11 +94,4 @@ total_loss = class_loss + lambda_att * attention_loss
 - `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
 - The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification

- ## Patch Extraction
-
- Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):
-
- - **With heatmaps**: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map; regions with high DWI and low ADC are sampled more frequently
- - **Without heatmaps**: Crops are sampled from positive (prostate) regions based on the binary mask
-
- Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).
 
# Architecture
## Patch Extraction

Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):

- **With heatmaps**: the combined DWI/ADC heatmap, multiplied by the prostate mask, serves as the sampling weight map; regions with high DWI and low ADC are sampled more frequently.
- **Without heatmaps**: crops are sampled from positive (prostate) regions based on the binary mask.

Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).
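The weighted sampling above can be sketched in plain Python on a toy 1-D grid of candidate patch centers (the heatmap and mask values here are made up for illustration; the real pipeline uses MONAI's `RandWeightedCropd` on 3-D volumes):

```python
# Toy sketch of weighted patch sampling: centers are drawn with probability
# proportional to heatmap * mask, so voxels outside the prostate (mask == 0)
# are never sampled. Values below are illustrative, not real data.
import random

heatmap = [0.1, 0.9, 0.4, 0.8, 0.2, 0.7]  # combined DWI/ADC heatmap (toy)
mask    = [0,   1,   1,   1,   0,   1]    # binary prostate mask (toy)

# Sampling weight map: heatmap masked to the prostate region.
weights = [h * m for h, m in zip(heatmap, mask)]

random.seed(0)
N = 24  # patches per scan (project default)
centers = random.choices(range(len(weights)), weights=weights, k=N)

# Zero-weight (outside-mask) centers can never be drawn.
assert all(mask[c] == 1 for c in centers)
assert len(centers) == N
```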
## Tensor Shape Convention

Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:

| Dim | Meaning | Default |
| --- | --- | --- |
| B | Batch size | - |
| N | Patches per scan | 24 |
| C | Input channels | 3 |
| D | Patch depth | 3 |
| H | Patch height | 64 |
| W | Patch width | 64 |
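A quick sanity check of the convention in plain Python (`B = 2` is an arbitrary example batch size, not a project value):

```python
# Index arithmetic behind the [B, N, C, D, H, W] convention: collapsing the
# batch and patch axes (done before the backbone) maps element (b, n) of the
# [B, N, ...] view to row b * N + n of the [B*N, ...] view.
B, N, C, D, H, W = 2, 24, 3, 3, 64, 64

def flat_index(b, n):
    return b * N + n

assert flat_index(0, 0) == 0
assert flat_index(1, 0) == N
assert flat_index(B - 1, N - 1) == B * N - 1
```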
## MILModel3D

The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B[ResNet18-3D Backbone]
    B --> C[Transformer Encoder<br/>4 layers, 8 heads]
    C --> D[Attention Pooling<br/>512 → 2048 → 1]
    D --> E["Weighted Sum [B, 512]"]
    E --> F["FC Head [B, num_classes]"]
```

### Forward Pass

1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.

2. **Transformer**: The 512-dimensional patch features from the ResNet18 backbone are reshaped to `[B, N, 512]` and passed through the transformer encoder (4 layers, 8 attention heads).

3. **Attention**: A two-layer attention network (`512 → 2048 → 1` with Tanh) computes a scalar weight per patch, normalized via softmax.
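The pooling in steps 2-3 reduces to a softmax-weighted sum over patches. A minimal plain-Python sketch (toy `feat_dim = 4` instead of 512, random stand-in values; not the project's code):

```python
# Attention pooling sketch: per-patch scalar scores are softmax-normalized
# over the N patches, then used to weight-sum the patch features into one
# pooled vector per scan.
import math
import random

random.seed(0)
N, feat_dim = 5, 4
features = [[random.random() for _ in range(feat_dim)] for _ in range(N)]
scores = [random.random() for _ in range(N)]  # stand-in attention MLP outputs

# Numerically stable softmax over the N patch scores.
m = max(scores)
exps = [math.exp(s - m) for s in scores]
total = sum(exps)
weights = [e / total for e in exps]
assert abs(sum(weights) - 1.0) < 1e-9

# Weighted sum of patch features -> pooled feature vector.
pooled = [sum(w * f[j] for w, f in zip(weights, features))
          for j in range(feat_dim)]
assert len(pooled) == feat_dim
```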
 
 
Wraps a frozen `MILModel_3D` backbone and replaces the classification head:

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone<br/>(ResNet18 + Transformer)"]
    B --> C["Pooled Features [B, 512]"]
    C --> D["SimpleNN Head<br/>512 → 256 → 128 → 1"]
    D --> E["Sigmoid → csPCa Probability"]
```
The `SimpleNN` head applies:

```
Linear(512, 256) → ReLU → Dropout(0.3)
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```
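A shape walk-through of the head in plain Python (placeholder random weights, dropout omitted as at inference; this illustrates the layer dimensions, not the trained model):

```python
# Forward a dummy 512-dim pooled feature through the 512 -> 256 -> 128 -> 1
# stack, with ReLU between hidden layers and a sigmoid on the final logit.
import math
import random

random.seed(0)
dims = [512, 256, 128, 1]
x = [random.random() for _ in range(dims[0])]  # stand-in pooled features

for i, (d_in, d_out) in enumerate(zip(dims, dims[1:])):
    W = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(d_out)]
    x = [sum(w * xi for w, xi in zip(row, x)) for row in W]  # Linear
    if i < len(dims) - 2:            # ReLU after hidden layers only
        x = [max(0.0, v) for v in x]

prob = 1.0 / (1.0 + math.exp(-x[0]))  # Sigmoid -> csPCa probability
assert len(x) == 1
assert 0.0 < prob < 1.0
```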
 
During csPCa training, the backbone's `net` (ResNet18) and `transformer` modules are frozen. The `attention` module and `SimpleNN` head remain trainable.
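The freeze/train split can be illustrated with a stand-in parameter table (the `net`, `transformer`, and `attention` module names come from the doc; the `head.*` name is illustrative, and the real code would toggle `requires_grad` on actual parameters):

```python
# Stand-in parameter table: mark everything under the frozen module name
# prefixes as non-trainable, leave the rest trainable.
params = [
    "net.conv1.weight",           # ResNet18 backbone (frozen)
    "transformer.layer0.weight",  # transformer encoder (frozen)
    "attention.fc1.weight",       # attention module (trainable)
    "head.fc1.weight",            # SimpleNN head (trainable; name illustrative)
]

frozen_prefixes = ("net.", "transformer.")
trainable = {name: not name.startswith(frozen_prefixes) for name in params}

assert not trainable["net.conv1.weight"]
assert not trainable["transformer.layer0.weight"]
assert trainable["attention.fc1.weight"]
assert trainable["head.fc1.weight"]
```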
 
## Attention Loss

The classification loss is combined with a weighted auxiliary attention loss:

```
total_loss = class_loss + lambda_att * attention_loss
```
- `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
- The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification
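A minimal sketch of the warm-up schedule and combined objective, assuming the weight is held at 2.0 after epoch 25 (the codebase may differ in detail):

```python
# Linear warm-up: lambda_att grows from 0 to max_value over warmup_epochs,
# then stays constant (assumed clamping behavior).
def lambda_att(epoch, max_value=2.0, warmup_epochs=25):
    return max_value * min(epoch / warmup_epochs, 1.0)

assert lambda_att(0) == 0.0
assert abs(lambda_att(25) - 2.0) < 1e-9
assert lambda_att(100) == 2.0

# Combined objective from the doc.
def total_loss(class_loss, attention_loss, epoch):
    return class_loss + lambda_att(epoch) * attention_loss
```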