# Pipeline

The full pipeline has three phases: preprocessing, PI-RADS training (Stage 1), and csPCa training (Stage 2).

```mermaid
flowchart TD
    subgraph Preprocessing
        R[register_and_crop] --> S[get_segmentation_mask]
        S --> H[histogram_match]
        H --> G[get_heatmap]
    end

    subgraph S1["Stage 1"]
        P[PI-RADS Training<br/>CrossEntropy + Attention Loss]
    end

    subgraph S2["Stage 2"]
        C[csPCa Training<br/>Frozen Backbone + BCE Loss]
    end

    G --> P
    P -->|frozen backbone| C
```

## Preprocessing
```bash
python preprocess_main.py \
    --config config/config_preprocess.yaml \
    --steps register_and_crop get_segmentation_mask histogram_match get_heatmap
```

The steps below run in sequence:

### Step 1: Register and Crop

Resamples T2W, DWI, and ADC to a common spacing of `(0.4, 0.4, 3.0)` mm using `picai_prep`, then center-crops with a configurable margin (default 20%).
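The cropping step can be illustrated with a minimal numpy sketch (an assumption about how the margin is applied, not the repo's `picai_prep`-based implementation: here the margin fraction is removed from each in-plane border):

```python
import numpy as np

def center_crop(volume: np.ndarray, margin: float = 0.20) -> np.ndarray:
    """Center-crop the in-plane (H, W) axes of a (D, H, W) volume,
    removing `margin` of the extent from each spatial border."""
    d, h, w = volume.shape
    crop_h, crop_w = int(h * margin), int(w * margin)
    return volume[:, crop_h:h - crop_h, crop_w:w - crop_w]

# A 100x100 in-plane grid with a 20% margin keeps the central 60x60 region.
cropped = center_crop(np.zeros((24, 100, 100)), margin=0.20)
```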

### Step 2: Prostate Segmentation

Runs a pre-trained segmentation model on T2W images to generate binary prostate masks. Post-processing retains only the top 10 slices by non-zero voxel count.
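The slice-retention post-processing can be sketched as follows (a numpy illustration assuming the slice axis comes first; the actual model inference is omitted):

```python
import numpy as np

def keep_top_slices(mask: np.ndarray, k: int = 10) -> np.ndarray:
    """Zero out all but the k axial slices with the most non-zero voxels.

    `mask` is a binary (D, H, W) prostate mask.
    """
    counts = (mask > 0).reshape(mask.shape[0], -1).sum(axis=1)
    keep = np.argsort(counts)[-k:]   # indices of the k densest slices
    out = np.zeros_like(mask)
    out[keep] = mask[keep]
    return out
```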

### Step 3: Histogram Matching

Matches the intensity histogram of each sequence to a reference image within the masked (prostate) region using `skimage.exposure.match_histograms`.
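`match_histograms` operates on whole images, so restricting the match to the masked region presumably means working on the voxels inside the mask. A minimal numpy sketch of quantile-based matching on masked voxels (an illustration, not the repo's implementation):

```python
import numpy as np

def match_masked_histogram(image, reference, img_mask, ref_mask):
    """Map intensities inside img_mask so their distribution matches the
    reference intensities inside ref_mask (quantile mapping), leaving
    voxels outside the mask unchanged."""
    out = image.astype(float).copy()
    src = image[img_mask > 0].ravel()
    ref = reference[ref_mask > 0].ravel()
    # rank of each source voxel -> corresponding reference quantile
    quantiles = np.argsort(np.argsort(src)) / max(len(src) - 1, 1)
    out[img_mask > 0] = np.quantile(ref, quantiles)
    return out
```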

### Step 4: Heatmap Generation

Creates weak-attention heatmaps from DWI and ADC:

- **DWI heatmap**: `(dwi - min) / (max - min)` β€” higher DWI signal = higher attention
- **ADC heatmap**: `(max - adc) / (max - min)` β€” lower ADC = higher attention (inverted)
- **Combined**: element-wise product, re-normalized to [0, 1]
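The three formulas above translate directly to numpy (a sketch; a small epsilon is added here to guard against division by zero, which the repo may handle differently):

```python
import numpy as np

def combined_heatmap(dwi: np.ndarray, adc: np.ndarray) -> np.ndarray:
    """Weak-attention heatmap: high DWI and low ADC both indicate lesions."""
    eps = 1e-8
    dwi_h = (dwi - dwi.min()) / (dwi.max() - dwi.min() + eps)  # high DWI -> 1
    adc_h = (adc.max() - adc) / (adc.max() - adc.min() + eps)  # low ADC -> 1 (inverted)
    h = dwi_h * adc_h                                          # element-wise product
    return (h - h.min()) / (h.max() - h.min() + eps)           # re-normalize to [0, 1]
```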

!!! note "Step Dependencies"
    Steps must run in the order shown above. The pipeline validates dependencies automatically β€” for example, `get_heatmap` requires `get_segmentation_mask` and `histogram_match` to have run first.

## Stage 1: PI-RADS Classification

Trains a 4-class PI-RADS classifier (grades 2–5, mapped to labels 0–3).

```bash
python run_pirads.py --mode train --config config/config_pirads_train.yaml
```

**Training details:**

| Component | Value |
|-----------|-------|
| Loss | CrossEntropy + cosine-similarity attention loss |
| Attention loss weight | Linear warmup over 25 epochs to `lambda=2.0` |
| Optimizer | AdamW (base LR `2e-4`, transformer LR `6e-5`) |
| Scheduler | CosineAnnealingLR |
| Metric | Quadratic Weighted Kappa (QWK) |
| Early stopping | After 40 epochs without validation loss improvement |
| AMP | Disabled by default (enabled in example YAML config) |

**Attention loss**: For each batch, the model's learned attention weights are compared against heatmap-derived attention labels via cosine similarity. PI-RADS 2 samples receive uniform attention (no lesion expected). The loss is weighted by `lambda_att`, which warms up linearly over the first 25 epochs.
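The attention loss and its warmup can be sketched in numpy (an illustration of the scheme described above, not the repo's PyTorch implementation; the `1 - cos` dissimilarity form is an assumption):

```python
import numpy as np

def attention_loss(pred_att, heatmap_att, pirads_labels, epoch,
                   warmup_epochs=25, lambda_max=2.0):
    """Cosine-similarity attention loss with linear lambda warmup.

    pred_att, heatmap_att: (B, N) attention weights per sample.
    PI-RADS 2 samples (label 0) get a uniform target: no lesion expected.
    """
    target = heatmap_att.copy()
    target[pirads_labels == 0] = 1.0 / target.shape[1]   # uniform attention
    cos = np.sum(pred_att * target, axis=1) / (
        np.linalg.norm(pred_att, axis=1) * np.linalg.norm(target, axis=1) + 1e-8)
    lam = lambda_max * min(epoch / warmup_epochs, 1.0)   # linear warmup
    return lam * np.mean(1.0 - cos)                      # dissimilarity
```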

## Stage 2: csPCa Risk Prediction

Builds on a frozen PI-RADS backbone to predict binary csPCa risk. The self-attention and classification head are fine-tuned.

```bash
python run_cspca.py --mode train --config config/config_cspca_train.yaml
```

**Training details:**

| Component | Value |
|-----------|-------|
| Loss | Binary Cross-Entropy (BCE) |
| Backbone | Frozen PI-RADS model (ResNet18 + Transformer); attention module is trainable |
| Head | SimpleNN: `512 β†’ 256 β†’ 128 β†’ 1` with ReLU + Dropout(0.3) + Sigmoid |
| Optimizer | AdamW (LR `2e-4`) |
| Seeds | 20 random seeds (default) for 95% CI |
| Metrics | AUC, Sensitivity, Specificity |
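As a sketch of the head's shape, here is a numpy stand-in for the SimpleNN forward pass (the PyTorch module itself is not shown here; weight values are random placeholders and inverted dropout is an assumption about the dropout convention):

```python
import numpy as np

rng = np.random.default_rng(0)

def simple_nn_forward(x, weights, train=False, p_drop=0.3):
    """Forward pass of the SimpleNN head: 512 -> 256 -> 128 -> 1
    with ReLU, Dropout(0.3) on hidden layers, and a final Sigmoid."""
    for i, (W, b) in enumerate(weights):
        x = x @ W + b
        if i < len(weights) - 1:            # hidden layers only
            x = np.maximum(x, 0)            # ReLU
            if train:                       # inverted dropout
                x *= rng.binomial(1, 1 - p_drop, x.shape) / (1 - p_drop)
    return 1.0 / (1.0 + np.exp(-x))         # Sigmoid -> risk in (0, 1)

dims = [512, 256, 128, 1]
weights = [(rng.normal(scale=0.01, size=(a, b)), np.zeros(b))
           for a, b in zip(dims[:-1], dims[1:])]
risk = simple_nn_forward(rng.normal(size=(4, 512)), weights)
```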

The backbone's feature extractor (`net`), transformer, and `myfc` are frozen; the attention module and `SimpleNN` classification head are trained. After training, the framework reports the mean and 95% confidence interval for AUC, sensitivity, and specificity, computed across the 20 random seeds.

Refer to [Getting Started](getting-started.md) for the JSON dataset format required by `run_pirads.py` and `run_cspca.py`.