---
license: mit
language:
- en
---
# MM-DLS
# Multi-task deep learning based on PET/CT images for the diagnosis and prognosis prediction of advanced non-small cell lung cancer
## Overview
**MM-DLS** is a multi-modal, multi-task deep learning framework for the diagnosis, staging, and prognosis prediction of advanced non-small cell lung cancer (NSCLC). It integrates multi-source data including CT images, PET metabolic parameters, and clinical information to provide a unified, non-invasive decision-making tool for personalized treatment planning.
This repository implements the full MM-DLS pipeline, consisting of:
- Lung-lesion segmentation with cross-attention transformer
- Multi-modal feature fusion (CT, PET, Clinical)
- Multi-task learning: Pathological classification, TNM staging, DFS and OS survival prediction
- Cox proportional hazards survival loss
The framework supports both classification (adenocarcinoma vs squamous cell carcinoma) and survival risk prediction tasks, and has been validated on large-scale multi-center clinical datasets.
---
## Key Features
- **Multi-modal fusion:** Combines CT-based imaging features, PET metabolic biomarkers (SUVmax, SUVmean, SUVpeak, TLG, MTV), and structured clinical variables (age, sex, smoking status, smoking duration, smoking cessation history, tumor size).
- **Multi-task learning:** Simultaneous optimization for:
  - Histological subtype classification (LUAD vs LUSC)
  - TNM stage classification (I-II, III, IV)
  - Disease-free survival (DFS) prediction
  - Overall survival (OS) prediction
- **Attention-based feature fusion:** Transformer cross-attention module to integrate lung-lesion spatial information.
- **Survival modeling:** Incorporates a Cox proportional hazards loss to predict DFS and OS risk scores from censored survival data.
- **Flexible data simulation and loading:** Includes utilities for synthetic data generation and multi-slice 2D volume processing.
---
## Architecture
The overall MM-DLS system consists of:





1. **Segmentation Module (LungLesionSegmentor):**
   - Shared ResNet encoder to extract features from CT images.
   - Dual decoders for lung and lesion segmentation.
   - Transformer-based cross-attention module for enhanced spatial feature interaction between lung and lesion regions.
2. **Feature Encoders:**
   - `LesionEncoder`: 2D convolutional encoder for lesion patches.
   - `SpaceEncoder`: 2D convolutional encoder for lung-space contextual patches.
3. **Attention Fusion Module:**
   - `LesionAttentionFusion`: Multi-head attention to fuse lesion and lung features into compact patient-level representations.
4. **Patient-Level Fusion Model (PatientLevelFusionModel):**
   - Fully connected network that combines imaging, PET, and clinical features.
   - Outputs classification logits, DFS and OS risk scores.
5. **Loss Functions:**
   - Binary cross-entropy loss for classification.
   - Cox proportional hazards loss (`CoxPHLoss`) for survival prediction.
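
The Cox proportional hazards loss listed above can be illustrated with a minimal, framework-independent sketch. It computes the negative partial log-likelihood averaged over observed events, using plain Python so the arithmetic is easy to follow. The function name `cox_ph_loss` and its argument names are assumptions made for this example; the repository's `CoxPHLoss` may differ in tie handling, normalization, and tensor framework.

```python
import math


def cox_ph_loss(risks, times, events):
    """Negative Cox partial log-likelihood, averaged over observed events.

    risks  -- predicted log-risk score per patient (higher = worse prognosis)
    times  -- follow-up time per patient
    events -- 1 if the event was observed, 0 if the patient was censored
    """
    n_events = sum(events)
    if n_events == 0:
        return 0.0  # no observed events: nothing to rank in this batch
    total = 0.0
    for r_i, t_i, e_i in zip(risks, times, events):
        if not e_i:
            continue  # censored patients contribute only through risk sets
        # Risk set: every patient still under observation at time t_i.
        risk_set = [r_j for r_j, t_j in zip(risks, times) if t_j >= t_i]
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(risk_set)
        log_sum = m + math.log(sum(math.exp(r - m) for r in risk_set))
        total += r_i - log_sum
    return -total / n_events


# Risk scores that rank patients in order of failure give a lower loss
# than scores that rank them in the opposite order.
loss_good = cox_ph_loss([2.0, 1.0, 0.0], [1, 2, 3], [1, 1, 0])
loss_bad = cox_ph_loss([0.0, 1.0, 2.0], [1, 2, 3], [1, 1, 0])
```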
---
## Code Structure
- `ModelLesionEncoder.py`: Lesion image encoder extracting discriminative features from multi-slice tumor regions.
- `ModelSpaceEncoder.py`: Lung space encoder modeling anatomical and spatial context beyond the lesion.
- `LesionAttentionFusion.py`: Attention-based fusion module for adaptive integration of lesion and spatial features.
- `ClinicalFusionModel.py`: Patient-level fusion network combining imaging features, radiomics, PET signals, and clinical variables.
- `HierMM_DLS.py`: Core hierarchical multimodal deep learning model supporting multi-task learning: (1) subtype classification; (2) TNM stage prediction; (3) DFS and OS modeling.
- `CoxphLoss.py`: Cox proportional hazards loss for survival modeling with censored data.
- `PatientDataset.py`: Patient dataset loader supporting imaging, radiomics, PET, clinical variables, survival outcomes, and treatment labels.
- `LungLesionSegmentation.py`: Lung-lesion segmentation model.
- `ImageDataLoader.py`: Image preprocessing and loading utilities for multi-slice inputs.
- `plot_results.py`: Visualization utilities for Kaplan–Meier curves, hazard ratios, and survival analysis results.
---
## Data Format
The input data is organized per patient as follows:
### Imaging Data:
- CT slices (PNG format)
- Lung masks (binary masks, PNG)
- Lesion masks (binary masks, PNG)
- Slices grouped per patient ID
### Tabular Data:
- Radiomics features: 128-dimensional vector (PyRadiomics extracted)
- PET features: [SUVmax, SUVmean, SUVpeak, TLG, MTV]
- Clinical features: [Age, Sex, Smoking Status, Smoking Duration, Smoking Cessation, Tumor Diameter]
- Survival data: DFS time/event, OS time/event
- Classification label: LUAD (0) or LUSC (1)
Simulated data utilities are provided for experimentation and reproducibility.
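
As a rough illustration of the tabular layout described above, the sketch below defines a per-patient record and fills it with random placeholder values. The names `PatientRecord` and `simulate_patient`, the field names, and the value ranges are all assumptions for this example; they are not the repository's `PatientDataset` API, and the simulated numbers are not clinically meaningful.

```python
import random
from dataclasses import dataclass
from typing import List

RADIOMICS_DIM = 128
PET_FIELDS = ["SUVmax", "SUVmean", "SUVpeak", "TLG", "MTV"]
CLINICAL_FIELDS = ["Age", "Sex", "Smoking Status", "Smoking Duration",
                   "Smoking Cessation", "Tumor Diameter"]


@dataclass
class PatientRecord:
    """Hypothetical container for one patient's tabular inputs and labels."""
    patient_id: str
    radiomics: List[float]   # 128 values
    pet: List[float]         # ordered as PET_FIELDS
    clinical: List[float]    # ordered as CLINICAL_FIELDS
    dfs_time: float
    dfs_event: int           # 1 = event observed, 0 = censored
    os_time: float
    os_event: int            # 1 = event observed, 0 = censored
    label: int               # 0 = LUAD, 1 = LUSC

    def __post_init__(self):
        # Reject records whose vectors do not match the documented sizes.
        expected = {"radiomics": RADIOMICS_DIM,
                    "pet": len(PET_FIELDS),
                    "clinical": len(CLINICAL_FIELDS)}
        for name, size in expected.items():
            if len(getattr(self, name)) != size:
                raise ValueError(f"{name} must have {size} values")
        if self.label not in (0, 1):
            raise ValueError("label must be 0 (LUAD) or 1 (LUSC)")


def simulate_patient(patient_id, rng):
    """Draw one synthetic record with random placeholder values."""
    os_time = rng.uniform(1.0, 60.0)
    return PatientRecord(
        patient_id=patient_id,
        radiomics=[rng.gauss(0.0, 1.0) for _ in range(RADIOMICS_DIM)],
        pet=[rng.uniform(0.0, 1.0) for _ in PET_FIELDS],
        clinical=[rng.uniform(0.0, 1.0) for _ in CLINICAL_FIELDS],
        dfs_time=rng.uniform(0.5, os_time),  # DFS cannot exceed OS
        dfs_event=rng.randint(0, 1),
        os_time=os_time,
        os_event=rng.randint(0, 1),
        label=rng.randint(0, 1),
    )


record = simulate_patient("P001", random.Random(0))
```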
---
## Installation
```bash
# Create and activate a conda environment
conda create -n mm_dls python=3.10 -y
conda activate mm_dls
# Clone the repository and enter it
git clone https://github.com/your_username/MM-DLS-NSCLC.git
cd MM-DLS-NSCLC
```
### Install dependencies
```bash
pip install -r requirements.txt
```
## Usage
### 🔽 Download Pretrained Models
Pretrained MM-DLS models are available for direct download:
- **MM-DLS (Full multimodal, best checkpoint)**
[⬇️ Download Pretrained Model](https://drive.google.com/file/d/1IcyCwMgCX8wv0NMp84U4wlzhLoXH7ayx/view?usp=drive_link) Size 1.3 MB
The MM-DLS model is intentionally lightweight (~1.3 MB), as it employs compact CNN encoders and MLP-based multimodal fusion rather than large pretrained backbones, enabling efficient deployment and fast inference.
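After downloading, place the model files under the `./MODEL/` directory.
### Training
```bash
python train_patient_model.py
```
### Evaluation
```bash
python test.py
```
### Example Forward Pass
```bash
# run_sample.ipynb is a Jupyter notebook, so open it with Jupyter rather than the Python interpreter
jupyter notebook run_sample.ipynb
```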
## Model Performance (from publication)
### Histological Subtype Classification:
- AUC: 0.85 ~ 0.92 across cohorts
- AP: 0.81 ~ 0.86
### TNM Stage Prediction:
- AUC: Stage I-II (0.86 ~ 0.96), Stage III (0.85 ~ 0.95), Stage IV (0.83 ~ 0.95)
- AP and calibration maintained across internal and external sets
### DFS & OS Prognosis:
- C-index: up to 0.75
- Time-dependent AUC (1/2/3 years): 0.77 ~ 0.91
- Brier score: consistently < 0.2 for DFS and < 0.3 for OS
- Superior to single-modality models (clinical-only or imaging-only)
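
For readers unfamiliar with the C-index reported above, the following is a minimal sketch of Harrell's concordance index on right-censored data. The function name `concordance_index` is chosen for this example, and the sketch ignores pairs with tied follow-up times for simplicity; the evaluation code behind the published numbers may treat ties and censoring differently.

```python
def concordance_index(risks, times, events):
    """Fraction of comparable patient pairs ranked correctly by risk score.

    A pair is comparable when the patient with the shorter follow-up time
    had an observed event. The pair is concordant when that patient also
    has the higher predicted risk; tied risks count as half.
    """
    concordant = 0.0
    comparable = 0
    n = len(risks)
    for i in range(n):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if times[j] > times[i]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs; C-index is undefined")
    return concordant / comparable


c_perfect = concordance_index([3.0, 2.0, 1.0], [1, 2, 3], [1, 1, 1])
c_reversed = concordance_index([1.0, 2.0, 3.0], [1, 2, 3], [1, 1, 1])
c_uninformative = concordance_index([1.0, 1.0, 1.0], [1, 2, 3], [1, 1, 1])
```

A value of 0.5 corresponds to uninformative ranking and 1.0 to perfect ranking, which gives context for the reported C-index of up to 0.75.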
## Reference
Please cite our original publication when using this work:
## License
This project is licensed under the MIT License.
⚠️ **Notice:** The pretrained model is shared solely for research validation purposes and **should not be used, distributed, or cited before the associated study is formally published**.
## Contact
For any questions or collaborations, please contact:
Dr. Fang Dai: daifang_cool@163.com