--- license: mit language: - en --- # MM-DLS # Multi-task deep learning based on PET/CT images for the diagnosis and prognosis prediction of advanced non-small cell lung cancer ## Overview **MM-DLS** is a multi-modal, multi-task deep learning framework for the diagnosis, staging, and prognosis prediction of advanced non-small cell lung cancer (NSCLC). It integrates multi-source data including CT images, PET metabolic parameters, and clinical information to provide a unified, non-invasive decision-making tool for personalized treatment planning. This repository implements the full MM-DLS pipeline, consisting of: - Lung-lesion segmentation with cross-attention transformer - Multi-modal feature fusion (CT, PET, Clinical) - Multi-task learning: Pathological classification, TNM staging, DFS and OS survival prediction - Cox proportional hazards survival loss The framework supports both classification (adenocarcinoma vs squamous cell carcinoma) and survival risk prediction tasks, and has been validated on large-scale multi-center clinical datasets. --- ## Key Features - **Multi-modal fusion:** Combines CT-based imaging features, PET metabolic biomarkers (SUVmax, SUVmean, SUVpeak, TLG, MTV), and structured clinical variables (age, sex, smoking status, smoking duration, smoking cessation history, tumor size). - **Multi-task learning:** Simultaneous optimization for: - Histological subtype classification (LUAD vs LUSC) - TNM stage classification (I-II, III, IV) - Disease-free survival (DFS) prediction - Overall survival (OS) prediction - **Attention-based feature fusion:** Transformer cross-attention module to integrate lung-lesion spatial information. - **Survival modeling:** Incorporates Cox Proportional Hazards loss for survival time prediction. - **Flexible data simulation and loading:** Includes utilities for synthetic data generation and multi-slice 2D volume processing. --- ## Architecture The overall MM-DLS system consists of: ![Python](https://img.shields.io/badge/python-3.9%2B-blue) ![PyTorch](https://img.shields.io/badge/PyTorch-2.x-red) ![CUDA](https://img.shields.io/badge/CUDA-11.8%2B-green) ![License](https://img.shields.io/badge/License-MIT-lightgrey) ![Status](https://img.shields.io/badge/Status-Research-orange) 1. **Segmentation Module (LungLesionSegmentor):** - Shared ResNet encoder to extract features from CT images. - Dual decoders for lung and lesion segmentation. - Transformer-based cross-attention module for enhanced spatial feature interaction between lung and lesion regions. 2. **Feature Encoders:** - `LesionEncoder`: 2D convolutional encoder for lesion patches. - `SpaceEncoder`: 2D convolutional encoder for lung-space contextual patches. 3. **Attention Fusion Module:** - `LesionAttentionFusion`: Multi-head attention to fuse lesion and lung features into compact patient-level representations. 4. **Patient-Level Fusion Model (PatientLevelFusionModel):** - Fully connected network that combines imaging, PET, and clinical features. - Outputs classification logits, DFS and OS risk scores. 5. **Loss Functions:** - Binary cross-entropy loss for classification. - Cox proportional hazards loss (`CoxPHLoss`) for survival prediction. --- ## Code Structure - `ModelLesionEncoder.py`: Lesion image encoder extracting discriminative features from multi-slice tumor regions. - `ModelSpaceEncoder.py`: Lung space encoder modeling anatomical and spatial context beyond the lesion. - `LesionAttentionFusion.py`: Attention-based fusion module for adaptive integration of lesion and spatial features. - `ClinicalFusionModel.py`: Patient-level fusion network combining imaging features, radiomics, PET signals, and clinical variables. - `HierMM_DLS.py`:Core hierarchical multimodal deep learning model supporting multi-task learning: (1)Subtype classification; (2)TNM stage prediction; (3)DFS and OS modeling - `CoxphLoss.py`: Cox proportional hazards loss for survival modeling with censored data. - `PatientDataset.py`:Patient dataset loader supporting imaging, radiomics, PET, clinical variables, survival outcomes, and treatment labels. - `LungLesionSegmentation.py`: Lung-lesion segmentation model - `ImageDataLoader.py`: Image preprocessing and loading utilities for multi-slice inputs. - `plot_results.py`: Visualization utilities for Kaplan–Meier curves, hazard ratios, and survival analysis results. --- ## Data Format The input data is organized per patient as follows: ### Imaging Data: - CT slices (PNG format) - Lung masks (binary masks, PNG) - Lesion masks (binary masks, PNG) - Slices grouped per patient ID ### Tabular Data: - Radiomics features: 128-dimensional vector (PyRadiomics extracted) - PET features: [SUVmax, SUVmean, SUVpeak, TLG, MTV] - Clinical features: [Age, Sex, Smoking Status, Smoking Duration, Smoking Cessation, Tumor Diameter] - Survival data: DFS time/event, OS time/event - Classification label: LUAD (0) or LUSC (1) Simulated data utilities are provided for experimentation and reproducibility. --- ## Installation ```bash # Clone repository conda create -n mm_dls python=3.10 -y conda activate mm_dls git clone https://github.com/your_username/MM-DLS-NSCLC.git ``` ## Install dependencies ```bash pip install -r requirements.txt ``` ## Usage ### 🔽 Download Pretrained Models Pretrained MM-DLS models are available for direct download: - **MM-DLS (Full multimodal, best checkpoint)** [⬇️ Download Pretrained Model](https://drive.google.com/file/d/1IcyCwMgCX8wv0NMp84U4wlzhLoXH7ayx/view?usp=drive_link) Size 1.3 MB The MM-DLS model is intentionally lightweight (~1.3 MB), as it employs compact CNN encoders and MLP-based multimodal fusion rather than large pretrained backbones, enabling efficient deployment and fast inference. After downloading, place the model files under the `./MODEL/` directory: Training: ```bash python train_patient_model.py ``` Evaluation: ```bash python test.py ``` Example Forward Pass: ```bash python run_sample.ipynb ``` ## Model Performance (from publication) ### Histological Subtype Classification: AUC: 0.85 ~ 0.92 across cohorts AP: 0.81 ~ 0.86 ### TNM Stage Prediction: AUC: Stage I-II (0.86 ~ 0.96), Stage III (0.85 ~ 0.95), Stage IV (0.83 ~ 0.95) ### AP and calibration maintained across internal and external sets DFS & OS Prognosis: C-index: up to 0.75 Time-dependent AUC (1/2/3 years): 0.77 ~ 0.91 Brier score: consistently < 0.2 for DFS and < 0.3 for OS Superior to single modality models (clinical-only or imaging-only) ## Reference Please cite our original publication when using this work: License This project is licensed under the MIT License. ⚠️ **Notice:** The pretrained model is shared solely for research validation purposes and **should not be used, distributed, or cited before the associated study is formally published**. Contact For any questions or collaborations, please contact: Dr. Fang Dai: daifang_cool@163.com