Sophia Tang committed on
Commit ·
b55bace
0
Parent(s):
Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +61 -0
- .gitignore +15 -0
- LICENSE +21 -0
- README.md +68 -0
- assets/branchsbm.png +3 -0
- assets/branchsbm_anim.gif +3 -0
- assets/clonidine.png +3 -0
- assets/lidar.png +3 -0
- assets/mouse.png +3 -0
- assets/trametinib.png +3 -0
- assets/veres.png +3 -0
- configs/.DS_Store +0 -0
- configs/clonidine_100D.yaml +22 -0
- configs/clonidine_150D.yaml +22 -0
- configs/clonidine_50D.yaml +22 -0
- configs/clonidine_50Dsingle.yaml +22 -0
- configs/lidar.yaml +15 -0
- configs/lidar_single.yaml +15 -0
- configs/mouse.yaml +18 -0
- configs/mouse_single.yaml +18 -0
- configs/trametinib.yaml +22 -0
- configs/trametinib_single.yaml +22 -0
- configs/veres.yaml +25 -0
- dataloaders/.DS_Store +0 -0
- dataloaders/clonidine_single_branch.py +265 -0
- dataloaders/clonidine_v2_data.py +280 -0
- dataloaders/lidar_data.py +529 -0
- dataloaders/lidar_data_single.py +274 -0
- dataloaders/mouse_data.py +453 -0
- dataloaders/three_branch_data.py +306 -0
- dataloaders/trametinib_single.py +268 -0
- dataloaders/veres_leiden_data.py +317 -0
- environment.yml +41 -0
- parsers.py +502 -0
- scripts/README.md +226 -0
- scripts/clonidine100.sh +26 -0
- scripts/clonidine150.sh +26 -0
- scripts/clonidine50.sh +26 -0
- scripts/clonidine50_single.sh +26 -0
- scripts/lidar.sh +26 -0
- scripts/lidar_single.sh +27 -0
- scripts/mouse.sh +25 -0
- scripts/mouse_single.sh +25 -0
- scripts/trametinib.sh +26 -0
- scripts/trametinib_single.sh +26 -0
- scripts/veres.sh +26 -0
- src/.DS_Store +0 -0
- src/branch_flow_net_test.py +1791 -0
- src/branch_flow_net_train.py +375 -0
- src/branch_growth_net_train.py +994 -0
.gitattributes
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
branchsbm.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
branchsbm/branchsbm.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
branchsbm/clonidine.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
branchsbm/lidar.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
branchsbm/mouse.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
branchsbm/trametinib.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
clonidine.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
lidar.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
mouse.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
trametinib.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
data/pca_and_leiden_labels.csv filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
data/mouse_hematopoiesis.csv filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
data/simulation_gene.csv filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
data/Veres_alltime.csv filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
data/Weinreb_alltime.csv filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
data/Weinreb_t2_leiden_clusters.csv filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
data/eb_noscale.csv filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
data/emt.csv filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
data/*.las filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
data/*.las filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
assets/veres.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
logs/
|
| 2 |
+
wandb/
|
| 3 |
+
__pycache__/
|
| 4 |
+
checkpoints/
|
| 5 |
+
lightining_logs/
|
| 6 |
+
results/
|
| 7 |
+
*.log
|
| 8 |
+
*.pyc
|
| 9 |
+
lightining_logs/
|
| 10 |
+
figures/
|
| 11 |
+
*.ckpt
|
| 12 |
+
*.csv
|
| 13 |
+
data/
|
| 14 |
+
extra/
|
| 15 |
+
.vscode/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Sophia Tang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [Branched Schrödinger Bridge Matching](https://arxiv.org/abs/2506.09007) (ICLR 2026) 🌳🧬
|
| 2 |
+
|
| 3 |
+
[**Sophia Tang**](https://sophtang.github.io/), [**Yinuo Zhang**](https://www.linkedin.com/in/yinuozhang98/), [**Alexander Tong**](https://www.alextong.net/) and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
This is the repository for [**Branched Schrödinger Bridge Matching**](https://arxiv.org/abs/2506.09007) (ICLR 2026) 🌳🧬. It is partially built on the [**Metric Flow Matching repo**](https://github.com/kkapusniak/metric-flow-matching) ([Kapusniak et al., 2024](https://arxiv.org/abs/2405.14780)).
|
| 8 |
+
|
| 9 |
+
Predicting how a population evolves between an initial and final state is central to many problems in generative modeling, from simulating perturbation responses to modelling cell fate decisions 🧫. Existing approaches, such as flow matching and Schrödinger Bridge Matching, effectively learn mappings between two distributions by modelling a single stochastic path. However, these methods are **inherently limited to unimodal transitions and cannot capture branched or divergent evolution from a common origin to multiple distinct outcomes.**
|
| 10 |
+
|
| 11 |
+
A key challenge in trajectory matching is reconstructing multi-modal marginals, particularly when modes diverge along distinct dynamical paths. Existing Schrödinger bridge and flow matching frameworks approximate multi-modal distributions by simulating many *independent* particle trajectories, which are susceptible to mode collapse, with particles concentrating on dominant high-density modes or traversing only low-energy intermediate paths.
|
| 12 |
+
|
| 13 |
+
To address this, we introduce **Branched Schrödinger Bridge Matching (BranchSBM)** 🌳🧬, a novel framework that learns a set of diverging velocity fields to reconstruct multi-modal target distributions while simultaneously learning growth networks that allocate mass across branches. Guided by a time-dependent potential energy function Vt, BranchSBM captures diverging, energy-minimizing dynamics without requiring intermediate-time supervision and can generate the full branched evolution from a single initial sample.
|
| 14 |
+
|
| 15 |
+
🌟 We define the **Branched Generalized Schrödinger Bridge problem** and introduce BranchSBM, a novel matching framework that learns optimal branched trajectories from an initial distribution to multiple target distributions.
|
| 16 |
+
|
| 17 |
+
🌟 We derive the Branched Conditional Stochastic Optimal Control (CondSOC) problem as the sum of Unbalanced CondSOC objectives and leverage a multi-stage training algorithm to learn the optimal branching drift and growth fields that transport mass along a branched trajectory.
|
| 18 |
+
|
| 19 |
+
🌟 We demonstrate the unique capability of BranchSBM to model dynamic branching trajectories across various real-world problems, including 3D navigation over LiDAR manifolds, modelling differentiating single-cell population dynamics, and simulating heterogeneous cellular responses to drug perturbation.
|
| 20 |
+
|
| 21 |
+
# Experiments
|
| 22 |
+
Code and instructions to reproduce our results are provided in `/scripts/README`.
|
| 23 |
+
|
| 24 |
+
## LiDAR Experiment 🗻
|
| 25 |
+
|
| 26 |
+
As a proof of concept, we first evaluate BranchSBM for navigating branched paths along the surface of a three-dimensional LiDAR manifold, from an initial distribution to two distinct target distributions while remaining on low-altitude regions of the manifold.
|
| 27 |
+
|
| 28 |
+

|
| 29 |
+
|
| 30 |
+
## Mouse Hematopoiesis and Pancreatic β-Cell Experiment 🧫
|
| 31 |
+
|
| 32 |
+
BranchSBM is uniquely positioned to model single-cell population dynamics where a homogeneous cell population (e.g., progenitor cells) differentiates into several distinct subpopulation branches, each of which independently undergoes growth dynamics. In this experiment, we demonstrate this capability on mouse hematopoiesis data and pancreatic β-cell differentiation data.
|
| 33 |
+
|
| 34 |
+
We evaluate BranchSBM on a mouse hematopoiesis scRNA-seq dataset containing three developmental time points representing progenitor cells differentiating into two terminal cell fates. Compared to a single-branch SBM, BranchSBM successfully learns distinct branching trajectories and accurately reconstructs intermediate cell states, demonstrating its ability to recover lineage bifurcation dynamics.
|
| 35 |
+
|
| 36 |
+

|
| 37 |
+
|
| 38 |
+
We evaluate BranchSBM on a pancreatic β-cell differentiation dataset ([Veres et al., 2019](https://www.nature.com/articles/s41586-019-1168-5)) containing 51,274 cells collected across eight time points as human pluripotent stem cells differentiate into pancreatic β-like cells. Cells are projected into a 30-dimensional PCA space, and Leiden clustering is used to define 11 terminal cell populations at the final time point.
|
| 39 |
+
|
| 40 |
+
BranchSBM is trained using only samples from the initial and final states, while intermediate distributions are inferred by learning trajectories constrained to the data manifold using an RBF state cost. Compared to baselines, BranchSBM significantly improves reconstruction of both intermediate and terminal distributions, achieving lower Wasserstein distances at validation time points. These results demonstrate that BranchSBM can accurately recover branching differentiation dynamics without intermediate supervision.
|
| 41 |
+
|
| 42 |
+

|
| 43 |
+
|
| 44 |
+
## Cell Perturbation Modelling Experiment 💉
|
| 45 |
+
Predicting the effects of perturbation on cell state dynamics is a crucial problem for therapeutic design. In this experiment, we leverage BranchSBM to model the **trajectories of a single cell line from a single homogeneous state to multiple heterogeneous states after a drug-induced perturbation**. We demonstrate that BranchSBM is capable of modeling high-dimensional gene expression data and learning branched trajectories that accurately reconstruct diverging perturbed cell populations.
|
| 46 |
+
|
| 47 |
+
We extract the data for a single cell line (A-549) under perturbation with Clonidine and Trametinib at 5 µM, selected based on cell abundance and response diversity from the Tahoe-100M dataset.
|
| 48 |
+
|
| 49 |
+
For the Clonidine perturbation data, we show that **BranchSBM reconstructs the ground-truth distributions, capturing the location and spread of the dataset**, whereas single-branch SBM fails to differentiate cells in cluster 1 that differ from cluster 0 in higher-dimensional principal components. We also show that BranchSBM can simulate trajectories in high-dimensional state spaces by *scaling up to 150 PCs*.
|
| 50 |
+
|
| 51 |
+

|
| 52 |
+
|
| 53 |
+
We further show that BranchSBM can **scale beyond two branches by modeling the perturbed cell population of Trametinib-treated cells**, which diverge into *three distinct clusters*. We trained BranchSBM with three endpoints and single-branch SBM with one endpoint containing all three clusters on the top 50 PCs.
|
| 54 |
+
|
| 55 |
+

|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
## Citation
|
| 59 |
+
If you find this repository helpful for your publications, please consider citing our paper:
|
| 60 |
+
```
|
| 61 |
+
@article{tang2026branchsbm,
|
| 62 |
+
title={Branched Schrödinger Bridge Matching},
|
| 63 |
+
author={Tang, Sophia and Zhang, Yinuo and Tong, Alexander and Chatterjee, Pranam},
|
| 64 |
+
journal={14th International Conference on Learning Representations (ICLR 2026)},
|
| 65 |
+
year={2026}
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
To use this repository, you agree to abide by the MIT License.
|
assets/branchsbm.png
ADDED
|
Git LFS Details
|
assets/branchsbm_anim.gif
ADDED
|
Git LFS Details
|
assets/clonidine.png
ADDED
|
Git LFS Details
|
assets/lidar.png
ADDED
|
Git LFS Details
|
assets/mouse.png
ADDED
|
Git LFS Details
|
assets/trametinib.png
ADDED
|
Git LFS Details
|
assets/veres.png
ADDED
|
Git LFS Details
|
configs/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
configs/clonidine_100D.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "tahoe"
|
| 2 |
+
data_name: "clonidine100D"
|
| 3 |
+
accelerator: "gpu"
|
| 4 |
+
hidden_dims_geopath: [1024, 1024, 1024]
|
| 5 |
+
hidden_dims_flow: [1024, 1024, 1024]
|
| 6 |
+
hidden_dims_growth: [1024, 1024, 1024]
|
| 7 |
+
dim: 100
|
| 8 |
+
t_exclude: []
|
| 9 |
+
time_geopath: true
|
| 10 |
+
whiten: false
|
| 11 |
+
velocity_metric: "rbf"
|
| 12 |
+
metric_patience: 25
|
| 13 |
+
patience: 25
|
| 14 |
+
n_centers: 300
|
| 15 |
+
kappa: 2
|
| 16 |
+
rho: -2.75
|
| 17 |
+
alpha_metric: 1
|
| 18 |
+
metric_epochs: 100
|
| 19 |
+
branchsbm: true
|
| 20 |
+
seeds: [42]
|
| 21 |
+
branches: 2
|
| 22 |
+
metric_clusters: 3
|
configs/clonidine_150D.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "tahoe"
|
| 2 |
+
data_name: "clonidine150D"
|
| 3 |
+
accelerator: "gpu"
|
| 4 |
+
hidden_dims_geopath: [1024, 1024, 1024]
|
| 5 |
+
hidden_dims_flow: [1024, 1024, 1024]
|
| 6 |
+
hidden_dims_growth: [1024, 1024, 1024]
|
| 7 |
+
dim: 150
|
| 8 |
+
t_exclude: []
|
| 9 |
+
time_geopath: true
|
| 10 |
+
whiten: false
|
| 11 |
+
velocity_metric: "rbf"
|
| 12 |
+
metric_patience: 25
|
| 13 |
+
patience: 25
|
| 14 |
+
n_centers: 300
|
| 15 |
+
kappa: 3
|
| 16 |
+
rho: -2.75
|
| 17 |
+
alpha_metric: 1
|
| 18 |
+
metric_epochs: 100
|
| 19 |
+
branchsbm: true
|
| 20 |
+
seeds: [42]
|
| 21 |
+
branches: 2
|
| 22 |
+
metric_clusters: 3
|
configs/clonidine_50D.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "tahoe"
|
| 2 |
+
data_name: "clonidine50D"
|
| 3 |
+
accelerator: "gpu"
|
| 4 |
+
hidden_dims_geopath: [1024, 1024, 1024]
|
| 5 |
+
hidden_dims_flow: [1024, 1024, 1024]
|
| 6 |
+
hidden_dims_growth: [1024, 1024, 1024]
|
| 7 |
+
dim: 50
|
| 8 |
+
t_exclude: []
|
| 9 |
+
time_geopath: true
|
| 10 |
+
whiten: false
|
| 11 |
+
velocity_metric: "rbf"
|
| 12 |
+
metric_patience: 25
|
| 13 |
+
patience: 25
|
| 14 |
+
n_centers: 150
|
| 15 |
+
kappa: 1.5
|
| 16 |
+
rho: -2.75
|
| 17 |
+
alpha_metric: 1
|
| 18 |
+
metric_epochs: 100
|
| 19 |
+
branchsbm: true
|
| 20 |
+
seeds: [42]
|
| 21 |
+
branches: 2
|
| 22 |
+
metric_clusters: 3
|
configs/clonidine_50Dsingle.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "tahoe"
|
| 2 |
+
data_name: "clonidine50Dsingle"
|
| 3 |
+
accelerator: "gpu"
|
| 4 |
+
hidden_dims_geopath: [1024, 1024, 1024]
|
| 5 |
+
hidden_dims_flow: [1024, 1024, 1024]
|
| 6 |
+
hidden_dims_growth: [1024, 1024, 1024]
|
| 7 |
+
dim: 50
|
| 8 |
+
t_exclude: []
|
| 9 |
+
time_geopath: true
|
| 10 |
+
whiten: false
|
| 11 |
+
velocity_metric: "rbf"
|
| 12 |
+
metric_patience: 25
|
| 13 |
+
patience: 25
|
| 14 |
+
n_centers: 150
|
| 15 |
+
kappa: 1.5
|
| 16 |
+
rho: -2.75
|
| 17 |
+
alpha_metric: 1
|
| 18 |
+
metric_epochs: 100
|
| 19 |
+
branchsbm: true
|
| 20 |
+
seeds: [42]
|
| 21 |
+
branches: 1
|
| 22 |
+
metric_clusters: 2
|
configs/lidar.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "lidar"
|
| 2 |
+
data_name: "lidar"
|
| 3 |
+
dim: 3
|
| 4 |
+
whiten: true
|
| 5 |
+
t_exclude: []
|
| 6 |
+
velocity_metric: "land"
|
| 7 |
+
gammas: [0.125]
|
| 8 |
+
rho: 0.001
|
| 9 |
+
branchsbm: true
|
| 10 |
+
seeds: [42]
|
| 11 |
+
patience_geopath: 50
|
| 12 |
+
metric_epochs: 100
|
| 13 |
+
time_geopath: true
|
| 14 |
+
branches: 2
|
| 15 |
+
metric_clusters: 3
|
configs/lidar_single.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "lidar"
|
| 2 |
+
data_name: "lidarsingle"
|
| 3 |
+
dim: 3
|
| 4 |
+
whiten: true
|
| 5 |
+
t_exclude: []
|
| 6 |
+
velocity_metric: "land"
|
| 7 |
+
gammas: [0.125]
|
| 8 |
+
rho: 0.001
|
| 9 |
+
branchsbm: true
|
| 10 |
+
seeds: [42]
|
| 11 |
+
patience_geopath: 50
|
| 12 |
+
metric_epochs: 100
|
| 13 |
+
time_geopath: true
|
| 14 |
+
branches: 1
|
| 15 |
+
metric_clusters: 2
|
configs/mouse.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "scrna"
|
| 2 |
+
data_name: "mouse"
|
| 3 |
+
hidden_dims_geopath: [64, 64, 64]
|
| 4 |
+
hidden_dims_flow: [64, 64, 64]
|
| 5 |
+
hidden_dims_growth: [64, 64, 64]
|
| 6 |
+
dim: 2
|
| 7 |
+
whiten: false
|
| 8 |
+
t_exclude: []
|
| 9 |
+
velocity_metric: "land"
|
| 10 |
+
gammas: [0.125]
|
| 11 |
+
rho: 0.001
|
| 12 |
+
branchsbm: true
|
| 13 |
+
seeds: [42]
|
| 14 |
+
patience_geopath: 50
|
| 15 |
+
metric_epochs: 100
|
| 16 |
+
time_geopath: false
|
| 17 |
+
branches: 2
|
| 18 |
+
metric_clusters: 2
|
configs/mouse_single.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "scrna"
|
| 2 |
+
data_name: "mousesingle"
|
| 3 |
+
hidden_dims_geopath: [64, 64, 64]
|
| 4 |
+
hidden_dims_flow: [64, 64, 64]
|
| 5 |
+
hidden_dims_growth: [64, 64, 64]
|
| 6 |
+
dim: 2
|
| 7 |
+
whiten: false
|
| 8 |
+
t_exclude: []
|
| 9 |
+
velocity_metric: "land"
|
| 10 |
+
gammas: [0.125]
|
| 11 |
+
rho: 0.001
|
| 12 |
+
branchsbm: true
|
| 13 |
+
seeds: [42]
|
| 14 |
+
patience_geopath: 50
|
| 15 |
+
metric_epochs: 100
|
| 16 |
+
time_geopath: true
|
| 17 |
+
branches: 1
|
| 18 |
+
metric_clusters: 2
|
configs/trametinib.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "tahoe"
|
| 2 |
+
data_name: "trametinib"
|
| 3 |
+
accelerator: "gpu"
|
| 4 |
+
hidden_dims_geopath: [1024, 1024, 1024]
|
| 5 |
+
hidden_dims_flow: [1024, 1024, 1024]
|
| 6 |
+
hidden_dims_growth: [1024, 1024, 1024]
|
| 7 |
+
dim: 50
|
| 8 |
+
t_exclude: []
|
| 9 |
+
time_geopath: true
|
| 10 |
+
whiten: false
|
| 11 |
+
velocity_metric: "rbf"
|
| 12 |
+
metric_patience: 25
|
| 13 |
+
patience: 25
|
| 14 |
+
n_centers: 150
|
| 15 |
+
kappa: 1.5
|
| 16 |
+
rho: -2.75
|
| 17 |
+
alpha_metric: 1
|
| 18 |
+
metric_epochs: 100
|
| 19 |
+
branchsbm: true
|
| 20 |
+
seeds: [42]
|
| 21 |
+
branches: 3
|
| 22 |
+
metric_clusters: 4
|
configs/trametinib_single.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "tahoe"
|
| 2 |
+
data_name: "trametinibsingle"
|
| 3 |
+
accelerator: "gpu"
|
| 4 |
+
hidden_dims_geopath: [1024, 1024, 1024]
|
| 5 |
+
hidden_dims_flow: [1024, 1024, 1024]
|
| 6 |
+
hidden_dims_growth: [1024, 1024, 1024]
|
| 7 |
+
dim: 50
|
| 8 |
+
t_exclude: []
|
| 9 |
+
time_geopath: true
|
| 10 |
+
whiten: false
|
| 11 |
+
velocity_metric: "rbf"
|
| 12 |
+
metric_patience: 25
|
| 13 |
+
patience: 25
|
| 14 |
+
n_centers: 150
|
| 15 |
+
kappa: 1.5
|
| 16 |
+
rho: -2.75
|
| 17 |
+
alpha_metric: 1
|
| 18 |
+
metric_epochs: 100
|
| 19 |
+
branchsbm: true
|
| 20 |
+
seeds: [42]
|
| 21 |
+
branches: 1
|
| 22 |
+
metric_clusters: 2
|
configs/veres.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_type: "scrna"
|
| 2 |
+
data_name: "veres"
|
| 3 |
+
data_path: "data/Veres_alltime.csv"
|
| 4 |
+
accelerator: "gpu"
|
| 5 |
+
hidden_dims_geopath: [512, 512, 512]
|
| 6 |
+
hidden_dims_flow: [512, 512, 512]
|
| 7 |
+
hidden_dims_growth: [512, 512, 512]
|
| 8 |
+
dim: 30
|
| 9 |
+
t_exclude: []
|
| 10 |
+
time_geopath: true
|
| 11 |
+
whiten: false
|
| 12 |
+
velocity_metric: "rbf"
|
| 13 |
+
metric_patience: 25
|
| 14 |
+
patience: 25
|
| 15 |
+
patience_geopath: 50
|
| 16 |
+
n_centers: 300
|
| 17 |
+
kappa: 2
|
| 18 |
+
rho: 0.001
|
| 19 |
+
alpha_metric: 1.0
|
| 20 |
+
metric_epochs: 100
|
| 21 |
+
branchsbm: true
|
| 22 |
+
seeds: [42]
|
| 23 |
+
branches: 5
|
| 24 |
+
metric_clusters: 2
|
| 25 |
+
batch_size: 256
|
dataloaders/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
dataloaders/clonidine_single_branch.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from functools import partial
|
| 10 |
+
from scipy.spatial import cKDTree
|
| 11 |
+
from sklearn.cluster import KMeans
|
| 12 |
+
from torch.utils.data import TensorDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ClonidineSingleBranchDataModule(pl.LightningDataModule):
|
| 16 |
+
def __init__(self, args):
    """Initialize the data module from the parsed argument namespace.

    Stores the data-related hyperparameters, hard-codes the path to the
    combined PCA + Leiden-label CSV, and eagerly loads/preprocesses the
    data at construction time.
    """
    super().__init__()
    self.save_hyperparameters()

    # Keep the full argument namespace plus the settings read from it.
    self.args = args
    self.batch_size = args.batch_size
    self.whiten = args.whiten
    self.split_ratios = args.split_ratios
    self.max_dim = args.dim
    self.dim = args.dim
    print("dimension")
    print(self.dim)

    # Combined DMSO/Clonidine PCA coordinates with Leiden cluster labels.
    self.data_path = "./data/pca_and_leiden_labels.csv"
    # Only the initial (DMSO) and final (perturbed) time points are used.
    self.num_timesteps = 2

    # Read and preprocess the CSV immediately rather than lazily.
    self._prepare_data()
|
| 33 |
+
|
| 34 |
+
def _prepare_data(self):
    """Load the PCA + Leiden-label CSV and build train/val/test loaders.

    Builds one source population (x0: DMSO control cells) and one target
    population (x1: the concatenation of two Clonidine Leiden clusters),
    splits them into train/val, and wraps everything in weighted
    DataLoaders plus KMeans-derived "metric sample" loaders over the
    full point cloud.
    """
    # '#'-prefixed CSV lines are comments; the first column (row index)
    # is dropped; empty strings become NaN so .notna() masks work below.
    df = pd.read_csv(self.data_path, comment='#')
    df = df.iloc[:, 1:]
    df = df.replace('', np.nan)
    # The first self.dim columns are taken as the PCA coordinates;
    # everything after is assumed to be label metadata -- TODO confirm.
    pc_cols = df.columns[:self.dim]
    for col in pc_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
    leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'

    # A cell belongs to a condition iff its Leiden column is populated.
    dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
    clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Clonidine column

    dmso_data = df[dmso_mask].copy()
    clonidine_data = df[clonidine_mask].copy()

    # The two Clonidine endpoint clusters used as targets (string labels,
    # hence the astype(str) comparison below).
    top_clonidine_clusters = ['0.0', '4.0']

    x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
    x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]

    x1_1_coords = x1_1_data[pc_cols].values
    x1_2_coords = x1_2_data[pc_cols].values

    x1_1_coords = x1_1_coords.astype(float)
    x1_2_coords = x1_2_coords.astype(float)

    # Target size is the minimum across the endpoint clusters so both
    # targets end up balanced.
    target_size = min(len(x1_1_coords), len(x1_2_coords),)

    # Helper: deterministically downsample a cluster by keeping the
    # target_size points closest to the cluster centroid.
    def select_closest_to_centroid(coords, target_size):
        if len(coords) <= target_size:
            return coords

        centroid = np.mean(coords, axis=0)

        distances = np.linalg.norm(coords - centroid, axis=1)

        closest_indices = np.argsort(distances)[:target_size]

        return coords[closest_indices]

    # Downsample both endpoint clusters to target_size.
    x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
    x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)

    dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

    # Source population: the largest DMSO Leiden cluster.
    largest_dmso_cluster = dmso_cluster_counts.index[0]
    dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

    dmso_coords = dmso_cluster_data[pc_cols].values

    # Centroid-based selection is used for DMSO too, for consistency
    # with the endpoint clusters.
    if len(dmso_coords) >= target_size:
        x0_coords = select_closest_to_centroid(dmso_coords, target_size)
    else:
        # Largest cluster is smaller than target: use all of it and pad
        # with cells from the remaining DMSO clusters.
        remaining_needed = target_size - len(dmso_coords)
        other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
        other_dmso_coords = other_dmso_data[pc_cols].values

        if len(other_dmso_coords) >= remaining_needed:
            # Pad with the other-cluster cells closest to their centroid.
            other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
            x0_coords = np.vstack([dmso_coords, other_selected])
        else:
            # Not enough DMSO cells at all: shrink target_size to what is
            # available and re-select the endpoint clusters to match.
            all_dmso_coords = dmso_data[pc_cols].values
            target_size = min(target_size, len(all_dmso_coords))
            x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

            x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
            x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)

    # At this point every population holds exactly target_size points.
    self.n_samples = target_size

    x0 = torch.tensor(x0_coords, dtype=torch.float32)
    x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
    x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
    # Single-branch variant: both Clonidine clusters are merged into one
    # target population, so x1 has 2 * target_size rows.
    x1 = torch.cat([x1_1, x1_2], dim=0)

    self.coords_t0 = x0
    self.coords_t1 = x1
    self.time_labels = np.concatenate([
        np.zeros(len(self.coords_t0)),  # t=0
        np.ones(len(self.coords_t1)),  # t=1
    ])

    split_index = int(target_size * self.split_ratios[0])

    # Guarantee the validation split holds at least one full batch.
    if target_size - split_index < self.batch_size:
        split_index = target_size - self.batch_size
    print('total count is:', target_size)

    # NOTE(review): split_index is derived from target_size but x1 has
    # 2 * target_size rows, so val_x1 is much larger than val_x0 -- the
    # min_size CombinedLoader mode below truncates to the shortest
    # loader, but confirm this imbalance is intended.
    train_x0 = x0[:split_index]
    val_x0 = x0[split_index:]
    train_x1 = x1[:split_index]
    val_x1 = x1[split_index:]


    self.val_x0 = val_x0

    # Uniform per-sample weights (single branch: everything weight 1.0).
    train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
    train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

    val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
    val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

    # Each loader yields (coords, weight) pairs.
    self.train_dataloaders = {
        "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
    }

    self.val_dataloaders = {
        "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
        "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
    }

    # Full point cloud (all rows with complete PCA coords) backs both the
    # test "dataset" loader and the KD-tree used for manifold smoothing.
    all_coords = df[pc_cols].dropna().values.astype(float)
    self.dataset = torch.tensor(all_coords, dtype=torch.float32)
    self.tree = cKDTree(all_coords)

    self.test_dataloaders = {
        "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
        "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
    }

    # Metric samples: a 2-way KMeans partition of the full point cloud,
    # each cluster served as a single full batch.
    km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset.numpy())

    cluster_labels = km_all.labels_

    cluster_0_mask = cluster_labels == 0
    cluster_1_mask = cluster_labels == 1

    samples = self.dataset.cpu().numpy()

    cluster_0_data = samples[cluster_0_mask]
    cluster_1_data = samples[cluster_1_mask]

    self.metric_samples_dataloaders = [
        DataLoader(
            torch.tensor(cluster_1_data, dtype=torch.float32),
            batch_size=cluster_1_data.shape[0],
            shuffle=False,
            drop_last=False,
        ),
        DataLoader(
            torch.tensor(cluster_0_data, dtype=torch.float32),
            batch_size=cluster_0_data.shape[0],
            shuffle=False,
            drop_last=False,
        ),
    ]
|
| 201 |
+
|
| 202 |
+
def train_dataloader(self):
    """Combine the training loaders with the metric-sample loaders.

    Inner loaders are truncated to the shortest member ("min_size");
    the outer pairing cycles the shorter side ("max_size_cycle").
    """
    training_part = CombinedLoader(self.train_dataloaders, mode="min_size")
    metric_part = CombinedLoader(self.metric_samples_dataloaders, mode="min_size")
    return CombinedLoader(
        {"train_samples": training_part, "metric_samples": metric_part},
        mode="max_size_cycle",
    )
|
| 210 |
+
|
| 211 |
+
def val_dataloader(self):
    """Combine the validation loaders with the metric-sample loaders."""
    validation_part = CombinedLoader(self.val_dataloaders, mode="min_size")
    metric_part = CombinedLoader(self.metric_samples_dataloaders, mode="min_size")
    return CombinedLoader(
        {"val_samples": validation_part, "metric_samples": metric_part},
        mode="max_size_cycle",
    )
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def test_dataloader(self):
    """Combine the test loaders with the metric-sample loaders."""
    testing_part = CombinedLoader(self.test_dataloaders, mode="min_size")
    metric_part = CombinedLoader(self.metric_samples_dataloaders, mode="min_size")
    return CombinedLoader(
        {"test_samples": testing_part, "metric_samples": metric_part},
        mode="max_size_cycle",
    )
|
| 232 |
+
|
| 233 |
+
def get_manifold_proj(self, points):
    """Return a manifold-projection callable for this data module.

    Uses k-NN local smoothing over the full dataset instead of plane
    fitting. NOTE(review): ``points`` is accepted for interface
    compatibility but is not used -- the operator is built from
    ``self.tree`` / ``self.dataset``.
    """
    smoothing_fn = partial(
        self.local_smoothing_op,
        tree=self.tree,
        dataset=self.dataset,
    )
    return smoothing_fn
|
| 236 |
+
|
| 237 |
+
@staticmethod
|
| 238 |
+
def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
|
| 239 |
+
"""
|
| 240 |
+
Apply local smoothing based on k-nearest neighbors in the full dataset
|
| 241 |
+
This replaces the plane projection for 2D manifold regularization
|
| 242 |
+
"""
|
| 243 |
+
points_np = x.detach().cpu().numpy()
|
| 244 |
+
_, idx = tree.query(points_np, k=k)
|
| 245 |
+
nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
|
| 246 |
+
|
| 247 |
+
# Compute weighted average of neighbors
|
| 248 |
+
dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
|
| 249 |
+
weights = torch.exp(-dists / temp)
|
| 250 |
+
weights = weights / weights.sum(dim=1, keepdim=True)
|
| 251 |
+
|
| 252 |
+
# Weighted average of neighbors
|
| 253 |
+
smoothed = (weights * nearest_pts).sum(dim=1)
|
| 254 |
+
|
| 255 |
+
# Blend original point with smoothed version
|
| 256 |
+
alpha = 0.3 # How much smoothing to apply
|
| 257 |
+
return (1 - alpha) * x + alpha * smoothed
|
| 258 |
+
|
| 259 |
+
def get_timepoint_data(self):
    """Bundle the cached timepoint tensors for downstream visualization."""
    return dict(
        t0=self.coords_t0,
        t1=self.coords_t1,
        time_labels=self.time_labels,
    )
|
dataloaders/clonidine_v2_data.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from functools import partial
|
| 10 |
+
from scipy.spatial import cKDTree
|
| 11 |
+
from sklearn.cluster import KMeans
|
| 12 |
+
from torch.utils.data import TensorDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ClonidineV2DataModule(pl.LightningDataModule):
    """Two-branch Clonidine data module.

    Loads a combined PCA + Leiden-label CSV, builds one DMSO source
    population (x0) and two Clonidine target populations (x1_1, x1_2 --
    Leiden clusters '0.0' and '4.0'), and exposes weighted train/val/test
    DataLoaders plus KMeans-derived metric-sample loaders over the full
    point cloud.
    """

    def __init__(self, args):
        """Store config from ``args`` and eagerly build all loaders.

        Assumes ``args`` provides batch_size, dim, whiten and
        split_ratios -- TODO confirm against the argument parser.
        """
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios

        self.dim = args.dim
        print("dimension")
        print(self.dim)
        # Path to the combined PCA-coordinates + Leiden-labels CSV.
        self.data_path = "./data/pca_and_leiden_labels.csv"
        self.num_timesteps = 2
        self.args = args
        # Data preparation happens eagerly at construction time (reads
        # the CSV, fits KMeans, builds every DataLoader).
        self._prepare_data()

    def _prepare_data(self):
        """Load the CSV and build all DataLoaders and metric samples."""
        # '#'-prefixed CSV lines are comments; the first column (row
        # index) is dropped; empty strings become NaN for .notna() masks.
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        # The first self.dim columns are taken as PCA coordinates;
        # everything after is label metadata -- TODO confirm layout.
        pc_cols = df.columns[:self.dim]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'

        # A cell belongs to a condition iff its Leiden column is set.
        dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Clonidine column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # The two Clonidine endpoint clusters that define the branches
        # (string labels, hence the astype(str) comparison below).
        top_clonidine_clusters = ['0.0', '4.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]

        x1_1_coords = x1_1_data[pc_cols].values
        x1_2_coords = x1_2_data[pc_cols].values

        x1_1_coords = x1_1_coords.astype(float)
        x1_2_coords = x1_2_coords.astype(float)

        # Target size is the minimum across the endpoint clusters so
        # both branches end up balanced.
        target_size = min(len(x1_1_coords), len(x1_2_coords),)

        # Helper: deterministically downsample a cluster by keeping the
        # target_size points closest to its centroid.
        def select_closest_to_centroid(coords, target_size):
            if len(coords) <= target_size:
                return coords

            centroid = np.mean(coords, axis=0)

            distances = np.linalg.norm(coords - centroid, axis=1)

            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Downsample both endpoint clusters to target_size.
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # Source population: the largest DMSO Leiden cluster,
        # centroid-selected for consistency with the endpoints.
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # Largest cluster smaller than target: use all of it and pad
            # with cells from the remaining DMSO clusters.
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Pad with other-cluster cells closest to their centroid.
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Not enough DMSO cells at all: shrink target_size to
                # what is available and re-select the endpoints to match.
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)

        # Every population now holds exactly target_size points.
        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)

        self.coords_t0 = x0
        self.coords_t1_1 = x1_1
        self.coords_t1_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1_1)),  # t=1, branch 1
            np.ones(len(self.coords_t1_2)),  # t=1, branch 2
        ])

        split_index = int(target_size * self.split_ratios[0])

        # Guarantee the validation split holds at least one full batch.
        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size
        print('total count is:', target_size)

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]


        self.val_x0 = val_x0

        # Per-sample weights: the source carries weight 1.0; each of the
        # two branch targets carries 0.5 so the branches sum to 1.0.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        # Each loader yields (coords, weight) pairs; one loader per
        # population (x0 source, x1_1 / x1_2 branch targets).
        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # NOTE(review): x0 is unshuffled here while x1_1/x1_2 shuffle --
        # confirm the asymmetry is intentional for validation pairing.
        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full point cloud (all rows with complete PCA coords) backs the
        # test "dataset" loader and the KD-tree for manifold smoothing.
        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: a 3-way KMeans partition of the full cloud,
        # each cluster served as a single full batch (order 2, 0, 1).
        km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1
        cluster_2_mask = cluster_labels == 2

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]
        cluster_2_data = samples[cluster_2_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_2_data, dtype=torch.float32),
                batch_size=cluster_2_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),

            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        """Combine training loaders with metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combine validation loaders with metric-sample loaders."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")


    def test_dataloader(self):
        """Combine test loaders with metric-sample loaders."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging
        instead of plane fitting.

        NOTE(review): ``points`` is unused; the operator is built from
        ``self.tree`` / ``self.dataset``.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for 2D manifold regularization.

        NOTE(review): with temp=1e-3, exp(-dists / temp) underflows to 0
        for squared distances beyond ~1e-1, making weights.sum() zero and
        the normalization NaN for queries far from the dataset -- consider
        subtracting the per-row minimum distance before exponentiating.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute weighted average of neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            'time_labels': self.time_labels
        }
|
| 280 |
+
|
dataloaders/lidar_data.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from pytorch_lightning.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import laspy
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.spatial import cKDTree
|
| 10 |
+
import math
|
| 11 |
+
from functools import partial
|
| 12 |
+
from torch.utils.data import TensorDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GaussianMM:
    """Isotropic Gaussian mixture with equal component weights.

    All K components share a single scalar variance ``var``; component
    means are given by ``mu`` (K x D). Provides exact log-density
    evaluation and i.i.d. sampling via ``__call__``.
    """

    def __init__(self, mu, var):
        super().__init__()
        self.centers = torch.tensor(mu)              # (K, D) component means
        self.logstd = torch.tensor(var).log() / 2.0  # shared scalar log-std
        self.K = self.centers.shape[0]

    def logprob(self, x):
        """Log-density of points ``x`` (N, D) under the mixture (weights 1/K)."""
        per_dim = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        # Sum the per-dimension log-densities, then log-mean-exp over components.
        per_component = per_dim.sum(dim=2)
        return torch.logsumexp(per_component, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        """Elementwise log N(z; mean, exp(log_std)^2)."""
        mean = mean + torch.tensor(0.0)
        log_std = log_std + torch.tensor(0.0)
        const = torch.tensor([math.log(2 * math.pi)]).to(z)
        scaled = (z - mean) * torch.exp(-log_std)
        return -0.5 * (scaled * scaled + 2 * log_std + const)

    def __call__(self, n_samples):
        """Draw ``n_samples`` points: pick a component uniformly, add noise."""
        idx = torch.randint(self.K, (n_samples,)).to(self.centers.device)
        mean = self.centers[idx]
        noise = torch.randn(*mean.shape).to(mean)
        return noise * torch.exp(self.logstd) + mean
|
| 41 |
+
|
| 42 |
+
class BranchedLidarDataModule(pl.LightningDataModule):
|
| 43 |
+
def __init__(self, args):
    """Configure the branched LiDAR data module.

    Assumes ``args`` provides data_path, batch_size, dim, whiten and
    split_ratios -- TODO confirm against the argument parser.
    """
    super().__init__()
    self.save_hyperparameters()

    self.data_path = args.data_path
    self.batch_size = args.batch_size
    self.max_dim = args.dim
    self.whiten = args.whiten
    # Source distribution p0: mixture centers in the normalized LiDAR
    # frame ([-5, 5] x [-5, 5] x [0, 2] after _prepare_data scaling).
    self.p0_mu = [
        [-4.5, -4.0, 0.5],
        [-4.2, -3.5, 0.5],
        [-4.0, -3.0, 0.5],
        [-3.75, -2.5, 0.5],
    ]
    self.p0_var = 0.02

    # Target distributions p1_1 / p1_2: one mixture per branch endpoint.
    self.p1_1_mu = [
        [-2.5, -0.25, 0.5],
        [-2.25, 0.675, 0.5],
        [-2, 1.5, 0.5],
    ]
    self.p1_2_mu = [
        [2, -2, 0.5],
        [2.6, -1.25, 0.5],
        [3.2, -0.5, 0.5]
    ]

    self.p1_var = 0.03
    # k nearest neighbors (presumably for local manifold operations --
    # usage not visible in this chunk, verify downstream).
    self.k = 20
    # Number of points sampled from each Gaussian mixture population.
    self.n_samples = 5000
    self.num_timesteps = 2
    self.split_ratios = args.split_ratios
    # Data preparation happens eagerly at construction time (reads the
    # LAS file, samples the mixtures, builds loaders).
    self._prepare_data()
|
| 76 |
+
|
| 77 |
+
def assign_region(self):
|
| 78 |
+
all_centers = {
|
| 79 |
+
0: torch.tensor(self.p0_mu), # Region 0: p0
|
| 80 |
+
1: torch.tensor(self.p1_1_mu), # Region 1: p1_1
|
| 81 |
+
2: torch.tensor(self.p1_2_mu), # Region 2: p1_2
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
dataset = self.dataset.to(torch.float32)
|
| 85 |
+
N = dataset.shape[0]
|
| 86 |
+
assignments = torch.zeros(N, dtype=torch.long)
|
| 87 |
+
|
| 88 |
+
# For each point, compute min distance to each region's centers
|
| 89 |
+
for i in range(N):
|
| 90 |
+
point = dataset[i]
|
| 91 |
+
min_dist = float("inf")
|
| 92 |
+
best_region = 0
|
| 93 |
+
for region, centers in all_centers.items():
|
| 94 |
+
dists = ((centers - point)**2).sum(dim=1)
|
| 95 |
+
region_min = dists.min()
|
| 96 |
+
if region_min < min_dist:
|
| 97 |
+
min_dist = region_min
|
| 98 |
+
best_region = region
|
| 99 |
+
assignments[i] = best_region
|
| 100 |
+
return assignments
|
| 101 |
+
|
| 102 |
+
def _prepare_data(self):
|
| 103 |
+
las = laspy.read(self.data_path)
|
| 104 |
+
# Extract only "ground" points.
|
| 105 |
+
self.mask = las.classification == 2
|
| 106 |
+
# Original Preprocessing
|
| 107 |
+
x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
|
| 108 |
+
y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
|
| 109 |
+
z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
|
| 110 |
+
dataset = np.vstack(
|
| 111 |
+
(
|
| 112 |
+
las.X[self.mask] * x_scale + x_offset,
|
| 113 |
+
las.Y[self.mask] * y_scale + y_offset,
|
| 114 |
+
las.Z[self.mask] * z_scale + z_offset,
|
| 115 |
+
)
|
| 116 |
+
).transpose()
|
| 117 |
+
mi = dataset.min(axis=0, keepdims=True)
|
| 118 |
+
ma = dataset.max(axis=0, keepdims=True)
|
| 119 |
+
dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]
|
| 120 |
+
|
| 121 |
+
self.dataset = torch.tensor(dataset, dtype=torch.float32)
|
| 122 |
+
self.tree = cKDTree(dataset)
|
| 123 |
+
|
| 124 |
+
x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
|
| 125 |
+
x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
|
| 126 |
+
x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)
|
| 127 |
+
|
| 128 |
+
x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
|
| 129 |
+
x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
|
| 130 |
+
x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)
|
| 131 |
+
|
| 132 |
+
split_index = int(self.n_samples * self.split_ratios[0])
|
| 133 |
+
|
| 134 |
+
self.scaler = StandardScaler()
|
| 135 |
+
if self.whiten:
|
| 136 |
+
self.dataset = torch.tensor(
|
| 137 |
+
self.scaler.fit_transform(dataset), dtype=torch.float32
|
| 138 |
+
)
|
| 139 |
+
x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
|
| 140 |
+
x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
|
| 141 |
+
x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
|
| 142 |
+
|
| 143 |
+
train_x0 = x0[:split_index]
|
| 144 |
+
val_x0 = x0[split_index:]
|
| 145 |
+
|
| 146 |
+
# branches
|
| 147 |
+
train_x1_1 = x1_1[:split_index]
|
| 148 |
+
print("train_x1_1")
|
| 149 |
+
print(train_x1_1.shape)
|
| 150 |
+
val_x1_1 = x1_1[split_index:]
|
| 151 |
+
train_x1_2 = x1_2[:split_index]
|
| 152 |
+
val_x1_2 = x1_2[split_index:]
|
| 153 |
+
|
| 154 |
+
self.val_x0 = val_x0
|
| 155 |
+
|
| 156 |
+
# Adjust split_index to ensure minimum validation samples
|
| 157 |
+
if self.n_samples - split_index < self.batch_size:
|
| 158 |
+
split_index = self.n_samples - self.batch_size
|
| 159 |
+
|
| 160 |
+
self.train_dataloaders = {
|
| 161 |
+
"x0": DataLoader(train_x0, batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 162 |
+
"x1_1": DataLoader(train_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 163 |
+
"x1_2": DataLoader(train_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 164 |
+
}
|
| 165 |
+
self.val_dataloaders = {
|
| 166 |
+
"x0": DataLoader(val_x0, batch_size=self.batch_size, shuffle=False, drop_last=True),
|
| 167 |
+
"x1_1": DataLoader(val_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 168 |
+
"x1_2": DataLoader(val_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 169 |
+
}
|
| 170 |
+
# to edit?
|
| 171 |
+
self.test_dataloaders = [
|
| 172 |
+
DataLoader(
|
| 173 |
+
self.val_x0,
|
| 174 |
+
batch_size=self.val_x0.shape[0],
|
| 175 |
+
shuffle=False,
|
| 176 |
+
drop_last=False,
|
| 177 |
+
),
|
| 178 |
+
DataLoader(
|
| 179 |
+
self.dataset,
|
| 180 |
+
batch_size=self.dataset.shape[0],
|
| 181 |
+
shuffle=False,
|
| 182 |
+
drop_last=False,
|
| 183 |
+
),
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
points = self.dataset.cpu().numpy()
|
| 187 |
+
x, y = points[:, 0], points[:, 1]
|
| 188 |
+
# Diagonal-based coordinates (rotated 45°)
|
| 189 |
+
u = (x + y) / np.sqrt(2) # along x=y
|
| 190 |
+
# start region (A) using u
|
| 191 |
+
u_thresh = np.percentile(u, 30) # tweak this threshold to control size
|
| 192 |
+
mask_A = u <= u_thresh
|
| 193 |
+
|
| 194 |
+
# among the rest, split by x=y diagonal
|
| 195 |
+
remaining = ~mask_A
|
| 196 |
+
mask_B = remaining & (x < y) # left of diagonal
|
| 197 |
+
mask_C = remaining & (x >= y) # right of diagonal
|
| 198 |
+
|
| 199 |
+
# Assign dataloaders
|
| 200 |
+
self.metric_samples_dataloaders = [
|
| 201 |
+
DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
|
| 202 |
+
DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
|
| 203 |
+
DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
|
| 204 |
+
]
|
| 205 |
+
|
| 206 |
+
def train_dataloader(self):
|
| 207 |
+
combined_loaders = {
|
| 208 |
+
"train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
|
| 209 |
+
"metric_samples": CombinedLoader(
|
| 210 |
+
self.metric_samples_dataloaders, mode="min_size"
|
| 211 |
+
),
|
| 212 |
+
}
|
| 213 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 214 |
+
|
| 215 |
+
def val_dataloader(self):
|
| 216 |
+
combined_loaders = {
|
| 217 |
+
"val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
|
| 218 |
+
"metric_samples": CombinedLoader(
|
| 219 |
+
self.metric_samples_dataloaders, mode="min_size"
|
| 220 |
+
),
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 224 |
+
|
| 225 |
+
def test_dataloader(self):
|
| 226 |
+
return CombinedLoader(self.test_dataloaders)
|
| 227 |
+
|
| 228 |
+
def get_tangent_proj(self, points):
|
| 229 |
+
w = self.get_tangent_plane(points)
|
| 230 |
+
return partial(BranchedLidarDataModule.projection_op, w=w)
|
| 231 |
+
|
| 232 |
+
def get_tangent_plane(self, points, temp=1e-3):
|
| 233 |
+
points_np = points.detach().cpu().numpy()
|
| 234 |
+
_, idx = self.tree.query(points_np, k=self.k)
|
| 235 |
+
nearest_pts = self.dataset[idx]
|
| 236 |
+
nearest_pts = torch.tensor(nearest_pts).to(points)
|
| 237 |
+
|
| 238 |
+
dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
|
| 239 |
+
weights = torch.exp(-dists / temp)
|
| 240 |
+
|
| 241 |
+
# Fits plane with least vertical distance.
|
| 242 |
+
w = BranchedLidarDataModule.fit_plane(nearest_pts, weights)
|
| 243 |
+
return w
|
| 244 |
+
|
| 245 |
+
@staticmethod
|
| 246 |
+
def fit_plane(points, weights=None):
|
| 247 |
+
"""Expects points to be of shape (..., 3).
|
| 248 |
+
Returns [a, b, c] such that the plane is defined as
|
| 249 |
+
ax + by + c = z
|
| 250 |
+
"""
|
| 251 |
+
D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
|
| 252 |
+
z = points[..., 2]
|
| 253 |
+
if weights is not None:
|
| 254 |
+
Dtrans = D.transpose(-1, -2)
|
| 255 |
+
else:
|
| 256 |
+
DW = D * weights
|
| 257 |
+
Dtrans = DW.transpose(-1, -2)
|
| 258 |
+
w = torch.linalg.solve(
|
| 259 |
+
torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
|
| 260 |
+
).squeeze(-1)
|
| 261 |
+
return w
|
| 262 |
+
|
| 263 |
+
@staticmethod
|
| 264 |
+
def projection_op(x, w):
|
| 265 |
+
"""Projects points to a plane defined by w."""
|
| 266 |
+
# Normal vector to the tangent plane.
|
| 267 |
+
n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)
|
| 268 |
+
|
| 269 |
+
pn = torch.sum(x * n, dim=-1, keepdim=True)
|
| 270 |
+
nn = torch.sum(n * n, dim=-1, keepdim=True)
|
| 271 |
+
|
| 272 |
+
# Offset.
|
| 273 |
+
d = w[..., 2:3]
|
| 274 |
+
|
| 275 |
+
# Projection of x onto n.
|
| 276 |
+
projn_x = ((pn + d) / nn) * n
|
| 277 |
+
|
| 278 |
+
# Remove component in the normal direction.
|
| 279 |
+
return x - projn_x
|
| 280 |
+
|
| 281 |
+
class WeightedBranchedLidarDataModule(pl.LightningDataModule):
    """Branched LiDAR data module that attaches a mass weight to every sample.

    Same geometry as BranchedLidarDataModule, but each endpoint tensor is
    paired with a per-sample weight — x0 carries mass 1.0 and each of the two
    branches carries 0.5 — so branch mass can be supervised downstream.
    """

    def __init__(self, args):
        """`args` must provide: data_path, batch_size, dim, whiten, split_ratios."""
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Mixture centers of the source distribution p0 (x, y, z).
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02
        # One mixture per branch target.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5]
        ]

        self.p1_var = 0.03
        self.k = 20            # neighbors used when fitting a local tangent plane
        self.n_samples = 5000
        self.num_timesteps = 2  # (duplicate second assignment removed)
        self.split_ratios = args.split_ratios

        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Load the scan, sample/project weighted endpoints, build all loaders."""
        las = laspy.read(self.data_path)
        # Keep only "ground" points (LAS classification code 2).
        self.mask = las.classification == 2
        # Undo the LAS integer encoding (scaled/offset int coordinates).
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Rescale into the box [-5, 5] x [-5, 5] x [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample Gaussian endpoints and snap them onto the surface.
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # BUGFIX: this adjustment originally ran *after* the tensors were
        # already split, so it never had any effect. Guarantee at least one
        # full validation batch before splitting.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            # NOTE(review): whitening replaces self.dataset but self.tree keeps
            # the unwhitened coordinates; tangent queries after this point mix
            # the two frames — confirm intended.
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)

        # Keep the full (post-whitening) endpoint clouds for visualization.
        self.coords_t0 = x0
        self.coords_t1_1 = x1_1
        self.coords_t1_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),    # t=0
            np.ones(len(self.coords_t1_1)),   # t=1
            np.ones(len(self.coords_t1_2)),   # t=1
        ])

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # Branch endpoints.
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        # Mass weights: the source carries 1.0; each branch carries half.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full-batch loaders for testing.
        # NOTE(review): the branch loaders use val_x0's size as batch_size with
        # drop_last=True — if a branch validation split is smaller, it yields
        # no batches; confirm the splits are always equal-sized.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Partition the scan into three metric regions: a start strip at low
        # (x + y), then the two halves on either side of the x = y diagonal.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinate (rotated 45 degrees).
        u = (x + y) / np.sqrt(2)  # along x = y
        u_thresh = np.percentile(u, 30)  # tweak this threshold to control size
        mask_A = u <= u_thresh

        remaining = ~mask_A
        mask_B = remaining & (x < y)   # left of the diagonal
        mask_C = remaining & (x >= y)  # right of the diagonal

        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        """Training batches plus full-region metric samples on every step."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Validation batches plus full-region metric samples on every step."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Test batches plus full-region metric samples on every step."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        """Return a callable projecting points onto their local tangent planes."""
        w = self.get_tangent_plane(points)
        # CONSISTENCY: use this class's own (fixed) static method instead of
        # delegating to BranchedLidarDataModule.
        return partial(WeightedBranchedLidarDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        """Fit one weighted tangent plane per query point from its k neighbors."""
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        # Index directly and move to the query device/dtype (the original
        # re-wrapped an existing tensor with torch.tensor, which copies and warns).
        nearest_pts = self.dataset[idx].to(points)

        # Gaussian kernel weights: closer neighbors count more.
        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits the plane with least (weighted) vertical distance.
        w = WeightedBranchedLidarDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Least-squares plane fit. Expects points of shape (..., 3).

        Returns [a, b, c] such that the plane is defined as ax + by + c = z.
        When `weights` is given (shape (..., k, 1)) each neighbor's row of the
        normal equations is scaled by its weight.
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        # BUGFIX: the branches were inverted — the unweighted path ran when
        # weights were supplied, and `D * weights` raised when weights=None.
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Projects points x onto the plane(s) defined by w = [a, b, c]."""
        # Normal vector to the tangent plane ax + by - z + c = 0.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset.
        d = w[..., 2:3]

        # Component of x along the (unnormalized) normal, including offset.
        projn_x = ((pn + d) / nn) * n

        # Remove the normal component so the point lands on the plane.
        return x - projn_x

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization."""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            'time_labels': self.time_labels
        }
|
| 525 |
+
|
| 526 |
+
def get_datamodule(args=None):
    """Build and set up a WeightedBranchedLidarDataModule for the "fit" stage.

    Args:
        args: configuration namespace forwarded to the data module. Defaults
            to the module-level ``args`` so existing zero-argument callers
            keep working; passing it explicitly avoids the global dependency.

    Returns:
        A WeightedBranchedLidarDataModule with ``setup(stage="fit")`` done.
    """
    if args is None:
        # Preserve the original behavior of reading a module-level `args`
        # (which must have been defined by the importing script).
        args = globals()["args"]
    datamodule = WeightedBranchedLidarDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
|
dataloaders/lidar_data_single.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from pytorch_lightning.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import laspy
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.spatial import cKDTree
|
| 10 |
+
import math
|
| 11 |
+
from functools import partial
|
| 12 |
+
from torch.utils.data import TensorDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GaussianMM:
    """Equal-weight mixture of isotropic Gaussians.

    Components share a single scalar variance `var`; `mu` is a (K, D) nested
    list of component centers. Calling the instance draws samples; `logprob`
    evaluates the mixture log-density.
    """

    def __init__(self, mu, var):
        super().__init__()
        self.centers = torch.tensor(mu)
        # Scalar log standard deviation: log(sqrt(var)) = log(var) / 2.
        self.logstd = torch.tensor(var).log() / 2.0
        self.K = self.centers.shape[0]

    def logprob(self, x):
        """Mixture log-density of x, shape (N, D) -> (N,)."""
        # Per-dimension log-probabilities against every center: (N, K, D).
        per_dim = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        # Independent dimensions -> sum, then mix with uniform 1/K weights.
        per_component = per_dim.sum(dim=2)
        return torch.logsumexp(per_component, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        """Elementwise univariate Gaussian log-density of z under N(mean, e^{2 log_std})."""
        mean = mean + torch.tensor(0.0)
        log_std = log_std + torch.tensor(0.0)
        log_two_pi = torch.tensor([math.log(2 * math.pi)]).to(z)
        # Standardize, then apply the closed-form Gaussian log-pdf.
        scaled = (z - mean) * torch.exp(-log_std)
        return -0.5 * (scaled.pow(2) + 2 * log_std + log_two_pi)

    def __call__(self, n_samples):
        """Draw n_samples points: pick a center uniformly, add isotropic noise."""
        choice = torch.randint(self.K, (n_samples,)).to(self.centers.device)
        picked = self.centers[choice]
        noise = torch.randn(*picked.shape).to(picked)
        return noise * torch.exp(self.logstd) + picked
|
| 41 |
+
|
| 42 |
+
class LidarSingleDataModule(pl.LightningDataModule):
    """Single-branch variant of the LiDAR data module.

    The two branch targets are concatenated into one target cloud x1, and
    every sample carries a unit weight, for training an unbranched baseline.
    """

    def __init__(self, args):
        """`args` must provide: data_path, batch_size, dim, whiten, split_ratios."""
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Mixture centers of the source distribution p0 (x, y, z).
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02
        # Two target mixtures; their samples are merged into one x1 below.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5]
        ]

        self.p1_var = 0.03
        self.k = 20            # neighbors used when fitting a local tangent plane
        self.n_samples = 5000
        self.num_timesteps = 2  # (duplicate second assignment removed)
        self.split_ratios = args.split_ratios

        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Load the scan, sample/project endpoints, merge targets, build loaders."""
        las = laspy.read(self.data_path)
        # Keep only "ground" points (LAS classification code 2).
        self.mask = las.classification == 2
        # Undo the LAS integer encoding (scaled/offset int coordinates).
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Rescale into the box [-5, 5] x [-5, 5] x [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample Gaussian endpoints and snap them onto the surface.
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # BUGFIX: this adjustment originally ran *after* the tensors were
        # already split, so it never had any effect. Guarantee at least one
        # full validation batch before splitting.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            # NOTE(review): whitening replaces self.dataset but self.tree keeps
            # the unwhitened coordinates; tangent queries after this point mix
            # the two frames — confirm intended.
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
        # Merge both target clouds into a single (2 * n_samples, 3) target.
        x1 = torch.cat([x1_1, x1_2], dim=0)

        # Keep the full endpoint clouds for visualization.
        self.coords_t0 = x0
        self.coords_t1 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),   # t=1
        ])

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # Merged target split. NOTE(review): x1 has 2 * n_samples points but
        # is split at the same index as x0, so its validation split is much
        # larger — confirm intended.
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        # Unit mass weights everywhere in the single-branch setting.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full-batch loaders for testing.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=False),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=True, drop_last=False),
        }

        # Two metric regions: a start strip at low (x + y) and everything else.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinate (rotated 45 degrees).
        u = (x + y) / np.sqrt(2)  # along x = y
        u_thresh = np.percentile(u, 30)  # tweak this threshold to control size
        mask_A = u <= u_thresh
        remaining = ~mask_A
        # (mask_B / mask_C split removed: it was computed but never used here.)

        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[remaining], dtype=torch.float32), batch_size=points[remaining].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        """Training batches plus full-region metric samples on every step."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Validation batches plus full-region metric samples on every step."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Test batches plus full-region metric samples on every step."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        """Return a callable projecting points onto their local tangent planes."""
        w = self.get_tangent_plane(points)
        return partial(LidarSingleDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        """Fit one weighted tangent plane per query point from its k neighbors."""
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        # Index directly and move to the query device/dtype (the original
        # re-wrapped an existing tensor with torch.tensor, which copies and warns).
        nearest_pts = self.dataset[idx].to(points)

        # Gaussian kernel weights: closer neighbors count more.
        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits the plane with least (weighted) vertical distance.
        w = LidarSingleDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Least-squares plane fit. Expects points of shape (..., 3).

        Returns [a, b, c] such that the plane is defined as ax + by + c = z.
        When `weights` is given (shape (..., k, 1)) each neighbor's row of the
        normal equations is scaled by its weight.
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        # BUGFIX: the branches were inverted — the unweighted path ran when
        # weights were supplied, and `D * weights` raised when weights=None.
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Projects points x onto the plane(s) defined by w = [a, b, c]."""
        # Normal vector to the tangent plane ax + by - z + c = 0.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset.
        d = w[..., 2:3]

        # Component of x along the (unnormalized) normal, including offset.
        projn_x = ((pn + d) / nn) * n

        # Remove the normal component so the point lands on the plane.
        return x - projn_x

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization."""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }
|
dataloaders/mouse_data.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import numpy as np
|
| 8 |
+
from scipy.spatial import cKDTree
|
| 9 |
+
import math
|
| 10 |
+
from functools import partial
|
| 11 |
+
from sklearn.cluster import KMeans, DBSCAN
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from torch.utils.data import TensorDataset
|
| 15 |
+
|
| 16 |
+
class WeightedBranchedCellDataModule(pl.LightningDataModule):
    """LightningDataModule for 2D cell trajectories with two weighted endpoint branches.

    Reads a CSV with columns ``samples`` (timepoint label), ``x1``, ``x2``;
    clusters the t=2 cells into two branches via KMeans and serves paired
    (coordinates, branch-weight) dataloaders for training/validation, plus
    KMeans-cluster "metric sample" loaders used for metric learning.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        # Configuration pulled from the argparse namespace.
        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20  # neighborhood size default; note local_smoothing_op uses its own k=10
        self.n_samples = 1429  # placeholder; overwritten with the real t=0 count in _prepare_data
        self.num_timesteps = 2  # NOTE(review): data has t=0,1,2 and SingleBranch uses 3 — confirm intended value
        self.split_ratios = args.split_ratios
        self.metric_clusters = args.metric_clusters
        self.args = args
        self._prepare_data()


    def _prepare_data(self):
        """Load the CSV, build branch endpoints, splits, and all dataloaders."""
        print("Preparing cell data in BranchedCellDataModule")

        df = pd.read_csv(self.data_path)

        # Build dictionary of coordinates by time: {t: (n_t, 2) array}.
        coords_by_t = {
            t: df[df["samples"] == t][["x1","x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0] # Number of T=0 points
        self.n_samples = n0 # Update n_samples to match actual data if changes

        # Cluster the t=2 cells into two branches (fixed seed for reproducibility).
        km = KMeans(n_clusters=2, random_state=42).fit(coords_by_t[2])
        df2 = df[df["samples"] == 2].copy()
        df2["branch"] = km.labels_

        cluster_counts = df2["branch"].value_counts().sort_index()
        print(cluster_counts)

        # Sample n0 points from each branch so endpoints pair 1:1 with t=0 points.
        endpoints = {}
        for b in (0, 1):
            endpoints[b] = (
                df2[df2["branch"] == b]
                .sample(n=n0, random_state=42)[["x1","x2"]]
                .values
            )

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)  # t=0 source points
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)  # intermediate t=1 points
        x1_1 = torch.tensor(endpoints[0], dtype=torch.float32)  # branch 0 endpoints
        x1_2 = torch.tensor(endpoints[1], dtype=torch.float32)  # branch 1 endpoints

        # Kept for get_timepoint_data() / visualization.
        self.coords_t0 = x0
        self.coords_t1 = x_inter
        self.coords_t2_1 = x1_1
        self.coords_t2_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)), # t=0
            np.ones(len(self.coords_t1)), # t=1
            np.ones(len(self.coords_t2_1)) * 2, # t=2, branch 0
            np.ones(len(self.coords_t2_2)) * 2, # t=2, branch 1
        ])

        split_index = int(n0 * self.split_ratios[0])

        # Guarantee the validation slice holds at least one full batch.
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        # Per-sample weights: source mass 1.0 split evenly (0.5/0.5) between branches.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        # NOTE(review): n_samples == n0 here, so this re-check is redundant with the one above.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # NOTE(review): "x0" is unshuffled but the branch loaders shuffle — confirm this asymmetry is intended.
        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full point cloud across all timepoints; KD-tree supports local_smoothing_op.
        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # if whitening is enabled, need to apply this to the full dataset
        #if self.whiten:
            #self.scaler = StandardScaler()
            #self.dataset = torch.tensor(
                #self.scaler.fit_transform(all_data), dtype=torch.float32
            #)

        # Test loaders yield everything in a single batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric Dataloader: KMeans clustering of ALL points into metric_clusters groups;
        # one full-batch loader per cluster (order of loaders is deliberate and consumed positionally).
        if self.metric_clusters == 3:
            km_all = KMeans(n_clusters=3, random_state=45).fit(self.dataset.numpy())
            cluster_labels = km_all.labels_

            cluster_0_mask = cluster_labels == 0
            cluster_1_mask = cluster_labels == 1
            cluster_2_mask = cluster_labels == 2

            samples = self.dataset.cpu().numpy()

            cluster_0_data = samples[cluster_0_mask]
            cluster_1_data = samples[cluster_1_mask]
            cluster_2_data = samples[cluster_2_mask]

            self.metric_samples_dataloaders = [
                DataLoader(
                    torch.tensor(cluster_1_data, dtype=torch.float32),
                    batch_size=cluster_1_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
                DataLoader(
                    torch.tensor(cluster_2_data, dtype=torch.float32),
                    batch_size=cluster_2_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),

                DataLoader(
                    torch.tensor(cluster_0_data, dtype=torch.float32),
                    batch_size=cluster_0_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
            ]
        else:
            km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
            cluster_labels = km_all.labels_

            cluster_0_mask = cluster_labels == 0
            cluster_1_mask = cluster_labels == 1

            samples = self.dataset.cpu().numpy()

            cluster_0_data = samples[cluster_0_mask]
            cluster_1_data = samples[cluster_1_mask]

            self.metric_samples_dataloaders = [
                DataLoader(
                    torch.tensor(cluster_1_data, dtype=torch.float32),
                    batch_size=cluster_1_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
                DataLoader(
                    torch.tensor(cluster_0_data, dtype=torch.float32),
                    batch_size=cluster_0_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
            ]


    def train_dataloader(self):
        """Return training loaders combined with the metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Return validation loaders combined with the metric-sample loaders."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Return test loaders (full-batch) combined with the metric-sample loaders."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting.

        ``points`` is unused; the returned callable closes over the KD-tree and dataset.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for 2D manifold regularization.
        Returns a blend of each point with a distance-weighted average of its
        k nearest dataset neighbors.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Gaussian-kernel weights from squared distances, normalized per point.
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors.
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version.
        alpha = 0.3  # How much smoothing to apply (0 = none, 1 = fully smoothed)
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization."""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            't2_1': self.coords_t2_1,
            't2_2': self.coords_t2_2,
            'time_labels': self.time_labels
        }
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class SingleBranchCellDataModule(pl.LightningDataModule):
    """LightningDataModule for 2D cell trajectories with a single endpoint branch.

    Reads a CSV with columns ``samples`` (timepoint label), ``x1``, ``x2``
    and serves (coordinates, weight) dataloaders from the t=0 source to the
    t=2 endpoint, plus KMeans-cluster "metric sample" loaders.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        # Configuration pulled from the argparse namespace.
        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20  # neighborhood size default; note local_smoothing_op uses its own k=10
        self.n_samples = 1429  # placeholder; overwritten with the real t=0 count in _prepare_data
        self.num_timesteps = 3  # t=0, t=1, t=2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3  # NOTE(review): hard-coded, but _prepare_data always builds 2 clusters
        self.args = args
        self._prepare_data()


    def _prepare_data(self):
        """Load the CSV, build train/val splits, and construct all dataloaders."""
        print("Preparing cell data in BranchedCellDataModule")

        df = pd.read_csv(self.data_path)

        # Build dictionary of coordinates by time: {t: (n_t, 2) array}.
        coords_by_t = {
            t: df[df["samples"] == t][["x1","x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0] # Number of T=0 points
        self.n_samples = n0 # Update n_samples to match actual data if changes

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)  # t=0 source points
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)  # intermediate t=1 points
        x1 = torch.tensor(coords_by_t[2], dtype=torch.float32)  # t=2 endpoint points

        # Store for get_timepoint_data()
        self.coords_t0 = x0
        self.coords_t1 = x_inter
        self.coords_t2 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(x0)),       # t=0
            np.ones(len(x_inter)),   # t=1
            np.ones(len(x1)) * 2,    # t=2
        ])

        split_index = int(n0 * self.split_ratios[0])

        # Guarantee the validation slice holds at least one full batch.
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        # NOTE(review): slicing x1 with n0-based indices assumes len(x1) >= n0 — confirm.
        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        # Per-sample weights: source mass 1.0, endpoint mass 0.5.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=0.5)

        # NOTE(review): n_samples == n0 here, so this re-check is redundant with the one above.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # NOTE(review): "x0" is unshuffled but "x1" shuffles — confirm this asymmetry is intended.
        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full point cloud across all timepoints; KD-tree supports local_smoothing_op.
        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # if whitening is enabled, need to apply this to the full dataset
        # NOTE(review): the KD-tree above is built on UN-whitened data — confirm that is intended.
        if self.whiten:
            self.scaler = StandardScaler()
            self.dataset = torch.tensor(
                self.scaler.fit_transform(all_data), dtype=torch.float32
            )

        # Test loaders yield everything in a single batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric Dataloader
        # K-means clustering of ALL points into 2 groups; one full-batch loader per cluster.
        km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        # Order is deliberate (cluster 1 first) and consumed positionally downstream.
        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]


    def train_dataloader(self):
        """Return training loaders combined with the metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Return validation loaders combined with the metric-sample loaders."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Return test loaders (full-batch) combined with the metric-sample loaders."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting.

        ``points`` is unused; the returned callable closes over the KD-tree and dataset.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for 2D manifold regularization.
        Returns a blend of each point with a distance-weighted average of its
        k nearest dataset neighbors.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Gaussian-kernel weights from squared distances, normalized per point.
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors.
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version.
        alpha = 0.3  # How much smoothing to apply (0 = none, 1 = fully smoothed)
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization."""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            't2': self.coords_t2,
            'time_labels': self.time_labels
        }
|
| 449 |
+
|
| 450 |
+
"""def get_datamodule():
|
| 451 |
+
datamodule = WeightedBranchedCellDataModule(args)
|
| 452 |
+
datamodule.setup(stage="fit")
|
| 453 |
+
return datamodule"""
|
dataloaders/three_branch_data.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from functools import partial
|
| 10 |
+
from scipy.spatial import cKDTree
|
| 11 |
+
from sklearn.cluster import KMeans
|
| 12 |
+
from torch.utils.data import TensorDataset
|
| 13 |
+
|
| 14 |
+
class ThreeBranchTahoeDataModule(pl.LightningDataModule):
|
| 15 |
+
def __init__(self, args):
    """Configure the module from the argparse namespace and eagerly load the data.

    Expects ``args`` to provide: batch_size, dim, whiten, split_ratios,
    working_dir. The CSV path is derived from ``working_dir``.
    """
    super().__init__()
    self.save_hyperparameters()

    self.batch_size = args.batch_size
    self.max_dim = args.dim
    self.whiten = args.whiten
    self.split_ratios = args.split_ratios
    self.num_timesteps = 2  # t=0 (DMSO control) and t=1 (Trametinib endpoints)
    # PCA + Leiden/UMAP labels for the Trametinib 5.0 uM condition.
    self.data_path = f"{args.working_dir}/data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv"
    self.args = args

    # Eager load: all dataloaders are built at construction time.
    self._prepare_data()
|
| 28 |
+
|
| 29 |
+
def _prepare_data(self):
    """Load the Trametinib CSV and build train/val/test and metric dataloaders.

    Pipeline: parse 50 PCA columns → pick three Leiden endpoint clusters →
    equalize cluster sizes by keeping points closest to each centroid →
    select matching DMSO (t=0) source points → split, weight, and wrap in
    DataLoaders.
    """
    df = pd.read_csv(self.data_path, comment='#')
    df = df.iloc[:, 1:]  # drop the leading index column
    df = df.replace('', np.nan)
    pc_cols = df.columns[:50]  # first 50 columns are the PCA coordinates
    for col in pc_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
    leiden_clonidine_col = 'leiden_Trametinib_5.0uM'

    dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
    clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Trametinib column

    dmso_data = df[dmso_mask].copy()
    clonidine_data = df[clonidine_mask].copy()

    # The three endpoint Leiden clusters (string-compared against float-formatted labels).
    top_clonidine_clusters = ['1.0', '3.0', '5.0']

    x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
    x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
    x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]

    x1_1_coords = x1_1_data[pc_cols].values
    x1_2_coords = x1_2_data[pc_cols].values
    x1_3_coords = x1_3_data[pc_cols].values

    x1_1_coords = x1_1_coords.astype(float)
    x1_2_coords = x1_2_coords.astype(float)
    x1_3_coords = x1_3_coords.astype(float)

    # Target size is the minimum across all three endpoint clusters.
    target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))

    # Helper: deterministically keep the target_size points nearest the cluster centroid.
    def select_closest_to_centroid(coords, target_size):
        if len(coords) <= target_size:
            return coords

        # Calculate centroid
        centroid = np.mean(coords, axis=0)

        # Calculate distances to centroid
        distances = np.linalg.norm(coords - centroid, axis=1)

        # Get indices of closest points
        closest_indices = np.argsort(distances)[:target_size]

        return coords[closest_indices]

    # Sample all endpoint clusters to target size using centroid-based selection.
    x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
    x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
    x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)

    dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

    # DMSO source: start from the largest DMSO Leiden cluster.
    largest_dmso_cluster = dmso_cluster_counts.index[0]
    dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

    dmso_coords = dmso_cluster_data[pc_cols].values

    # Centroid-based selection for DMSO as well, for consistency with endpoints.
    if len(dmso_coords) >= target_size:
        x0_coords = select_closest_to_centroid(dmso_coords, target_size)
    else:
        # Largest cluster too small: pad from the remaining DMSO cells.
        remaining_needed = target_size - len(dmso_coords)
        other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
        other_dmso_coords = other_dmso_data[pc_cols].values

        if len(other_dmso_coords) >= remaining_needed:
            # Select closest to centroid from other DMSO cells.
            other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
            x0_coords = np.vstack([dmso_coords, other_selected])
        else:
            # Still not enough: use all DMSO cells and shrink target_size to match.
            all_dmso_coords = dmso_data[pc_cols].values
            target_size = min(target_size, len(all_dmso_coords))
            x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

            # Re-select endpoint clusters with the reduced target size.
            x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
            x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
            x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)


    self.n_samples = target_size

    # Kept for plotting / get_timepoint_data().
    self.coords_t0 = torch.tensor(x0_coords, dtype=torch.float32)
    self.coords_t1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
    self.coords_t1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
    self.coords_t1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)

    self.time_labels = np.concatenate([
        np.zeros(len(self.coords_t0)),   # t=0
        np.ones(len(self.coords_t1_1)),  # t=1, branch 1
        np.ones(len(self.coords_t1_2)),  # t=1, branch 2
        np.ones(len(self.coords_t1_3)),  # t=1, branch 3
    ])

    x0 = torch.tensor(x0_coords, dtype=torch.float32)
    x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
    x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
    x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)

    split_index = int(target_size * self.split_ratios[0])

    # Guarantee the validation slice holds at least one full batch.
    if target_size - split_index < self.batch_size:
        split_index = target_size - self.batch_size

    train_x0 = x0[:split_index]
    val_x0 = x0[split_index:]
    train_x1_1 = x1_1[:split_index]
    val_x1_1 = x1_1[split_index:]
    train_x1_2 = x1_2[:split_index]
    val_x1_2 = x1_2[split_index:]
    train_x1_3 = x1_3[:split_index]
    val_x1_3 = x1_3[split_index:]

    self.val_x0 = val_x0

    # Per-sample weights: source mass 1.0 split across branches
    # (0.603 / 0.255 / 0.142 — presumably empirical branch proportions; confirm source).
    train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
    train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.603)
    train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.255)
    train_x1_3_weights = torch.full((train_x1_3.shape[0], 1), fill_value=0.142)

    val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
    val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.603)
    val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.255)
    val_x1_3_weights = torch.full((val_x1_3.shape[0], 1), fill_value=0.142)

    # Train dataloaders for all three endpoint branches.
    self.train_dataloaders = {
        "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        "x1_3": DataLoader(TensorDataset(train_x1_3, train_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
    }

    # Val dataloaders; NOTE(review): "x0" is unshuffled but the branch loaders shuffle — confirm intended.
    self.val_dataloaders = {
        "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
        "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        "x1_3": DataLoader(TensorDataset(val_x1_3, val_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
    }

    # Full point cloud (rows with all 50 PCs present); KD-tree for neighbor queries.
    all_coords = df[pc_cols].dropna().values.astype(float)
    self.dataset = torch.tensor(all_coords, dtype=torch.float32)
    self.tree = cKDTree(all_coords)

    # Test loaders yield everything in a single batch.
    self.test_dataloaders = {
        "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
        "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
    }

    # Metric samples: KMeans into 4 clusters on the FIRST 3 PCs only.
    #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
    km_all = KMeans(n_clusters=4, random_state=0).fit(self.dataset[:, :3].numpy())

    cluster_labels = km_all.labels_

    cluster_0_mask = cluster_labels == 0
    cluster_1_mask = cluster_labels == 1
    cluster_2_mask = cluster_labels == 2
    cluster_3_mask = cluster_labels == 3

    samples = self.dataset.cpu().numpy()

    cluster_0_data = samples[cluster_0_mask]
    cluster_1_data = samples[cluster_1_mask]
    cluster_2_data = samples[cluster_2_mask]
    cluster_3_data = samples[cluster_3_mask]

    # One full-batch loader per cluster; the ordering (1, 3, 2, 0) is deliberate
    # and consumed positionally downstream.
    self.metric_samples_dataloaders = [
        DataLoader(
            torch.tensor(cluster_1_data, dtype=torch.float32),
            batch_size=cluster_1_data.shape[0],
            shuffle=False,
            drop_last=False,
        ),
        DataLoader(
            torch.tensor(cluster_3_data, dtype=torch.float32),
            batch_size=cluster_3_data.shape[0],
            shuffle=False,
            drop_last=False,
        ),


        DataLoader(
            torch.tensor(cluster_2_data, dtype=torch.float32),
            batch_size=cluster_2_data.shape[0],
            shuffle=False,
            drop_last=False,
        ),
        DataLoader(
            torch.tensor(cluster_0_data, dtype=torch.float32),
            batch_size=cluster_0_data.shape[0],
            shuffle=False,
            drop_last=False,
        ),
    ]
|
| 235 |
+
|
| 236 |
+
def train_dataloader(self):
    """Combine the per-source training loaders with the metric-sample loaders.

    Each inner CombinedLoader steps in "min_size" mode (stops at the shortest
    loader); the outer one cycles the shorter side ("max_size_cycle").
    """
    sample_loader = CombinedLoader(self.train_dataloaders, mode="min_size")
    metric_loader = CombinedLoader(self.metric_samples_dataloaders, mode="min_size")
    return CombinedLoader(
        {"train_samples": sample_loader, "metric_samples": metric_loader},
        mode="max_size_cycle",
    )
|
| 244 |
+
|
| 245 |
+
def val_dataloader(self):
    """Combine the per-source validation loaders with the metric-sample loaders."""
    sample_loader = CombinedLoader(self.val_dataloaders, mode="min_size")
    metric_loader = CombinedLoader(self.metric_samples_dataloaders, mode="min_size")
    return CombinedLoader(
        {"val_samples": sample_loader, "metric_samples": metric_loader},
        mode="max_size_cycle",
    )
|
| 254 |
+
|
| 255 |
+
def test_dataloader(self):
    """Combine the test loaders (full-batch x0 + full dataset) with the metric loaders."""
    sample_loader = CombinedLoader(self.test_dataloaders, mode="min_size")
    metric_loader = CombinedLoader(self.metric_samples_dataloaders, mode="min_size")
    return CombinedLoader(
        {"test_samples": sample_loader, "metric_samples": metric_loader},
        mode="max_size_cycle",
    )
|
| 264 |
+
|
| 265 |
+
def get_manifold_proj(self, points):
    """Return a projection operator that nudges points toward the data manifold
    via k-NN local smoothing over the full dataset (see `local_smoothing_op`).

    NOTE(review): `points` is unused — the operator is built from the prefit
    KD-tree and full dataset instead; confirm callers rely on this signature.
    """
    return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
|
| 268 |
+
|
| 269 |
+
@staticmethod
def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
    """
    Pull each point of `x` toward a softmax-weighted average of its k nearest
    neighbors in `dataset` (neighbors looked up via the prebuilt cKDTree).

    Args:
        x: torch tensor of query points, shape (batch_size, dim).
        tree: scipy cKDTree built over the same points as `dataset`.
        dataset: torch tensor of reference points, shape (n_points, dim).
        k: number of neighbors to average over.
        temp: temperature on squared distances; smaller = sharper weighting.

    Returns:
        Tensor of shape (batch_size, dim): 70% original point, 30% smoothed.
    """
    points_np = x.detach().cpu().numpy()
    _, idx = tree.query(points_np, k=k)
    nearest_pts = dataset[idx]  # shape: (batch_size, k, dim)

    # Softmax-style weights over squared distances: closer neighbors count more.
    dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
    weights = torch.exp(-dists / temp)
    weights = weights / weights.sum(dim=1, keepdim=True)

    # Weighted average of neighbors
    smoothed = (weights * nearest_pts).sum(dim=1)

    # Blend original point with smoothed version
    alpha = 0.3  # How much smoothing to apply
    return (1 - alpha) * x + alpha * smoothed
|
| 290 |
+
|
| 291 |
+
def get_timepoint_data(self):
    """Return data organized by timepoints for visualization.

    Keys: 't0' (source cells), 't1_1'/'t1_2'/'t1_3' (the three endpoint
    branches) and 'time_labels' (0/1 label per row, t0 rows first).
    """
    return {
        't0': self.coords_t0,
        't1_1': self.coords_t1_1,
        't1_2': self.coords_t1_2,
        't1_3': self.coords_t1_3,
        'time_labels': self.time_labels
    }
|
| 300 |
+
|
| 301 |
+
def get_datamodule():
    """Build a ThreeBranchTahoeDataModule from CLI args and set it up for fitting.

    Import is kept local to avoid a hard dependency on the plotting package
    at module-import time.
    """
    from plot.parsers_tahoe import parse_args
    args = parse_args()
    datamodule = ThreeBranchTahoeDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
|
dataloaders/trametinib_single.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from functools import partial
|
| 10 |
+
from scipy.spatial import cKDTree
|
| 11 |
+
from sklearn.cluster import KMeans
|
| 12 |
+
from torch.utils.data import TensorDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TrametinibSingleBranchDataModule(pl.LightningDataModule):
    """Single-branch datamodule: DMSO control cells (t=0) -> one Trametinib
    endpoint cluster (t=1), built from a PCA-coordinate CSV with Leiden labels.

    Only the first endpoint cluster (x1_1) is used for train/val; the other two
    are kept for plotting via `get_timepoint_data`.
    """

    def __init__(self, args):
        """Store config from `args` (batch_size, dim, whiten, split_ratios,
        data_path) and immediately load/process the data."""
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios
        self.num_timesteps = 2
        self.data_path = args.data_path
        self.args = args

        self._prepare_data()

    def _prepare_data(self):
        """Load the CSV, pick endpoint/source clusters, balance their sizes,
        and build train/val/test/metric dataloaders."""
        # First column is an index; first 50 remaining columns are PCA coords.
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        pc_cols = df.columns[:50]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # NOTE(review): despite the variable names, the treatment column here is
        # Trametinib, not Clonidine — names were carried over from a sibling module.
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Trametinib_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Trametinib column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # The three Trametinib endpoint clusters considered (leiden ids as strings).
        top_clonidine_clusters = ['1.0', '3.0', '5.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
        x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]

        x1_1_coords = x1_1_data[pc_cols].values
        x1_2_coords = x1_2_data[pc_cols].values
        x1_3_coords = x1_3_data[pc_cols].values

        x1_1_coords = x1_1_coords.astype(float)
        x1_2_coords = x1_2_coords.astype(float)
        x1_3_coords = x1_3_coords.astype(float)

        # Target size is the minimum across the three endpoint clusters.
        target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))

        # Helper: deterministically downsample by keeping points nearest the centroid.
        def select_closest_to_centroid(coords, target_size):
            if len(coords) <= target_size:
                return coords

            # Calculate centroid
            centroid = np.mean(coords, axis=0)

            # Calculate distances to centroid
            distances = np.linalg.norm(coords - centroid, axis=1)

            # Get indices of closest points
            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Sample all endpoint clusters to target size using centroid-based selection
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
        x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # Source distribution (x0): largest DMSO cluster, downsampled the same way.
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # Largest cluster is smaller than target: pad with other DMSO cells.
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Select closest to centroid from other DMSO cells
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Not enough DMSO cells overall: shrink target_size to fit and
                # re-select the endpoints at the smaller size.
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                # Re-select endpoint clusters with updated target size
                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
                x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)

        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
        x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
        # All three endpoint clusters kept together for plotting only.
        x1 = torch.cat([x1_1, x1_2, x1_3], dim=0)

        self.coords_t0 = x0
        self.coords_t1 = x1

        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),   # t=1
        ])

        # Train/val split; guarantee the val set holds at least one full batch.
        split_index = int(target_size * self.split_ratios[0])

        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        # Single-branch: only the first endpoint cluster is trained against.
        train_x1 = x1_1[:split_index]
        val_x1 = x1_1[split_index:]

        self.val_x0 = val_x0

        # Uniform per-sample weights (single branch -> no branch reweighting).
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # NOTE(review): val "x1" shuffles while "x0" does not — confirm intended.
        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full dataset + KD-tree used by the manifold-projection operator.
        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: 2-way KMeans on the first 3 PCs of the full dataset.
        #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
        km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset[:, :3].numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        # Full-batch loaders, one per KMeans cluster (order: cluster 1, cluster 0).
        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        """Combine training loaders with metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combine validation loaders with metric-sample loaders."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combine test loaders with metric-sample loaders."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Return a k-NN local-smoothing operator over the full dataset.

        NOTE(review): `points` is unused; operator is built from the prefit
        KD-tree and dataset.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Pull each point of `x` toward a softmax-weighted average of its k
        nearest neighbors in `dataset`; returns 0.7*x + 0.3*smoothed.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # shape: (batch_size, k, dim)

        # Softmax weights over squared distances: closer neighbors count more.
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization."""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }
|
dataloaders/veres_leiden_data.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from sklearn.preprocessing import StandardScaler
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 7 |
+
import numpy as np
|
| 8 |
+
from scipy.spatial import cKDTree
|
| 9 |
+
import math
|
| 10 |
+
from functools import partial
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from torch.utils.data import TensorDataset
|
| 14 |
+
from sklearn.neighbors import kneighbors_graph
|
| 15 |
+
import igraph as ig
|
| 16 |
+
from leidenalg import find_partition, ModularityVertexPartition
|
| 17 |
+
|
| 18 |
+
class WeightedBranchedVeresDataModule(pl.LightningDataModule):
|
| 19 |
+
|
| 20 |
+
def __init__(self, args):
    """Store configuration from `args` and immediately load/cluster the data.

    Args:
        args: parsed CLI namespace; must provide `data_path`, `batch_size`,
            `dim`, `whiten`, `split_ratios`, `metric_clusters`. `branches`
            and `discard` are optional.
    """
    super().__init__()
    self.save_hyperparameters()

    self.data_path = args.data_path
    self.batch_size = args.batch_size
    self.max_dim = args.dim
    self.whiten = args.whiten
    self.k = 20  # kNN-graph neighborhood size used for Leiden clustering
    self.num_timesteps = 8
    # Initial placeholder; overwritten with the merged Leiden cluster count
    # in _prepare_data(). getattr replaces the hasattr/ternary pattern.
    self.num_branches = getattr(args, 'branches', None)
    self.split_ratios = args.split_ratios
    self.metric_clusters = args.metric_clusters
    # Optional flag: discard (rather than merge) undersized Leiden clusters.
    self.discard_small = getattr(args, 'discard', False)
    self.args = args
    self._prepare_data()
|
| 37 |
+
|
| 38 |
+
def _prepare_data(self):
    """Load the Veres timecourse CSV, Leiden-cluster the final timepoint into
    branches, merge/discard small clusters, and build all dataloaders.

    Side effects: sets n_samples, coords_t0, coords_intermediate,
    branch_endpoints, num_branches, time_labels, val_x0,
    train_coords_intermediate, train/val/test_dataloaders, dataset, tree,
    and metric_samples_dataloaders.
    """
    print("Preparing Veres cell data with Leiden clustering in WeightedBranchedVeresLeidenDataModule")
    df = pd.read_csv(self.data_path)

    # Build dictionary of coordinates by time ('samples' column = timepoint id).
    coords_by_t = {
        t: df[df["samples"] == t].iloc[:, 1:].values  # Skip 'samples' column
        for t in sorted(df["samples"].unique())
    }

    n0 = coords_by_t[0].shape[0]
    self.n_samples = n0

    print("Timepoint distribution:")
    for t in sorted(coords_by_t.keys()):
        print(f" t={t}: {coords_by_t[t].shape[0]} points")

    # Leiden clustering on final timepoint: kNN graph -> igraph -> modularity partition.
    final_t = max(coords_by_t.keys())
    coords_final = coords_by_t[final_t]
    k = 20  # NOTE(review): shadows self.k — kept local here; confirm intended.
    knn_graph = kneighbors_graph(coords_final, k, mode='connectivity', include_self=False)
    sources, targets = knn_graph.nonzero()
    edgelist = list(zip(sources.tolist(), targets.tolist()))
    graph = ig.Graph(edgelist, directed=False)
    partition = find_partition(graph, ModularityVertexPartition)
    leiden_labels = np.array(partition.membership)
    n_leiden = len(np.unique(leiden_labels))
    print(f"Leiden found {n_leiden} clusters at t={final_t}")

    df_final = df[df["samples"] == final_t].copy()
    df_final["branch"] = leiden_labels

    cluster_counts = df_final["branch"].value_counts().sort_index()
    print(f"Branch distribution at t={final_t} (pre-merge):")
    print(cluster_counts)

    # Merge small clusters to nearest large cluster (by centroid)
    min_cells = 100  # threshold; adjust if needed
    cluster_data_dict = {}
    cluster_sizes = []
    for b in range(n_leiden):
        # iloc[:, 1:-1] drops the 'samples' column and the appended 'branch' column.
        branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values
        cluster_data_dict[b] = branch_data
        cluster_sizes.append(branch_data.shape[0])

    large_clusters = [b for b, size in enumerate(cluster_sizes) if size >= min_cells]
    small_clusters = [b for b, size in enumerate(cluster_sizes) if size < min_cells]

    # If no large cluster exists (all small), treat all clusters as large
    if len(large_clusters) == 0:
        large_clusters = list(range(n_leiden))
        small_clusters = []

    if self.discard_small:
        # Discard small clusters instead of merging
        print(f"Discarding {len(small_clusters)} small clusters (< {min_cells} cells)")
        # Keep only cells from large clusters
        mask = np.isin(leiden_labels, large_clusters)
        df_final = df_final[mask].copy()
        merged_labels = leiden_labels[mask]

        # Remap to contiguous ids
        new_ids = np.unique(merged_labels)
        id_map = {old: new for new, old in enumerate(new_ids)}
        merged_labels = np.array([id_map[x] for x in merged_labels])
        n_merged = len(np.unique(merged_labels))

        df_final["branch"] = merged_labels
        print(f"Kept {n_merged} large clusters")
    else:
        centroids = {b: np.mean(cluster_data_dict[b], axis=0) for b in range(n_leiden) if cluster_data_dict[b].shape[0] > 0}

        merged_labels = leiden_labels.copy()
        for b in small_clusters:
            if cluster_data_dict[b].shape[0] == 0:
                continue
            # find nearest large cluster
            dists = [np.linalg.norm(centroids[b] - centroids[bl]) for bl in large_clusters]
            nearest_large = large_clusters[int(np.argmin(dists))]
            merged_labels[leiden_labels == b] = nearest_large

        # remap to contiguous ids
        new_ids = np.unique(merged_labels)
        id_map = {old: new for new, old in enumerate(new_ids)}
        merged_labels = np.array([id_map[x] for x in merged_labels])
        n_merged = len(np.unique(merged_labels))

        df_final["branch"] = merged_labels
        print(f"Merged into {n_merged} clusters")

    cluster_counts_merged = df_final["branch"].value_counts().sort_index()
    print(f"Branch distribution at t={final_t} (post-merge):")
    print(cluster_counts_merged)

    # Resample every branch to exactly n0 points (with replacement if smaller).
    # NOTE(review): uses the global numpy RNG — not seeded here.
    endpoints = {}
    cluster_sizes = []
    for b in range(n_merged):
        branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values
        cluster_sizes.append(branch_data.shape[0])
        replace = branch_data.shape[0] < n0
        sampled_indices = np.random.choice(branch_data.shape[0], size=n0, replace=replace)
        endpoints[b] = branch_data[sampled_indices]
    total_t_final = sum(cluster_sizes)

    x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)
    self.coords_t0 = x0
    # intermediate timepoints
    self.coords_intermediate = {t: torch.tensor(coords_by_t[t], dtype=torch.float32)
                                for t in coords_by_t.keys() if t != 0 and t != final_t}

    self.branch_endpoints = {b: torch.tensor(endpoints[b], dtype=torch.float32) for b in range(n_merged)}
    self.num_branches = n_merged

    # time labels (for visualization)
    time_labels_list = [np.zeros(len(self.coords_t0))]
    for t in sorted(self.coords_intermediate.keys()):
        time_labels_list.append(np.ones(len(self.coords_intermediate[t])) * t)
    for b in range(self.num_branches):
        time_labels_list.append(np.ones(len(self.branch_endpoints[b])) * final_t)
    self.time_labels = np.concatenate(time_labels_list)

    # splits — guarantee at least one full validation batch
    split_index = int(n0 * self.split_ratios[0])
    if n0 - split_index < self.batch_size:
        split_index = n0 - self.batch_size

    train_x0 = x0[:split_index]
    val_x0 = x0[split_index:]
    self.val_x0 = val_x0

    train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
    val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)

    # branch weights proportional to cluster sizes
    branch_weights = [size / total_t_final for size in cluster_sizes]

    # Split intermediate timepoints for sequential training support
    train_intermediate = {}
    val_intermediate = {}
    self.train_coords_intermediate = {}  # Store training-only intermediate data for MMD
    for t in sorted(self.coords_intermediate.keys()):
        coords_t = self.coords_intermediate[t]
        # NOTE(review): reuses the x0 split index on timepoints of possibly
        # different sizes — confirm intended.
        train_coords_t = coords_t[:split_index]
        val_coords_t = coords_t[split_index:]
        train_weights_t = torch.full((train_coords_t.shape[0], 1), fill_value=1.0)
        val_weights_t = torch.full((val_coords_t.shape[0], 1), fill_value=1.0)
        train_intermediate[f"x{t}"] = (train_coords_t, train_weights_t)
        val_intermediate[f"x{t}"] = (val_coords_t, val_weights_t)
        self.train_coords_intermediate[t] = train_coords_t  # Store training data by int key

    train_loaders = {
        "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
    }
    val_loaders = {
        "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
    }

    # Add all intermediate timepoints to loaders
    for t_key in sorted(train_intermediate.keys()):
        train_coords_t, train_weights_t = train_intermediate[t_key]
        val_coords_t, val_weights_t = val_intermediate[t_key]
        train_loaders[t_key] = DataLoader(
            TensorDataset(train_coords_t, train_weights_t),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True
        )
        val_loaders[t_key] = DataLoader(
            TensorDataset(val_coords_t, val_weights_t),
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=True
        )

    for b in range(self.num_branches):
        # Calculate split based on this branch's size, not t=0 size
        branch_size = self.branch_endpoints[b].shape[0]
        branch_split_index = int(branch_size * self.split_ratios[0])
        if branch_size - branch_split_index < self.batch_size:
            branch_split_index = max(0, branch_size - self.batch_size)

        train_branch = self.branch_endpoints[b][:branch_split_index]
        val_branch = self.branch_endpoints[b][branch_split_index:]
        # Per-sample weight = this branch's share of the final timepoint.
        train_branch_weights = torch.full((train_branch.shape[0], 1), fill_value=branch_weights[b])
        val_branch_weights = torch.full((val_branch.shape[0], 1), fill_value=branch_weights[b])
        train_loaders[f"x1_{b+1}"] = DataLoader(
            TensorDataset(train_branch, train_branch_weights),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True
        )
        # NOTE(review): validation branch loaders shuffle — confirm intended.
        val_loaders[f"x1_{b+1}"] = DataLoader(
            TensorDataset(val_branch, val_branch_weights),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True
        )

    self.train_dataloaders = train_loaders
    self.val_dataloaders = val_loaders

    # full dataset (all timepoints stacked) + KD-tree for manifold projection
    all_data_list = [coords_by_t[t] for t in sorted(coords_by_t.keys())]
    all_data = np.vstack(all_data_list)
    self.dataset = torch.tensor(all_data, dtype=torch.float32)
    self.tree = cKDTree(all_data)

    self.test_dataloaders = {
        "x0": DataLoader(TensorDataset(self.val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
        "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
    }

    # Metric dataloaders: t0 vs (t1..t_final + endpoints)
    cluster_0_data = self.coords_t0.cpu().numpy()
    cluster_1_list = [self.coords_intermediate[t].cpu().numpy() for t in sorted(self.coords_intermediate.keys())]
    cluster_1_list.extend([self.branch_endpoints[b].cpu().numpy() for b in range(self.num_branches)])
    cluster_1_data = np.vstack(cluster_1_list)

    self.metric_samples_dataloaders = [
        DataLoader(torch.tensor(cluster_0_data, dtype=torch.float32), batch_size=cluster_0_data.shape[0], shuffle=False, drop_last=False),
        DataLoader(torch.tensor(cluster_1_data, dtype=torch.float32), batch_size=cluster_1_data.shape[0], shuffle=False, drop_last=False),
    ]
|
| 261 |
+
|
| 262 |
+
def train_dataloader(self):
|
| 263 |
+
combined_loaders = {
|
| 264 |
+
"train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
|
| 265 |
+
"metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
|
| 266 |
+
}
|
| 267 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 268 |
+
|
| 269 |
+
def val_dataloader(self):
|
| 270 |
+
combined_loaders = {
|
| 271 |
+
"val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
|
| 272 |
+
"metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
|
| 273 |
+
}
|
| 274 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 275 |
+
|
| 276 |
+
def test_dataloader(self):
|
| 277 |
+
combined_loaders = {
|
| 278 |
+
"test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
|
| 279 |
+
"metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
|
| 280 |
+
}
|
| 281 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 282 |
+
|
| 283 |
+
def get_manifold_proj(self, points):
|
| 284 |
+
return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
|
| 285 |
+
|
| 286 |
+
@staticmethod
|
| 287 |
+
def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
|
| 288 |
+
points_np = x.detach().cpu().numpy()
|
| 289 |
+
_, idx = tree.query(points_np, k=k)
|
| 290 |
+
nearest_pts = dataset[idx]
|
| 291 |
+
dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
|
| 292 |
+
weights = torch.exp(-dists / temp)
|
| 293 |
+
weights = weights / weights.sum(dim=1, keepdim=True)
|
| 294 |
+
smoothed = (weights * nearest_pts).sum(dim=1)
|
| 295 |
+
alpha = 0.3
|
| 296 |
+
return (1 - alpha) * x + alpha * smoothed
|
| 297 |
+
|
| 298 |
+
def get_timepoint_data(self):
|
| 299 |
+
result = {
|
| 300 |
+
't0': self.coords_t0,
|
| 301 |
+
'time_labels': self.time_labels
|
| 302 |
+
}
|
| 303 |
+
# intermediate timepoints
|
| 304 |
+
for t in sorted(self.coords_intermediate.keys()):
|
| 305 |
+
result[f't{t}'] = self.coords_intermediate[t]
|
| 306 |
+
final_t = max([0] + list(self.coords_intermediate.keys())) + 1
|
| 307 |
+
for b in range(self.num_branches):
|
| 308 |
+
result[f't{final_t}_{b}'] = self.branch_endpoints[b]
|
| 309 |
+
return result
|
| 310 |
+
|
| 311 |
+
def get_train_intermediate_data(self):
|
| 312 |
+
if hasattr(self, 'train_coords_intermediate'):
|
| 313 |
+
return self.train_coords_intermediate
|
| 314 |
+
else:
|
| 315 |
+
# Fallback to full intermediate data if train split not available
|
| 316 |
+
print("Warning: train_coords_intermediate not found, returning full intermediate data.")
|
| 317 |
+
return self.coords_intermediate
|
environment.yml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: branchsbm
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- pytorch
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- conda-forge::python=3.10
|
| 8 |
+
- conda-forge::openssl
|
| 9 |
+
- ca-certificates
|
| 10 |
+
- certifi
|
| 11 |
+
- pytorch::pytorch
|
| 12 |
+
- matplotlib
|
| 13 |
+
- pandas
|
| 14 |
+
- seaborn
|
| 15 |
+
- torchmetrics
|
| 16 |
+
- numpy>=1.26.0,<2.0.0
|
| 17 |
+
- scikit-learn
|
| 18 |
+
- pyyaml
|
| 19 |
+
- jupyter
|
| 20 |
+
- ipykernel
|
| 21 |
+
- notebook
|
| 22 |
+
- tqdm
|
| 23 |
+
- pytorch-lightning>=2.0.0
|
| 24 |
+
- lightning>=2.0.0
|
| 25 |
+
- python-igraph
|
| 26 |
+
- leidenalg
|
| 27 |
+
- pip
|
| 28 |
+
- pip:
|
| 29 |
+
- scipy==1.13.1
|
| 30 |
+
- wandb==0.22.1
|
| 31 |
+
- torchcfm==1.0.7
|
| 32 |
+
- torchdyn==1.0.6
|
| 33 |
+
- torchdiffeq
|
| 34 |
+
- pot
|
| 35 |
+
- hydra-core
|
| 36 |
+
- omegaconf
|
| 37 |
+
- laspy
|
| 38 |
+
- umap-learn
|
| 39 |
+
- scanpy
|
| 40 |
+
- lpips
|
| 41 |
+
- geomloss
|
parsers.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
def parse_args():
|
| 4 |
+
parser = argparse.ArgumentParser(description="Train BranchSBM")
|
| 5 |
+
|
| 6 |
+
parser.add_argument("--seed", default=2, type=int)
|
| 7 |
+
|
| 8 |
+
parser.add_argument(
|
| 9 |
+
"--config_path", type=str,
|
| 10 |
+
default='',
|
| 11 |
+
help="Path to config file"
|
| 12 |
+
)
|
| 13 |
+
####### ITERATES IN THE CODE #######
|
| 14 |
+
parser.add_argument(
|
| 15 |
+
"--seeds",
|
| 16 |
+
nargs="+",
|
| 17 |
+
type=int,
|
| 18 |
+
default=[42, 43, 44, 45, 46],
|
| 19 |
+
help="Random seeds to iterate over",
|
| 20 |
+
)
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--t_exclude",
|
| 23 |
+
nargs="+",
|
| 24 |
+
type=int,
|
| 25 |
+
default=None,
|
| 26 |
+
help="Time points to exclude (iterating over)",
|
| 27 |
+
)
|
| 28 |
+
####################################
|
| 29 |
+
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--working_dir",
|
| 32 |
+
type=str,
|
| 33 |
+
default="path/to/your/home/BranchSBM",
|
| 34 |
+
help="Working directory",
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--resume_flow_model_ckpt",
|
| 38 |
+
type=str,
|
| 39 |
+
default=None,
|
| 40 |
+
help="Path to the flow model to resume training",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--resume_growth_model_ckpt",
|
| 44 |
+
type=str,
|
| 45 |
+
default=None,
|
| 46 |
+
help="Path to the flow model to resume training",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--load_geopath_model_ckpt",
|
| 50 |
+
type=str,
|
| 51 |
+
default=None,
|
| 52 |
+
help="Path to the geopath model to resume training",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--sequential",
|
| 56 |
+
action=argparse.BooleanOptionalAction,
|
| 57 |
+
default=False,
|
| 58 |
+
help="Use sequential training for multi-timepoint data",
|
| 59 |
+
)
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--discard",
|
| 62 |
+
action=argparse.BooleanOptionalAction,
|
| 63 |
+
default=False,
|
| 64 |
+
help="Discard small clusters instead of merging them in Leiden clustering",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--pseudo",
|
| 68 |
+
action=argparse.BooleanOptionalAction,
|
| 69 |
+
default=False,
|
| 70 |
+
help="Use pseudotime-based clustering for Weinreb data instead of Leiden on t=2",
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--branches",
|
| 74 |
+
type=int,
|
| 75 |
+
default=2,
|
| 76 |
+
help="Number of branches",
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--metric_clusters",
|
| 80 |
+
type=int,
|
| 81 |
+
default=3,
|
| 82 |
+
help="Number of metric clusters",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--resolution",
|
| 86 |
+
type=float,
|
| 87 |
+
default=1.0,
|
| 88 |
+
help="Resolution parameter for Leiden clustering",
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
######### DATASETS #################
|
| 92 |
+
parser = datasets_parser(parser)
|
| 93 |
+
####################################
|
| 94 |
+
|
| 95 |
+
######### IMAGE DATASETS ###########
|
| 96 |
+
parser = image_datasets_parser(parser)
|
| 97 |
+
####################################
|
| 98 |
+
|
| 99 |
+
######### METRICS ##################
|
| 100 |
+
parser = metric_parser(parser)
|
| 101 |
+
####################################
|
| 102 |
+
|
| 103 |
+
######### General Training #########
|
| 104 |
+
parser = general_training_parser(parser)
|
| 105 |
+
####################################
|
| 106 |
+
|
| 107 |
+
######### Training GeoPath Network ####
|
| 108 |
+
parser = geopath_network_parser(parser)
|
| 109 |
+
####################################
|
| 110 |
+
|
| 111 |
+
######### Training Flow Network ####
|
| 112 |
+
parser = flow_network_parser(parser)
|
| 113 |
+
####################################
|
| 114 |
+
|
| 115 |
+
parser = growth_network_parser(parser)
|
| 116 |
+
|
| 117 |
+
return parser.parse_args()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def datasets_parser(parser):
|
| 121 |
+
parser.add_argument("--dim", type=int, default=3, help="Dimension of data")
|
| 122 |
+
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--data_type",
|
| 125 |
+
type=str,
|
| 126 |
+
default="lidar",
|
| 127 |
+
help="Type of data, now wither scrna or one of toys",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--data_path",
|
| 131 |
+
type=str,
|
| 132 |
+
default="",
|
| 133 |
+
help="lidar data path",
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--data_name",
|
| 137 |
+
type=str,
|
| 138 |
+
default="lidar",
|
| 139 |
+
help="Path to the dataset",
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--whiten",
|
| 143 |
+
action=argparse.BooleanOptionalAction,
|
| 144 |
+
default=True,
|
| 145 |
+
help="Whiten the data",
|
| 146 |
+
)
|
| 147 |
+
parser.add_argument(
|
| 148 |
+
"--min_cells",
|
| 149 |
+
type=int,
|
| 150 |
+
default=500,
|
| 151 |
+
help="Minimum cells per cluster for Leiden clustering",
|
| 152 |
+
)
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
"--k",
|
| 155 |
+
type=int,
|
| 156 |
+
default=20,
|
| 157 |
+
help="Number of neighbors for KNN graph in Leiden clustering",
|
| 158 |
+
)
|
| 159 |
+
parser.add_argument(
|
| 160 |
+
"--pseudotime_threshold",
|
| 161 |
+
type=float,
|
| 162 |
+
default=0.6,
|
| 163 |
+
help="Pseudotime threshold for terminal cells (only used when --pseudo is True)",
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--terminal_neighbors",
|
| 167 |
+
type=int,
|
| 168 |
+
default=20,
|
| 169 |
+
help="Number of neighbors for terminal cell clustering (only used when --pseudo is True)",
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--terminal_resolution",
|
| 173 |
+
type=float,
|
| 174 |
+
default=0.2,
|
| 175 |
+
help="Resolution for terminal cell Leiden clustering (only used when --pseudo is True)",
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--n_dcs",
|
| 179 |
+
type=int,
|
| 180 |
+
default=10,
|
| 181 |
+
help="Number of diffusion components for DPT (only used when --pseudo is True)",
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--initial_neighbors",
|
| 185 |
+
type=int,
|
| 186 |
+
default=30,
|
| 187 |
+
help="Number of neighbors for initial kNN graph (only used when --pseudo is True)",
|
| 188 |
+
)
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--initial_resolution",
|
| 191 |
+
type=float,
|
| 192 |
+
default=1.0,
|
| 193 |
+
help="Resolution for initial Leiden clustering (only used when --pseudo is True)",
|
| 194 |
+
)
|
| 195 |
+
return parser
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def image_datasets_parser(parser):
|
| 199 |
+
parser.add_argument(
|
| 200 |
+
"--image_size",
|
| 201 |
+
type=int,
|
| 202 |
+
default=128,
|
| 203 |
+
help="Size of the image",
|
| 204 |
+
)
|
| 205 |
+
parser.add_argument(
|
| 206 |
+
"--x0_label",
|
| 207 |
+
type=str,
|
| 208 |
+
default="dog",
|
| 209 |
+
help="Label for x0",
|
| 210 |
+
)
|
| 211 |
+
parser.add_argument(
|
| 212 |
+
"--x1_label",
|
| 213 |
+
type=str,
|
| 214 |
+
default="cat",
|
| 215 |
+
help="Label for x1",
|
| 216 |
+
)
|
| 217 |
+
return parser
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def metric_parser(parser):
|
| 221 |
+
parser.add_argument(
|
| 222 |
+
"--branchsbm",
|
| 223 |
+
action=argparse.BooleanOptionalAction,
|
| 224 |
+
default=True,
|
| 225 |
+
help="If branched SBM",
|
| 226 |
+
)
|
| 227 |
+
parser.add_argument(
|
| 228 |
+
"--n_centers",
|
| 229 |
+
type=int,
|
| 230 |
+
default=100,
|
| 231 |
+
help="Number of centers for RBF network",
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--kappa",
|
| 235 |
+
type=float,
|
| 236 |
+
default=1.0,
|
| 237 |
+
help="Kappa parameter for RBF network",
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--rho",
|
| 241 |
+
type=float,
|
| 242 |
+
default=0.001,
|
| 243 |
+
help="Rho parameter in Riemanian Velocity Calculation",
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--velocity_metric",
|
| 247 |
+
type=str,
|
| 248 |
+
default="rbf",
|
| 249 |
+
help="Metric for velocity calculation",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--gammas",
|
| 253 |
+
nargs="+",
|
| 254 |
+
type=float,
|
| 255 |
+
default=[0.2, 0.2],
|
| 256 |
+
help="Gamma parameter in Riemanian Velocity Calculation",
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
parser.add_argument(
|
| 260 |
+
"--metric_epochs",
|
| 261 |
+
type=int,
|
| 262 |
+
default=100,
|
| 263 |
+
help="Number of epochs for metric learning",
|
| 264 |
+
)
|
| 265 |
+
parser.add_argument(
|
| 266 |
+
"--metric_patience",
|
| 267 |
+
type=int,
|
| 268 |
+
default=20,
|
| 269 |
+
help="Patience for metric learning",
|
| 270 |
+
)
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--metric_lr",
|
| 273 |
+
type=float,
|
| 274 |
+
default=1e-2,
|
| 275 |
+
help="Learning rate for metric learning",
|
| 276 |
+
)
|
| 277 |
+
parser.add_argument(
|
| 278 |
+
"--alpha_metric",
|
| 279 |
+
type=float,
|
| 280 |
+
default=1.0,
|
| 281 |
+
help="Alpha parameter for metric learning",
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
return parser
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def general_training_parser(parser):
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--batch_size", type=int, default=128, help="Batch size for training"
|
| 290 |
+
)
|
| 291 |
+
parser.add_argument(
|
| 292 |
+
"--optimal_transport_method",
|
| 293 |
+
type=str,
|
| 294 |
+
default="exact",
|
| 295 |
+
help="Use optimal transport in CFM training",
|
| 296 |
+
)
|
| 297 |
+
parser.add_argument(
|
| 298 |
+
"--ema_decay",
|
| 299 |
+
type=float,
|
| 300 |
+
default=None,
|
| 301 |
+
help="Decay for EMA",
|
| 302 |
+
)
|
| 303 |
+
parser.add_argument(
|
| 304 |
+
"--split_ratios",
|
| 305 |
+
nargs=2,
|
| 306 |
+
type=float,
|
| 307 |
+
default=[0.9, 0.1],
|
| 308 |
+
help="Split ratios for training/validation data in CFM training",
|
| 309 |
+
)
|
| 310 |
+
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--accelerator", type=str, default="gpu", help="Training accelerator"
|
| 313 |
+
)
|
| 314 |
+
parser.add_argument(
|
| 315 |
+
"--run_name", type=str, default=None, help="Name for the wandb run"
|
| 316 |
+
)
|
| 317 |
+
parser.add_argument(
|
| 318 |
+
"--sim_num_steps",
|
| 319 |
+
type=int,
|
| 320 |
+
default=1000,
|
| 321 |
+
help="Number of steps in simulation",
|
| 322 |
+
)
|
| 323 |
+
return parser
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def geopath_network_parser(parser):
|
| 327 |
+
parser.add_argument(
|
| 328 |
+
"--manifold",
|
| 329 |
+
action=argparse.BooleanOptionalAction,
|
| 330 |
+
default=True,
|
| 331 |
+
help="If use data manifold metric",
|
| 332 |
+
)
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
"--patience_geopath",
|
| 335 |
+
type=int,
|
| 336 |
+
default=50,
|
| 337 |
+
help="Patience for training geopath model",
|
| 338 |
+
)
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--hidden_dims_geopath",
|
| 341 |
+
nargs="+",
|
| 342 |
+
type=int,
|
| 343 |
+
default=[64, 64, 64],
|
| 344 |
+
help="Dimensions of hidden layers for GeoPath model training",
|
| 345 |
+
)
|
| 346 |
+
parser.add_argument(
|
| 347 |
+
"--time_geopath",
|
| 348 |
+
action=argparse.BooleanOptionalAction,
|
| 349 |
+
default=False,
|
| 350 |
+
help="Use time in GeoPath model",
|
| 351 |
+
)
|
| 352 |
+
parser.add_argument(
|
| 353 |
+
"--activation_geopath",
|
| 354 |
+
type=str,
|
| 355 |
+
default="selu",
|
| 356 |
+
help="Activation function for GeoPath",
|
| 357 |
+
)
|
| 358 |
+
parser.add_argument(
|
| 359 |
+
"--geopath_optimizer",
|
| 360 |
+
type=str,
|
| 361 |
+
default="adam",
|
| 362 |
+
help="Optimizer for GeoPath training",
|
| 363 |
+
)
|
| 364 |
+
parser.add_argument(
|
| 365 |
+
"--geopath_lr",
|
| 366 |
+
type=float,
|
| 367 |
+
default=1e-4,
|
| 368 |
+
help="Learning rate for GeoPath training",
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--geopath_weight_decay",
|
| 372 |
+
type=float,
|
| 373 |
+
default=1e-5,
|
| 374 |
+
help="Weight decay for GeoPath training",
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--mmd_weight",
|
| 378 |
+
type=float,
|
| 379 |
+
default=0.1,
|
| 380 |
+
help="Weight for MMD loss at intermediate timepoints (only used when >2 timepoints)",
|
| 381 |
+
)
|
| 382 |
+
return parser
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def flow_network_parser(parser):
|
| 386 |
+
parser.add_argument(
|
| 387 |
+
"--sigma", type=float, default=0.1, help="Sigma parameter for CFM (variance)"
|
| 388 |
+
)
|
| 389 |
+
parser.add_argument(
|
| 390 |
+
"--patience",
|
| 391 |
+
type=int,
|
| 392 |
+
default=5,
|
| 393 |
+
help="Patience for early stopping in CFM training",
|
| 394 |
+
)
|
| 395 |
+
parser.add_argument(
|
| 396 |
+
"--hidden_dims_flow",
|
| 397 |
+
nargs="+",
|
| 398 |
+
type=int,
|
| 399 |
+
default=[64, 64, 64],
|
| 400 |
+
help="Dimensions of hidden layers for CFM training",
|
| 401 |
+
)
|
| 402 |
+
parser.add_argument(
|
| 403 |
+
"--check_val_every_n_epoch",
|
| 404 |
+
type=int,
|
| 405 |
+
default=10,
|
| 406 |
+
help="Check validation every N epochs during CFM training",
|
| 407 |
+
)
|
| 408 |
+
parser.add_argument(
|
| 409 |
+
"--activation_flow",
|
| 410 |
+
type=str,
|
| 411 |
+
default="selu",
|
| 412 |
+
help="Activation function for CFM",
|
| 413 |
+
)
|
| 414 |
+
parser.add_argument(
|
| 415 |
+
"--flow_optimizer",
|
| 416 |
+
type=str,
|
| 417 |
+
default="adamw",
|
| 418 |
+
help="Optimizer for GeoPath training",
|
| 419 |
+
)
|
| 420 |
+
parser.add_argument(
|
| 421 |
+
"--flow_lr",
|
| 422 |
+
type=float,
|
| 423 |
+
default=1e-3,
|
| 424 |
+
help="Learning rate for GeoPath training",
|
| 425 |
+
)
|
| 426 |
+
parser.add_argument(
|
| 427 |
+
"--flow_weight_decay",
|
| 428 |
+
type=float,
|
| 429 |
+
default=1e-5,
|
| 430 |
+
help="Weight decay for GeoPath training",
|
| 431 |
+
)
|
| 432 |
+
return parser
|
| 433 |
+
|
| 434 |
+
def growth_network_parser(parser):
|
| 435 |
+
parser.add_argument(
|
| 436 |
+
"--patience_growth",
|
| 437 |
+
type=int,
|
| 438 |
+
default=5,
|
| 439 |
+
help="Patience for early stopping in CFM training",
|
| 440 |
+
)
|
| 441 |
+
parser.add_argument(
|
| 442 |
+
"--time_growth",
|
| 443 |
+
action=argparse.BooleanOptionalAction,
|
| 444 |
+
default=False,
|
| 445 |
+
help="Use time in GeoPath model",
|
| 446 |
+
)
|
| 447 |
+
parser.add_argument(
|
| 448 |
+
"--hidden_dims_growth",
|
| 449 |
+
nargs="+",
|
| 450 |
+
type=int,
|
| 451 |
+
default=[64, 64, 64],
|
| 452 |
+
help="Dimensions of hidden layers for growth net training",
|
| 453 |
+
)
|
| 454 |
+
parser.add_argument(
|
| 455 |
+
"--activation_growth",
|
| 456 |
+
type=str,
|
| 457 |
+
default="tanh",
|
| 458 |
+
help="Activation function for CFM",
|
| 459 |
+
)
|
| 460 |
+
parser.add_argument(
|
| 461 |
+
"--growth_optimizer",
|
| 462 |
+
type=str,
|
| 463 |
+
default="adamw",
|
| 464 |
+
help="Optimizer for GeoPath training",
|
| 465 |
+
)
|
| 466 |
+
parser.add_argument(
|
| 467 |
+
"--growth_lr",
|
| 468 |
+
type=float,
|
| 469 |
+
default=1e-3,
|
| 470 |
+
help="Learning rate for GeoPath training",
|
| 471 |
+
)
|
| 472 |
+
parser.add_argument(
|
| 473 |
+
"--growth_weight_decay",
|
| 474 |
+
type=float,
|
| 475 |
+
default=1e-5,
|
| 476 |
+
help="Weight decay for GeoPath training",
|
| 477 |
+
)
|
| 478 |
+
parser.add_argument(
|
| 479 |
+
"--lambda_energy",
|
| 480 |
+
type=float,
|
| 481 |
+
default=1.0,
|
| 482 |
+
help="Weight for energy loss",
|
| 483 |
+
)
|
| 484 |
+
parser.add_argument(
|
| 485 |
+
"--lambda_mass",
|
| 486 |
+
type=float,
|
| 487 |
+
default=100.0,
|
| 488 |
+
help="Weight for mass loss",
|
| 489 |
+
)
|
| 490 |
+
parser.add_argument(
|
| 491 |
+
"--lambda_match",
|
| 492 |
+
type=float,
|
| 493 |
+
default=1000.0,
|
| 494 |
+
help="Weight for matching loss",
|
| 495 |
+
)
|
| 496 |
+
parser.add_argument(
|
| 497 |
+
"--lambda_recons",
|
| 498 |
+
type=float,
|
| 499 |
+
default=1.0,
|
| 500 |
+
help="Weight for reconstruction loss",
|
| 501 |
+
)
|
| 502 |
+
return parser
|
scripts/README.md
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Running Experiments with BranchSBM 🌳🧬
|
| 2 |
+
|
| 3 |
+
This directory contains training scripts for all experiments with BranchSBM, including LiDAR navigation 🗻, simulating cell differentiation 🧫, and cell state perturbation modelling 🧬. This codebase contains code from the [Metric Flow Matching repo](https://github.com/kkapusniak/metric-flow-matching) ([Kapusniak et al. 2024](https://arxiv.org/abs/2405.14780)).
|
| 4 |
+
|
| 5 |
+
## Environment Installation
|
| 6 |
+
```
|
| 7 |
+
conda env create -f environment.yml
|
| 8 |
+
|
| 9 |
+
conda activate branchsbm
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
## Data
|
| 13 |
+
LiDAR data is taken from the [Generalized Schrödinger Bridge Matching repo](https://github.com/facebookresearch/generalized-schrodinger-bridge-matching) and Mouse Hematopoesis is taken from the [DeepRUOT repo](https://github.com/zhenyiizhang/DeepRUOT)
|
| 14 |
+
|
| 15 |
+
We use perturbation data from the [Tahoe-100M dataset](https://huggingface.co/datasets/tahoebio/Tahoe-100M) containing control DMSO-treated cell data and perturbed cell data.
|
| 16 |
+
|
| 17 |
+
The raw data contains a total of 60K genes. We select the top 2000 highly variable genes (HVGs) and perform principal component analysis (PCA), to maximally capture the variance in the data via the top principal components (38% in the top-50 PCs). **Our goal is to learn the dynamic trajectories that map control cell clusters to the perturbd cell clusters.**
|
| 18 |
+
|
| 19 |
+
**Specifically, we model the following perturbations**:
|
| 20 |
+
|
| 21 |
+
1. **Clonidine**: Cell states under 5uM Clonidine perturbation at various PC dimensions (50D, 100D, 150D) with 1 unseen population.
|
| 22 |
+
2. **Trametinib**: Cell states under 5uM Trametinib perturbation (50D) with 2 unseen populations.
|
| 23 |
+
|
| 24 |
+
All data files are stored in:
|
| 25 |
+
```
|
| 26 |
+
BranchSBMl/data/
|
| 27 |
+
├── rainier2-thin.las # LiDAR data
|
| 28 |
+
├── mouse_hematopoiesis.csv # Mouse Hematopoiesis data
|
| 29 |
+
├── pca_and_leiden_labels.csv # Clonidine data
|
| 30 |
+
├── Trametinib_5.0uM_pca_and_leidenumap_labels.csv # Trametinib data
|
| 31 |
+
└── Veres_alltime.csv # Pancreatic β-Cell data
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Running Experiments
|
| 35 |
+
|
| 36 |
+
All training scripts are located in `BranchSBM/scripts/`. Each script is pre-configured for a specific experiment.
|
| 37 |
+
|
| 38 |
+
The scripts for BranchSBM experiments include:
|
| 39 |
+
|
| 40 |
+
- **`lidar.sh`** - LiDAR trajectory data with 2 branches
|
| 41 |
+
- **`mouse.sh`** - Mouse cell differentiation with 2 branches
|
| 42 |
+
- **`clonidine.sh`** - Clonidine perturbation with 2 branches
|
| 43 |
+
- **`trametinib.sh`** - Trametinib perturbation with 3 branches
|
| 44 |
+
- **`veres.sh`** - Pancreatic beta-cell differentiation with 11 branches
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
The scripts for the baseline single-branch SBM experiments include:
|
| 48 |
+
|
| 49 |
+
- **`mouse_single.sh`** - Mouse single branch
|
| 50 |
+
- **`clonidine_single.sh`** - Clonidine single branch
|
| 51 |
+
- **`trametinib_single.sh`** - Trametinib single branch
|
| 52 |
+
- **`lidar_single.sh`** - LiDAR single branch
|
| 53 |
+
|
| 54 |
+
**Before running experiments:**
|
| 55 |
+
|
| 56 |
+
1. Set `HOME_LOC` to the base path where BranchSBM is located and `ENV_PATH` to the directory where your environment is downloaded in the `.sh` files in `scripts/`
|
| 57 |
+
2. Create a path `BranchSBM/results` where the simulated trajectory figures and metrics will be saved. Also, create `BranchSBM/logs` where the training logs will be saved.
|
| 58 |
+
3. Activate the conda environment:
|
| 59 |
+
```
|
| 60 |
+
conda activate branchsbm
|
| 61 |
+
```
|
| 62 |
+
4. Login to wandb using `wandb login`
|
| 63 |
+
|
| 64 |
+
**Run experiment using `nohup` with the following commands:**
|
| 65 |
+
|
| 66 |
+
```
|
| 67 |
+
cd scripts
|
| 68 |
+
|
| 69 |
+
chmod lidar.sh
|
| 70 |
+
|
| 71 |
+
nohup ./lidar.sh > lidar.log 2>&1 &
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Evaluation will run automatically after the specified number of rollouts `--num_rollouts` is finished. To see metrics, go to `results/<experiment>/metrics/` or the end of `logs/<experiment>.log`.
|
| 75 |
+
|
| 76 |
+
For Clonidine, `x1_1` indicates the cell cluster that is sampled from for training and `x1_2` is the held-out cell cluster. For Trametinib `x1_1` indicates the cell cluster that is sampled from for training and `x1_2` and `x1_3` are the held-out cell clusters.
|
| 77 |
+
|
| 78 |
+
We report the following metrics for each of the clusters in our paper:
|
| 79 |
+
1. Maximum Mean Discrepancy (RBF-MMD) of simualted cell cluster with target cell cluster (same cell count).
|
| 80 |
+
2. 1-Wasserstein and 2-Wasserstein distances against full cell population in the cluster.
|
| 81 |
+
|
| 82 |
+
## Overview of Outputs
|
| 83 |
+
|
| 84 |
+
**Training outputs are saved to experiment-specific directories:**
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
BranchSBM/results/
|
| 88 |
+
├── <DATE>_clonidine50D_branched/
|
| 89 |
+
│ └── figures/ # Figures of simulated
|
| 90 |
+
│ └── metrics.csv # JSON of metrics
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
**PyTorch Lightning automatically saves model checkpoints to:**
|
| 94 |
+
|
| 95 |
+
```
|
| 96 |
+
BranchSBM/scripts/lightning_logs/
|
| 97 |
+
├── <wandb-run-id>/
|
| 98 |
+
│ ├── checkpoints/
|
| 99 |
+
│ │ ├── epoch=N-step=M.ckpt # Checkpoint
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
**Training logs are saved in:**
|
| 103 |
+
```
|
| 104 |
+
entangled-cell/logs/
|
| 105 |
+
├── <DATE>_lidar_single_train.log
|
| 106 |
+
├── <DATE>_lidar_train.log
|
| 107 |
+
├── <DATE>_mouse_single_train.log
|
| 108 |
+
├── <DATE>_mouse_train.log
|
| 109 |
+
├── <DATE>_clonidine_single_train.log
|
| 110 |
+
├── <DATE>_clonidine50D_train.log
|
| 111 |
+
├── <DATE>_clonidine100D_train.log
|
| 112 |
+
├── <DATE>_clonidine150D_train.log
|
| 113 |
+
├── <DATE>_trametinib_single_train.log
|
| 114 |
+
├── <DATE>_trametinib_train.log
|
| 115 |
+
└── <DATE>_veres_train.log
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
## Available Experiments
|
| 119 |
+
|
| 120 |
+
### Branched Experiments (Multi-branch trajectories)
|
| 121 |
+
|
| 122 |
+
These experiments model cell differentiation or perturbation with multiple branches:
|
| 123 |
+
|
| 124 |
+
- **`mouse.sh`** - Mouse cell differentiation with 2 branches (GPU 0)
|
| 125 |
+
- **`trametinib.sh`** - Trametinib perturbation with 3 branches (GPU 1)
|
| 126 |
+
- **`lidar.sh`** - LiDAR trajectory data with 2 branches (GPU 2)
|
| 127 |
+
- **`clonidine.sh`** - Clonidine perturbation with 2 branches (GPU 3)
|
| 128 |
+
|
| 129 |
+
### Single-Branch Experiments (Control/baseline)
|
| 130 |
+
|
| 131 |
+
These are baseline experiments with single trajectories:
|
| 132 |
+
|
| 133 |
+
- **`mouse_single.sh`** - Mouse single trajectory (GPU 4)
|
| 134 |
+
- **`clonidine_single.sh`** - Clonidine single trajectory (GPU 5)
|
| 135 |
+
- **`trametinib_single.sh`** - Trametinib single trajectory (GPU 6)
|
| 136 |
+
- **`lidar_single.sh`** - LiDAR single trajectory (GPU 7)
|
| 137 |
+
|
| 138 |
+
## Running Scripts
|
| 139 |
+
|
| 140 |
+
### Run a single experiment
|
| 141 |
+
|
| 142 |
+
From the `scripts/` directory:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
cd scripts
|
| 146 |
+
chmod +x mouse.sh
|
| 147 |
+
nohup ./mouse.sh > mouse.log 2>&1 &
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### Run all branched experiments in parallel
|
| 151 |
+
|
| 152 |
+
```bash
|
| 153 |
+
nohup ./mouse.sh > mouse.log 2>&1 &
|
| 154 |
+
nohup ./trametinib.sh > trametinib.log 2>&1 &
|
| 155 |
+
nohup ./lidar.sh > lidar.log 2>&1 &
|
| 156 |
+
nohup ./clonidine.sh > clonidine.log 2>&1 &
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### Run all single-branch experiments in parallel
|
| 160 |
+
|
| 161 |
+
```bash
|
| 162 |
+
nohup ./mouse_single.sh > mouse_single.log 2>&1 &
|
| 163 |
+
nohup ./clonidine_single.sh > clonidine_single.log 2>&1 &
|
| 164 |
+
nohup ./trametinib_single.sh > trametinib_single.log 2>&1 &
|
| 165 |
+
nohup ./lidar_single.sh > lidar_single.log 2>&1 &
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
### Run all experiments simultaneously
|
| 169 |
+
|
| 170 |
+
Each script is assigned to a different GPU, so you can run all 8 in parallel:
|
| 171 |
+
|
| 172 |
+
```bash
|
| 173 |
+
nohup ./mouse.sh > mouse.log 2>&1 &
|
| 174 |
+
nohup ./trametinib.sh > trametinib.log 2>&1 &
|
| 175 |
+
nohup ./lidar.sh > lidar.log 2>&1 &
|
| 176 |
+
nohup ./clonidine.sh > clonidine.log 2>&1 &
|
| 177 |
+
nohup ./mouse_single.sh > mouse_single.log 2>&1 &
|
| 178 |
+
nohup ./clonidine_single.sh > clonidine_single.log 2>&1 &
|
| 179 |
+
nohup ./trametinib_single.sh > trametinib_single.log 2>&1 &
|
| 180 |
+
nohup ./lidar_single.sh > lidar_single.log 2>&1 &
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
## Monitoring Training
|
| 184 |
+
|
| 185 |
+
Logs are saved in `./BranchSBM/logs/` with format `MM_DD_<experiment>_train.log`.
|
| 186 |
+
|
| 187 |
+
Each experiment logs to wandb with a unique run name:
|
| 188 |
+
- Branched experiments: `<dataset>_branched` (e.g., `mouse_branched`)
|
| 189 |
+
- Single experiments: `<dataset>_single` (e.g., `mouse_single`)
|
| 190 |
+
|
| 191 |
+
Visit your wandb dashboard to view training progress in real-time.
|
| 192 |
+
|
| 193 |
+
## Training Parameters
|
| 194 |
+
|
| 195 |
+
Default training parameters for each experiment:
|
| 196 |
+
|
| 197 |
+
| Parameter | LiDAR | Mouse Hematopoiesis scRNA | Clonidine (50 PCs) | Clonidine (100 PCs) | Clonidine (150 PCs) | Trametinib | Pancreatic β-Cell |
|
| 198 |
+
|---|---|---|---|---|---|---|---|
|
| 199 |
+
| branches | 2 | 2 | 2 | 2 | 2 | 3 | 11 |
|
| 200 |
+
| data dimension | 3 | 2 | 50 | 100 | 150 | 50 | 30 |
|
| 201 |
+
| batch size | 128 | 128 | 32 | 32 | 32 | 32 | 256 |
|
| 202 |
+
| λ_energy | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
|
| 203 |
+
| λ_mass | 100 | 100 | 100 | 100 | 100 | 100 | 100 |
|
| 204 |
+
| λ_match | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ |
|
| 205 |
+
| λ_recons | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
|
| 206 |
+
| λ_growth | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 |
|
| 207 |
+
| V_t | LAND | LAND | RBF | RBF | RBF | RBF | RBF |
|
| 208 |
+
| RBF N_c | - | - | 150 | 300 | 300 | 150 | 300 |
|
| 209 |
+
| RBF κ | - | - | 1.5 | 2.0 | 3.0 | 1.5 | 3.0 |
|
| 210 |
+
| hidden dimension | 64 | 64 | 1024 | 1024 | 1024 | 1024 | 1024 |
|
| 211 |
+
| lr interpolant | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ |
|
| 212 |
+
| lr velocity | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ |
|
| 213 |
+
| lr growth | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ |
|
| 214 |
+
|
| 215 |
+
To modify parameters, edit the corresponding `.sh` file.
|
| 216 |
+
|
| 217 |
+
## Training Pipeline
|
| 218 |
+
|
| 219 |
+
Each experiment runs through 4 stages:
|
| 220 |
+
|
| 221 |
+
1. **Stage 1: Geopath** - Train geodesic path interpolants
|
| 222 |
+
2. **Stage 2: Flow Matching** - Train continuous normalizing flows
|
| 223 |
+
3. **Stage 3: Growth** - Train growth networks for branches
|
| 224 |
+
4. **Stage 4: Joint** - Joint training of all components
|
| 225 |
+
|
| 226 |
+
Checkpoints are saved automatically and loaded between stages.
|
scripts/clonidine100.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train branched BranchSBM on the Clonidine dataset (100 principal components).

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='clonidine100D_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=4

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/clonidine_100D.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/clonidine150.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train branched BranchSBM on the Clonidine dataset (150 principal components).

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='clonidine150D_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=5

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/clonidine_150D.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/clonidine50.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train branched BranchSBM on the Clonidine dataset (50 principal components).

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='clonidine50D_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=3

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/clonidine_50D.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/clonidine50_single.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train single-branch BranchSBM on the Clonidine dataset (50 principal components).

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='clonidine_single'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=3

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

# NOTE(review): unlike the sibling scripts, --run_name here is a fixed
# string while the log file name is date-prefixed — confirm this is intended.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "clonidine50D_single" \
    --config_path "$SCRIPT_LOC/configs/clonidine_50Dsingle.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/lidar.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train branched BranchSBM on the LiDAR dataset.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='lidar_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=2

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --config_path "$SCRIPT_LOC/configs/lidar.yaml" \
    --epochs 10 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --batch_size 128 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/lidar_single.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train single-branch BranchSBM on the LiDAR dataset.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='lidar_single'
# set 3 have skip connection
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=2

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --config_path "$SCRIPT_LOC/configs/lidar_single.yaml" \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --epochs 100 \
    --batch_size 128 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/mouse.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train branched BranchSBM on the mouse hematopoiesis scRNA dataset.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='mouse_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=1

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --config_path "$SCRIPT_LOC/configs/mouse.yaml" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/mouse_single.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train single-branch BranchSBM on the mouse hematopoiesis scRNA dataset.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='mouse_single'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=1

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/mouse_single.yaml" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/trametinib.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train branched BranchSBM on the Trametinib dataset.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='trametinib_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=6

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/trametinib.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/trametinib_single.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train single-branch BranchSBM on the Trametinib dataset.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='trametinib_single'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=6

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/trametinib_single.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
scripts/veres.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train BranchSBM on the Veres pancreatic beta-cell dataset.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='veres'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=7

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fail fast if the project directory is missing; make sure the log
# directory exists before redirecting output into it.
cd "$HOME_LOC" || exit 1
mkdir -p "$LOG_LOC"

# NOTE(review): every sibling script passes --config_path; this one passes
# --config — confirm train.py accepts both spellings.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --min_cells 100 \
    --config "$SCRIPT_LOC/configs/veres.yaml" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/branch_flow_net_test.py
ADDED
|
@@ -0,0 +1,1791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Separate test classes for each BranchSBM experiment with specific plotting styles.
|
| 3 |
+
Each class handles testing and visualization for: LiDAR, Mouse, Clonidine, Trametinib, Veres.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import csv
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import pytorch_lightning as pl
|
| 13 |
+
import random
|
| 14 |
+
import ot
|
| 15 |
+
from torchdyn.core import NeuralODE
|
| 16 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 17 |
+
from matplotlib.collections import LineCollection
|
| 18 |
+
from .networks.utils import flow_model_torch_wrapper
|
| 19 |
+
from .branch_flow_net_train import BranchFlowNetTrainBase
|
| 20 |
+
from .branch_growth_net_train import GrowthNetTrain
|
| 21 |
+
from .utils import wasserstein, mix_rbf_mmd2, plot_lidar
|
| 22 |
+
import json
|
| 23 |
+
|
| 24 |
+
def evaluate_model(gt_data, model_data, a, b):
    """Compute the exact earth mover's distance between two point clouds.

    Args:
        gt_data: ground-truth samples, shape (n, d); tensor or array-like.
        model_data: model-generated samples, shape (m, d); tensor or array-like.
        a: histogram of weights over the rows of ``gt_data`` (length n).
        b: histogram of weights over the rows of ``model_data`` (length m).

    Returns:
        The EMD value from ``ot.emd2``, or ``np.nan`` if the cost matrix
        contains non-finite entries.
    """
    # Ensure inputs are float32 tensors.
    if not isinstance(gt_data, torch.Tensor):
        gt_data = torch.tensor(gt_data, dtype=torch.float32)
    if not isinstance(model_data, torch.Tensor):
        model_data = torch.tensor(model_data, dtype=torch.float32)

    # Prefer model_data's device if it is not CPU, otherwise use gt_data's
    # device (tensors always expose .device, so no defensive try/except needed).
    device = model_data.device if model_data.device.type != 'cpu' else gt_data.device

    gt = gt_data.to(device=device, dtype=torch.float32)
    md = model_data.to(device=device, dtype=torch.float32)

    # Pairwise Euclidean cost matrix, moved to CPU/NumPy for POT.
    M = torch.cdist(gt, md, p=2).cpu().numpy()
    if not np.isfinite(M).all():
        # A NaN/Inf cost would make the solver meaningless; signal failure.
        return np.nan
    # numItermax must be an integer for POT's network-simplex solver.
    return ot.emd2(a, b, M, numItermax=int(1e7))
|
| 50 |
+
|
| 51 |
+
def compute_distribution_distances(pred, true, pred_full=None, true_full=None):
    """Return Wasserstein-1/2 and RBF-MMD distances between sample sets.

    ``pred_full``/``true_full``, when given, are used for the MMD term
    (e.g. full-dimensional data while W1/W2 use a projection).
    """
    metrics = {
        "W1": wasserstein(pred, true, power=1),
        "W2": wasserstein(pred, true, power=2),
    }

    # Pick the tensors used for MMD: full-dimensional versions if supplied.
    mmd_a = pred if pred_full is None else pred_full
    mmd_b = true if true_full is None else true_full

    # The MMD estimator needs equally-sized sets; randomly subsample the
    # larger one down to the size of the smaller.
    size_a, size_b = mmd_a.shape[0], mmd_b.shape[0]
    if size_a > size_b:
        mmd_a = mmd_a[torch.randperm(size_a)[:size_b]]
    elif size_b > size_a:
        mmd_b = mmd_b[torch.randperm(size_b)[:size_a]]

    metrics["MMD"] = mix_rbf_mmd2(mmd_a, mmd_b, sigma_list=[0.01, 0.1, 1, 10, 100]).item()
    return metrics
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def compute_tmv_from_mass_over_time(mass_over_time, all_endpoints, time_points=None, timepoint_data=None, time_index=None, target_time=None, gt_key_template='t1_{}', weights_over_time=None):
    """Compute the Total Mass Variation (TMV) between predicted and
    ground-truth per-branch masses at one time point.

    TMV = 0.5 * sum_i |pred_frac_i - gt_frac_i|, where fractions are each
    branch's mass divided by a common normalizer (the initial cell count
    when available).

    Args:
        mass_over_time: per-branch sequences of mean particle weight,
            indexable as ``mass_over_time[i][time_index]``; may be None.
        all_endpoints: per-branch endpoint collections; only the particle
            count per branch is used (``.shape[0]`` or ``len``).
        time_points: optional sequence of time values used to resolve
            ``target_time`` into an index.
        timepoint_data: optional dict of ground-truth arrays; ``'t0'`` gives
            the initial population, and per-branch keys are built from
            ``gt_key_template`` (assumes values expose ``.shape[0]`` —
            TODO confirm against callers).
        time_index: explicit index into the time dimension; inferred from
            ``target_time``/``time_points`` or defaults to the last index
            when omitted.
        target_time: time value to locate in ``time_points`` (nearest match).
        gt_key_template: format string producing per-branch ground-truth
            keys, e.g. ``'t1_{}'`` -> ``'t1_0'``, ``'t1_1'``, ...
        weights_over_time: per-branch sequences of per-particle weight
            tensors; preferred over ``mass_over_time`` when present.

    Returns:
        dict with ``time_index``, ``pred_masses``, ``gt_masses``,
        ``pred_fracs``, ``gt_fracs`` and the scalar ``tmv``.
    """

    # Resolve time_index only when some mass/weight trajectory exists.
    if weights_over_time is not None or mass_over_time is not None:
        if time_index is None:
            if target_time is not None and time_points is not None:
                # Nearest time point to the requested target time.
                arr = np.array(time_points)
                time_index = int(np.argmin(np.abs(arr - float(target_time))))
            else:
                # default to last index
                ref_list = weights_over_time if weights_over_time is not None else mass_over_time
                time_index = len(ref_list[0]) - 1
    else:
        # neither available; time_index not used
        if time_index is None:
            time_index = -1

    n_branches = len(all_endpoints)

    # initial total cells for normalization
    n_initial = None
    if timepoint_data is not None and 't0' in timepoint_data:
        try:
            n_initial = int(timepoint_data['t0'].shape[0])
        except Exception:
            n_initial = None

    pred_masses = []
    for i in range(n_branches):
        # Use sum of actual particle weights if available, otherwise mean_weight * num_particles
        if weights_over_time is not None:
            try:
                weights_tensor = weights_over_time[i][time_index]
                # Sum all particle weights to get total mass for this branch
                total_mass = float(weights_tensor.sum().item())
                pred_masses.append(total_mass)
                continue
            except Exception:
                pass  # Fall through to mean weight calculation

        # Fallback: mean weight from mass_over_time if available, otherwise assume weight=1
        mean_w = 1.0
        if mass_over_time is not None:
            try:
                mean_w = float(mass_over_time[i][time_index])
            except Exception:
                mean_w = 1.0

        # determine number of particles for this branch
        num_particles = 0
        try:
            if hasattr(all_endpoints[i], 'shape'):
                num_particles = int(all_endpoints[i].shape[0])
            else:
                num_particles = int(len(all_endpoints[i]))
        except Exception:
            num_particles = 0

        pred_masses.append(mean_w * float(num_particles))

    # ground-truth masses per branch
    gt_masses = []
    if timepoint_data is not None:
        for i in range(n_branches):
            key1 = gt_key_template.format(i)
            if key1 in timepoint_data:
                gt_masses.append(float(timepoint_data[key1].shape[0]))
            else:
                # Fall back to the template's prefix (e.g. 't1') when the
                # per-branch key is absent.
                base_key = gt_key_template.split("_")[0] if '_' in gt_key_template else gt_key_template
                if base_key in timepoint_data:
                    gt_masses.append(float(timepoint_data[base_key].shape[0]))
                else:
                    gt_masses.append(0.0)
    else:
        gt_masses = [0.0 for _ in range(n_branches)]

    # determine normalization denominator
    if n_initial is None:
        s = float(sum(gt_masses))
        if s > 0:
            n_initial = s
        else:
            # Last resort: normalize by predicted mass, or 1 to avoid /0.
            n_initial = float(sum(pred_masses)) if sum(pred_masses) > 0 else 1.0

    pred_fracs = [m / float(n_initial) for m in pred_masses]
    gt_fracs = [m / float(n_initial) for m in gt_masses]

    # Total variation distance between the two (unnormalized) fraction vectors.
    tmv = 0.5 * float(np.sum(np.abs(np.array(pred_fracs) - np.array(gt_fracs))))

    return {
        'time_index': time_index,
        'pred_masses': pred_masses,
        'gt_masses': gt_masses,
        'pred_fracs': pred_fracs,
        'gt_fracs': gt_fracs,
        'tmv': tmv,
    }
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class FlowNetTestLidar(GrowthNetTrain):
    """Test-time evaluation for the LiDAR experiment.

    Simulates every branch from the full validation ``x0`` set, scores the
    simulated endpoints against the ground-truth clouds with W1/W2/MMD
    (per branch and pooled over all branches), persists the metrics as
    JSON + CSV, and renders 3D trajectory figures over the LiDAR point cloud.
    """

    @staticmethod
    def _normalize_metric_samples(metric_samples):
        """Coerce the metric samples into a flat list of per-branch batches.

        CombinedLoader may hand back a dict of dataloader outputs, a
        list/tuple (possibly with each element wrapped in a 1-tuple), or a
        single batch.
        """
        if isinstance(metric_samples, dict):
            return list(metric_samples.values())
        if isinstance(metric_samples, (list, tuple)):
            return [m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m
                    for m in metric_samples]
        return [metric_samples]

    @classmethod
    def _unwrap_test_batch(cls, batch):
        """Normalize a CombinedLoader batch into ``(main_batch, metric_batch)``.

        Handles three shapes: the new ``{"test_samples": ..., "metric_samples": ...}``
        dict format, the old ``(test_samples, metric_samples)`` tuple format,
        and a raw batch fallback (empty metric list).
        """
        # Unwrap CombinedLoader outer tuple if needed.
        if isinstance(batch, (list, tuple)) and len(batch) == 1:
            batch = batch[0]

        if isinstance(batch, dict) and "test_samples" in batch:
            test_samples = batch["test_samples"]
            metric_samples = batch["metric_samples"]

            # CombinedLoader can append an integer batch index; strip it.
            if isinstance(test_samples, (list, tuple)) and len(test_samples) >= 2 and isinstance(test_samples[-1], int):
                test_samples = test_samples[0]
            if isinstance(metric_samples, (list, tuple)) and len(metric_samples) >= 2 and isinstance(metric_samples[-1], int):
                metric_samples = metric_samples[0]

            if isinstance(test_samples, (list, tuple)) and len(test_samples) == 1:
                test_samples = test_samples[0]
            return test_samples, cls._normalize_metric_samples(metric_samples)

        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            # Old tuple format: (test_samples, metric_samples); each may be
            # a dict, a list/tuple, or a bare batch.
            test_samples, metric_samples = batch
            if isinstance(test_samples, (list, tuple)):
                main_batch = test_samples[0]
            else:
                main_batch = test_samples
            return main_batch, cls._normalize_metric_samples(metric_samples)

        # Fallback: treat the raw batch as the main batch with no metric batches.
        return batch, []

    @staticmethod
    def _subsampled_metric_stats(pred, true, n_trials):
        """W1/W2/MMD mean±std over ``n_trials`` random equal-size subsamples.

        Distances are computed on the first two coordinates; the full
        feature vectors are passed through for the metrics that use them.
        ``ddof=1`` gives the sample (unbiased) standard deviation.
        """
        # The common subsample size is invariant across trials — hoisted out
        # of the trial loop.
        n_min = min(pred.shape[0], true.shape[0])
        w1, w2, mmd = [], [], []
        for _ in range(n_trials):
            perm_pred = torch.randperm(pred.shape[0])[:n_min]
            perm_gt = torch.randperm(true.shape[0])[:n_min]
            m = compute_distribution_distances(
                pred[perm_pred, :2], true[perm_gt, :2],
                pred_full=pred[perm_pred], true_full=true[perm_gt]
            )
            w1.append(m["W1"])
            w2.append(m["W2"])
            mmd.append(m["MMD"])
        return {
            "W1_mean": float(np.mean(w1)), "W1_std": float(np.std(w1, ddof=1)),
            "W2_mean": float(np.mean(w2)), "W2_std": float(np.std(w2, ddof=1)),
            "MMD_mean": float(np.mean(mmd)), "MMD_std": float(np.std(mmd, ddof=1)),
        }

    @staticmethod
    def _save_metric_files(metrics_dict, results_dir):
        """Persist metrics as ``metrics.json`` plus a per-group CSV summary."""
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                # Fall back to legacy single-value keys ("W1") when the
                # *_mean variants are absent.
                writer.writerow([key,
                                 f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

    def test_step(self, batch, batch_idx):
        """Evaluate every branch on LiDAR data; save metrics and figures.

        Interface unchanged: Lightning calls this with ``(batch, batch_idx)``.
        """
        main_batch, metric_batch = self._unwrap_test_batch(batch)

        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        if isinstance(main_batch, dict):
            device = main_batch["x0"][0].device
        else:
            device = main_batch[0]["x0"][0].device

        # Simulate from the entire validation x0 set with unit weights.
        x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
        full_batch = {"x0": (x0_all, w0_all)}

        (time_points, all_endpoints, all_trajs, mass_over_time,
         energy_over_time, weights_over_time) = self.get_mass_and_position(full_batch, metric_batch)

        cloud_points = main_batch["dataset"][0]  # [N, 3] LiDAR point cloud

        # Random-subsampling repetitions for robust metric mean/std.
        n_trials = 5

        def _branch_truth(i):
            # Per-branch ground truth at t1, falling back to the shared 't1' key.
            key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            return torch.tensor(timepoint_data[key], dtype=torch.float32)

        # ---- Per-branch metrics ----
        metrics_dict = {}
        for i, endpoints in enumerate(all_endpoints):
            true_data = _branch_truth(i).to(endpoints.device)
            stats = self._subsampled_metric_stats(endpoints, true_data, n_trials)
            metrics_dict[f"branch_{i+1}"] = stats
            self.log(f"test/W1_branch{i+1}", stats["W1_mean"], on_epoch=True)
            print(f"Branch {i+1} — W1: {stats['W1_mean']:.6f}±{stats['W1_std']:.6f}, "
                  f"W2: {stats['W2_mean']:.6f}±{stats['W2_std']:.6f}, "
                  f"MMD: {stats['MMD_mean']:.6f}±{stats['MMD_std']:.6f}")

        # ---- Combined metrics across all branches ----
        all_pred_combined = torch.cat(list(all_endpoints), dim=0)
        all_true_combined = torch.cat(
            [_branch_truth(i).to(all_pred_combined.device) for i in range(len(all_endpoints))],
            dim=0)
        combined = self._subsampled_metric_stats(all_pred_combined, all_true_combined, n_trials)
        combined["n_trials"] = n_trials
        metrics_dict["combined"] = combined
        self.log("test/W1_combined", combined["W1_mean"], on_epoch=True)
        self.log("test/W2_combined", combined["W2_mean"], on_epoch=True)
        self.log("test/MMD_combined", combined["MMD_mean"], on_epoch=True)
        print(f"\n=== Combined ===")
        print(f"W1: {combined['W1_mean']:.6f} ± {combined['W1_std']:.6f}")
        print(f"W2: {combined['W2_mean']:.6f} ± {combined['W2_std']:.6f}")
        print(f"MMD: {combined['MMD_mean']:.6f} ± {combined['MMD_std']:.6f}")

        # Inverse-transform cloud points for visualization when the data was
        # whitened during training.
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # Results layout: <working_dir>/results/<run_name>/figures
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        self._save_metric_files(metrics_dict, results_dir)

        # all_trajs[i] is a list of T tensors of shape [B, D];
        # stack along dim=1 to get [B, T, D] per branch.
        stacked_trajs = [torch.stack(traj_list, dim=1) for traj_list in all_trajs]

        # Undo whitening so trajectories share the cloud's coordinate frame.
        if self.whiten:
            unwhitened = []
            for traj in stacked_trajs:
                B, T, D = traj.shape
                flat = traj.reshape(-1, D).cpu().detach().numpy()
                flat = self.trainer.datamodule.scaler.inverse_transform(flat)
                unwhitened.append(torch.tensor(flat).reshape(B, T, D))
            stacked_trajs = unwhitened

        # ===== Plot all branches on one figure =====
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(stacked_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        plt.savefig(f'{figures_dir}/{self.args.data_name}_all_branches.png', dpi=300)
        plt.close()

        # ===== One figure per branch =====
        for i, traj in enumerate(stacked_trajs):
            fig = plt.figure(figsize=(10, 8))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            plt.savefig(f'{figures_dir}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
            plt.close()

        print(f"LiDAR figures saved to {figures_dir}")
+
|
| 378 |
+
class FlowNetTestMouse(GrowthNetTrain):
    """Test-time evaluation for the mouse experiment (2 branches).

    Integrates each branch's flow network from the validation ``x0`` set,
    scores the endpoints against ground truth at t1 (pooled) and t2
    (per branch and pooled), writes metrics to JSON/CSV, and renders 2D
    trajectory figures with the timepoint clouds in the background.
    """

    @staticmethod
    def _integrate(flow_net, x0, t_span):
        """Euler-integrate ``flow_net`` from ``x0`` over ``t_span``.

        Returns the trajectory on CPU with shape [T, B, D]. The solve is
        deterministic (fixed-step Euler, no stochastic terms in the wrapper).
        """
        node = NeuralODE(
            flow_model_torch_wrapper(flow_net),
            solver="euler",
            sensitivity="adjoint",
        ).to(x0.device)
        with torch.no_grad():
            return node.trajectory(x0, t_span).cpu()

    def test_step(self, batch, batch_idx):
        """Evaluate all branches; save metrics and figures.

        Interface unchanged: Lightning calls this with ``(batch, batch_idx)``.
        """
        # Handle both tuple and dict batch formats from CombinedLoader.
        if isinstance(batch, dict):
            main_batch = batch.get("test_samples", batch)
            if isinstance(main_batch, tuple):
                main_batch = main_batch[0]
        elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
            if isinstance(batch[0], dict):
                main_batch = batch[0].get("test_samples", batch[0])
                if isinstance(main_batch, tuple):
                    main_batch = main_batch[0]
            else:
                main_batch = batch[0][0]
        else:
            main_batch = batch

        device = main_batch["x0"][0].device

        # Use validation x0 as initial conditions.
        x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)

        timepoint_data = self.trainer.datamodule.get_timepoint_data()

        # Ground truth at t1 (intermediate timepoint).
        data_t1 = torch.tensor(timepoint_data['t1'], dtype=torch.float32)

        # Color schemes for the two mouse branches.
        custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
        custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)

        # Full trajectories for plotting: [B, T, D] per branch.
        t_span_full = torch.linspace(0, 1.0, 100).to(device)
        all_trajs = [torch.transpose(self._integrate(fn, x0, t_span_full), 0, 1)
                     for fn in self.flow_nets]

        t_span_metric_t1 = torch.linspace(0, 0.5, 50).to(device)
        t_span_metric_t2 = torch.linspace(0, 1.0, 100).to(device)
        n_trials = 5

        # The Euler solve is deterministic, so integrate each branch ONCE per
        # t-span and reuse the endpoints across trials. (Previously the ODE
        # was rebuilt and re-integrated inside every trial loop even though
        # only the random subsampling below varies between trials.)
        endpoints_t1 = [self._integrate(fn, x0, t_span_metric_t1)[-1] for fn in self.flow_nets]
        endpoints_t2 = [self._integrate(fn, x0, t_span_metric_t2)[-1] for fn in self.flow_nets]

        # Gather t2 branch ground truth (per-branch keys; branch 1 may fall
        # back to the shared 't2' key).
        data_t2_branches = []
        for i in range(len(self.flow_nets)):
            key = f't2_{i+1}'
            if key in timepoint_data:
                data_t2_branches.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
            elif i == 0 and 't2' in timepoint_data:
                data_t2_branches.append(torch.tensor(timepoint_data['t2'], dtype=torch.float32))
            else:
                data_t2_branches.append(None)

        # Combined t2 ground truth (all branches merged).
        data_t2_all_list = [d for d in data_t2_branches if d is not None]
        data_t2_combined = torch.cat(data_t2_all_list, dim=0) if data_t2_all_list else None

        # ---- t1 combined metrics (all branches pooled, compared to t1) ----
        preds_t1 = torch.cat(endpoints_t1, dim=0)
        w1_t1_trials, w2_t1_trials, mmd_t1_trials = [], [], []
        for _ in range(n_trials):
            # NOTE(review): if t1 has fewer points than the pooled
            # predictions, the ground-truth subsample is smaller than
            # preds_t1 — assumed handled by compute_distribution_distances.
            perm = torch.randperm(data_t1.shape[0])[:preds_t1.shape[0]]
            metrics = compute_distribution_distances(
                preds_t1[:, :2], data_t1[perm][:, :2]
            )
            w1_t1_trials.append(metrics["W1"])
            w2_t1_trials.append(metrics["W2"])
            mmd_t1_trials.append(metrics["MMD"])

        # ---- t2 per-branch metrics (each branch endpoint vs its own t2 cluster) ----
        branch_t2_metrics = {}
        for i, x_final in enumerate(endpoints_t2):
            gt = data_t2_branches[i]
            if gt is None:
                continue
            n_min = min(x_final.shape[0], gt.shape[0])  # invariant across trials
            w1_br, w2_br, mmd_br = [], [], []
            for _ in range(n_trials):
                perm_pred = torch.randperm(x_final.shape[0])[:n_min]
                perm_gt = torch.randperm(gt.shape[0])[:n_min]
                m = compute_distribution_distances(
                    x_final[perm_pred, :2], gt[perm_gt, :2]
                )
                w1_br.append(m["W1"])
                w2_br.append(m["W2"])
                mmd_br.append(m["MMD"])
            branch_t2_metrics[f"branch_{i+1}_t2"] = {
                "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
                "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
                "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
            }
            print(f"Branch {i+1} @ t2 — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")

        # ---- t2 combined metrics (all branches pooled, compared to all t2) ----
        w1_t2_trials, w2_t2_trials, mmd_t2_trials = [], [], []
        if data_t2_combined is not None:
            preds_t2 = torch.cat(endpoints_t2, dim=0)
            n_min = min(preds_t2.shape[0], data_t2_combined.shape[0])
            for _ in range(n_trials):
                perm_pred = torch.randperm(preds_t2.shape[0])[:n_min]
                perm_gt = torch.randperm(data_t2_combined.shape[0])[:n_min]
                m = compute_distribution_distances(
                    preds_t2[perm_pred, :2], data_t2_combined[perm_gt, :2]
                )
                w1_t2_trials.append(m["W1"])
                w2_t2_trials.append(m["W2"])
                mmd_t2_trials.append(m["MMD"])

        # Mean/std over trials (sample std, ddof=1).
        w1_t1_mean, w1_t1_std = np.mean(w1_t1_trials), np.std(w1_t1_trials, ddof=1)
        w2_t1_mean, w2_t1_std = np.mean(w2_t1_trials), np.std(w2_t1_trials, ddof=1)
        mmd_t1_mean, mmd_t1_std = np.mean(mmd_t1_trials), np.std(mmd_t1_trials, ddof=1)

        self.log("test/W1_combined_t1", w1_t1_mean, on_epoch=True)
        self.log("test/W2_combined_t1", w2_t1_mean, on_epoch=True)
        self.log("test/MMD_combined_t1", mmd_t1_mean, on_epoch=True)

        metrics_dict = {
            "combined_t1": {
                "W1_mean": float(w1_t1_mean), "W1_std": float(w1_t1_std),
                "W2_mean": float(w2_t1_mean), "W2_std": float(w2_t1_std),
                "MMD_mean": float(mmd_t1_mean), "MMD_std": float(mmd_t1_std),
                "n_trials": n_trials,
            }
        }
        metrics_dict.update(branch_t2_metrics)

        if w1_t2_trials:
            w1_t2_mean, w1_t2_std = np.mean(w1_t2_trials), np.std(w1_t2_trials, ddof=1)
            w2_t2_mean, w2_t2_std = np.mean(w2_t2_trials), np.std(w2_t2_trials, ddof=1)
            mmd_t2_mean, mmd_t2_std = np.mean(mmd_t2_trials), np.std(mmd_t2_trials, ddof=1)
            self.log("test/W1_combined_t2", w1_t2_mean, on_epoch=True)
            self.log("test/W2_combined_t2", w2_t2_mean, on_epoch=True)
            self.log("test/MMD_combined_t2", mmd_t2_mean, on_epoch=True)
            metrics_dict["combined_t2"] = {
                "W1_mean": float(w1_t2_mean), "W1_std": float(w1_t2_std),
                "W2_mean": float(w2_t2_mean), "W2_std": float(w2_t2_std),
                "MMD_mean": float(mmd_t2_mean), "MMD_std": float(mmd_t2_std),
                "n_trials": n_trials,
            }

        print(f"\n=== Combined @ t1 ===")
        print(f"W1: {w1_t1_mean:.6f} ± {w1_t1_std:.6f}")
        print(f"W2: {w2_t1_mean:.6f} ± {w2_t1_std:.6f}")
        print(f"MMD: {mmd_t1_mean:.6f} ± {mmd_t1_std:.6f}")
        if w1_t2_trials:
            print(f"\n=== Combined @ t2 ===")
            print(f"W1: {w1_t2_mean:.6f} ± {w1_t2_std:.6f}")
            print(f"W2: {w2_t2_mean:.6f} ± {w2_t2_std:.6f}")
            print(f"MMD: {mmd_t2_mean:.6f} ± {mmd_t2_std:.6f}")

        # Results layout: <working_dir>/results/<run_name>/figures
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        # Save metrics to JSON.
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        # Save detailed metrics to CSV.
        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                writer.writerow([key,
                                 f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # ===== Figures (use the full t_span trajectories) =====
        self._plot_mouse_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
        self._plot_mouse_combined(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)

        print(f"Mouse figures saved to {figures_dir}")

    @staticmethod
    def _to_numpy_trajs(all_trajs):
        """Convert each trajectory (list of [B,D] tensors or a [B,T,D] tensor)
        to a [B, T, D] numpy array."""
        out = []
        for traj in all_trajs:
            if isinstance(traj, list):
                traj = torch.stack(traj, dim=1)  # list of [B,D] -> [B,T,D]
            if torch.is_tensor(traj):
                traj = traj.cpu().detach().numpy()
            out.append(traj)
        return out

    @staticmethod
    def _timepoints_to_numpy(timepoint_data):
        """Convert any tensor values in ``timepoint_data`` to numpy.

        NOTE(review): mutates the dict IN PLACE — callers share the
        datamodule's dict, so subsequent reads see numpy arrays.
        """
        for key in list(timepoint_data.keys()):
            if torch.is_tensor(timepoint_data[key]):
                timepoint_data[key] = timepoint_data[key].cpu().numpy()

    @staticmethod
    def _add_trajectory_collection(ax, traj_2d, cmap):
        """Draw all trajectories in ``traj_2d`` ([B, T, 2] numpy) as one
        time-colored LineCollection (much faster than per-line plotting)."""
        n_time = traj_2d.shape[1]
        color_vals = cmap(np.linspace(0, 1, n_time))
        segments, seg_colors = [], []
        for pts in traj_2d:  # pts: [T, 2]
            segments.append(np.stack([pts[:-1], pts[1:]], axis=1))
            seg_colors.append(color_vals[:-1])
        lc = LineCollection(np.concatenate(segments, axis=0),
                            colors=np.concatenate(seg_colors, axis=0),
                            linewidths=2, alpha=0.8)
        ax.add_collection(lc)

    def _plot_mouse_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
        """Plot each branch separately with the timepoint clouds behind it."""
        n_branches = len(all_trajs)
        branch_names = [f'Branch {i+1}' for i in range(n_branches)]
        branch_colors = ['#B83CFF', '#50B2D7'][:n_branches]
        cmaps = [cmap1, cmap2][:n_branches]

        all_trajs = self._to_numpy_trajs(all_trajs)
        self._timepoints_to_numpy(timepoint_data)

        # Global axis limits over all timepoints and trajectories.
        all_coords = [timepoint_data[k][:, :2]
                      for k in ['t0', 't1', 't2', 't2_1', 't2_2'] if k in timepoint_data]
        all_coords += [t.reshape(-1, t.shape[-1])[:, :2] for t in all_trajs]
        all_coords = np.concatenate(all_coords, axis=0)
        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

        # 5% margin on each side.
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min, x_max = x_min - x_margin, x_max + x_margin
        y_min, y_max = y_min - y_margin, y_max + y_margin

        for i, traj in enumerate(all_trajs):
            fig, ax = plt.subplots(figsize=(10, 8))
            cmap = cmaps[i]
            c_end = branch_colors[i]

            # Timepoint background (branch-specific t2 key when available).
            t2_key = f't2_{i+1}' if f't2_{i+1}' in timepoint_data else 't2'
            coords_list = [timepoint_data['t0'], timepoint_data['t1'], timepoint_data[t2_key]]
            tp_colors = ['#05009E', '#A19EFF', c_end]
            tp_labels = ["t=0", "t=1", f"t=2 (branch {i+1})"]

            for coords, color, label in zip(coords_list, tp_colors, tp_labels):
                # t=0 cells (dark blue) drawn slightly more opaque.
                alpha = 0.8 if color == '#05009E' else 0.6
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=alpha, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            traj_2d = traj[:, :, :2]
            self._add_trajectory_collection(ax, traj_2d, cmap)

            # Start and end points.
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o', label='Trajectory Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory End',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC1", fontsize=12)
            ax.set_ylabel("PC2", fontsize=12)
            ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=12, frameon=False)

            plt.tight_layout()
            plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            plt.close()

    def _plot_mouse_combined(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
        """Plot all branches together on one figure."""
        n_branches = len(all_trajs)
        branch_names = [f'Branch {i+1}' for i in range(n_branches)]
        branch_colors = ['#B83CFF', '#50B2D7'][:n_branches]

        # Timepoint key/color/label lists depend on whether t2 is branched.
        if 't2_1' in timepoint_data:
            tp_keys = ['t0', 't1', 't2_1', 't2_2']
            tp_colors = ['#05009E', '#A19EFF', '#B83CFF', '#50B2D7']
            tp_labels = ['t=0', 't=1', 't=2 (branch 1)', 't=2 (branch 2)']
        else:
            tp_keys = ['t0', 't1', 't2']
            tp_colors = ['#05009E', '#A19EFF', '#B83CFF']
            tp_labels = ['t=0', 't=1', 't=2']

        all_trajs = self._to_numpy_trajs(all_trajs)
        self._timepoints_to_numpy(timepoint_data)

        fig, ax = plt.subplots(figsize=(12, 10))

        # Timepoint background.
        for t_key, color, label in zip(tp_keys, tp_colors, tp_labels):
            if t_key in timepoint_data:
                coords = timepoint_data[t_key]
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

        # Trajectories with per-branch color gradients.
        cmaps = [cmap1, cmap2]
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[:, :, :2]
            c_end = branch_colors[i]
            self._add_trajectory_collection(ax, traj_2d, cmaps[i])

            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o',
                       label=f'{branch_names[i]} Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o',
                       label=f'{branch_names[i]} End',
                       zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlabel("PC1", fontsize=14)
        ax.set_ylabel("PC2", fontsize=14)
        ax.set_title("All Branch Trajectories with Timepoint Background",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close()
+
|
| 794 |
+
class FlowNetTestClonidine(BranchFlowNetTrainBase):
    """Test class for Clonidine perturbation experiment (1 or 2 branches).

    Integrates each branch's learned flow field from the validation ``x0``
    population, computes distribution distances (W1/W2/MMD) between simulated
    endpoints and held-out t=1 data, writes metrics to JSON/CSV, and renders
    per-branch and combined trajectory figures.
    """

    def test_step(self, batch, batch_idx):
        """Run the full evaluation pipeline for one test batch.

        The batch itself is only used to recover the device; initial
        conditions come from the datamodule's validation ``x0`` dataset.
        Side effects: Lightning logging, stdout prints, JSON/CSV metric
        files, and PNG figures under ``<working_dir>/results/<run_name>``.
        """
        # Handle both dict and tuple batch formats from CombinedLoader
        if isinstance(batch, dict) and "test_samples" in batch:
            # New format: {"test_samples": {...}, "metric_samples": {...}}
            main_batch = batch["test_samples"]
        elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
            # Old format with nested structure
            test_samples = batch[0]
            if isinstance(test_samples, dict) and "test_samples" in test_samples:
                main_batch = test_samples["test_samples"][0]
            else:
                main_batch = test_samples
        else:
            # Fallback: treat the raw batch as the main batch
            main_batch = batch

        # Get timepoint data (full datasets keyed 't0', 't1' or 't1_<k>';
        # assumed to be array-like with >= 2 feature columns — TODO confirm)
        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        device = main_batch["x0"][0].device

        # Use val x0 as initial conditions for trajectory simulation
        x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        t_span = torch.linspace(0, 1, 100).to(device)

        # Define color schemes for clonidine (2 branches): shared start color,
        # branch-specific end colors for the trajectory gradients
        custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
        custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)

        all_trajs = []      # per-branch trajectories, each [B, T, D]
        all_endpoints = []  # per-branch final states, each [B, D]

        # Integrate each branch's flow field over t in [0, 1]
        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)
            all_endpoints.append(traj[:, -1, :])

        # Run 5 trials with random subsampling for robust metrics
        n_trials = 5
        n_branches = len(self.flow_nets)

        # Gather per-branch ground truth; fall back to the shared 't1' key
        # when per-branch keys ('t1_1', 't1_2', ...) are absent
        gt_data_per_branch = []
        for i in range(n_branches):
            if n_branches == 1:
                key = 't1'
            else:
                key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
        gt_all = torch.cat(gt_data_per_branch, dim=0)

        # Per-branch metrics (5 trials); only the first two feature columns
        # are compared (PC1/PC2)
        metrics_dict = {}
        for i in range(n_branches):
            w1_br, w2_br, mmd_br = [], [], []
            pred = all_endpoints[i]
            gt = gt_data_per_branch[i]
            for trial in range(n_trials):
                # Subsample both sets to a common size with fresh random permutations
                n_min = min(pred.shape[0], gt.shape[0])
                perm_pred = torch.randperm(pred.shape[0])[:n_min]
                perm_gt = torch.randperm(gt.shape[0])[:n_min]
                m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
                w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
            metrics_dict[f"branch_{i+1}"] = {
                "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
                "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
                "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
            }
            self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
            print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")

        # Combined metrics (5 trials): pool all branch endpoints against
        # pooled ground truth
        pred_all = torch.cat(all_endpoints, dim=0)
        w1_trials, w2_trials, mmd_trials = [], [], []
        for trial in range(n_trials):
            n_min = min(pred_all.shape[0], gt_all.shape[0])
            perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
            perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
            m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
            w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])

        w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
        w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
        mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
        self.log("test/W1_t1_combined", w1_mean, on_epoch=True)
        self.log("test/W2_t1_combined", w2_mean, on_epoch=True)
        self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True)
        metrics_dict['t1_combined'] = {
            "W1_mean": float(w1_mean), "W1_std": float(w1_std),
            "W2_mean": float(w2_mean), "W2_std": float(w2_std),
            "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
            "n_trials": n_trials,
        }
        print(f"\n=== Combined @ t1 ===")
        print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
        print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
        print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")

        # Create results directory structure: <working_dir>/results/<run_name>
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        # Save metrics to JSON
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        # Save detailed metrics to CSV (one row per metric group; .get()
        # fallbacks tolerate entries that only carry raw W1/W2/MMD keys)
        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                writer.writerow([key,
                                 f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # ===== Plot branches =====
        self._plot_clonidine_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
        self._plot_clonidine_combined(all_trajs, timepoint_data, figures_dir)

        print(f"Clonidine figures saved to {figures_dir}")

    def _plot_clonidine_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
        """Plot each branch separately.

        One PNG per branch: observed t=0/t=1 cells as background scatter plus
        time-gradient trajectory lines. All panels share axis limits computed
        from every timepoint and trajectory so branches are comparable.
        """
        branch_names = ['Branch 1', 'Branch 2']
        branch_colors = ['#B83CFF', '#50B2D7']
        cmaps = [cmap1, cmap2]

        # Compute global axis limits – handle single vs multi branch keys
        all_coords = []
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
        for key in tp_keys:
            all_coords.append(timepoint_data[key][:, :2])
        for traj in all_trajs:
            all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2])

        all_coords = np.concatenate(all_coords, axis=0)
        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

        # 5% padding around the data extent
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min -= x_margin
        x_max += x_margin
        y_min -= y_margin
        y_max += y_margin

        for i, traj in enumerate(all_trajs):
            fig, ax = plt.subplots(figsize=(10, 8))
            c_end = branch_colors[i]

            # Plot timepoint background (branch-specific t=1 key if available)
            t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            coords_list = [timepoint_data['t0'], timepoint_data[t1_key]]
            tp_colors = ['#05009E', c_end]
            t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"
            tp_labels = ["t=0", t1_label]

            for coords, color, label in zip(coords_list, tp_colors, tp_labels):
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            # Plot continuous trajectories with LineCollection for speed:
            # one segment per consecutive pair of timepoints, colored by time
            traj_2d = traj[:, :, :2]
            n_time = traj_2d.shape[1]
            color_vals = cmaps[i](np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            # Start and end points (drawn above the lines via zorder)
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o', label='Trajectory Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory End',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC1", fontsize=12)
            ax.set_ylabel("PC2", fontsize=12)
            ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=16, frameon=False)

            plt.tight_layout()
            plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            plt.close()

    def _plot_clonidine_combined(self, all_trajs, timepoint_data, save_dir):
        """Plot all branches together in a single figure.

        Same background-scatter + gradient-trajectory style as the per-branch
        plots, with all branches overlaid on one axis.
        """
        branch_names = ['Branch 1', 'Branch 2']
        branch_colors = ['#B83CFF', '#50B2D7']

        fig, ax = plt.subplots(figsize=(12, 10))

        # Build timepoint keys/colors/labels depending on single vs multi branch
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))]
            tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
            tp_labels_list = ['t=0', 't=1']
        tp_colors = ['#05009E', '#B83CFF', '#50B2D7'][:len(tp_keys)]

        # Plot timepoint background
        for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list):
            coords = timepoint_data[t_key]
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=color, s=80, alpha=0.4, marker='x',
                       label=f'{label} cells', linewidth=1.5)

        # Plot trajectories with color gradients (colormaps rebuilt locally
        # with distinct names; same color stops as in test_step)
        custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        cmaps = [
            LinearSegmentedColormap.from_list("clon_cmap1", custom_colors_1),
            LinearSegmentedColormap.from_list("clon_cmap2", custom_colors_2),
        ]
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[:, :, :2]
            c_end = branch_colors[i]
            cmap = cmaps[i]
            n_time = traj_2d.shape[1]
            color_vals = cmap(np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o',
                       label=f'{branch_names[i]} Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o',
                       label=f'{branch_names[i]} End',
                       zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlabel("PC1", fontsize=14)
        ax.set_ylabel("PC2", fontsize=14)
        ax.set_title("All Branch Trajectories with Timepoint Background",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close()
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
class FlowNetTestTrametinib(BranchFlowNetTrainBase):
    """Test class for Trametinib perturbation experiment (1 or 3 branches).

    Mirrors the clonidine test class: simulate each branch's flow from the
    validation ``x0`` population, score W1/W2/MMD against held-out t=1 data,
    persist metrics (JSON/CSV), and render per-branch and combined figures.
    """

    def test_step(self, batch, batch_idx):
        """Run the full evaluation pipeline for one test batch.

        The batch is only used to recover the device; initial conditions come
        from the datamodule's validation ``x0`` dataset. Side effects:
        Lightning logging, stdout prints, JSON/CSV metric files, and PNG
        figures under ``<working_dir>/results/<run_name>``.
        """
        # Handle both dict and tuple batch formats from CombinedLoader
        if isinstance(batch, dict) and "test_samples" in batch:
            # New format: {"test_samples": {...}, "metric_samples": {...}}
            main_batch = batch["test_samples"]
        elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
            # Old format with nested structure
            test_samples = batch[0]
            if isinstance(test_samples, dict) and "test_samples" in test_samples:
                main_batch = test_samples["test_samples"][0]
            else:
                main_batch = test_samples
        else:
            # Fallback: treat the raw batch as the main batch
            main_batch = batch

        # Get timepoint data (full datasets keyed 't0', 't1' or 't1_<k>';
        # assumed array-like with >= 2 feature columns — TODO confirm)
        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        device = main_batch["x0"][0].device

        # Use val x0 as initial conditions for trajectory simulation
        x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        t_span = torch.linspace(0, 1, 100).to(device)

        # Define color schemes for trametinib (3 branches): shared start
        # color, branch-specific end colors for the trajectory gradients
        custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        custom_colors_3 = ["#05009E", "#A19EFF", "#B83CFF"]
        custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
        custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)
        custom_cmap_3 = LinearSegmentedColormap.from_list("cmap3", custom_colors_3)

        all_trajs = []      # per-branch trajectories, each [B, T, D]
        all_endpoints = []  # per-branch final states, each [B, D]

        # Integrate each branch's flow field over t in [0, 1]
        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)
            all_endpoints.append(traj[:, -1, :])

        # Run 5 trials with random subsampling for robust metrics
        n_trials = 5
        n_branches = len(self.flow_nets)

        # Gather per-branch ground truth; fall back to the shared 't1' key
        # when per-branch keys ('t1_1', 't1_2', ...) are absent
        gt_data_per_branch = []
        for i in range(n_branches):
            if n_branches == 1:
                key = 't1'
            else:
                key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
        gt_all = torch.cat(gt_data_per_branch, dim=0)

        # Per-branch metrics (5 trials); only the first two feature columns
        # are compared (PC1/PC2)
        metrics_dict = {}
        for i in range(n_branches):
            w1_br, w2_br, mmd_br = [], [], []
            pred = all_endpoints[i]
            gt = gt_data_per_branch[i]
            for trial in range(n_trials):
                # Subsample both sets to a common size with fresh random permutations
                n_min = min(pred.shape[0], gt.shape[0])
                perm_pred = torch.randperm(pred.shape[0])[:n_min]
                perm_gt = torch.randperm(gt.shape[0])[:n_min]
                m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
                w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
            metrics_dict[f"branch_{i+1}"] = {
                "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
                "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
                "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
            }
            self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
            print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")

        # Combined metrics (5 trials): pool all branch endpoints against
        # pooled ground truth
        pred_all = torch.cat(all_endpoints, dim=0)
        w1_trials, w2_trials, mmd_trials = [], [], []
        for trial in range(n_trials):
            n_min = min(pred_all.shape[0], gt_all.shape[0])
            perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
            perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
            m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
            w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])

        w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
        w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
        mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
        self.log("test/W1_t1_combined", w1_mean, on_epoch=True)
        self.log("test/W2_t1_combined", w2_mean, on_epoch=True)
        self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True)
        metrics_dict['t1_combined'] = {
            "W1_mean": float(w1_mean), "W1_std": float(w1_std),
            "W2_mean": float(w2_mean), "W2_std": float(w2_std),
            "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
            "n_trials": n_trials,
        }
        print(f"\n=== Combined @ t1 ===")
        print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
        print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
        print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")

        # Create results directory structure: <working_dir>/results/<run_name>
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        # Save metrics to JSON
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        # Save detailed metrics to CSV (one row per metric group; .get()
        # fallbacks tolerate entries that only carry raw W1/W2/MMD keys)
        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                writer.writerow([key,
                                 f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # ===== Plot branches =====
        self._plot_trametinib_branches(all_trajs, timepoint_data, figures_dir,
                                       custom_cmap_1, custom_cmap_2, custom_cmap_3)
        self._plot_trametinib_combined(all_trajs, timepoint_data, figures_dir)

        print(f"Trametinib figures saved to {figures_dir}")

    def _plot_trametinib_branches(self, all_trajs, timepoint_data, save_dir,
                                  cmap1, cmap2, cmap3):
        """Plot each branch separately.

        One PNG per branch: observed t=0/t=1 cells as background scatter plus
        time-gradient trajectory lines. All panels share axis limits computed
        from every timepoint and trajectory so branches are comparable.
        """
        branch_names = ['Branch 1', 'Branch 2', 'Branch 3']
        branch_colors = ['#9793F8', '#50B2D7', '#B83CFF']
        cmaps = [cmap1, cmap2, cmap3]

        # Compute global axis limits – handle single vs multi branch keys
        all_coords = []
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
        for key in tp_keys:
            all_coords.append(timepoint_data[key][:, :2])
        for traj in all_trajs:
            all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2])

        all_coords = np.concatenate(all_coords, axis=0)
        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

        # 5% padding around the data extent
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min -= x_margin
        x_max += x_margin
        y_min -= y_margin
        y_max += y_margin

        for i, traj in enumerate(all_trajs):
            fig, ax = plt.subplots(figsize=(10, 8))
            c_end = branch_colors[i]

            # Plot timepoint background (branch-specific t=1 key if available)
            t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            coords_list = [timepoint_data['t0'], timepoint_data[t1_key]]
            tp_colors = ['#05009E', c_end]
            t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"
            tp_labels = ["t=0", t1_label]

            for coords, color, label in zip(coords_list, tp_colors, tp_labels):
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            # Plot continuous trajectories with LineCollection for speed:
            # one segment per consecutive pair of timepoints, colored by time
            traj_2d = traj[:, :, :2]
            n_time = traj_2d.shape[1]
            color_vals = cmaps[i](np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            # Start and end points (drawn above the lines via zorder)
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o', label='Trajectory Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory End',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC1", fontsize=12)
            ax.set_ylabel("PC2", fontsize=12)
            ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=16, frameon=False)

            plt.tight_layout()
            plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            plt.close()

    def _plot_trametinib_combined(self, all_trajs, timepoint_data, save_dir):
        """Plot all 3 branches together in a single figure.

        Same background-scatter + gradient-trajectory style as the per-branch
        plots, with all branches overlaid on one axis.
        """
        branch_names = ['Branch 1', 'Branch 2', 'Branch 3']
        branch_colors = ['#9793F8', '#50B2D7', '#B83CFF']

        fig, ax = plt.subplots(figsize=(12, 10))

        # Build timepoint keys/colors/labels depending on single vs multi branch
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))]
            tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
            tp_labels_list = ['t=0', 't=1']
        tp_colors = ['#05009E', '#9793F8', '#50B2D7', '#B83CFF'][:len(tp_keys)]

        # Plot timepoint background
        for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list):
            coords = timepoint_data[t_key]
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=color, s=80, alpha=0.4, marker='x',
                       label=f'{label} cells', linewidth=1.5)

        # Plot trajectories with color gradients.
        # NOTE(review): branch-3 gradient here ends in "#D577FF" while
        # test_step / the per-branch plots use "#B83CFF" — possibly a
        # deliberate contrast choice for the overlay; confirm with the author.
        custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        custom_colors_3 = ["#05009E", "#A19EFF", "#D577FF"]
        cmaps = [
            LinearSegmentedColormap.from_list("tram_cmap1", custom_colors_1),
            LinearSegmentedColormap.from_list("tram_cmap2", custom_colors_2),
            LinearSegmentedColormap.from_list("tram_cmap3", custom_colors_3),
        ]
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[:, :, :2]
            c_end = branch_colors[i]
            cmap = cmaps[i]
            n_time = traj_2d.shape[1]
            color_vals = cmap(np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o',
                       label=f'{branch_names[i]} Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o',
                       label=f'{branch_names[i]} End',
                       zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlabel("PC1", fontsize=14)
        ax.set_ylabel("PC2", fontsize=14)
        ax.set_title("All Branch Trajectories with Timepoint Background",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close()
|
| 1384 |
+
|
| 1385 |
+
class FlowNetTestVeres(GrowthNetTrain):
    """Test class for Veres pancreatic endocrinogenesis experiment (3 or 5 branches).

    Evaluates trained branch flow networks by simulating trajectories from the
    validation t=0 population, computing W1/W2/MMD distribution distances against
    held-out timepoints (both on the first two PCs and on all dimensions), saving
    the metrics to JSON/CSV under ``<working_dir>/results/<run_name>/``, and
    rendering per-branch and combined PCA trajectory figures.
    """

    def test_step(self, batch, batch_idx):
        """Run the full evaluation pipeline once.

        Args:
            batch: Either a dict with ``test_samples``/``metric_samples`` entries,
                or a list/tuple wrapping that dict, or a plain
                ``(test_samples, metric_samples)`` tuple — all three shapes are
                produced by different CombinedLoader configurations.
            batch_idx: Lightning batch index (unused beyond the hook signature).
        """
        # Handle both tuple and dict batch formats from CombinedLoader
        if isinstance(batch, dict):
            main_batch = batch["test_samples"][0]
            metric_batch = batch["metric_samples"][0]
        else:
            # batch is a list/tuple
            if isinstance(batch[0], dict):
                # batch[0] contains the dict with test_samples and metric_samples
                main_batch = batch[0]["test_samples"][0]
                metric_batch = batch[0]["metric_samples"][0]
            else:
                # batch is a tuple: (test_samples, metric_samples)
                main_batch = batch[0][0]
                metric_batch = batch[1][0]

        # Get timepoint data (full datasets, not just val split)
        # NOTE(review): values are treated as tensors in the plot helpers
        # (`.cpu().numpy()`) but wrapped with `torch.tensor(...)` below —
        # presumably they are torch tensors already; confirm against the datamodule.
        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        device = main_batch["x0"][0].device

        # Use val x0 as initial conditions
        x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        # Unit initial mass weight for every sample.
        w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
        full_batch = {"x0": (x0_all, w0_all)}

        time_points, all_endpoints, all_trajs, mass_over_time, energy_over_time, weights_over_time = self.get_mass_and_position(full_batch, metric_batch)

        n_branches = len(self.flow_nets)

        # trajectory time grid
        # NOTE(review): `t_span` (and several of the unpacked values above) are
        # currently unused in this method.
        t_span = torch.linspace(0, 1, 101).to(device)

        # `all_trajs` returned from `get_mass_and_position` is expected to be a list where each
        # element is a sequence of per-timepoint tensors for that branch (shape [B, D] each).
        # Convert each branch to [T, B, D] then to [B, T, D] for downstream processing.
        trajs_TBD = [torch.stack(branch_list, dim=0) for branch_list in all_trajs]  # each is [T, B, D]
        trajs_BTD = [t.permute(1, 0, 2) for t in trajs_TBD]  # each -> [B, T, D]

        # Rebind: from here on these hold post-processed per-branch results.
        all_trajs = []
        all_endpoints = []
        # will store per-branch intermediate frames: each entry -> tensor [B, n_intermediate, D]
        all_intermediates = []

        for traj in trajs_BTD:
            # traj is [B, T, D]
            # optionally inverse-transform if whitened
            if self.whiten:
                traj_np = traj.detach().cpu().numpy()
                n_samples, n_time, n_dims = traj_np.shape
                traj_flat = traj_np.reshape(-1, n_dims)
                traj_inv_flat = self.trainer.datamodule.scaler.inverse_transform(traj_flat)
                traj_inv = traj_inv_flat.reshape(n_samples, n_time, n_dims)
                traj = torch.tensor(traj_inv, dtype=torch.float32)

            all_trajs.append(traj)

            # Collect six evenly spaced intermediate frames between t=0 and t=1 (exclude endpoints)
            n_T = traj.shape[1]
            # choose 8 points including endpoints -> take inner 6 as intermediates
            inter_times = np.linspace(0.0, 1.0, 8)[1:-1]  # 6 values
            inter_indices = [int(round(t * (n_T - 1))) for t in inter_times]
            # stack per-branch intermediate frames -> [B, 6, D]
            intermediates = torch.stack([traj[:, idx, :] for idx in inter_indices], dim=1)
            all_intermediates.append(intermediates)

            # Final endpoints (t=1)
            all_endpoints.append(traj[:, -1, :])

        # Run 5 trials with random subsampling for robust metrics
        n_trials = 5
        metrics_dict = {}

        # --- Intermediate timepoints (t1-t6) combined metrics ---
        # Keys like 't1'..'t6' (no underscore, not 't0') are intermediate snapshots.
        intermediate_keys = sorted([k for k in timepoint_data.keys()
                                    if k.startswith('t') and '_' not in k and k != 't0'])

        if intermediate_keys:
            # At most six evaluations — matches the six intermediate frames collected above.
            n_evals = min(6, len(intermediate_keys))
            for j in range(n_evals):
                intermediate_key = intermediate_keys[j]
                true_data_intermediate = torch.tensor(timepoint_data[intermediate_key], dtype=torch.float32)

                # Gather predicted intermediates across all branches
                raw_intermediates = [branch[:, j, :] for branch in all_intermediates]
                all_raw_concat = torch.cat(raw_intermediates, dim=0).cpu()  # [n_branches*B, D]

                w1_t, w2_t, mmd_t = [], [], []
                w1_t_full, w2_t_full, mmd_t_full = [], [], []
                for trial in range(n_trials):
                    # Subsample both sides to a common size with a fresh permutation each trial.
                    n_min = min(all_raw_concat.shape[0], true_data_intermediate.shape[0])
                    perm_pred = torch.randperm(all_raw_concat.shape[0])[:n_min]
                    perm_gt = torch.randperm(true_data_intermediate.shape[0])[:n_min]
                    # 2D metrics (PC1-PC2)
                    m = compute_distribution_distances(
                        all_raw_concat[perm_pred, :2], true_data_intermediate[perm_gt, :2])
                    w1_t.append(m["W1"]); w2_t.append(m["W2"]); mmd_t.append(m["MMD"])
                    # Full-dimensional metrics (all PCs)
                    m_full = compute_distribution_distances(
                        all_raw_concat[perm_pred], true_data_intermediate[perm_gt])
                    w1_t_full.append(m_full["W1"]); w2_t_full.append(m_full["W2"]); mmd_t_full.append(m_full["MMD"])

                metrics_dict[f'{intermediate_key}_combined'] = {
                    "W1_mean": float(np.mean(w1_t)), "W1_std": float(np.std(w1_t, ddof=1)),
                    "W2_mean": float(np.mean(w2_t)), "W2_std": float(np.std(w2_t, ddof=1)),
                    "MMD_mean": float(np.mean(mmd_t)), "MMD_std": float(np.std(mmd_t, ddof=1)),
                    "W1_full_mean": float(np.mean(w1_t_full)), "W1_full_std": float(np.std(w1_t_full, ddof=1)),
                    "W2_full_mean": float(np.mean(w2_t_full)), "W2_full_std": float(np.std(w2_t_full, ddof=1)),
                    "MMD_full_mean": float(np.mean(mmd_t_full)), "MMD_full_std": float(np.std(mmd_t_full, ddof=1)),
                }
                self.log(f"test/W1_{intermediate_key}_combined", np.mean(w1_t), on_epoch=True)
                self.log(f"test/W1_full_{intermediate_key}_combined", np.mean(w1_t_full), on_epoch=True)
                print(f"{intermediate_key} combined — W1: {np.mean(w1_t):.6f}±{np.std(w1_t, ddof=1):.6f}, "
                      f"W2: {np.mean(w2_t):.6f}±{np.std(w2_t, ddof=1):.6f}, "
                      f"MMD: {np.mean(mmd_t):.6f}±{np.std(mmd_t, ddof=1):.6f}")
                print(f"{intermediate_key} combined (full) — W1: {np.mean(w1_t_full):.6f}±{np.std(w1_t_full, ddof=1):.6f}, "
                      f"W2: {np.mean(w2_t_full):.6f}±{np.std(w2_t_full, ddof=1):.6f}, "
                      f"MMD: {np.mean(mmd_t_full):.6f}±{np.std(mmd_t_full, ddof=1):.6f}")

        # --- Final timepoint per-branch metrics ---
        gt_keys = sorted([k for k in timepoint_data.keys() if k.startswith('t7_')])
        for i, endpoints in enumerate(all_endpoints):
            true_data_key = f"t7_{i}"
            if true_data_key not in timepoint_data:
                print(f"Warning: {true_data_key} not found in timepoint_data")
                continue
            gt = torch.tensor(timepoint_data[true_data_key], dtype=torch.float32)
            pred = endpoints.cpu()

            w1_br, w2_br, mmd_br = [], [], []
            w1_br_full, w2_br_full, mmd_br_full = [], [], []
            for trial in range(n_trials):
                n_min = min(pred.shape[0], gt.shape[0])
                perm_pred = torch.randperm(pred.shape[0])[:n_min]
                perm_gt = torch.randperm(gt.shape[0])[:n_min]
                # 2D metrics (PC1-PC2)
                m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
                w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
                # Full-dimensional metrics (all PCs)
                m_full = compute_distribution_distances(pred[perm_pred], gt[perm_gt])
                w1_br_full.append(m_full["W1"]); w2_br_full.append(m_full["W2"]); mmd_br_full.append(m_full["MMD"])

            metrics_dict[f"branch_{i}"] = {
                "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
                "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
                "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
                "W1_full_mean": float(np.mean(w1_br_full)), "W1_full_std": float(np.std(w1_br_full, ddof=1)),
                "W2_full_mean": float(np.mean(w2_br_full)), "W2_full_std": float(np.std(w2_br_full, ddof=1)),
                "MMD_full_mean": float(np.mean(mmd_br_full)), "MMD_full_std": float(np.std(mmd_br_full, ddof=1)),
            }
            self.log(f"test/W1_branch{i}", np.mean(w1_br), on_epoch=True)
            self.log(f"test/W1_full_branch{i}", np.mean(w1_br_full), on_epoch=True)
            print(f"Branch {i} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
            print(f"Branch {i} (full) — W1: {np.mean(w1_br_full):.6f}±{np.std(w1_br_full, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br_full):.6f}±{np.std(w2_br_full, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br_full):.6f}±{np.std(mmd_br_full, ddof=1):.6f}")

        # --- Final timepoint combined metrics ---
        # Pool all branches' ground truth and all branches' predicted endpoints.
        gt_list = [torch.tensor(timepoint_data[k], dtype=torch.float32) for k in gt_keys]
        if len(gt_list) > 0 and len(all_endpoints) > 0:
            gt_all = torch.cat(gt_list, dim=0)
            pred_all = torch.cat([e.cpu() for e in all_endpoints], dim=0)

            w1_trials, w2_trials, mmd_trials = [], [], []
            w1_trials_full, w2_trials_full, mmd_trials_full = [], [], []
            for trial in range(n_trials):
                n_min = min(pred_all.shape[0], gt_all.shape[0])
                perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
                perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
                # 2D metrics (PC1-PC2)
                m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
                w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])
                # Full-dimensional metrics (all PCs)
                m_full = compute_distribution_distances(pred_all[perm_pred], gt_all[perm_gt])
                w1_trials_full.append(m_full["W1"]); w2_trials_full.append(m_full["W2"]); mmd_trials_full.append(m_full["MMD"])

            w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
            w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
            mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
            w1_mean_f, w1_std_f = np.mean(w1_trials_full), np.std(w1_trials_full, ddof=1)
            w2_mean_f, w2_std_f = np.mean(w2_trials_full), np.std(w2_trials_full, ddof=1)
            mmd_mean_f, mmd_std_f = np.mean(mmd_trials_full), np.std(mmd_trials_full, ddof=1)
            self.log("test/W1_t7_combined", w1_mean, on_epoch=True)
            self.log("test/W2_t7_combined", w2_mean, on_epoch=True)
            self.log("test/MMD_t7_combined", mmd_mean, on_epoch=True)
            self.log("test/W1_full_t7_combined", w1_mean_f, on_epoch=True)
            self.log("test/W2_full_t7_combined", w2_mean_f, on_epoch=True)
            self.log("test/MMD_full_t7_combined", mmd_mean_f, on_epoch=True)
            metrics_dict['t7_combined'] = {
                "W1_mean": float(w1_mean), "W1_std": float(w1_std),
                "W2_mean": float(w2_mean), "W2_std": float(w2_std),
                "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
                "W1_full_mean": float(w1_mean_f), "W1_full_std": float(w1_std_f),
                "W2_full_mean": float(w2_mean_f), "W2_full_std": float(w2_std_f),
                "MMD_full_mean": float(mmd_mean_f), "MMD_full_std": float(mmd_std_f),
                "n_trials": n_trials,
            }
            print(f"\n=== Combined @ t7 ===")
            print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
            print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
            print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")
            print(f"W1 (full): {w1_mean_f:.6f} ± {w1_std_f:.6f}")
            print(f"W2 (full): {w2_mean_f:.6f} ± {w2_std_f:.6f}")
            print(f"MMD (full): {mmd_mean_f:.6f} ± {mmd_std_f:.6f}")

        # Create results directory structure
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        # Save metrics to JSON
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        # Save detailed metrics to CSV
        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group',
                             'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std',
                             'W1_Full_Mean', 'W1_Full_Std', 'W2_Full_Mean', 'W2_Full_Std', 'MMD_Full_Mean', 'MMD_Full_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                writer.writerow([key,
                                 f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}',
                                 f'{m.get("W1_full_mean", 0):.6f}', f'{m.get("W1_full_std", 0):.6f}',
                                 f'{m.get("W2_full_mean", 0):.6f}', f'{m.get("W2_full_std", 0):.6f}',
                                 f'{m.get("MMD_full_mean", 0):.6f}', f'{m.get("MMD_full_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # ===== Plot branches =====
        self._plot_veres_branches(all_trajs, timepoint_data, figures_dir, n_branches)
        self._plot_veres_combined(all_trajs, timepoint_data, figures_dir, n_branches)

        print(f"Veres figures saved to {figures_dir}")

    def _plot_veres_branches(self, all_trajs, timepoint_data, save_dir, n_branches):
        """Plot each branch separately in PCA space (PC1 vs PC2).

        Args:
            all_trajs: per-branch trajectory tensors, each [n_samples, n_time, D].
            timepoint_data: dict of real-cell tensors keyed 't0', 't7_<i>', ...
            save_dir: directory where one PNG per branch is written.
            n_branches: number of branches (indexes the 't7_<i>' keys).
        """
        branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9',
                         '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8',
                         '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8',
                         '#64B5F6', '#C5E1A5']

        # Project to first 2 PCs (data is already in PCA space)
        t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2]
        t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)]

        # Slice trajectories to first 2 PCs
        trajs_2d = []
        for traj in all_trajs:
            trajs_2d.append(traj.cpu().numpy()[:, :, :2])  # [n_samples, n_time, 2]

        # Compute global axis limits
        all_coords = [t0_2d] + t7_2d
        for traj_2d in trajs_2d:
            all_coords.append(traj_2d.reshape(-1, 2))

        all_coords = np.concatenate(all_coords, axis=0)
        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

        # 5% padding on every side so markers aren't clipped at the frame.
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min -= x_margin
        x_max += x_margin
        y_min -= y_margin
        y_max += y_margin

        for i, traj_2d in enumerate(trajs_2d):
            fig, ax = plt.subplots(figsize=(10, 8))
            c_end = branch_colors[i % len(branch_colors)]

            # Plot timepoint background
            ax.scatter(t0_2d[:, 0], t0_2d[:, 1],
                       c='#05009E', s=80, alpha=0.4, marker='x',
                       label='t=0 cells', linewidth=1.5)
            ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1],
                       c=c_end, s=80, alpha=0.4, marker='x',
                       label=f't=7 (branch {i+1}) cells', linewidth=1.5)

            # Plot continuous trajectories with LineCollection for speed
            cmap_colors = ["#05009E", "#A19EFF", c_end]
            cmap = LinearSegmentedColormap.from_list(f"veres_cmap_{i}", cmap_colors)
            n_time = traj_2d.shape[1]
            segments = []
            seg_colors = []
            # One color per timestep: dark blue at t=0 fading to the branch color at t=1.
            color_vals = cmap(np.linspace(0, 1, n_time))
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]  # [T, 2]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)  # [T-1, 2, 2]
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            # Start and end points
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o', label='Trajectory start (t=0)',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory end (t=1)',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC 1", fontsize=12)
            ax.set_ylabel("PC 2", fontsize=12)
            ax.set_title(f"Branch {i+1}: Trajectories (PCA)", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=9, frameon=False)

            plt.tight_layout()
            plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            plt.close()

    def _plot_veres_combined(self, all_trajs, timepoint_data, save_dir, n_branches):
        """Plot all branches together in PCA space (PC1 vs PC2).

        Same inputs as :meth:`_plot_veres_branches`, but draws every branch on a
        single figure; axis limits here come from the real cells only so that
        stray predicted trajectories cannot blow up the view.
        """
        branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9',
                         '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8',
                         '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8',
                         '#64B5F6', '#C5E1A5']

        # Project to first 2 PCs (data is already in PCA space)
        t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2]
        t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)]

        # Slice trajectories to first 2 PCs
        trajs_2d = []
        for traj in all_trajs:
            trajs_2d.append(traj.cpu().numpy()[:, :, :2])  # [n_samples, n_time, 2]

        # Compute axis limits from REAL CELLS ONLY
        all_coords_real = [t0_2d] + t7_2d
        all_coords_real = np.concatenate(all_coords_real, axis=0)
        x_min, x_max = all_coords_real[:, 0].min(), all_coords_real[:, 0].max()
        y_min, y_max = all_coords_real[:, 1].min(), all_coords_real[:, 1].max()
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min -= x_margin
        x_max += x_margin
        y_min -= y_margin
        y_max += y_margin

        fig, ax = plt.subplots(figsize=(14, 12))
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        # Plot t=0 cells
        ax.scatter(t0_2d[:, 0], t0_2d[:, 1],
                   c='#05009E', s=60, alpha=0.3, marker='x',
                   label='t=0 cells', linewidth=1.5)

        # Plot each branch's cells and trajectories
        for i, traj_2d in enumerate(trajs_2d):
            c_end = branch_colors[i % len(branch_colors)]

            # Plot t=7 cells for this branch
            ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1],
                       c=c_end, s=60, alpha=0.3, marker='x',
                       label=f't=7 (branch {i+1})', linewidth=1.5)

            # Plot continuous trajectories with LineCollection for speed
            cmap_colors = ["#05009E", "#A19EFF", c_end]
            cmap = LinearSegmentedColormap.from_list(f"veres_combined_cmap_{i}", cmap_colors)
            n_time = traj_2d.shape[1]
            segments = []
            seg_colors = []
            color_vals = cmap(np.linspace(0, 1, n_time))
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]  # [T, 2]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)  # [T-1, 2, 2]
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=1.5, alpha=0.6)
            ax.add_collection(lc)

            # Start and end points
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=20, marker='o',
                       zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=20, marker='o',
                       zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7)

        ax.set_xlabel("PC 1", fontsize=14)
        ax.set_ylabel("PC 2", fontsize=14)
        ax.set_title(f"All {n_branches} Branch Trajectories (Veres) - PCA Projection",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=10, frameon=False, ncol=2)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close()
src/branch_flow_net_train.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import wandb
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
from torch.optim import AdamW
|
| 8 |
+
from torchmetrics.functional import mean_squared_error
|
| 9 |
+
from torchdyn.core import NeuralODE
|
| 10 |
+
from .networks.utils import flow_model_torch_wrapper
|
| 11 |
+
from .utils import wasserstein, plot_lidar
|
| 12 |
+
from .ema import EMA
|
| 13 |
+
|
| 14 |
+
class BranchFlowNetTrainBase(pl.LightningModule):
    """Lightning module that trains one conditional-flow-matching net per branch.

    Each branch has its own velocity network in ``flow_nets``; the loss is the
    sum over branches of the MSE between the predicted velocity and the target
    conditional flow sampled from ``flow_matcher``.

    Fixes vs. previous revision:
    - removed leftover debug ``print()`` calls from ``training_step``;
    - ``configure_optimizers`` now raises a clear ``ValueError`` for an unknown
      optimizer name instead of failing with ``UnboundLocalError``;
    - dropped unused locals (``w0s``/``w1s_list``) in ``_compute_loss``.
    """

    def __init__(
        self,
        flow_matcher,
        flow_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
    ):
        """
        Args:
            flow_matcher: object providing ``sample_location_and_conditional_flow``
                (and, when its ``alpha != 0``, a ``geopath_net_output`` tensor).
            flow_nets: sequence of per-branch velocity networks, each callable
                as ``net(t, x)``.
                NOTE(review): for parameters to be registered with Lightning
                this should be an ``nn.ModuleList``, not a plain list — confirm
                at the call site.
            skipped_time_points: optional list of timepoint indices held out
                during training (only the first entry is consulted).
            ot_sampler: optional OT mini-batch coupler with ``sample_plan``.
            args: namespace with ``flow_optimizer``, ``flow_lr``,
                ``flow_weight_decay``, ``whiten``, ``working_dir``.
        """
        super().__init__()
        self.args = args

        self.flow_matcher = flow_matcher
        self.flow_nets = flow_nets  # one velocity network per branch
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        self.optimizer_name = args.flow_optimizer
        self.lr = args.flow_lr
        self.weight_decay = args.flow_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        # branching
        self.branches = len(flow_nets)

    def forward(self, t, xt, branch_idx):
        """Return the velocity predicted by branch ``branch_idx`` at (t, xt)."""
        return self.flow_nets[branch_idx](t, xt)

    def _extract_pairs(self, main_batch):
        """Split a batch dict into the shared x0 and the per-branch x1 lists.

        Multi-branch batches carry keys ``x1_1 .. x1_<branches>``; a
        single-branch batch carries plain ``x1``. Each value is a
        ``(positions, weights)`` tuple; only the positions are used here.
        """
        x0s = [main_batch["x0"][0]]

        x1s_list = []
        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
        else:
            x1s_list.append([main_batch["x1"][0]])

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"
        return x0s, x1s_list

    def _compute_loss(self, main_batch):
        """Sum the flow-matching MSE loss over all branches."""
        x0s, x1s_list = self._extract_pairs(main_batch)

        loss = 0
        for branch_idx in range(self.branches):
            ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)

            # Concatenate all segments of the piecewise schedule into one batch.
            t = torch.cat(ts)
            xt = torch.cat(xts)
            ut = torch.cat(uts)
            vt = self(t[:, None], xt, branch_idx)

            loss += mean_squared_error(vt, ut)

        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample (t, x_t, u_t) targets for each consecutive timepoint pair.

        Args:
            x0s: list of source tensors, one per segment.
            x1s: list of matching target tensors for this branch.
            branch_idx: branch whose flow matcher conditioning is used.

        Returns:
            Three parallel lists of per-segment tensors: sampled times,
            interpolated states, and target conditional velocities.
        """
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]

        for i, (x0, x1) in enumerate(zip(x0s, x1s)):

            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)

            # Optionally re-pair (x0, x1) with an OT mini-batch coupling.
            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(
                    x0,
                    x1,
                    replace=True,
                )
            # If a timepoint is held out, this segment spans two grid steps.
            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]

            # Sample from the branch-conditioned flow matcher.
            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx
            )

            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next
        return ts, xts, uts

    def _unwrap_batch(self, batch, key):
        """Normalize CombinedLoader output (dict, tuple-wrapped dict, …) to the
        inner sample dict stored under ``key``."""
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        if isinstance(batch, dict) and key in batch:
            main_batch = batch[key]
            if isinstance(main_batch, tuple):
                main_batch = main_batch[0]
            return main_batch
        # Fallback: assume the batch already is the sample dict.
        return batch.get(key, batch) if isinstance(batch, dict) else batch

    def training_step(self, batch, batch_idx):
        """One optimization step: build the timestep grid and minimize the
        summed per-branch flow-matching loss."""
        main_batch = self._unwrap_batch(batch, "train_samples")

        # Uniform time grid with one node per stored timepoint.
        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        loss = self._compute_loss(main_batch)
        if self.flow_matcher.alpha != 0:
            self.log(
                "FlowNet/mean_geopath_cfm",
                (self.flow_matcher.geopath_net_output.abs().mean()),
                on_step=False,
                on_epoch=True,
                prog_bar=True,
            )

        self.log(
            "FlowNet/train_loss_cfm",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation mirror of ``training_step`` (no parameter update)."""
        main_batch = self._unwrap_batch(batch, "val_samples")

        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        val_loss = self._compute_loss(main_batch)
        self.log(
            "FlowNet/val_loss_cfm",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return val_loss

    def optimizer_step(self, *args, **kwargs):
        """After each optimizer step, refresh the EMA shadow weights of any
        EMA-wrapped flow network."""
        super().optimizer_step(*args, **kwargs)

        for net in self.flow_nets:
            if isinstance(net, EMA):
                net.update_ema()

    def configure_optimizers(self):
        """Build the optimizer named by ``args.flow_optimizer``.

        Raises:
            ValueError: if the optimizer name is not 'adamw' or 'adam'
                (previously this fell through to an UnboundLocalError).
        """
        if self.optimizer_name == "adamw":
            optimizer = AdamW(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.weight_decay,
            )
        elif self.optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
            )
        else:
            raise ValueError(f"Unsupported flow optimizer: {self.optimizer_name!r}")

        return optimizer
+
|
| 191 |
+
|
| 192 |
+
class FlowNetTrainTrajectory(BranchFlowNetTrainBase):
    """Test variant that evaluates 1-Wasserstein error at a held-out timepoint.

    Integrates the learned flow from the timepoint preceding the skipped one
    and compares the predicted marginal against the held-out ground truth.

    Fixes vs. previous revision: removed the unused ``data_type`` local, and
    moved the EMD computation inside the ``t_exclude`` guard — with no skipped
    timepoint there is no ``X_mid_pred``/``batch[t_exclude]`` to evaluate.
    """

    def test_step(self, batch, batch_idx):
        """Evaluate EMD at the excluded timepoint.

        Args:
            batch: sequence of per-timepoint sample tensors, indexable by
                timepoint index.
            batch_idx: Lightning batch index (unused).
        """
        node = NeuralODE(
            flow_model_torch_wrapper(self.flow_nets),
            solver="euler",
            sensitivity="adjoint",
            atol=1e-5,
            rtol=1e-5,
        )

        # Only the first skipped timepoint is evaluated.
        t_exclude = self.skipped_time_points[0] if self.skipped_time_points else None
        if t_exclude is not None:
            # Integrate from the timepoint before the hole up to the hole.
            traj = node.trajectory(
                batch[t_exclude - 1],
                t_span=torch.linspace(
                    self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101
                ),
            )
            X_mid_pred = traj[-1]
            # Also integrate across the hole to the next observed timepoint.
            # NOTE(review): this second trajectory is currently unused — kept
            # for parity with the previous revision; confirm whether it can go.
            traj = node.trajectory(
                batch[t_exclude - 1],
                t_span=torch.linspace(
                    self.timesteps[t_exclude - 1],
                    self.timesteps[t_exclude + 1],
                    101,
                ),
            )

            # 1-Wasserstein distance between predicted and held-out marginals.
            EMD = wasserstein(X_mid_pred, batch[t_exclude], p=1)
            self.final_EMD = EMD

            self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True)
| 225 |
+
|
| 226 |
+
class FlowNetTrainCell(BranchFlowNetTrainBase):
    """Test-time evaluation for cell-state data: integrates every branch's
    flow net from the shared initial population and renders the resulting
    2D trajectories over the full dataset."""

    def test_step(self, batch, batch_idx):
        """Integrate each branch flow ODE over t in [0, 1] and save 2D plots.

        Expects ``batch[0]["test_samples"][0]`` to hold ``"x0"`` ([B, D]
        start points) and ``"dataset"`` ([N, D] background points).
        Figures are written under ``<working_dir>/results/<run>/figures``.
        """
        x0 = batch[0]["test_samples"][0]["x0"][0]  # [B, D]
        dataset_points = batch[0]["test_samples"][0]["dataset"][0]  # [N, D]
        t_span = torch.linspace(0, 1, 101)

        # Undo training-time standardization of the background dataset ONCE.
        # The original called inverse_transform inside the per-branch loop,
        # which crashed on the second branch (numpy arrays have no .cpu()).
        if self.whiten:
            dataset_points = self.trainer.datamodule.scaler.inverse_transform(
                dataset_points.cpu().detach().numpy()
            )

        all_trajs = []
        for flow_net in self.flow_nets:
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                # Undo standardization of the trajectory as well.
                traj_shape = traj.shape
                traj = traj.reshape(-1, traj.shape[-1])
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            traj = torch.tensor(traj)
            all_trajs.append(torch.transpose(traj, 0, 1))  # [B, T, D]

        # First two coordinates for plotting; convert tensors to numpy once.
        # (The original's ternary had two identical branches and then called
        # .cpu().numpy() unconditionally, which failed on whitened numpy data.)
        if isinstance(dataset_points, torch.Tensor):
            dataset_2d = dataset_points[:, :2].cpu().numpy()
        else:
            dataset_2d = dataset_points[:, :2]

        # ===== Plot all 2D trajectories together with dataset and endpoints =====
        fig, ax = plt.subplots(figsize=(6, 5))
        ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
        for traj in all_trajs:
            traj_2d = traj[..., :2]  # [B, T, 2]
            for i in range(traj_2d.shape[0]):
                ax.plot(traj_2d[i, :, 0], traj_2d[i, :, 1], alpha=0.8, zorder=2)
                ax.scatter(traj_2d[i, 0, 0], traj_2d[i, 0, 1], c='green', s=10, label="t=0" if i == 0 else "", zorder=3)
                ax.scatter(traj_2d[i, -1, 0], traj_2d[i, -1, 1], c='red', s=10, label="t=1" if i == 0 else "", zorder=3)

        ax.set_title("All Branch Trajectories (2D) with Dataset")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        plt.axis("equal")
        handles, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend()

        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        save_path = os.path.join(results_dir, 'figures')

        os.makedirs(save_path, exist_ok=True)
        plt.savefig(f'{save_path}/{self.args.data_name}_all_branches.png', dpi=300)
        plt.close()

        # ===== Plot each 2D trajectory separately with dataset and endpoints =====
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[..., :2]
            fig, ax = plt.subplots(figsize=(6, 5))
            ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
            for j in range(traj_2d.shape[0]):
                ax.plot(traj_2d[j, :, 0], traj_2d[j, :, 1], alpha=0.9, zorder=2)
                ax.scatter(traj_2d[j, 0, 0], traj_2d[j, 0, 1], c='green', s=12, label="t=0" if j == 0 else "", zorder=3)
                ax.scatter(traj_2d[j, -1, 0], traj_2d[j, -1, 1], c='red', s=12, label="t=1" if j == 0 else "", zorder=3)

            ax.set_title(f"Branch {i + 1} Trajectories (2D) with Dataset")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            plt.axis("equal")
            handles, labels = ax.get_legend_handles_labels()
            if labels:
                ax.legend()
            plt.savefig(f'{save_path}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
            plt.close()
|
| 306 |
+
|
| 307 |
+
class FlowNetTrainLidar(BranchFlowNetTrainBase):
    """Test-time evaluation for the lidar dataset: integrates each branch's
    flow net from the shared start points and renders the 3D trajectories
    over the lidar point cloud."""

    def test_step(self, batch, batch_idx):
        # Handle both tuple and dict batch formats from CombinedLoader
        if isinstance(batch, dict):
            main_batch = batch["test_samples"][0]
            metric_batch = batch["metric_samples"][0]
        else:
            # batch is a tuple: (test_samples, metric_samples)
            main_batch = batch[0][0]
            metric_batch = batch[1][0]

        x0 = main_batch["x0"][0]  # [B, D] shared start points
        cloud_points = main_batch["dataset"][0]  # full point cloud, [N, D]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []

        # Integrate each branch's flow independently from the same x0.
        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                # Undo training-time standardization. The reshape to (-1, 3)
                # assumes 3-dimensional lidar points — TODO confirm.
                traj_shape = traj.shape
                traj = traj.reshape(-1, 3)
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            traj = torch.tensor(traj)
            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)

        # Inverse-transform the point cloud once (outside the branch loop).
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # Create directory for saving figures.
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        lidar_fig_dir = os.path.join(results_dir, 'figures')
        os.makedirs(lidar_fig_dir, exist_ok=True)

        # ===== Plot all trajectories together =====
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        plt.savefig(os.path.join(lidar_fig_dir, 'lidar_all_branches.png'), dpi=300)
        plt.close()

        # ===== Plot each trajectory separately =====
        for i, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            plt.savefig(os.path.join(lidar_fig_dir, f'lidar_branch_{i + 1}.png'), dpi=300)
            plt.close()
|
src/branch_growth_net_train.py
ADDED
|
@@ -0,0 +1,994 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import wandb
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
from torch.optim import AdamW
|
| 8 |
+
from torchmetrics.functional import mean_squared_error
|
| 9 |
+
from torchdyn.core import NeuralODE
|
| 10 |
+
import numpy as np
|
| 11 |
+
import lpips
|
| 12 |
+
from .networks.utils import flow_model_torch_wrapper
|
| 13 |
+
from .utils import plot_lidar
|
| 14 |
+
from .ema import EMA
|
| 15 |
+
from torchdiffeq import odeint as odeint2
|
| 16 |
+
from .losses.energy_loss import EnergySolver, ReconsLoss
|
| 17 |
+
|
| 18 |
+
class GrowthNetTrain(pl.LightningModule):
|
| 19 |
+
def __init__(
    self,
    flow_nets,
    growth_nets,
    skipped_time_points=None,
    ot_sampler=None,
    args=None,
    state_cost=None,
    data_manifold_metric=None,
    joint=False,
):
    """Lightning module that trains per-branch growth networks.

    Args:
        flow_nets: pretrained flow networks, one per branch; frozen unless
            ``joint`` is True.
        growth_nets: list of growth (mass-rate) networks, one per branch.
        skipped_time_points: indices of held-out marginals, if any.
        ot_sampler: optional optimal-transport pairing sampler.
        args: parsed CLI namespace supplying optimizer and loss settings.
        state_cost: state-cost term used by the energy solver.
        data_manifold_metric: metric used when ``args.manifold`` is set.
        joint: when True, flow nets are trained together with growth nets.
    """
    super().__init__()
    #self.save_hyperparameters()
    self.flow_nets = flow_nets

    # When not training jointly, freeze the flow nets so only the growth
    # nets receive gradients.
    if not joint:
        for param in self.flow_nets.parameters():
            param.requires_grad = False

    self.growth_nets = growth_nets  # list of growth networks for each branch

    self.ot_sampler = ot_sampler
    self.skipped_time_points = skipped_time_points

    # Optimization hyperparameters (from CLI args).
    self.optimizer_name = args.growth_optimizer
    self.lr = args.growth_lr
    self.weight_decay = args.growth_weight_decay
    self.whiten = args.whiten
    self.working_dir = args.working_dir

    self.args = args

    # Branching / manifold-metric configuration.
    self.state_cost = state_cost
    self.data_manifold_metric = data_manifold_metric
    self.branches = len(growth_nets)
    self.metric_clusters = args.metric_clusters

    self.recons_loss = ReconsLoss()

    # Loss term weights.
    self.lambda_energy = args.lambda_energy
    self.lambda_mass = args.lambda_mass
    self.lambda_match = args.lambda_match
    self.lambda_recons = args.lambda_recons

    self.joint = joint
|
| 68 |
+
|
| 69 |
+
def forward(self, t, xt, branch_idx):
    """Return the growth rate of branch ``branch_idx`` at state ``xt`` and time ``t``."""
    growth_net = self.growth_nets[branch_idx]
    return growth_net(t, xt)
|
| 72 |
+
|
| 73 |
+
def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
    """Simulate every branch forward through the time grid and assemble the loss.

    The total loss is a weighted sum of:
      * energy_loss        — path energy accumulated by the EnergySolver per branch;
      * mass_loss          — MSE between the summed branch weights and 1 at each step;
      * neg_weight_penalty — penalty on negative branch weights;
      * match_loss         — MSE between final branch weights and target weights;
      * recons_loss        — reconstruction loss between final states and targets.

    Args:
        main_batch: dict with "x0" -> (positions, weights) and per-branch
            targets under "x1_{i}" (or "x1" when there is a single branch).
        metric_samples_batch: per-cluster sample tensors for the data-manifold
            metric; only read when ``self.args.manifold`` is truthy.
        validation: switches log names between train and val variants.

    Returns:
        Scalar loss combining all terms with their lambda weights.
    """
    x0s = main_batch["x0"][0]
    w0s = main_batch["x0"][1]
    x1s_list = []
    w1s_list = []

    # Collect per-branch endpoint targets (positions and weights).
    if self.branches > 1:
        for i in range(self.branches):
            x1s_list.append([main_batch[f"x1_{i+1}"][0]])
            w1s_list.append([main_batch[f"x1_{i+1}"][1]])
    else:
        x1s_list.append([main_batch["x1"][0]])
        w1s_list.append([main_batch["x1"][1]])

    if self.args.manifold:
        # Map each branch to the (start-cluster, end-cluster) metric samples
        # it should use for the manifold metric.
        #changed
        if self.metric_clusters == 7 and self.branches == 6:
            # Weinreb 6-branch scenario: cluster 0 (root) → clusters 1-6 (6 branches)
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                (metric_samples_batch[0], metric_samples_batch[3]),  # x0 → x1_3 (branch 3)
                (metric_samples_batch[0], metric_samples_batch[4]),  # x0 → x1_4 (branch 4)
                (metric_samples_batch[0], metric_samples_batch[5]),  # x0 → x1_5 (branch 5)
                (metric_samples_batch[0], metric_samples_batch[6]),  # x0 → x1_6 (branch 6)
            ]
        elif self.metric_clusters == 4:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                (metric_samples_batch[0], metric_samples_batch[3]),
            ]
        elif self.metric_clusters == 3:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
            ]
        elif self.metric_clusters == 2 and self.branches == 2:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_2 (branch 2)
            ]
        elif self.metric_clusters == 2:
            # For any number of branches with 2 metric clusters (initial vs remaining)
            # All branches use the same metric cluster pair
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1])  # x0 → all branches
            ] * self.branches
        else:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
            ]

    batch_size = x0s.shape[0]

    assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

    # Per-branch accumulators (energy/match/recons) and scalar totals.
    energy_loss = [0.] * self.branches
    mass_loss = 0.
    neg_weight_penalty = 0.
    match_loss = [0.] * self.branches
    recons_loss = [0.] * self.branches

    dtype = x0s[0].dtype
    #w0s = torch.zeros((batch_size, 1), dtype=dtype)
    # Accumulated energy starts at zero for every sample.
    m0s = torch.zeros_like(w0s, dtype=dtype)
    start_state = (x0s, w0s, m0s)

    # Every branch starts from the same positions; only branch 0 carries the
    # initial mass, the others start with zero weight.
    xt = [x0s.clone() for _ in range(self.branches)]
    w0_branch = torch.zeros_like(w0s, dtype=dtype)
    w0_branches = []
    w0_branches.append(w0s)
    for _ in range(self.branches - 1):
        w0_branches.append(w0_branch)
    #w0_branches = [w0_branch.clone() for _ in range(self.branches - 1)]
    wt = w0_branches

    mt = [m0s.clone() for _ in range(self.branches)]

    # loop through timesteps
    for step_idx, (s, t) in enumerate(zip(self.timesteps[:-1], self.timesteps[1:])):
        time = torch.Tensor([s, t])

        total_w_t = 0
        # loop through branches
        for i in range(self.branches):

            if self.args.manifold:
                start_samples, end_samples = branch_sample_pairs[i]
                samples = torch.cat([start_samples, end_samples], dim=0)
            else:
                samples = None

            # initialize weight and energy
            start_state = (xt[i], wt[i], mt[i])

            # integrate (position, weight, energy) over [s, t]
            xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx)

            # final states at the end of the interval
            xt_last = xt_next[-1]
            wt_last = wt_next[-1]
            mt_last = mt_next[-1]

            total_w_t += wt_last

            # Energy accrued during this interval (mt_last minus the value at
            # the start of the interval, still held in mt[i] at this point).
            energy_loss[i] += (mt_last - mt[i])
            neg_weight_penalty += torch.relu(-wt_last).sum()

            # update branch state (detached so each interval is optimized
            # independently rather than backpropagating through all steps)
            xt[i] = xt_last.clone().detach()
            wt[i] = wt_last.clone().detach()
            mt[i] = mt_last.clone().detach()

        # calculate mass loss from all branches: total mass should stay 1
        target = torch.ones_like(total_w_t)
        mass_loss += mean_squared_error(total_w_t, target)

    # calculate loss that matches final weights
    for i in range(self.branches):
        match_loss[i] = mean_squared_error(wt[i], w1s_list[i][0])
        # compute reconstruction loss
        recons_loss[i] = self.recons_loss(xt[i], x1s_list[i][0])

    # average across time steps (loop runs len(timesteps)-1 times)
    mass_loss = mass_loss / max(len(self.timesteps) - 1, 1)

    # Weighted mean across branches (inversely weighted by cluster size)
    # Get cluster sizes from datamodule if available
    if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'):
        cluster_sizes = self.trainer.datamodule.cluster_sizes
        max_size = max(cluster_sizes)
        # Inverse weighting: smaller clusters get higher weight
        branch_weights = torch.tensor([max_size / size for size in cluster_sizes],
                                      dtype=energy_loss[0].dtype, device=energy_loss[0].device)
        # Normalize weights to sum to num_branches for fair comparison
        branch_weights = branch_weights * self.branches / branch_weights.sum()

        energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss]) * branch_weights)
        match_loss = torch.mean(torch.stack(match_loss) * branch_weights)
        recons_loss = torch.mean(torch.stack(recons_loss) * branch_weights)
    else:
        # Fallback to uniform weighting
        energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss]))
        match_loss = torch.mean(torch.stack(match_loss))
        recons_loss = torch.mean(torch.stack(recons_loss))

    loss = (self.lambda_energy * energy_loss) + (self.lambda_mass * (mass_loss + neg_weight_penalty)) + (self.lambda_match * match_loss) \
        + (self.lambda_recons * recons_loss)

    # Log individual terms under the joint- or growth-training namespace.
    if self.joint:
        if validation:
            self.log("JointTrain/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
        else:
            self.log("JointTrain/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
    else:
        if validation:
            self.log("GrowthNet/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
        else:
            self.log("GrowthNet/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)

    return loss
|
| 252 |
+
|
| 253 |
+
def take_step(self, t, start_state, branch_idx, samples=None, timestep_idx=0):
    """Integrate (position, weight, energy) for one branch over the span ``t``.

    Uses a fixed-step Euler integration of the EnergySolver built from the
    branch's flow and growth networks.
    """
    solver = EnergySolver(
        self.flow_nets[branch_idx],
        self.growth_nets[branch_idx],
        self.state_cost,
        self.data_manifold_metric,
        samples,
        timestep_idx,
    )
    x_path, w_path, m_path = odeint2(
        solver, start_state, t, options=dict(step_size=0.1), method='euler'
    )
    return x_path, w_path, m_path
|
| 262 |
+
|
| 263 |
+
def training_step(self, batch, batch_idx):
    """Unpack one CombinedLoader batch, compute the growth loss, and log it."""
    data = batch[0] if isinstance(batch, (list, tuple)) else batch
    if isinstance(data, dict) and "train_samples" in data:
        samples = data["train_samples"]
        metric_samples = data["metric_samples"]
        if isinstance(samples, tuple):
            samples = samples[0]
        if isinstance(metric_samples, tuple):
            metric_samples = metric_samples[0]
    else:
        # Fallback for loaders that deliver the samples directly.
        samples = data.get("train_samples", data)
        metric_samples = data.get("metric_samples", [])

    # Uniform time grid with one knot per marginal in the batch.
    self.timesteps = torch.linspace(0.0, 1.0, len(samples["x0"])).tolist()
    loss = self._compute_loss(samples, metric_samples, validation=False)

    metric_name = "JointTrain/train_loss" if self.joint else "GrowthNet/train_loss"
    self.log(
        metric_name,
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )
    return loss
|
| 301 |
+
|
| 302 |
+
def validation_step(self, batch, batch_idx):
    """Unpack one CombinedLoader validation batch, compute and log the loss."""
    data = batch[0] if isinstance(batch, (list, tuple)) else batch
    if isinstance(data, dict) and "val_samples" in data:
        samples = data["val_samples"]
        metric_samples = data["metric_samples"]
        if isinstance(samples, tuple):
            samples = samples[0]
        if isinstance(metric_samples, tuple):
            metric_samples = metric_samples[0]
    else:
        # Fallback for loaders that deliver the samples directly.
        samples = data.get("val_samples", data)
        metric_samples = data.get("metric_samples", [])

    # Uniform time grid with one knot per marginal in the batch.
    self.timesteps = torch.linspace(0.0, 1.0, len(samples["x0"])).tolist()
    val_loss = self._compute_loss(samples, metric_samples, validation=True)

    metric_name = "JointTrain/val_loss" if self.joint else "GrowthNet/val_loss"
    self.log(
        metric_name,
        val_loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )
    return val_loss
|
| 339 |
+
|
| 340 |
+
def optimizer_step(self, *args, **kwargs):
    """Run the optimizer step, then refresh EMA shadows of the growth nets
    (and of the flow nets as well when training jointly)."""
    super().optimizer_step(*args, **kwargs)
    nets_to_refresh = list(self.growth_nets)
    if self.joint:
        nets_to_refresh += list(self.flow_nets)
    for net in nets_to_refresh:
        if isinstance(net, EMA):
            net.update_ema()
|
| 349 |
+
|
| 350 |
+
def configure_optimizers(self):
    """Collect trainable parameters and build the configured optimizer.

    Parameters come from the growth nets, plus the flow nets when training
    jointly. Supported optimizer names: ``"adamw"`` and ``"adam"``.

    Returns:
        A configured ``torch.optim`` optimizer over the collected parameters.

    Raises:
        ValueError: for an unrecognized ``self.optimizer_name``. (The
            original fell through both branches and crashed with
            ``UnboundLocalError`` on ``return optimizer``.)
    """
    params = []
    for net in self.growth_nets:
        params += list(net.parameters())

    if self.joint:
        for net in self.flow_nets:
            params += list(net.parameters())

    if self.optimizer_name == "adamw":
        return AdamW(
            params,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
    if self.optimizer_name == "adam":
        return torch.optim.Adam(
            params,
            lr=self.lr,
        )
    raise ValueError(f"Unsupported optimizer: {self.optimizer_name!r}")
|
| 372 |
+
|
| 373 |
+
@torch.no_grad()
def get_mass_and_position(self, main_batch, metric_samples_batch=None):
    """Roll the learned dynamics forward over t in [0, 1] and record per-branch state.

    Integrates 100 Euler-style steps via ``self.take_step`` and accumulates,
    per branch: positions, mean weight ("mass"), mean accumulated energy, and
    per-sample weights at every step.

    Args:
        main_batch: dict (or a sequence whose first element is that dict) with
            key ``"x0"`` holding ``(positions, weights)`` for the initial population.
        metric_samples_batch: per-cluster sample tensors for the data-manifold
            metric; only consulted when ``self.args.manifold`` is set.

    Returns:
        Tuple ``(time_points, xt, all_trajs, mass_over_time, energy_over_time,
        weights_over_time)`` — all but ``time_points`` are indexed by branch.
    """
    # Accept either a plain dict or a (dict, ...) container from a combined loader.
    if isinstance(main_batch, dict):
        main_batch = main_batch
    else:
        main_batch = main_batch[0]

    x0s = main_batch["x0"][0]  # initial positions — presumably (B, D); TODO confirm
    w0s = main_batch["x0"][1]  # initial per-sample weights

    if self.args.manifold:
        # Pair the root cluster (index 0) with each target cluster so the
        # metric can be evaluated per branch.
        if self.metric_clusters == 4:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                (metric_samples_batch[0], metric_samples_batch[3]),
            ]
        elif self.metric_clusters == 3:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
            ]
        elif self.metric_clusters == 2 and self.branches == 2:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_2 (branch 2)
            ]
        elif self.metric_clusters == 2:
            # For any number of branches with 2 metric clusters (initial vs remaining)
            # All branches use the same metric cluster pair
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1])  # x0 → all branches
            ] * self.branches
        else:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
            ]

    # NOTE(review): batch_size is computed but never used below.
    batch_size = x0s.shape[0]
    dtype = x0s[0].dtype

    # Accumulated energy starts at zero for every sample.
    m0s = torch.zeros_like(w0s, dtype=dtype)
    # Every branch starts from an identical copy of the initial positions.
    xt = [x0s.clone() for _ in range(self.branches)]

    # Branch 0 starts with the real weights; the other branches start empty
    # (zero weight) and acquire mass through the growth dynamics.
    w0_branch = torch.zeros_like(w0s, dtype=dtype)
    w0_branches = []
    w0_branches.append(w0s)
    for _ in range(self.branches - 1):
        w0_branches.append(w0_branch)

    wt = w0_branches
    mt = [m0s.clone() for _ in range(self.branches)]

    time_points = []
    mass_over_time = [[] for _ in range(self.branches)]
    energy_over_time = [[] for _ in range(self.branches)]
    # record per-sample weights at each time for each branch (to allow OT with per-sample masses)
    weights_over_time = [[] for _ in range(self.branches)]
    all_trajs = [[] for _ in range(self.branches)]

    # 100 integration steps over [0, 1]; each step integrates the [s, t] interval.
    t_span = torch.linspace(0, 1, 101)
    for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])):
        time_points.append(t.item())
        time = torch.Tensor([s, t])

        for i in range(self.branches):
            if self.args.manifold:
                # Concatenate start/end cluster samples for the metric evaluation.
                start_samples, end_samples = branch_sample_pairs[i]
                samples = torch.cat([start_samples, end_samples], dim=0)
            else:
                samples = None

            start_state = (xt[i], wt[i], mt[i])
            xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx)

            # take_step returns per-substep states; keep only the final one,
            # detached so the rollout holds no autograd graph.
            xt[i] = xt_next[-1].clone().detach()
            wt[i] = wt_next[-1].clone().detach()
            mt[i] = mt_next[-1].clone().detach()

            all_trajs[i].append(xt[i].clone().detach())
            mass_over_time[i].append(wt[i].mean().item())
            energy_over_time[i].append(mt[i].mean().item())
            # store per-sample weights (clone to detach from graph)
            try:
                weights_over_time[i].append(wt[i].clone().detach())
            except Exception:
                # fallback: store mean as singleton tensor
                weights_over_time[i].append(torch.tensor(wt[i].mean().item()).unsqueeze(0))

    return time_points, xt, all_trajs, mass_over_time, energy_over_time, weights_over_time
|
| 463 |
+
|
| 464 |
+
@torch.no_grad()
def _plot_mass_and_energy(self, main_batch, metric_samples_batch=None, save_dir=None):
    """Simulate the branch dynamics over t in [0, 1] and save mass/energy line plots.

    Args:
        main_batch: dict with key ``"x0"`` holding ``(positions, weights)``.
        metric_samples_batch: per-cluster sample tensors for the manifold metric;
            only used when ``self.args.manifold`` is set.
        save_dir: output directory for the PNGs; defaults to
            ``<working_dir>/results/<run_name>/figures``.
    """
    x0s = main_batch["x0"][0]  # initial positions
    w0s = main_batch["x0"][1]  # initial per-sample weights

    if self.args.manifold:
        # Pair the root cluster (index 0) with each target cluster, per branch.
        if self.metric_clusters == 7 and self.branches == 6:
            # Weinreb 6-branch scenario: cluster 0 (root) → clusters 1-6 (6 branches)
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                (metric_samples_batch[0], metric_samples_batch[3]),  # x0 → x1_3 (branch 3)
                (metric_samples_batch[0], metric_samples_batch[4]),  # x0 → x1_4 (branch 4)
                (metric_samples_batch[0], metric_samples_batch[5]),  # x0 → x1_5 (branch 5)
                (metric_samples_batch[0], metric_samples_batch[6]),  # x0 → x1_6 (branch 6)
            ]
        elif self.metric_clusters == 4:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                (metric_samples_batch[0], metric_samples_batch[3]),
            ]
        elif self.metric_clusters == 3:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
            ]
        elif self.metric_clusters == 2 and self.branches == 2:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_2 (branch 2)
            ]
        else:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
            ]

    batch_size = x0s.shape[0]
    dtype = x0s[0].dtype

    # Branch 0 starts with the real weights; other branches start at zero mass.
    m0s = torch.zeros_like(w0s, dtype=dtype)
    xt = [x0s.clone() for _ in range(self.branches)]

    w0_branch = torch.zeros_like(w0s, dtype=dtype)
    w0_branches = []
    w0_branches.append(w0s)
    for _ in range(self.branches - 1):
        w0_branches.append(w0_branch)

    wt = w0_branches
    mt = [m0s.clone() for _ in range(self.branches)]

    time_points = []
    mass_over_time = [[] for _ in range(self.branches)]
    energy_over_time = [[] for _ in range(self.branches)]

    # 100 integration steps over [0, 1].
    t_span = torch.linspace(0, 1, 101)
    for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])):
        time_points.append(t.item())
        time = torch.Tensor([s, t])

        for i in range(self.branches):
            if self.args.manifold:
                start_samples, end_samples = branch_sample_pairs[i]
                samples = torch.cat([start_samples, end_samples], dim=0)
            else:
                samples = None

            start_state = (xt[i], wt[i], mt[i])
            xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx)

            # Keep only the final substep state, detached from autograd.
            xt[i] = xt_next[-1].clone().detach()
            wt[i] = wt_next[-1].clone().detach()
            mt[i] = mt_next[-1].clone().detach()

            mass_over_time[i].append(wt[i].mean().item())
            energy_over_time[i].append(mt[i].mean().item())

    if save_dir is None:
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        save_dir = os.path.join(self.args.working_dir, 'results', run_name, 'figures')
    os.makedirs(save_dir, exist_ok=True)

    # Visually distinct per-branch colors.
    if self.args.branches == 3:
        branch_colors = ['#9793F8', '#50B2D7', '#D577FF']
    elif self.args.branches <= 2:
        branch_colors = ['#50B2D7', '#D577FF']
    else:
        # Bug fix: the old two-color fallback raised IndexError for >3 branches
        # (e.g. the 6-branch Weinreb configuration supported above). Use a
        # qualitative colormap for any other branch count.
        branch_colors = [plt.cm.tab10(i % 10) for i in range(self.branches)]

    # --- Plot Mass ---
    plt.figure(figsize=(8, 5))
    for i in range(self.branches):
        color = branch_colors[i]
        plt.plot(time_points, mass_over_time[i], color=color, linewidth=2.5, label=f"Mass Branch {i}")
    plt.xlabel("Time")
    plt.ylabel("Mass")
    plt.title("Mass Evolution per Branch")
    plt.legend()
    plt.grid(True)
    if self.joint:
        mass_path = os.path.join(save_dir, f"{self.args.data_name}_joint_mass.png")
    else:
        mass_path = os.path.join(save_dir, f"{self.args.data_name}_growth_mass.png")
    plt.savefig(mass_path, dpi=300, bbox_inches="tight")
    plt.close()

    # --- Plot Energy ---
    plt.figure(figsize=(8, 5))
    for i in range(self.branches):
        color = branch_colors[i]
        plt.plot(time_points, energy_over_time[i], color=color, linewidth=2.5, label=f"Energy Branch {i}")
    plt.xlabel("Time")
    plt.ylabel("Energy")
    plt.title("Energy Evolution per Branch")
    plt.legend()
    plt.grid(True)
    if self.joint:
        energy_path = os.path.join(save_dir, f"{self.args.data_name}_joint_energy.png")
    else:
        energy_path = os.path.join(save_dir, f"{self.args.data_name}_growth_energy.png")
    plt.savefig(energy_path, dpi=300, bbox_inches="tight")
    plt.close()
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
class GrowthNetTrainLidar(GrowthNetTrain):
    """Growth-net trainer for the lidar dataset: the test step renders 3D trajectory plots."""

    def test_step(self, batch, batch_idx):
        """Plot mass/energy curves, then integrate each flow net and render lidar figures.

        Saves one combined figure plus one figure per branch under
        ``<working_dir>/results/<run_name>/figures``.
        """
        # Handle both tuple and dict batch formats from CombinedLoader
        if isinstance(batch, dict):
            main_batch = batch["test_samples"][0]
            metric_batch = batch["metric_samples"][0]
        else:
            # batch is a tuple: (test_samples, metric_samples)
            main_batch = batch[0][0]
            metric_batch = batch[1][0]

        self._plot_mass_and_energy(main_batch, metric_batch)

        x0 = main_batch["x0"][0]  # [B, D]
        cloud_points = main_batch["dataset"][0]  # full dataset, [N, D]
        t_span = torch.linspace(0, 1, 101)


        all_trajs = []

        # One Neural ODE rollout per branch's flow network.
        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            # Undo the datamodule's whitening so plots are in data coordinates.
            # The reshape(-1, 3) assumes 3D lidar points.
            if self.whiten:
                traj_shape = traj.shape
                traj = traj.reshape(-1, 3)
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            # NOTE(review): when whiten is False, traj is already a tensor here and
            # torch.tensor() copies it (and warns); torch.as_tensor would avoid that.
            traj = torch.tensor(traj)
            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)

        # Inverse-transform the point cloud once
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # ===== Plot all trajectories together =====
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        lidar_fig_dir = os.path.join(results_dir, 'figures')
        os.makedirs(lidar_fig_dir, exist_ok=True)
        if self.joint:
            plt.savefig(os.path.join(lidar_fig_dir, 'joint_lidar_all_branches.png'), dpi=300)
        else:
            plt.savefig(os.path.join(lidar_fig_dir, 'growth_lidar_all_branches.png'), dpi=300)
        plt.close()

        # ===== Plot each trajectory separately =====
        for i, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            if self.joint:
                plt.savefig(os.path.join(lidar_fig_dir, f'joint_lidar_branch_{i + 1}.png'), dpi=300)
            else:
                plt.savefig(os.path.join(lidar_fig_dir, f'growth_lidar_branch_{i + 1}.png'), dpi=300)
            plt.close()
|
| 664 |
+
|
| 665 |
+
class GrowthNetTrainCell(GrowthNetTrain):
    """Growth-net trainer for cell datasets: the test step only renders mass/energy plots."""

    def test_step(self, batch, batch_idx):
        """Unwrap the test batch and plot per-branch mass/energy curves."""
        # The scrna/tahoe loaders wrap the batch dict in one extra list level.
        container = batch[0] if self.args.data_type in ["scrna", "tahoe"] else batch
        main_batch = container["test_samples"][0]
        metric_batch = container["metric_samples"][0]

        self._plot_mass_and_energy(main_batch, metric_batch)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class SequentialGrowthNetTrain(pl.LightningModule):
|
| 678 |
+
"""
|
| 679 |
+
Sequential growth network training for multi-timepoint data.
|
| 680 |
+
Learns growth rates for transitions between consecutive timepoints.
|
| 681 |
+
"""
|
| 682 |
+
def __init__(
    self,
    flow_nets,
    growth_nets,
    skipped_time_points=None,
    ot_sampler=None,
    args=None,
    data_manifold_metric=None,
    joint=False
):
    """Store networks, hyperparameters and loss weights for sequential growth training.

    When ``joint`` is False the flow networks are frozen so only the growth
    networks receive gradients.
    """
    super().__init__()
    self.flow_nets = flow_nets

    # Freeze flow parameters unless flow and growth are trained together.
    if not joint:
        for p in self.flow_nets.parameters():
            p.requires_grad = False

    self.growth_nets = growth_nets
    self.ot_sampler = ot_sampler
    self.skipped_time_points = skipped_time_points
    self.data_manifold_metric = data_manifold_metric
    self.args = args
    self.joint = joint

    # Optimizer hyperparameters (growth-specific CLI args).
    self.optimizer_name = args.growth_optimizer
    self.lr = args.growth_lr
    self.weight_decay = args.growth_weight_decay
    self.whiten = args.whiten
    self.working_dir = args.working_dir

    # One growth network per branch.
    self.branches = len(growth_nets)
    self.metric_clusters = args.metric_clusters

    self.recons_loss = ReconsLoss()

    # Relative weights of the individual loss terms.
    self.lambda_energy = args.lambda_energy
    self.lambda_mass = args.lambda_mass
    self.lambda_match = args.lambda_match
    self.lambda_recons = args.lambda_recons

    # Filled in by setup() once the datamodule is available.
    self.num_timepoints = None
    self.timepoint_keys = None
|
| 725 |
+
|
| 726 |
+
def forward(self, t, xt, branch_idx):
    """Evaluate the growth network for ``branch_idx`` at time ``t`` and state ``xt``."""
    branch_net = self.growth_nets[branch_idx]
    return branch_net(t, xt)
|
| 728 |
+
|
| 729 |
+
def setup(self, stage=None):
|
| 730 |
+
"""Initialize timepoint keys before training/validation starts."""
|
| 731 |
+
if self.timepoint_keys is None:
|
| 732 |
+
timepoint_data = self.trainer.datamodule.get_timepoint_data()
|
| 733 |
+
self.timepoint_keys = [k for k in sorted(timepoint_data.keys())
|
| 734 |
+
if not any(x in k for x in ['_', 'time_labels'])]
|
| 735 |
+
self.num_timepoints = len(self.timepoint_keys)
|
| 736 |
+
print(f"Training sequential growth for {self.num_timepoints} timepoints: {self.timepoint_keys}")
|
| 737 |
+
|
| 738 |
+
def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
    """Compute loss for sequential growth between timepoints.

    Walks every consecutive pair of timepoint keys, derives the matching batch
    keys, accumulates per-transition losses from ``_compute_transition_loss``,
    logs the averaged components, and returns the (unaveraged) total loss.

    Args:
        main_batch: dict keyed like ``"x0"``, ``"x0.5"``, ..., ``"x1_<b>"``,
            each holding ``(positions, weights)``.
        metric_samples_batch: per-cluster sample tensors for the manifold
            metric; only consulted when ``self.args.manifold`` is set.
        validation: selects the ``val_*`` vs ``train_*`` logging keys.
    """
    # NOTE(review): x0s and w0s are extracted but never used in this method.
    x0s = main_batch["x0"][0]
    w0s = main_batch["x0"][1]

    # Setup metric sample pairs
    if self.args.manifold:
        if self.metric_clusters == 2:
            # One shared (root, target) pair for all branches.
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1])
            ] * self.branches
        else:
            # One (root, cluster b+1) pair per branch; fall back to cluster 1
            # when there are fewer clusters than branches.
            branch_sample_pairs = []
            for b in range(self.branches):
                if b + 1 < len(metric_samples_batch):
                    branch_sample_pairs.append(
                        (metric_samples_batch[0], metric_samples_batch[b + 1])
                    )
                else:
                    branch_sample_pairs.append(
                        (metric_samples_batch[0], metric_samples_batch[1])
                    )

    total_loss = 0
    total_energy_loss = 0
    total_mass_loss = 0
    total_match_loss = 0
    total_recons_loss = 0
    num_transitions = 0

    # Process each consecutive timepoint transition
    for i in range(len(self.timepoint_keys) - 1):
        t_curr_key = self.timepoint_keys[i]
        t_next_key = self.timepoint_keys[i + 1]

        # Map a timepoint key to its batch key, e.g. "t0.5" -> "x0.5",
        # "tfinal" -> "x1".
        batch_curr_key = f"x{t_curr_key.replace('t', '').replace('final', '1')}"
        x_curr = main_batch[batch_curr_key][0]
        w_curr = main_batch[batch_curr_key][1]

        if i == len(self.timepoint_keys) - 2:
            # Final transition to branches
            # Get cluster size weights if available
            if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'):
                cluster_sizes = self.trainer.datamodule.cluster_sizes
                max_size = max(cluster_sizes)
                # Inverse weighting: smaller clusters get higher weight
                branch_weights = [max_size / size for size in cluster_sizes]
            else:
                branch_weights = [1.0] * self.branches

            for b in range(self.branches):
                # Branch-specific endpoints, keyed "x1_1", "x1_2", ...
                x_next = main_batch[f"x1_{b+1}"][0]
                w_next = main_batch[f"x1_{b+1}"][1]

                # Compute growth-based loss for this transition
                loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss(
                    x_curr, w_curr, x_next, w_next, b, i,
                    branch_sample_pairs[b] if self.args.manifold else None
                )
                # Apply branch weight
                total_loss += loss * branch_weights[b]
                total_energy_loss += energy_l * branch_weights[b]
                total_mass_loss += mass_l * branch_weights[b]
                total_match_loss += match_l * branch_weights[b]
                total_recons_loss += recons_l * branch_weights[b]
                num_transitions += 1
        else:
            # Regular consecutive timepoints
            batch_next_key = f"x{t_next_key.replace('t', '').replace('final', '1')}"
            x_next = main_batch[batch_next_key][0]
            w_next = main_batch[batch_next_key][1]

            for b in range(self.branches):
                loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss(
                    x_curr, w_curr, x_next, w_next, b, i,
                    branch_sample_pairs[b] if self.args.manifold else None
                )
                total_loss += loss
                total_energy_loss += energy_l
                total_mass_loss += mass_l
                total_match_loss += match_l
                total_recons_loss += recons_l
                num_transitions += 1

    # Average losses
    avg_energy_loss = total_energy_loss / num_transitions if num_transitions > 0 else total_energy_loss
    avg_mass_loss = total_mass_loss / num_transitions if num_transitions > 0 else total_mass_loss
    avg_match_loss = total_match_loss / num_transitions if num_transitions > 0 else total_match_loss
    avg_recons_loss = total_recons_loss / num_transitions if num_transitions > 0 else total_recons_loss

    # Log individual components
    if self.joint:
        if validation:
            self.log("JointTrain/val_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/val_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/val_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/val_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)
        else:
            self.log("JointTrain/train_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/train_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/train_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("JointTrain/train_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)
    else:
        if validation:
            self.log("GrowthNet/val_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/val_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/val_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/val_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)
        else:
            self.log("GrowthNet/train_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/train_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/train_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("GrowthNet/train_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)

    # Note: the returned loss is the weighted sum, not the average.
    return total_loss
|
| 853 |
+
|
| 854 |
+
def _compute_transition_loss(self, x0, w0, x1, w1, branch_idx, transition_idx, metric_pair):
    """Compute loss for a single timepoint transition.

    Simulates the (frozen, no-grad) flow trajectory from ``x0`` over 10 Euler
    steps, then accumulates energy / mass / negative-weight penalties from the
    growth network along the trajectory, plus matching and reconstruction
    losses at the final time.

    Returns:
        ``(total_loss, energy_loss, mass_loss + neg_weight_penalty,
        match_loss, recons_loss)``.
    """
    # Optionally re-pair samples via the OT plan before simulating.
    if self.ot_sampler is not None:
        x0, x1 = self.ot_sampler.sample_plan(x0, x1, replace=True)

    # Simulate trajectory using flow network
    t_span = torch.linspace(0, 1, 10, device=x0.device)

    flow_model = flow_model_torch_wrapper(self.flow_nets[branch_idx])
    node = NeuralODE(flow_model, solver="euler", sensitivity="adjoint")

    # no_grad: gradients flow only through the growth net below, not the flow net.
    with torch.no_grad():
        traj = node.trajectory(x0, t_span)

    # Compute energy and mass losses
    energy_loss = 0
    mass_loss = 0
    neg_weight_penalty = 0

    for t_idx in range(len(t_span)):
        t = t_span[t_idx]
        xt = traj[t_idx]

        # Growth rate
        growth = self.growth_nets[branch_idx](t.unsqueeze(0).expand(xt.shape[0]), xt)

        # Energy loss
        if self.args.manifold and metric_pair is not None:
            start_samples, end_samples = metric_pair
            samples = torch.cat([start_samples, end_samples], dim=0)
            # NOTE(review): velocity is passed as zeros here, so only the
            # metric's kinetic/potential terms at position xt contribute.
            _, kinetic, potential = self.data_manifold_metric.calculate_velocity(
                xt, torch.zeros_like(xt), samples, transition_idx
            )
            energy = kinetic + potential
        else:
            # Euclidean fallback: penalize the squared growth output.
            energy = (growth ** 2).sum(dim=-1)

        energy_loss += energy.mean()

        # Mass conservation
        # NOTE(review): wt is derived from w0 at every step (not accumulated
        # across steps) — confirm this non-cumulative update is intended.
        growth_sum = growth.sum(dim=-1, keepdim=True)  # Keep dimension for proper broadcasting
        wt = w0 * torch.exp(growth_sum)
        mass = wt.sum()
        mass_loss += (mass - w1.sum()).abs()
        neg_weight_penalty += torch.relu(-wt).sum()

    # Match and reconstruction losses (computed at final time)
    # wt here is the value from the last loop iteration (t = 1).
    xt_final = traj[-1]
    match_loss = mean_squared_error(wt, w1)
    recons_loss = self.recons_loss(xt_final, x1)

    # Weighted sum of all components via the configured lambda coefficients.
    total_loss = (
        self.lambda_energy * energy_loss +
        self.lambda_mass * (mass_loss + neg_weight_penalty) +
        self.lambda_match * match_loss +
        self.lambda_recons * recons_loss
    )

    return total_loss, energy_loss, mass_loss + neg_weight_penalty, match_loss, recons_loss
|
| 913 |
+
|
| 914 |
+
def training_step(self, batch, batch_idx):
    """Unpack the combined training batch, compute the loss, and log it."""
    # A CombinedLoader may deliver the dict wrapped in a one-element list/tuple.
    if isinstance(batch, (list, tuple)):
        batch = batch[0]

    main_batch = batch["train_samples"]
    if isinstance(main_batch, tuple):
        main_batch = main_batch[0]

    metric_batch = batch["metric_samples"]
    if isinstance(metric_batch, tuple):
        metric_batch = metric_batch[0]

    loss = self._compute_loss(main_batch, metric_batch)

    # Joint (flow + growth) runs log under a different namespace.
    prefix = "JointTrain" if self.joint else "GrowthNet"
    self.log(
        f"{prefix}/train_loss",
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )
    return loss
|
| 946 |
+
|
| 947 |
+
def validation_step(self, batch, batch_idx):
    """Unpack the combined validation batch, compute the loss, and log it."""
    # A CombinedLoader may deliver the dict wrapped in a one-element list/tuple.
    if isinstance(batch, (list, tuple)):
        batch = batch[0]

    main_batch = batch["val_samples"]
    if isinstance(main_batch, tuple):
        main_batch = main_batch[0]

    metric_batch = batch["metric_samples"]
    if isinstance(metric_batch, tuple):
        metric_batch = metric_batch[0]

    loss = self._compute_loss(main_batch, metric_batch, validation=True)

    # Joint (flow + growth) runs log under a different namespace.
    prefix = "JointTrain" if self.joint else "GrowthNet"
    self.log(
        f"{prefix}/val_loss",
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )
    return loss
|
| 979 |
+
|
| 980 |
+
def configure_optimizers(self):
    """Build one optimizer over all growth-net (and, in joint mode, flow-net) parameters.

    Returns:
        A configured ``torch.optim.Adam`` or ``torch.optim.AdamW`` optimizer.

    Raises:
        ValueError: if ``self.optimizer_name`` is neither ``"adam"`` nor ``"adamw"``.
    """
    import itertools
    params = list(itertools.chain(*[net.parameters() for net in self.growth_nets]))
    # In joint mode the flow networks are trained together with the growth networks.
    if self.joint:
        params += list(itertools.chain(*[net.parameters() for net in self.flow_nets]))

    if self.optimizer_name == "adam":
        optimizer = torch.optim.Adam(params, lr=self.lr)
    elif self.optimizer_name == "adamw":
        optimizer = torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
    else:
        # Bug fix: previously an unknown name fell through and raised an
        # opaque NameError on ``return optimizer``; fail loudly instead.
        raise ValueError(f"Unsupported growth optimizer: {self.optimizer_name!r}")
    return optimizer
|