---
license: mit
language:
- en
---

# MM-DLS

# Multi-task deep learning based on PET/CT images for the diagnosis and prognosis prediction of advanced non-small cell lung cancer

## Overview

**MM-DLS** is a multi-modal, multi-task deep learning framework for the diagnosis, staging, and prognosis prediction of advanced non-small cell lung cancer (NSCLC). It integrates multi-source data, including CT images, PET metabolic parameters, and clinical information, to provide a unified, non-invasive decision-making tool for personalized treatment planning.

This repository implements the full MM-DLS pipeline, consisting of:

- Lung-lesion segmentation with a cross-attention transformer
- Multi-modal feature fusion (CT, PET, clinical)
- Multi-task learning: pathological classification, TNM staging, and DFS and OS survival prediction
- Cox proportional hazards survival loss

The framework supports both classification (adenocarcinoma vs. squamous cell carcinoma) and survival risk prediction tasks, and has been validated on large-scale multi-center clinical datasets.

---

## Key Features

- **Multi-modal fusion:** Combines CT-based imaging features, PET metabolic biomarkers (SUVmax, SUVmean, SUVpeak, TLG, MTV), and structured clinical variables (age, sex, smoking status, smoking duration, smoking cessation history, tumor size).
- **Multi-task learning:** Simultaneous optimization for:
  - Histological subtype classification (LUAD vs LUSC)
  - TNM stage classification (I-II, III, IV)
  - Disease-free survival (DFS) prediction
  - Overall survival (OS) prediction
- **Attention-based feature fusion:** Transformer cross-attention module to integrate lung-lesion spatial information.
- **Survival modeling:** Incorporates a Cox proportional hazards loss for survival risk prediction.
- **Flexible data simulation and loading:** Includes utilities for synthetic data generation and multi-slice 2D volume processing.

---

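## Architecture

![image](https://cdn-uploads.huggingface.co/production/uploads/67c557cd8e8d8259f8d854fe/PIgTvwSVJcfbBOPRBNIbU.png)

![image](https://cdn-uploads.huggingface.co/production/uploads/67c557cd8e8d8259f8d854fe/HqPaLyqzDH39wJ62GnS8N.png)

![image](https://cdn-uploads.huggingface.co/production/uploads/67c557cd8e8d8259f8d854fe/8GX_0NfzHZGWm-0CfiwKe.png)

![image](https://cdn-uploads.huggingface.co/production/uploads/67c557cd8e8d8259f8d854fe/vq2DJCsNSymExe-NTk0A1.png)

![image](https://cdn-uploads.huggingface.co/production/uploads/67c557cd8e8d8259f8d854fe/jKWlt1N-8BeuXIdKXhvZ-.png)

The overall MM-DLS system consists of:
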
1. **Segmentation Module (`LungLesionSegmentor`):**
   - Shared ResNet encoder to extract features from CT images.
   - Dual decoders for lung and lesion segmentation.
   - Transformer-based cross-attention module for enhanced spatial feature interaction between lung and lesion regions.

2. **Feature Encoders:**
   - `LesionEncoder`: 2D convolutional encoder for lesion patches.
   - `SpaceEncoder`: 2D convolutional encoder for lung-space contextual patches.

3. **Attention Fusion Module:**
   - `LesionAttentionFusion`: Multi-head attention to fuse lesion and lung features into compact patient-level representations.

4. **Patient-Level Fusion Model (`PatientLevelFusionModel`):**
   - Fully connected network that combines imaging, PET, and clinical features.
   - Outputs classification logits, DFS and OS risk scores.

5. **Loss Functions:**
   - Binary cross-entropy loss for classification.
   - Cox proportional hazards loss (`CoxPHLoss`) for survival prediction.

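The attention fusion step (item 3 above) can be illustrated with single-head scaled dot-product attention, where one lesion feature vector acts as the query over a set of lung-space feature vectors. This is a dependency-free sketch for intuition only, not the repository's `LesionAttentionFusion`, which is described as multi-head. The function names `softmax` and `attend` and the toy vectors are assumptions made for illustration.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector.

    query:  lesion feature vector (length d)
    keys:   lung-space feature vectors (each of length d)
    values: vectors to be mixed (same count as keys)
    Returns (fused_vector, attention_weights).
    """
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    fused = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return fused, weights

# Toy example: one lesion feature attends over three lung-space features.
lesion = [1.0, 0.0]
space = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
fused, weights = attend(lesion, space, space)
print([round(w, 3) for w in weights])
```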

---

## Code Structure

- `ModelLesionEncoder.py`: Lesion image encoder extracting discriminative features from multi-slice tumor regions.
- `ModelSpaceEncoder.py`: Lung space encoder modeling anatomical and spatial context beyond the lesion.
- `LesionAttentionFusion.py`: Attention-based fusion module for adaptive integration of lesion and spatial features.
- `ClinicalFusionModel.py`: Patient-level fusion network combining imaging features, radiomics, PET signals, and clinical variables.
- `HierMM_DLS.py`: Core hierarchical multimodal deep learning model supporting multi-task learning: (1) subtype classification; (2) TNM stage prediction; (3) DFS and OS modeling.
- `CoxphLoss.py`: Cox proportional hazards loss for survival modeling with censored data.
- `PatientDataset.py`: Patient dataset loader supporting imaging, radiomics, PET, clinical variables, survival outcomes, and treatment labels.
- `LungLesionSegmentation.py`: Lung-lesion segmentation model.
- `ImageDataLoader.py`: Image preprocessing and loading utilities for multi-slice inputs.
- `plot_results.py`: Visualization utilities for Kaplan–Meier curves, hazard ratios, and survival analysis results.

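To make the role of a Cox proportional hazards loss concrete, the sketch below computes the negative partial log-likelihood in plain Python. It is an illustration under stated assumptions, not the contents of `CoxphLoss.py`: the function name `cox_ph_loss`, the averaging over observed events, and the Breslow-style handling of tied times are choices made here. Censored patients contribute no term of their own but still appear in the risk sets of earlier events, which is how censoring is accommodated.

```python
import math

def cox_ph_loss(risk_scores, times, events):
    """Negative Cox partial log-likelihood, averaged over observed events.

    risk_scores: predicted log-risk per patient (higher = higher hazard)
    times:       follow-up time per patient
    events:      1 if the event was observed, 0 if censored
    """
    n_events = sum(events)
    if n_events == 0:
        return 0.0  # no observed events: the partial likelihood has no terms
    total = 0.0
    for r_i, t_i, e_i in zip(risk_scores, times, events):
        if not e_i:
            continue  # censored patients only appear inside risk sets
        # Risk set: everyone still under observation at time t_i.
        risk_set = [r_j for r_j, t_j in zip(risk_scores, times) if t_j >= t_i]
        m = max(risk_set)  # log-sum-exp trick for numerical stability
        log_sum = m + math.log(sum(math.exp(r - m) for r in risk_set))
        total += r_i - log_sum
    return -total / n_events

times = [5.0, 8.0, 12.0, 20.0]
events = [1, 1, 0, 1]
good = cox_ph_loss([2.0, 1.0, 0.5, -1.0], times, events)  # higher risk, earlier event
bad = cox_ph_loss([-1.0, 0.5, 1.0, 2.0], times, events)   # ordering reversed
print(good < bad)
```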

---

## Data Format

The input data is organized per patient as follows:

### Imaging Data:

- CT slices (PNG format)
- Lung masks (binary masks, PNG)
- Lesion masks (binary masks, PNG)
- Slices grouped per patient ID

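Grouping slices per patient ID might look like the sketch below. The file naming scheme `<patient_id>_<slice_index>.png`, the `ct/` folder, and the function name `group_slices_by_patient` are hypothetical; the repository's `ImageDataLoader.py` may organize files differently.

```python
from collections import defaultdict
from pathlib import PurePath

def group_slices_by_patient(paths):
    """Group slice files by patient ID and sort each group by slice index.

    Assumes the hypothetical naming scheme '<patient_id>_<slice_index>.png'.
    """
    groups = defaultdict(list)
    for p in paths:
        stem = PurePath(p).stem  # e.g. 'P001_12'
        patient_id, _, slice_idx = stem.rpartition("_")
        groups[patient_id].append((int(slice_idx), p))
    # Sort numerically so that slice 2 comes before slice 10.
    return {pid: [path for _, path in sorted(items)] for pid, items in groups.items()}

files = ["ct/P002_1.png", "ct/P001_10.png", "ct/P001_2.png"]
grouped = group_slices_by_patient(files)
print(grouped)
```

Sorting on the integer slice index rather than the file name avoids the lexicographic trap where `P001_10.png` would sort before `P001_2.png`.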
### Tabular Data:

- Radiomics features: 128-dimensional vector (extracted with PyRadiomics)
- PET features: [SUVmax, SUVmean, SUVpeak, TLG, MTV]
- Clinical features: [Age, Sex, Smoking Status, Smoking Duration, Smoking Cessation, Tumor Diameter]
- Survival data: DFS time/event, OS time/event
- Classification label: LUAD (0) or LUSC (1)

Simulated data utilities are provided for experimentation and reproducibility.

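The tabular layout above can be mirrored by a small container that checks vector lengths before a record reaches the model. This is a sketch, not the repository's `PatientDataset.py`: the class name `PatientRecord`, the field names, the event encoding (1 = observed, 0 = censored), and the randomly simulated values are all assumptions made for illustration.

```python
import random
from dataclasses import dataclass
from typing import List

PET_FIELDS = ["SUVmax", "SUVmean", "SUVpeak", "TLG", "MTV"]
CLINICAL_FIELDS = ["Age", "Sex", "Smoking Status", "Smoking Duration",
                   "Smoking Cessation", "Tumor Diameter"]

@dataclass
class PatientRecord:
    """One patient's tabular inputs and targets (hypothetical container)."""
    radiomics: List[float]  # 128-dimensional radiomics vector
    pet: List[float]        # ordered as PET_FIELDS
    clinical: List[float]   # ordered as CLINICAL_FIELDS, numerically encoded
    dfs_time: float
    dfs_event: int          # assumed encoding: 1 = event observed, 0 = censored
    os_time: float
    os_event: int
    label: int              # 0 = LUAD, 1 = LUSC

    def __post_init__(self):
        expected = {"radiomics": 128, "pet": len(PET_FIELDS),
                    "clinical": len(CLINICAL_FIELDS)}
        for name, size in expected.items():
            actual = len(getattr(self, name))
            if actual != size:
                raise ValueError(f"{name}: expected {size} values, got {actual}")
        if self.label not in (0, 1):
            raise ValueError("label must be 0 (LUAD) or 1 (LUSC)")

# Simulated record with arbitrary values, for shape checking only.
rng = random.Random(0)
example = PatientRecord(
    radiomics=[rng.random() for _ in range(128)],
    pet=[rng.random() for _ in PET_FIELDS],
    clinical=[rng.random() for _ in CLINICAL_FIELDS],
    dfs_time=14.0, dfs_event=1, os_time=30.0, os_event=0, label=0,
)
print(len(example.radiomics), len(example.pet), len(example.clinical))
```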

---

## Installation

```bash
# Create and activate a dedicated conda environment
conda create -n mm_dls python=3.10 -y
conda activate mm_dls

# Clone the repository and enter it
git clone https://github.com/your_username/MM-DLS-NSCLC.git
cd MM-DLS-NSCLC
```

### Install dependencies

```bash
pip install -r requirements.txt
```

## Usage

### 🔽 Download Pretrained Models

Pretrained MM-DLS models are available for direct download:

- **MM-DLS (Full multimodal, best checkpoint)**
  [⬇️ Download Pretrained Model](https://drive.google.com/file/d/1IcyCwMgCX8wv0NMp84U4wlzhLoXH7ayx/view?usp=drive_link) (size: 1.3 MB)

The MM-DLS model is intentionally lightweight (~1.3 MB): it uses compact CNN encoders and MLP-based multimodal fusion rather than large pretrained backbones, which enables efficient deployment and fast inference.

After downloading, place the model files under the `./MODEL/` directory.

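### Training

```bash
python train_patient_model.py
```

### Evaluation

```bash
python test.py
```

### Example Forward Pass

```bash
# run_sample.ipynb is a notebook, so it is run with Jupyter rather than python.
# This executes it and saves an executed copy (requires Jupyter to be installed).
jupyter nbconvert --to notebook --execute run_sample.ipynb
```
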
## Model Performance (from publication)

### Histological Subtype Classification:

- AUC: 0.85 to 0.92 across cohorts
- AP: 0.81 to 0.86

### TNM Stage Prediction:

- AUC: Stage I-II (0.86 to 0.96), Stage III (0.85 to 0.95), Stage IV (0.83 to 0.95)
- AP and calibration maintained across internal and external sets

### DFS & OS Prognosis:

- C-index: up to 0.75
- Time-dependent AUC (1/2/3 years): 0.77 to 0.91
- Brier score: consistently < 0.2 for DFS and < 0.3 for OS
- Superior to single-modality models (clinical-only or imaging-only)

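For readers unfamiliar with the C-index reported above, the sketch below computes Harrell's concordance index in plain Python. It is a minimal illustration of the metric's definition, not the evaluation code behind the published numbers; the function name `concordance_index` and the toy data are assumptions.

```python
def concordance_index(risk_scores, times, events):
    """Harrell's C-index: fraction of comparable pairs ranked correctly.

    A pair (i, j) is comparable when patient i has an observed event and a
    shorter time than patient j. It is concordant when patient i also has the
    higher predicted risk; ties in predicted risk count as 0.5.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable

times = [5.0, 8.0, 12.0, 20.0]
events = [1, 1, 0, 1]
c_perfect = concordance_index([2.0, 1.0, 0.5, -1.0], times, events)
c_random = concordance_index([0.0, 0.0, 0.0, 0.0], times, events)
print(c_perfect, c_random)
```

A value of 0.5 corresponds to uninformative risk scores and 1.0 to a perfect ranking, which gives context for the reported value of up to 0.75.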
## Reference

Please cite our original publication when using this work.

## License

This project is licensed under the MIT License.

⚠️ **Notice:** The pretrained model is shared solely for research validation purposes and **should not be used, distributed, or cited before the associated study is formally published**.

## Contact

For any questions or collaborations, please contact:

Dr. Fang Dai: daifang_cool@163.com