---
license: mit
language:
- en
metrics:
- accuracy 0.9860
- roc_auc 0.9979
pipeline_tag: image-feature-extraction
tags:
- gnn
- link-prediction
- pytorch
- pytorch-geometric
- graph-neural-networks
- biology
- microscopy
- cell-tracking
- lineage-tracking
---
# DuMM Bacteria Tracker Model
## Overview

This repository contains the trained weights and documentation for the **DuMM Bacteria Tracker Model**, a **Graph Neural Network (GNN)** designed for **cell lineage link prediction** in time-lapse microscopy data.

The model uses a custom implementation of a **Parameter Decoupled Network (PDN)** variant within an **Edge-Propagation Message Passing Neural Network (EP-MPNN)** architecture, inspired by recent advancements in dynamic graph representation learning.

## Model Architecture and Implementation Details

* **Framework:** PyTorch and PyTorch Geometric (PyG).
* **Architecture:** `LineageLinkPredictionGNN` (custom `nn.Module`).
* **Core Layers:** Custom `EP_MPNN_Block` layers, which incorporate **Distance & Similarity (DS)** features and a **Jumping Knowledge (JK)** network for aggregating features across multiple layers.
* **Task:** Binary classification (link prediction) on candidate edges between adjacent time frames (continuation or division links).
* **Loss Function:** `BCEWithLogitsLoss`.
* **Input Features:** Node features are scaled with a `StandardScalerTransform`; the `DS_block` then computes edge attributes from them dynamically (absolute difference + cosine similarity).
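
The DS edge attributes (absolute difference plus cosine similarity) can be illustrated with a small framework-free sketch. The function name `ds_edge_features` and the `eps` guard are assumptions made for illustration; the repository's `DS_block` is not reproduced here and may differ, for example by operating on batched PyTorch tensors.

```python
import math


def ds_edge_features(x_src, x_dst, eps=1e-8):
    """Edge attributes for one candidate link (illustrative, not the repo's DS_block).

    Returns the element-wise absolute difference of the two node feature
    vectors, followed by their cosine similarity as the last entry.
    """
    abs_diff = [abs(a - b) for a, b in zip(x_src, x_dst)]
    dot = sum(a * b for a, b in zip(x_src, x_dst))
    norm_src = math.sqrt(sum(a * a for a in x_src))
    norm_dst = math.sqrt(sum(b * b for b in x_dst))
    # eps (an assumption) avoids division by zero for all-zero vectors.
    cos_sim = dot / max(norm_src * norm_dst, eps)
    return abs_diff + [cos_sim]


# Two orthogonal 2-D vectors: differences of 1.0 each, cosine similarity 0.0.
print(ds_edge_features([1.0, 0.0], [0.0, 1.0]))
```

Under this sketch, a 10-dimensional node feature vector yields an 11-dimensional edge attribute: ten absolute differences and one similarity score.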

### Input Features

The model uses a **10-dimensional feature vector** for each node (cell), derived from image analysis. These features capture positional, morphological, and intensity properties of the bacterial cells across two channels (Phase Contrast and Fluorescence).

| Feature Name | Type | Description |
| :--- | :--- | :--- |
| `area` | Morphological | Area of the cell segment. |
| `centroid_y` | Positional | Y-coordinate of the cell's centroid (critical for 1D growth systems). |
| `axis_major_length` | Morphological | Length of the cell's major axis. |
| `axis_minor_length` | Morphological | Length of the cell's minor axis. |
| `intensity_mean/max/min_phase` | Intensity | Mean, max, and min pixel intensity in the **Phase Contrast** channel (three features). |
| `intensity_mean/max/min_fluor` | Intensity | Mean, max, and min pixel intensity in the **Fluorescence** channel (three features). |

The model was trained on microscopy images from the duplex mother machine developed by the Jun lab (https://jun.ucsd.edu/mother_machine.php).

### Data Preprocessing and Splitting

#### Splitting Strategy
To test how well the model generalizes to later, unseen frames, a **temporal split** over the unique time frames was employed:
1. **Training Set:** First 60% of unique time frames (`sorted_time_frames[:train_split_idx]`).
2. **Validation Set:** Next 20% of unique time frames (`sorted_time_frames[train_split_idx:val_split_idx]`).
3. **Test Set:** Final 20% of unique time frames (`sorted_time_frames[val_split_idx:]`).
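
The split above can be sketched as follows. The slice expressions and the names `sorted_time_frames`, `train_split_idx`, and `val_split_idx` come from the description; the wrapper function `temporal_split` and the use of `int()` to truncate the split indices are assumptions, since the exact rounding rule is not stated.

```python
def temporal_split(time_frames, train_frac=0.6, val_frac=0.2):
    """Split unique time frames chronologically into train/val/test lists."""
    sorted_time_frames = sorted(set(time_frames))
    n_frames = len(sorted_time_frames)
    # Truncating with int() is an assumption about how indices are rounded.
    train_split_idx = int(n_frames * train_frac)
    val_split_idx = int(n_frames * (train_frac + val_frac))
    train_frames = sorted_time_frames[:train_split_idx]
    val_frames = sorted_time_frames[train_split_idx:val_split_idx]
    test_frames = sorted_time_frames[val_split_idx:]
    return train_frames, val_frames, test_frames


# Frame indices may repeat (one entry per cell); only unique frames are split.
train, val, test = temporal_split([3, 1, 2, 2, 5, 4, 7, 6, 9, 8, 10, 10])
print(train, val, test)
```

Because the split is by frame rather than by cell, every cell observed in a given frame lands in the same subset.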

#### Normalization
* **Method:** Node features were normalized using **standard scaling** (`sklearn.preprocessing.StandardScaler`).
* **Fit:** The scaler was fitted **only on the training-set features** (`all_train_node_features_df`).
* **Application:** The fitted scaler was then applied to the features in the training, validation, and test sets via the `StandardScalerTransform`. Because validation and test statistics never enter the fit, this avoids data leakage.
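
The fit-on-train, apply-everywhere pattern can be shown with a minimal standard-library sketch. It mirrors two conventions of `StandardScaler` (population standard deviation, and a scale of 1.0 for constant columns), but the helper names `fit_scaler` and `transform` are hypothetical and this is not the repository's `StandardScalerTransform`.

```python
from statistics import fmean, pstdev


def fit_scaler(train_rows):
    """Compute per-feature mean and std from the training rows only."""
    columns = list(zip(*train_rows))
    means = [fmean(col) for col in columns]
    # Population std (ddof=0); constant columns fall back to a scale of 1.0.
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds


def transform(rows, means, stds):
    """Standardize rows with statistics that were fitted elsewhere."""
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]


train_rows = [[1.0, 10.0], [3.0, 10.0]]
test_rows = [[5.0, 12.0]]

means, stds = fit_scaler(train_rows)          # statistics from train only
train_scaled = transform(train_rows, means, stds)
test_scaled = transform(test_rows, means, stds)  # reuses the train statistics
print(train_scaled, test_scaled)
```

Note that the test rows are never passed to `fit_scaler`, so their values cannot influence the scaling statistics.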

#### Graph Creation (Candidate Generation)
Candidate edges (links between cells in adjacent time frames) were generated based on a custom set of geometric and morphological heuristics:
* **Distance Constraint:** Max distance between centroids is limited by `max_dist_link` (default 50.0).
* **Area Ratio Constraints:**
  * **Continuation (1-to-1):** `min_area_ratio_continuation` (0.8) to `max_area_ratio_continuation` (1.2).
  * **Division (1-to-2):** `min_area_ratio_division` (1.8) to `max_area_ratio_division` (2.2).
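
One way these heuristics could combine is sketched below. The parameter names and defaults are taken from the list above; everything else is an assumption made for illustration, in particular the function name `candidate_edges`, the `centroid` and `area` dictionary keys, the Euclidean distance, and the direction of the area ratio (earlier-frame area divided by later-frame area). The actual candidate generation code may apply the constraints differently.

```python
import math


def candidate_edges(
    cells_t,
    cells_t1,
    max_dist_link=50.0,
    min_area_ratio_continuation=0.8,
    max_area_ratio_continuation=1.2,
    min_area_ratio_division=1.8,
    max_area_ratio_division=2.2,
):
    """Return (src_idx, dst_idx, kind) candidates between adjacent frames."""
    edges = []
    for i, parent in enumerate(cells_t):
        for j, child in enumerate(cells_t1):
            # Distance constraint on centroids (Euclidean is an assumption).
            if math.dist(parent["centroid"], child["centroid"]) > max_dist_link:
                continue
            # Assumed ratio direction: area at frame t over area at frame t+1.
            ratio = parent["area"] / child["area"]
            if min_area_ratio_continuation <= ratio <= max_area_ratio_continuation:
                edges.append((i, j, "continuation"))
            elif min_area_ratio_division <= ratio <= max_area_ratio_division:
                edges.append((i, j, "division"))
    return edges


frame_t = [{"centroid": (0.0, 10.0), "area": 100.0}]
frame_t1 = [
    {"centroid": (0.0, 12.0), "area": 105.0},   # similar size, nearby
    {"centroid": (0.0, 20.0), "area": 50.0},    # about half the size, nearby
    {"centroid": (0.0, 200.0), "area": 100.0},  # too far away
    {"centroid": (0.0, 15.0), "area": 70.0},    # ratio fits neither window
]
print(candidate_edges(frame_t, frame_t1))
```

Pairs that pass these filters are only candidates; the GNN then classifies each one as a true or false link.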

### Training Protocol

| Hyperparameter | Value | Description |
| :--- | :--- | :--- |
| **GNN Layers (`num_blocks`)** | 2 | Number of sequential EP-MPNN blocks. |
| **Hidden Channels** | 128 | Dimension for node and edge embeddings. |
| **Optimizer** | Adam | Standard optimization algorithm. |
| **Learning Rate** | 0.001 | Base learning rate. |
| **Weight Decay** | 0.0005 | L2 regularization applied to prevent overfitting. |
| **Batch Size** | 32 | Number of graphs processed per iteration. |
| **Evaluation Metric** | Validation Accuracy (`val_acc`) | Used to select the checkpoint saved as `best_link_prediction_model.pt`. |
| **Early Stopping** | Yes | Monitors Validation Loss (`val_loss`) with a **patience of 10 epochs**. |
| **Max Epochs** | 500 | Maximum number of training epochs. |
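
The table describes two separate criteria: the best checkpoint is chosen by `val_acc`, while early stopping watches `val_loss`. A minimal sketch of that bookkeeping follows. The class name `TrainingMonitor` and its interface are hypothetical, and details such as tie handling or a minimum improvement threshold are assumptions, since the training script is not shown.

```python
class TrainingMonitor:
    """Tracks checkpointing (by val_acc) and early stopping (by val_loss)."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_val_acc = float("-inf")
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def update(self, val_loss, val_acc):
        """Return (save_checkpoint, stop_training) for the finished epoch."""
        # Checkpoint criterion: strictly better validation accuracy.
        save_checkpoint = val_acc > self.best_val_acc
        if save_checkpoint:
            self.best_val_acc = val_acc
        # Early-stopping criterion: strictly lower validation loss.
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        stop_training = self.epochs_without_improvement >= self.patience
        return save_checkpoint, stop_training


# Simulated (val_loss, val_acc) history; patience shortened for the demo.
monitor = TrainingMonitor(patience=2)
for epoch, (loss, acc) in enumerate([(0.5, 0.90), (0.6, 0.95), (0.7, 0.94)]):
    save, stop = monitor.update(loss, acc)
    print(epoch, save, stop)
    if stop:
        break
```

In a real training loop, `save` would trigger writing the model weights to `best_link_prediction_model.pt`, and the loop would also end once the maximum of 500 epochs is reached.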